Switch use_mmap to a pointer type
This uses nil as undefined for a cleaner implementation.
This commit is contained in:
parent
3518aaef33
commit
97c9e11768
3 changed files with 63 additions and 93 deletions
83
api/types.go
83
api/types.go
|
@ -168,42 +168,11 @@ type Runner struct {
|
||||||
F16KV bool `json:"f16_kv,omitempty"`
|
F16KV bool `json:"f16_kv,omitempty"`
|
||||||
LogitsAll bool `json:"logits_all,omitempty"`
|
LogitsAll bool `json:"logits_all,omitempty"`
|
||||||
VocabOnly bool `json:"vocab_only,omitempty"`
|
VocabOnly bool `json:"vocab_only,omitempty"`
|
||||||
UseMMap TriState `json:"use_mmap,omitempty"`
|
UseMMap *bool `json:"use_mmap,omitempty"`
|
||||||
UseMLock bool `json:"use_mlock,omitempty"`
|
UseMLock bool `json:"use_mlock,omitempty"`
|
||||||
NumThread int `json:"num_thread,omitempty"`
|
NumThread int `json:"num_thread,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type TriState int
|
|
||||||
|
|
||||||
const (
|
|
||||||
TriStateUndefined TriState = -1
|
|
||||||
TriStateFalse TriState = 0
|
|
||||||
TriStateTrue TriState = 1
|
|
||||||
)
|
|
||||||
|
|
||||||
func (b *TriState) UnmarshalJSON(data []byte) error {
|
|
||||||
var v bool
|
|
||||||
if err := json.Unmarshal(data, &v); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if v {
|
|
||||||
*b = TriStateTrue
|
|
||||||
}
|
|
||||||
*b = TriStateFalse
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *TriState) MarshalJSON() ([]byte, error) {
|
|
||||||
if *b == TriStateUndefined {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
var v bool
|
|
||||||
if *b == TriStateTrue {
|
|
||||||
v = true
|
|
||||||
}
|
|
||||||
return json.Marshal(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
// EmbeddingRequest is the request passed to [Client.Embeddings].
|
// EmbeddingRequest is the request passed to [Client.Embeddings].
|
||||||
type EmbeddingRequest struct {
|
type EmbeddingRequest struct {
|
||||||
// Model is the model name.
|
// Model is the model name.
|
||||||
|
@ -437,19 +406,6 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if reflect.PointerTo(field.Type()) == reflect.TypeOf((*TriState)(nil)) {
|
|
||||||
val, ok := val.(bool)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("option %q must be of type boolean", key)
|
|
||||||
}
|
|
||||||
if val {
|
|
||||||
field.SetInt(int64(TriStateTrue))
|
|
||||||
} else {
|
|
||||||
field.SetInt(int64(TriStateFalse))
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
switch field.Kind() {
|
switch field.Kind() {
|
||||||
case reflect.Int:
|
case reflect.Int:
|
||||||
switch t := val.(type) {
|
switch t := val.(type) {
|
||||||
|
@ -496,6 +452,17 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
|
||||||
slice[i] = str
|
slice[i] = str
|
||||||
}
|
}
|
||||||
field.Set(reflect.ValueOf(slice))
|
field.Set(reflect.ValueOf(slice))
|
||||||
|
case reflect.Pointer:
|
||||||
|
var b bool
|
||||||
|
if field.Type() == reflect.TypeOf(&b) {
|
||||||
|
val, ok := val.(bool)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("option %q must be of type boolean", key)
|
||||||
|
}
|
||||||
|
field.Set(reflect.ValueOf(&val))
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("unknown type loading config params: %v %v", field.Kind(), field.Type())
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unknown type loading config params: %v", field.Kind())
|
return fmt.Errorf("unknown type loading config params: %v", field.Kind())
|
||||||
}
|
}
|
||||||
|
@ -538,7 +505,7 @@ func DefaultOptions() Options {
|
||||||
LowVRAM: false,
|
LowVRAM: false,
|
||||||
F16KV: true,
|
F16KV: true,
|
||||||
UseMLock: false,
|
UseMLock: false,
|
||||||
UseMMap: TriStateUndefined,
|
UseMMap: nil,
|
||||||
UseNUMA: false,
|
UseNUMA: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -608,19 +575,6 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) {
|
||||||
} else {
|
} else {
|
||||||
field := valueOpts.FieldByName(opt.Name)
|
field := valueOpts.FieldByName(opt.Name)
|
||||||
if field.IsValid() && field.CanSet() {
|
if field.IsValid() && field.CanSet() {
|
||||||
if reflect.PointerTo(field.Type()) == reflect.TypeOf((*TriState)(nil)) {
|
|
||||||
boolVal, err := strconv.ParseBool(vals[0])
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid bool value %s", vals)
|
|
||||||
}
|
|
||||||
if boolVal {
|
|
||||||
out[key] = TriStateTrue
|
|
||||||
} else {
|
|
||||||
out[key] = TriStateFalse
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
switch field.Kind() {
|
switch field.Kind() {
|
||||||
case reflect.Float32:
|
case reflect.Float32:
|
||||||
floatVal, err := strconv.ParseFloat(vals[0], 32)
|
floatVal, err := strconv.ParseFloat(vals[0], 32)
|
||||||
|
@ -648,6 +602,17 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) {
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
// TODO: only string slices are supported right now
|
// TODO: only string slices are supported right now
|
||||||
out[key] = vals
|
out[key] = vals
|
||||||
|
case reflect.Pointer:
|
||||||
|
var b bool
|
||||||
|
if field.Type() == reflect.TypeOf(&b) {
|
||||||
|
boolVal, err := strconv.ParseBool(vals[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid bool value %s", vals)
|
||||||
|
}
|
||||||
|
out[key] = &boolVal
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
|
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
|
||||||
}
|
}
|
||||||
|
|
|
@ -108,25 +108,27 @@ func TestDurationMarshalUnmarshal(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUseMmapParsingFromJSON(t *testing.T) {
|
func TestUseMmapParsingFromJSON(t *testing.T) {
|
||||||
|
tr := true
|
||||||
|
fa := false
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
req string
|
req string
|
||||||
exp TriState
|
exp *bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Undefined",
|
name: "Undefined",
|
||||||
req: `{ }`,
|
req: `{ }`,
|
||||||
exp: TriStateUndefined,
|
exp: nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "True",
|
name: "True",
|
||||||
req: `{ "use_mmap": true }`,
|
req: `{ "use_mmap": true }`,
|
||||||
exp: TriStateTrue,
|
exp: &tr,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "False",
|
name: "False",
|
||||||
req: `{ "use_mmap": false }`,
|
req: `{ "use_mmap": false }`,
|
||||||
exp: TriStateFalse,
|
exp: &fa,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -144,50 +146,52 @@ func TestUseMmapParsingFromJSON(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUseMmapFormatParams(t *testing.T) {
|
func TestUseMmapFormatParams(t *testing.T) {
|
||||||
|
tr := true
|
||||||
|
fa := false
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
req map[string][]string
|
req map[string][]string
|
||||||
exp TriState
|
exp *bool
|
||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "True",
|
name: "True",
|
||||||
req: map[string][]string{
|
req: map[string][]string{
|
||||||
"use_mmap": []string{"true"},
|
"use_mmap": {"true"},
|
||||||
},
|
},
|
||||||
exp: TriStateTrue,
|
exp: &tr,
|
||||||
err: nil,
|
err: nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "False",
|
name: "False",
|
||||||
req: map[string][]string{
|
req: map[string][]string{
|
||||||
"use_mmap": []string{"false"},
|
"use_mmap": {"false"},
|
||||||
},
|
},
|
||||||
exp: TriStateFalse,
|
exp: &fa,
|
||||||
err: nil,
|
err: nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Numeric True",
|
name: "Numeric True",
|
||||||
req: map[string][]string{
|
req: map[string][]string{
|
||||||
"use_mmap": []string{"1"},
|
"use_mmap": {"1"},
|
||||||
},
|
},
|
||||||
exp: TriStateTrue,
|
exp: &tr,
|
||||||
err: nil,
|
err: nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Numeric False",
|
name: "Numeric False",
|
||||||
req: map[string][]string{
|
req: map[string][]string{
|
||||||
"use_mmap": []string{"0"},
|
"use_mmap": {"0"},
|
||||||
},
|
},
|
||||||
exp: TriStateFalse,
|
exp: &fa,
|
||||||
err: nil,
|
err: nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid string",
|
name: "invalid string",
|
||||||
req: map[string][]string{
|
req: map[string][]string{
|
||||||
"use_mmap": []string{"foo"},
|
"use_mmap": {"foo"},
|
||||||
},
|
},
|
||||||
exp: TriStateUndefined,
|
exp: nil,
|
||||||
err: fmt.Errorf("invalid bool value [foo]"),
|
err: fmt.Errorf("invalid bool value [foo]"),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -195,11 +199,11 @@ func TestUseMmapFormatParams(t *testing.T) {
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
resp, err := FormatParams(test.req)
|
resp, err := FormatParams(test.req)
|
||||||
require.Equal(t, err, test.err)
|
require.Equal(t, test.err, err)
|
||||||
respVal, ok := resp["use_mmap"]
|
respVal, ok := resp["use_mmap"]
|
||||||
if test.exp != TriStateUndefined {
|
if test.exp != nil {
|
||||||
assert.True(t, ok, "resp: %v", resp)
|
assert.True(t, ok, "resp: %v", resp)
|
||||||
assert.Equal(t, test.exp, respVal)
|
assert.Equal(t, *test.exp, *respVal.(*bool))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -208,7 +208,8 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||||
if g.Library == "metal" &&
|
if g.Library == "metal" &&
|
||||||
uint64(opts.NumGPU) > 0 &&
|
uint64(opts.NumGPU) > 0 &&
|
||||||
uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
|
uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
|
||||||
opts.UseMMap = api.TriStateFalse
|
opts.UseMMap = new(bool)
|
||||||
|
*opts.UseMMap = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -219,10 +220,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||||
// Windows CUDA should not use mmap for best performance
|
// Windows CUDA should not use mmap for best performance
|
||||||
// Linux with a model larger than free space, mmap leads to thrashing
|
// Linux with a model larger than free space, mmap leads to thrashing
|
||||||
// For CPU loads we want the memory to be allocated, not FS cache
|
// For CPU loads we want the memory to be allocated, not FS cache
|
||||||
if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && opts.UseMMap == api.TriStateUndefined) ||
|
if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && opts.UseMMap == nil) ||
|
||||||
(runtime.GOOS == "linux" && systemFreeMemory < estimate.TotalSize && opts.UseMMap == api.TriStateUndefined) ||
|
(runtime.GOOS == "linux" && systemFreeMemory < estimate.TotalSize && opts.UseMMap == nil) ||
|
||||||
(gpus[0].Library == "cpu" && opts.UseMMap == api.TriStateUndefined) ||
|
(gpus[0].Library == "cpu" && opts.UseMMap == nil) ||
|
||||||
opts.UseMMap == api.TriStateFalse {
|
(opts.UseMMap != nil && !*opts.UseMMap) {
|
||||||
params = append(params, "--no-mmap")
|
params = append(params, "--no-mmap")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue