Merge pull request #5105 from dhiltgen/cuda_mmap
Adjust mmap logic for cuda windows for faster model load
This commit is contained in:
commit
c9c8c98bf6
3 changed files with 96 additions and 15 deletions
48
api/types.go
48
api/types.go
|
@ -168,11 +168,42 @@ type Runner struct {
|
||||||
F16KV bool `json:"f16_kv,omitempty"`
|
F16KV bool `json:"f16_kv,omitempty"`
|
||||||
LogitsAll bool `json:"logits_all,omitempty"`
|
LogitsAll bool `json:"logits_all,omitempty"`
|
||||||
VocabOnly bool `json:"vocab_only,omitempty"`
|
VocabOnly bool `json:"vocab_only,omitempty"`
|
||||||
UseMMap bool `json:"use_mmap,omitempty"`
|
UseMMap TriState `json:"use_mmap,omitempty"`
|
||||||
UseMLock bool `json:"use_mlock,omitempty"`
|
UseMLock bool `json:"use_mlock,omitempty"`
|
||||||
NumThread int `json:"num_thread,omitempty"`
|
NumThread int `json:"num_thread,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TriState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
TriStateUndefined TriState = -1
|
||||||
|
TriStateFalse TriState = 0
|
||||||
|
TriStateTrue TriState = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
func (b *TriState) UnmarshalJSON(data []byte) error {
|
||||||
|
var v bool
|
||||||
|
if err := json.Unmarshal(data, &v); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if v {
|
||||||
|
*b = TriStateTrue
|
||||||
|
}
|
||||||
|
*b = TriStateFalse
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *TriState) MarshalJSON() ([]byte, error) {
|
||||||
|
if *b == TriStateUndefined {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
var v bool
|
||||||
|
if *b == TriStateTrue {
|
||||||
|
v = true
|
||||||
|
}
|
||||||
|
return json.Marshal(v)
|
||||||
|
}
|
||||||
|
|
||||||
// EmbeddingRequest is the request passed to [Client.Embeddings].
|
// EmbeddingRequest is the request passed to [Client.Embeddings].
|
||||||
type EmbeddingRequest struct {
|
type EmbeddingRequest struct {
|
||||||
// Model is the model name.
|
// Model is the model name.
|
||||||
|
@ -403,6 +434,19 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if reflect.PointerTo(field.Type()) == reflect.TypeOf((*TriState)(nil)) {
|
||||||
|
val, ok := val.(bool)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("option %q must be of type boolean", key)
|
||||||
|
}
|
||||||
|
if val {
|
||||||
|
field.SetInt(int64(TriStateTrue))
|
||||||
|
} else {
|
||||||
|
field.SetInt(int64(TriStateFalse))
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
switch field.Kind() {
|
switch field.Kind() {
|
||||||
case reflect.Int:
|
case reflect.Int:
|
||||||
switch t := val.(type) {
|
switch t := val.(type) {
|
||||||
|
@ -491,7 +535,7 @@ func DefaultOptions() Options {
|
||||||
LowVRAM: false,
|
LowVRAM: false,
|
||||||
F16KV: true,
|
F16KV: true,
|
||||||
UseMLock: false,
|
UseMLock: false,
|
||||||
UseMMap: true,
|
UseMMap: TriStateUndefined,
|
||||||
UseNUMA: false,
|
UseNUMA: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -105,3 +105,39 @@ func TestDurationMarshalUnmarshal(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUseMmapParsingFromJSON(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req string
|
||||||
|
exp TriState
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Undefined",
|
||||||
|
req: `{ }`,
|
||||||
|
exp: TriStateUndefined,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "True",
|
||||||
|
req: `{ "use_mmap": true }`,
|
||||||
|
exp: TriStateTrue,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "False",
|
||||||
|
req: `{ "use_mmap": false }`,
|
||||||
|
exp: TriStateFalse,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
var oMap map[string]interface{}
|
||||||
|
err := json.Unmarshal([]byte(test.req), &oMap)
|
||||||
|
require.NoError(t, err)
|
||||||
|
opts := DefaultOptions()
|
||||||
|
err = opts.FromMap(oMap)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, test.exp, opts.UseMMap)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -200,7 +200,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||||
if g.Library == "metal" &&
|
if g.Library == "metal" &&
|
||||||
uint64(opts.NumGPU) > 0 &&
|
uint64(opts.NumGPU) > 0 &&
|
||||||
uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
|
uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
|
||||||
opts.UseMMap = false
|
opts.UseMMap = api.TriStateFalse
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -208,7 +208,8 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||||
params = append(params, "--flash-attn")
|
params = append(params, "--flash-attn")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !opts.UseMMap {
|
// Windows CUDA should not use mmap for best performance
|
||||||
|
if (runtime.GOOS == "windows" && gpus[0].Library == "cuda") || opts.UseMMap == api.TriStateFalse {
|
||||||
params = append(params, "--no-mmap")
|
params = append(params, "--no-mmap")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue