use model defaults for num_gqa, rope_frequency_base and rope_frequency_scale (#1983)

This commit is contained in:
Jeffrey Morgan 2024-05-09 09:06:13 -07:00 committed by GitHub
parent daa1a032f7
commit d5eec16d23
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 61 additions and 86 deletions

View file

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"math" "math"
"os" "os"
"reflect" "reflect"
@ -161,7 +162,6 @@ type Runner struct {
UseNUMA bool `json:"numa,omitempty"` UseNUMA bool `json:"numa,omitempty"`
NumCtx int `json:"num_ctx,omitempty"` NumCtx int `json:"num_ctx,omitempty"`
NumBatch int `json:"num_batch,omitempty"` NumBatch int `json:"num_batch,omitempty"`
NumGQA int `json:"num_gqa,omitempty"`
NumGPU int `json:"num_gpu,omitempty"` NumGPU int `json:"num_gpu,omitempty"`
MainGPU int `json:"main_gpu,omitempty"` MainGPU int `json:"main_gpu,omitempty"`
LowVRAM bool `json:"low_vram,omitempty"` LowVRAM bool `json:"low_vram,omitempty"`
@ -171,11 +171,6 @@ type Runner struct {
UseMMap bool `json:"use_mmap,omitempty"` UseMMap bool `json:"use_mmap,omitempty"`
UseMLock bool `json:"use_mlock,omitempty"` UseMLock bool `json:"use_mlock,omitempty"`
NumThread int `json:"num_thread,omitempty"` NumThread int `json:"num_thread,omitempty"`
// Unused: RopeFrequencyBase is ignored. Instead the value in the model will be used
RopeFrequencyBase float32 `json:"rope_frequency_base,omitempty"`
// Unused: RopeFrequencyScale is ignored. Instead the value in the model will be used
RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
} }
// EmbeddingRequest is the request passed to [Client.Embeddings]. // EmbeddingRequest is the request passed to [Client.Embeddings].
@ -359,8 +354,6 @@ func (m *Metrics) Summary() {
} }
} }
// ErrInvalidOpts is returned when invalid options are passed to the client.
var ErrInvalidOpts = errors.New("invalid options")
var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST") var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
func (opts *Options) FromMap(m map[string]interface{}) error { func (opts *Options) FromMap(m map[string]interface{}) error {
@ -376,73 +369,71 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
} }
} }
invalidOpts := []string{}
for key, val := range m { for key, val := range m {
if opt, ok := jsonOpts[key]; ok { opt, ok := jsonOpts[key]
field := valueOpts.FieldByName(opt.Name) if !ok {
if field.IsValid() && field.CanSet() { slog.Warn("invalid option provided", "option", opt.Name)
if val == nil { continue
continue }
}
switch field.Kind() { field := valueOpts.FieldByName(opt.Name)
case reflect.Int: if field.IsValid() && field.CanSet() {
switch t := val.(type) { if val == nil {
case int64: continue
field.SetInt(t) }
case float64:
// when JSON unmarshals numbers, it uses float64, not int switch field.Kind() {
field.SetInt(int64(t)) case reflect.Int:
default: switch t := val.(type) {
return fmt.Errorf("option %q must be of type integer", key) case int64:
} field.SetInt(t)
case reflect.Bool: case float64:
val, ok := val.(bool) // when JSON unmarshals numbers, it uses float64, not int
if !ok { field.SetInt(int64(t))
return fmt.Errorf("option %q must be of type boolean", key) default:
} return fmt.Errorf("option %q must be of type integer", key)
field.SetBool(val) }
case reflect.Float32: case reflect.Bool:
// JSON unmarshals to float64 val, ok := val.(bool)
val, ok := val.(float64) if !ok {
if !ok { return fmt.Errorf("option %q must be of type boolean", key)
return fmt.Errorf("option %q must be of type float32", key) }
} field.SetBool(val)
field.SetFloat(val) case reflect.Float32:
case reflect.String: // JSON unmarshals to float64
val, ok := val.(string) val, ok := val.(float64)
if !ok { if !ok {
return fmt.Errorf("option %q must be of type string", key) return fmt.Errorf("option %q must be of type float32", key)
} }
field.SetString(val) field.SetFloat(val)
case reflect.Slice: case reflect.String:
// JSON unmarshals to []interface{}, not []string val, ok := val.(string)
val, ok := val.([]interface{}) if !ok {
if !ok { return fmt.Errorf("option %q must be of type string", key)
return fmt.Errorf("option %q must be of type array", key) }
} field.SetString(val)
// convert []interface{} to []string case reflect.Slice:
slice := make([]string, len(val)) // JSON unmarshals to []interface{}, not []string
for i, item := range val { val, ok := val.([]interface{})
str, ok := item.(string) if !ok {
if !ok { return fmt.Errorf("option %q must be of type array", key)
return fmt.Errorf("option %q must be of an array of strings", key) }
} // convert []interface{} to []string
slice[i] = str slice := make([]string, len(val))
} for i, item := range val {
field.Set(reflect.ValueOf(slice)) str, ok := item.(string)
default: if !ok {
return fmt.Errorf("unknown type loading config params: %v", field.Kind()) return fmt.Errorf("option %q must be of an array of strings", key)
} }
slice[i] = str
}
field.Set(reflect.ValueOf(slice))
default:
return fmt.Errorf("unknown type loading config params: %v", field.Kind())
} }
} else {
invalidOpts = append(invalidOpts, key)
} }
} }
if len(invalidOpts) > 0 {
return fmt.Errorf("%w: %v", ErrInvalidOpts, strings.Join(invalidOpts, ", "))
}
return nil return nil
} }
@ -475,8 +466,7 @@ func DefaultOptions() Options {
NumCtx: 2048, NumCtx: 2048,
NumBatch: 512, NumBatch: 512,
NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically
NumGQA: 1, NumThread: 0, // let the runtime decide
NumThread: 0, // let the runtime decide
LowVRAM: false, LowVRAM: false,
F16KV: true, F16KV: true,
UseMLock: false, UseMLock: false,

View file

@ -313,7 +313,6 @@ curl http://localhost:11434/api/generate -d '{
"numa": false, "numa": false,
"num_ctx": 1024, "num_ctx": 1024,
"num_batch": 2, "num_batch": 2,
"num_gqa": 1,
"num_gpu": 1, "num_gpu": 1,
"main_gpu": 0, "main_gpu": 0,
"low_vram": false, "low_vram": false,
@ -321,8 +320,6 @@ curl http://localhost:11434/api/generate -d '{
"vocab_only": false, "vocab_only": false,
"use_mmap": true, "use_mmap": true,
"use_mlock": false, "use_mlock": false,
"rope_frequency_base": 1.1,
"rope_frequency_scale": 0.8,
"num_thread": 8 "num_thread": 8
} }
}' }'

View file

@ -127,10 +127,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
opts, err := modelOptions(model, req.Options) opts, err := modelOptions(model, req.Options)
if err != nil { if err != nil {
if errors.Is(err, api.ErrInvalidOpts) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
@ -370,10 +366,6 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
opts, err := modelOptions(model, req.Options) opts, err := modelOptions(model, req.Options)
if err != nil { if err != nil {
if errors.Is(err, api.ErrInvalidOpts) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
@ -1177,10 +1169,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
opts, err := modelOptions(model, req.Options) opts, err := modelOptions(model, req.Options)
if err != nil { if err != nil {
if errors.Is(err, api.ErrInvalidOpts) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }