read runner parameter options from map

- read runner options from map to see what was specified explicitly and overwrite zero values
This commit is contained in:
Bruce MacDonald 2023-08-01 13:36:31 -04:00
parent daa0d1de7a
commit 1c5a8770ee
5 changed files with 102 additions and 38 deletions

View file

@ -3,9 +3,12 @@ package api
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"math" "math"
"os" "os"
"reflect"
"runtime" "runtime"
"strings"
"time" "time"
) )
@ -34,7 +37,7 @@ type GenerateRequest struct {
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Context []int `json:"context,omitempty"` Context []int `json:"context,omitempty"`
Options `json:"options"` Options map[string]interface{} `json:"options"`
} }
type CreateRequest struct { type CreateRequest struct {
@ -177,6 +180,81 @@ type Options struct {
NumThread int `json:"num_thread,omitempty"` NumThread int `json:"num_thread,omitempty"`
} }
func (opts *Options) FromMap(m map[string]interface{}) error {
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
// build map of json struct tags to their types
jsonOpts := make(map[string]reflect.StructField)
for _, field := range reflect.VisibleFields(typeOpts) {
jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
if jsonTag != "" {
jsonOpts[jsonTag] = field
}
}
for key, val := range m {
if opt, ok := jsonOpts[key]; ok {
field := valueOpts.FieldByName(opt.Name)
if field.IsValid() && field.CanSet() {
switch field.Kind() {
case reflect.Int:
// when JSON unmarshals numbers, it uses float64 by default, not int
val, ok := val.(float64)
if !ok {
log.Printf("could not convert model parmeter %v to int, skipped", key)
continue
}
field.SetInt(int64(val))
case reflect.Bool:
val, ok := val.(bool)
if !ok {
log.Printf("could not convert model parmeter %v to bool, skipped", key)
continue
}
field.SetBool(val)
case reflect.Float32:
// JSON unmarshals to float64
val, ok := val.(float64)
if !ok {
log.Printf("could not convert model parmeter %v to float32, skipped", key)
continue
}
field.SetFloat(val)
case reflect.String:
val, ok := val.(string)
if !ok {
log.Printf("could not convert model parmeter %v to string, skipped", key)
continue
}
field.SetString(val)
case reflect.Slice:
// JSON unmarshals to []interface{}, not []string
val, ok := val.([]interface{})
if !ok {
log.Printf("could not convert model parmeter %v to slice, skipped", key)
continue
}
// convert []interface{} to []string
slice := make([]string, len(val))
for i, item := range val {
str, ok := item.(string)
if !ok {
log.Printf("could not convert model parmeter %v to slice of strings, skipped", key)
continue
}
slice[i] = str
}
field.Set(reflect.ValueOf(slice))
default:
return fmt.Errorf("unknown type loading config params: %v", field.Kind())
}
}
}
}
return nil
}
func DefaultOptions() Options { func DefaultOptions() Options {
return Options{ return Options{
Seed: -1, Seed: -1,

6
go.mod
View file

@ -11,13 +11,9 @@ require (
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
) )
require ( require github.com/rivo/uniseg v0.2.0 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
)
require ( require (
dario.cat/mergo v1.0.0
github.com/bytedance/sonic v1.9.1 // indirect github.com/bytedance/sonic v1.9.1 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/chzyer/readline v1.5.1 github.com/chzyer/readline v1.5.1

4
go.sum
View file

@ -1,5 +1,3 @@
dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk=
dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
@ -73,8 +71,6 @@ github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWV
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ=
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=

View file

@ -19,7 +19,6 @@ import (
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/parser" "github.com/jmorganca/ollama/parser"
"github.com/mitchellh/mapstructure"
) )
type RegistryOptions struct { type RegistryOptions struct {
@ -34,7 +33,7 @@ type Model struct {
Template string Template string
System string System string
Digest string Digest string
Options api.Options Options map[string]interface{}
} }
func (m *Model) Prompt(request api.GenerateRequest) (string, error) { func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
@ -178,14 +177,7 @@ func GetModel(name string) (*Model, error) {
defer params.Close() defer params.Close()
// parse model options parameters into a map so that we can see which fields have been specified explicitly // parse model options parameters into a map so that we can see which fields have been specified explicitly
// TODO: once there are no modelfiles in the wild that do not have default options populated this can be removed if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
var opts map[string]interface{}
if err = json.NewDecoder(params).Decode(&opts); err != nil {
return nil, err
}
// update the default options on the model with the options that have been specified
if err := mapstructure.Decode(opts, &model.Options); err != nil {
return nil, err return nil, err
} }
} }
@ -448,11 +440,13 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
return newLayer, nil return newLayer, nil
} }
// paramsToReader converts specified parameter options to their correct types, and returns a reader for the json
func paramsToReader(params map[string][]string) (io.ReadSeeker, error) { func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
opts := api.DefaultOptions() opts := api.Options{}
typeOpts := reflect.TypeOf(opts) valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
// build map of json struct tags // build map of json struct tags to their types
jsonOpts := make(map[string]reflect.StructField) jsonOpts := make(map[string]reflect.StructField)
for _, field := range reflect.VisibleFields(typeOpts) { for _, field := range reflect.VisibleFields(typeOpts) {
jsonTag := strings.Split(field.Tag.Get("json"), ",")[0] jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
@ -461,7 +455,7 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
} }
} }
valueOpts := reflect.ValueOf(&opts).Elem() out := make(map[string]interface{})
// iterate params and set values based on json struct tags // iterate params and set values based on json struct tags
for key, vals := range params { for key, vals := range params {
if opt, ok := jsonOpts[key]; ok { if opt, ok := jsonOpts[key]; ok {
@ -474,25 +468,26 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
return nil, fmt.Errorf("invalid float value %s", vals) return nil, fmt.Errorf("invalid float value %s", vals)
} }
field.SetFloat(floatVal) out[key] = floatVal
case reflect.Int: case reflect.Int:
intVal, err := strconv.ParseInt(vals[0], 10, 0) intVal, err := strconv.ParseInt(vals[0], 10, 0)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid int value %s", vals) return nil, fmt.Errorf("invalid int value %s", vals)
} }
field.SetInt(intVal) out[key] = intVal
case reflect.Bool: case reflect.Bool:
boolVal, err := strconv.ParseBool(vals[0]) boolVal, err := strconv.ParseBool(vals[0])
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid bool value %s", vals) return nil, fmt.Errorf("invalid bool value %s", vals)
} }
field.SetBool(boolVal) out[key] = boolVal
case reflect.String: case reflect.String:
field.SetString(vals[0]) out[key] = vals[0]
case reflect.Slice: case reflect.Slice:
field.Set(reflect.ValueOf(vals)) // TODO: only string slices are supported right now
out[key] = vals
default: default:
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key) return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
} }
@ -500,12 +495,6 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
} }
} }
// convert opts to map so that zero fields are not omitted
out := make(map[string]interface{})
if err := mapstructure.Decode(opts, &out); err != nil {
return nil, err
}
bts, err := json.Marshal(out) bts, err := json.Marshal(out)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -15,7 +15,6 @@ import (
"sync" "sync"
"time" "time"
"dario.cat/mergo"
"github.com/gin-contrib/cors" "github.com/gin-contrib/cors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -59,8 +58,14 @@ func GenerateHandler(c *gin.Context) {
loaded.llm = nil loaded.llm = nil
} }
opts := model.Options opts := api.DefaultOptions()
if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil { if err := opts.FromMap(model.Options); err != nil {
log.Printf("could not load model options: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := opts.FromMap(req.Options); err != nil {
log.Printf("could not merge model options: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }