allow specifying zero values in modelfile
This commit is contained in:
commit
8b1e791820
5 changed files with 101 additions and 25 deletions
80
api/types.go
80
api/types.go
|
@ -3,9 +3,12 @@ package api
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"math"
|
"math"
|
||||||
"os"
|
"os"
|
||||||
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -34,7 +37,7 @@ type GenerateRequest struct {
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
Context []int `json:"context,omitempty"`
|
Context []int `json:"context,omitempty"`
|
||||||
|
|
||||||
Options `json:"options"`
|
Options map[string]interface{} `json:"options"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateRequest struct {
|
type CreateRequest struct {
|
||||||
|
@ -177,6 +180,81 @@ type Options struct {
|
||||||
NumThread int `json:"num_thread,omitempty"`
|
NumThread int `json:"num_thread,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (opts *Options) FromMap(m map[string]interface{}) error {
|
||||||
|
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
||||||
|
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
|
||||||
|
|
||||||
|
// build map of json struct tags to their types
|
||||||
|
jsonOpts := make(map[string]reflect.StructField)
|
||||||
|
for _, field := range reflect.VisibleFields(typeOpts) {
|
||||||
|
jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
|
||||||
|
if jsonTag != "" {
|
||||||
|
jsonOpts[jsonTag] = field
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, val := range m {
|
||||||
|
if opt, ok := jsonOpts[key]; ok {
|
||||||
|
field := valueOpts.FieldByName(opt.Name)
|
||||||
|
if field.IsValid() && field.CanSet() {
|
||||||
|
switch field.Kind() {
|
||||||
|
case reflect.Int:
|
||||||
|
// when JSON unmarshals numbers, it uses float64 by default, not int
|
||||||
|
val, ok := val.(float64)
|
||||||
|
if !ok {
|
||||||
|
log.Printf("could not convert model parmeter %v to int, skipped", key)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
field.SetInt(int64(val))
|
||||||
|
case reflect.Bool:
|
||||||
|
val, ok := val.(bool)
|
||||||
|
if !ok {
|
||||||
|
log.Printf("could not convert model parmeter %v to bool, skipped", key)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
field.SetBool(val)
|
||||||
|
case reflect.Float32:
|
||||||
|
// JSON unmarshals to float64
|
||||||
|
val, ok := val.(float64)
|
||||||
|
if !ok {
|
||||||
|
log.Printf("could not convert model parmeter %v to float32, skipped", key)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
field.SetFloat(val)
|
||||||
|
case reflect.String:
|
||||||
|
val, ok := val.(string)
|
||||||
|
if !ok {
|
||||||
|
log.Printf("could not convert model parmeter %v to string, skipped", key)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
field.SetString(val)
|
||||||
|
case reflect.Slice:
|
||||||
|
// JSON unmarshals to []interface{}, not []string
|
||||||
|
val, ok := val.([]interface{})
|
||||||
|
if !ok {
|
||||||
|
log.Printf("could not convert model parmeter %v to slice, skipped", key)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// convert []interface{} to []string
|
||||||
|
slice := make([]string, len(val))
|
||||||
|
for i, item := range val {
|
||||||
|
str, ok := item.(string)
|
||||||
|
if !ok {
|
||||||
|
log.Printf("could not convert model parmeter %v to slice of strings, skipped", key)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
slice[i] = str
|
||||||
|
}
|
||||||
|
field.Set(reflect.ValueOf(slice))
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown type loading config params: %v", field.Kind())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func DefaultOptions() Options {
|
func DefaultOptions() Options {
|
||||||
return Options{
|
return Options{
|
||||||
Seed: -1,
|
Seed: -1,
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -14,7 +14,6 @@ require (
|
||||||
require github.com/rivo/uniseg v0.2.0 // indirect
|
require github.com/rivo/uniseg v0.2.0 // indirect
|
||||||
|
|
||||||
require (
|
require (
|
||||||
dario.cat/mergo v1.0.0
|
|
||||||
github.com/bytedance/sonic v1.9.1 // indirect
|
github.com/bytedance/sonic v1.9.1 // indirect
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||||
github.com/chzyer/readline v1.5.1
|
github.com/chzyer/readline v1.5.1
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -1,5 +1,3 @@
|
||||||
dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk=
|
|
||||||
dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
|
|
||||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||||
|
|
|
@ -33,7 +33,7 @@ type Model struct {
|
||||||
Template string
|
Template string
|
||||||
System string
|
System string
|
||||||
Digest string
|
Digest string
|
||||||
Options api.Options
|
Options map[string]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
|
func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
|
||||||
|
@ -176,12 +176,10 @@ func GetModel(name string) (*Model, error) {
|
||||||
}
|
}
|
||||||
defer params.Close()
|
defer params.Close()
|
||||||
|
|
||||||
var opts api.Options
|
// parse model options parameters into a map so that we can see which fields have been specified explicitly
|
||||||
if err = json.NewDecoder(params).Decode(&opts); err != nil {
|
if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
model.Options = opts
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -442,11 +440,13 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
|
||||||
return newLayer, nil
|
return newLayer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// paramsToReader converts specified parameter options to their correct types, and returns a reader for the json
|
||||||
func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
|
func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
|
||||||
opts := api.DefaultOptions()
|
opts := api.Options{}
|
||||||
typeOpts := reflect.TypeOf(opts)
|
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
|
||||||
|
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
|
||||||
|
|
||||||
// build map of json struct tags
|
// build map of json struct tags to their types
|
||||||
jsonOpts := make(map[string]reflect.StructField)
|
jsonOpts := make(map[string]reflect.StructField)
|
||||||
for _, field := range reflect.VisibleFields(typeOpts) {
|
for _, field := range reflect.VisibleFields(typeOpts) {
|
||||||
jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
|
jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
|
||||||
|
@ -455,7 +455,7 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
valueOpts := reflect.ValueOf(&opts).Elem()
|
out := make(map[string]interface{})
|
||||||
// iterate params and set values based on json struct tags
|
// iterate params and set values based on json struct tags
|
||||||
for key, vals := range params {
|
for key, vals := range params {
|
||||||
if opt, ok := jsonOpts[key]; ok {
|
if opt, ok := jsonOpts[key]; ok {
|
||||||
|
@ -468,25 +468,26 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
|
||||||
return nil, fmt.Errorf("invalid float value %s", vals)
|
return nil, fmt.Errorf("invalid float value %s", vals)
|
||||||
}
|
}
|
||||||
|
|
||||||
field.SetFloat(floatVal)
|
out[key] = floatVal
|
||||||
case reflect.Int:
|
case reflect.Int:
|
||||||
intVal, err := strconv.ParseInt(vals[0], 10, 0)
|
intVal, err := strconv.ParseInt(vals[0], 10, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid int value %s", vals)
|
return nil, fmt.Errorf("invalid int value %s", vals)
|
||||||
}
|
}
|
||||||
|
|
||||||
field.SetInt(intVal)
|
out[key] = intVal
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
boolVal, err := strconv.ParseBool(vals[0])
|
boolVal, err := strconv.ParseBool(vals[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid bool value %s", vals)
|
return nil, fmt.Errorf("invalid bool value %s", vals)
|
||||||
}
|
}
|
||||||
|
|
||||||
field.SetBool(boolVal)
|
out[key] = boolVal
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
field.SetString(vals[0])
|
out[key] = vals[0]
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
field.Set(reflect.ValueOf(vals))
|
// TODO: only string slices are supported right now
|
||||||
|
out[key] = vals
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
|
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
|
||||||
}
|
}
|
||||||
|
@ -494,7 +495,7 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bts, err := json.Marshal(opts)
|
bts, err := json.Marshal(out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,6 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"dario.cat/mergo"
|
|
||||||
"github.com/gin-contrib/cors"
|
"github.com/gin-contrib/cors"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
@ -61,12 +60,13 @@ func GenerateHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
opts := api.DefaultOptions()
|
opts := api.DefaultOptions()
|
||||||
if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
|
if err := opts.FromMap(model.Options); err != nil {
|
||||||
|
log.Printf("could not load model options: %v", err)
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if err := opts.FromMap(req.Options); err != nil {
|
||||||
if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
|
log.Printf("could not merge model options: %v", err)
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue