specify stop params separately

This commit is contained in:
Bruce MacDonald 2023-07-28 11:29:00 -04:00
parent 184ad8f057
commit f5cbcb08e6

View file

@ -14,7 +14,6 @@ import (
"path" "path"
"path/filepath" "path/filepath"
"reflect" "reflect"
"regexp"
"strconv" "strconv"
"strings" "strings"
"text/template" "text/template"
@ -203,7 +202,7 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
} }
var layers []*LayerReader var layers []*LayerReader
params := make(map[string]string) params := make(map[string][]string)
for _, c := range commands { for _, c := range commands {
log.Printf("[%s] - %s\n", c.Name, c.Args) log.Printf("[%s] - %s\n", c.Name, c.Args)
@ -287,8 +286,8 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
layer.MediaType = mediaType layer.MediaType = mediaType
layers = append(layers, layer) layers = append(layers, layer)
default: default:
// runtime parameters // runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop tokens)
params[c.Name] = c.Args params[c.Name] = append(params[c.Name], c.Args)
} }
} }
@ -430,7 +429,7 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
return newLayer, nil return newLayer, nil
} }
func paramsToReader(params map[string]string) (io.ReadSeeker, error) { func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
opts := api.DefaultOptions() opts := api.DefaultOptions()
typeOpts := reflect.TypeOf(opts) typeOpts := reflect.TypeOf(opts)
@ -445,42 +444,36 @@ func paramsToReader(params map[string]string) (io.ReadSeeker, error) {
valueOpts := reflect.ValueOf(&opts).Elem() valueOpts := reflect.ValueOf(&opts).Elem()
// iterate params and set values based on json struct tags // iterate params and set values based on json struct tags
for key, val := range params { for key, vals := range params {
if opt, ok := jsonOpts[key]; ok { if opt, ok := jsonOpts[key]; ok {
field := valueOpts.FieldByName(opt.Name) field := valueOpts.FieldByName(opt.Name)
if field.IsValid() && field.CanSet() { if field.IsValid() && field.CanSet() {
switch field.Kind() { switch field.Kind() {
case reflect.Float32: case reflect.Float32:
floatVal, err := strconv.ParseFloat(val, 32) floatVal, err := strconv.ParseFloat(vals[0], 32)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid float value %s", val) return nil, fmt.Errorf("invalid float value %s", vals)
} }
field.SetFloat(floatVal) field.SetFloat(floatVal)
case reflect.Int: case reflect.Int:
intVal, err := strconv.ParseInt(val, 10, 0) intVal, err := strconv.ParseInt(vals[0], 10, 0)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid int value %s", val) return nil, fmt.Errorf("invalid int value %s", vals)
} }
field.SetInt(intVal) field.SetInt(intVal)
case reflect.Bool: case reflect.Bool:
boolVal, err := strconv.ParseBool(val) boolVal, err := strconv.ParseBool(vals[0])
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid bool value %s", val) return nil, fmt.Errorf("invalid bool value %s", vals)
} }
field.SetBool(boolVal) field.SetBool(boolVal)
case reflect.String: case reflect.String:
field.SetString(val) field.SetString(vals[0])
case reflect.Slice: case reflect.Slice:
re := regexp.MustCompile(`"(.*?)"`) // matches everything enclosed in quotes field.Set(reflect.ValueOf(vals))
vals := re.FindAllStringSubmatch(val, -1)
var sliceVal []string
for _, v := range vals {
sliceVal = append(sliceVal, v[1]) // v[1] is the captured group, v[0] is the entire match
}
field.Set(reflect.ValueOf(sliceVal))
default: default:
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key) return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
} }