diff --git a/api/types.go b/api/types.go index d62079ae..7822a603 100644 --- a/api/types.go +++ b/api/types.go @@ -159,18 +159,49 @@ type Options struct { // Runner options which must be set when the model is loaded into memory type Runner struct { - UseNUMA bool `json:"numa,omitempty"` - NumCtx int `json:"num_ctx,omitempty"` - NumBatch int `json:"num_batch,omitempty"` - NumGPU int `json:"num_gpu,omitempty"` - MainGPU int `json:"main_gpu,omitempty"` - LowVRAM bool `json:"low_vram,omitempty"` - F16KV bool `json:"f16_kv,omitempty"` - LogitsAll bool `json:"logits_all,omitempty"` - VocabOnly bool `json:"vocab_only,omitempty"` - UseMMap bool `json:"use_mmap,omitempty"` - UseMLock bool `json:"use_mlock,omitempty"` - NumThread int `json:"num_thread,omitempty"` + UseNUMA bool `json:"numa,omitempty"` + NumCtx int `json:"num_ctx,omitempty"` + NumBatch int `json:"num_batch,omitempty"` + NumGPU int `json:"num_gpu,omitempty"` + MainGPU int `json:"main_gpu,omitempty"` + LowVRAM bool `json:"low_vram,omitempty"` + F16KV bool `json:"f16_kv,omitempty"` + LogitsAll bool `json:"logits_all,omitempty"` + VocabOnly bool `json:"vocab_only,omitempty"` + UseMMap TriState `json:"use_mmap,omitempty"` + UseMLock bool `json:"use_mlock,omitempty"` + NumThread int `json:"num_thread,omitempty"` +} + +type TriState int + +const ( + TriStateUndefined TriState = -1 + TriStateFalse TriState = 0 + TriStateTrue TriState = 1 +) + +func (b *TriState) UnmarshalJSON(data []byte) error { + var v bool + if err := json.Unmarshal(data, &v); err != nil { + return err + } + if v { + *b = TriStateTrue + } + *b = TriStateFalse + return nil +} + +func (b *TriState) MarshalJSON() ([]byte, error) { + if *b == TriStateUndefined { + return nil, nil + } + var v bool + if *b == TriStateTrue { + v = true + } + return json.Marshal(v) } // EmbeddingRequest is the request passed to [Client.Embeddings]. @@ -403,6 +434,19 @@ func (opts *Options) FromMap(m map[string]interface{}) error { continue } + if reflect.PointerTo(field.Type()) == reflect.TypeOf((*TriState)(nil)) { + val, ok := val.(bool) + if !ok { + return fmt.Errorf("option %q must be of type boolean", key) + } + if val { + field.SetInt(int64(TriStateTrue)) + } else { + field.SetInt(int64(TriStateFalse)) + } + continue + } + switch field.Kind() { case reflect.Int: switch t := val.(type) { @@ -491,7 +535,7 @@ func DefaultOptions() Options { LowVRAM: false, F16KV: true, UseMLock: false, - UseMMap: true, + UseMMap: TriStateUndefined, UseNUMA: false, }, } diff --git a/api/types_test.go b/api/types_test.go index 211385c7..7b4a0f83 100644 --- a/api/types_test.go +++ b/api/types_test.go @@ -105,3 +105,39 @@ func TestDurationMarshalUnmarshal(t *testing.T) { }) } } + +func TestUseMmapParsingFromJSON(t *testing.T) { + tests := []struct { + name string + req string + exp TriState + }{ + { + name: "Undefined", + req: `{ }`, + exp: TriStateUndefined, + }, + { + name: "True", + req: `{ "use_mmap": true }`, + exp: TriStateTrue, + }, + { + name: "False", + req: `{ "use_mmap": false }`, + exp: TriStateFalse, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var oMap map[string]interface{} + err := json.Unmarshal([]byte(test.req), &oMap) + require.NoError(t, err) + opts := DefaultOptions() + err = opts.FromMap(oMap) + require.NoError(t, err) + assert.Equal(t, test.exp, opts.UseMMap) + }) + } +} diff --git a/llm/server.go b/llm/server.go index 117565ba..dd986292 100644 --- a/llm/server.go +++ b/llm/server.go @@ -200,7 +200,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr if g.Library == "metal" && uint64(opts.NumGPU) > 0 && uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 { - opts.UseMMap = false + opts.UseMMap = api.TriStateFalse } } @@ -208,7 +208,8 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr params = append(params, "--flash-attn") } - if !opts.UseMMap { + // Windows CUDA should not use mmap for best performance + if (runtime.GOOS == "windows" && gpus[0].Library == "cuda") || opts.UseMMap == api.TriStateFalse { params = append(params, "--no-mmap") }