diff --git a/api/types.go b/api/types.go index eda7a992..4ec7f56c 100644 --- a/api/types.go +++ b/api/types.go @@ -3,9 +3,12 @@ package api import ( "encoding/json" "fmt" + "log" "math" "os" + "reflect" "runtime" + "strings" "time" ) @@ -34,7 +37,7 @@ type GenerateRequest struct { Prompt string `json:"prompt"` Context []int `json:"context,omitempty"` - Options `json:"options"` + Options map[string]interface{} `json:"options"` } type CreateRequest struct { @@ -177,6 +180,81 @@ type Options struct { NumThread int `json:"num_thread,omitempty"` } +func (opts *Options) FromMap(m map[string]interface{}) error { + valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct + typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct + + // build map of json struct tags to their types + jsonOpts := make(map[string]reflect.StructField) + for _, field := range reflect.VisibleFields(typeOpts) { + jsonTag := strings.Split(field.Tag.Get("json"), ",")[0] + if jsonTag != "" { + jsonOpts[jsonTag] = field + } + } + + for key, val := range m { + if opt, ok := jsonOpts[key]; ok { + field := valueOpts.FieldByName(opt.Name) + if field.IsValid() && field.CanSet() { + switch field.Kind() { + case reflect.Int: + // when JSON unmarshals numbers, it uses float64 by default, not int + val, ok := val.(float64) + if !ok { + log.Printf("could not convert model parmeter %v to int, skipped", key) + continue + } + field.SetInt(int64(val)) + case reflect.Bool: + val, ok := val.(bool) + if !ok { + log.Printf("could not convert model parmeter %v to bool, skipped", key) + continue + } + field.SetBool(val) + case reflect.Float32: + // JSON unmarshals to float64 + val, ok := val.(float64) + if !ok { + log.Printf("could not convert model parmeter %v to float32, skipped", key) + continue + } + field.SetFloat(val) + case reflect.String: + val, ok := val.(string) + if !ok { + log.Printf("could not convert model parmeter %v to string, skipped", key) + continue + } + field.SetString(val) + case reflect.Slice: + // JSON unmarshals to []interface{}, not []string + val, ok := val.([]interface{}) + if !ok { + log.Printf("could not convert model parmeter %v to slice, skipped", key) + continue + } + // convert []interface{} to []string + slice := make([]string, len(val)) + for i, item := range val { + str, ok := item.(string) + if !ok { + log.Printf("could not convert model parmeter %v to slice of strings, skipped", key) + continue + } + slice[i] = str + } + field.Set(reflect.ValueOf(slice)) + default: + return fmt.Errorf("unknown type loading config params: %v", field.Kind()) + } + } + } + } + return nil +} + func DefaultOptions() Options { return Options{ Seed: -1, diff --git a/go.mod b/go.mod index 2df15bfd..554473cb 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,6 @@ require ( require github.com/rivo/uniseg v0.2.0 // indirect require ( - dario.cat/mergo v1.0.0 github.com/bytedance/sonic v1.9.1 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/chzyer/readline v1.5.1 diff --git a/go.sum b/go.sum index 0dd32085..c4097bdb 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= -dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= diff --git a/server/images.go b/server/images.go index 17478fd2..73448a36 100644 --- a/server/images.go +++ b/server/images.go @@ -32,8 +32,8 @@ type Model struct { ModelPath string Template string System string - Digest string - Options api.Options + Digest string + Options map[string]interface{} } func (m *Model) Prompt(request api.GenerateRequest) (string, error) { @@ -135,7 +135,7 @@ func GetModel(name string) (*Model, error) { } model := &Model{ - Name: mp.GetFullTagname(), + Name: mp.GetFullTagname(), Digest: manifest.Config.Digest, } @@ -176,12 +176,10 @@ func GetModel(name string) (*Model, error) { } defer params.Close() - var opts api.Options - if err = json.NewDecoder(params).Decode(&opts); err != nil { + // parse model options parameters into a map so that we can see which fields have been specified explicitly + if err = json.NewDecoder(params).Decode(&model.Options); err != nil { return nil, err } - - model.Options = opts } } @@ -442,11 +440,13 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) { return newLayer, nil } +// paramsToReader converts specified parameter options to their correct types, and returns a reader for the json func paramsToReader(params map[string][]string) (io.ReadSeeker, error) { - opts := api.DefaultOptions() - typeOpts := reflect.TypeOf(opts) + opts := api.Options{} + valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct + typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct - // build map of json struct tags + // build map of json struct tags to their types jsonOpts := make(map[string]reflect.StructField) for _, field := range reflect.VisibleFields(typeOpts) { jsonTag := strings.Split(field.Tag.Get("json"), ",")[0] @@ -455,7 +455,7 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) { } } - valueOpts := reflect.ValueOf(&opts).Elem() + out := make(map[string]interface{}) // iterate params and set values based on json struct tags for key, vals := range params { if opt, ok := jsonOpts[key]; ok { @@ -468,25 +468,26 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) { return nil, fmt.Errorf("invalid float value %s", vals) } - field.SetFloat(floatVal) + out[key] = floatVal case reflect.Int: intVal, err := strconv.ParseInt(vals[0], 10, 0) if err != nil { return nil, fmt.Errorf("invalid int value %s", vals) } - field.SetInt(intVal) + out[key] = intVal case reflect.Bool: boolVal, err := strconv.ParseBool(vals[0]) if err != nil { return nil, fmt.Errorf("invalid bool value %s", vals) } - field.SetBool(boolVal) + out[key] = boolVal case reflect.String: - field.SetString(vals[0]) + out[key] = vals[0] case reflect.Slice: - field.Set(reflect.ValueOf(vals)) + // TODO: only string slices are supported right now + out[key] = vals default: return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key) } @@ -494,7 +495,7 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) { } } - bts, err := json.Marshal(opts) + bts, err := json.Marshal(out) if err != nil { return nil, err } diff --git a/server/routes.go b/server/routes.go index 0c94c10a..e2fc74ab 100644 --- a/server/routes.go +++ b/server/routes.go @@ -15,7 +15,6 @@ import ( "sync" "time" - "dario.cat/mergo" "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" @@ -61,12 +60,13 @@ func GenerateHandler(c *gin.Context) { } opts := api.DefaultOptions() - if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil { + if err := opts.FromMap(model.Options); err != nil { + log.Printf("could not load model options: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - - if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil { + if err := opts.FromMap(req.Options); err != nil { + log.Printf("could not merge model options: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return }