specify stop params separately

This commit is contained in:
Bruce MacDonald 2023-07-28 11:29:00 -04:00
parent 184ad8f057
commit f5cbcb08e6

View file

@ -14,7 +14,6 @@ import (
"path"
"path/filepath"
"reflect"
"regexp"
"strconv"
"strings"
"text/template"
@ -203,7 +202,7 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
}
var layers []*LayerReader
params := make(map[string]string)
params := make(map[string][]string)
for _, c := range commands {
log.Printf("[%s] - %s\n", c.Name, c.Args)
@ -287,8 +286,8 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
layer.MediaType = mediaType
layers = append(layers, layer)
default:
// runtime parameters
params[c.Name] = c.Args
// runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop tokens)
params[c.Name] = append(params[c.Name], c.Args)
}
}
@ -430,7 +429,7 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
return newLayer, nil
}
func paramsToReader(params map[string]string) (io.ReadSeeker, error) {
func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
opts := api.DefaultOptions()
typeOpts := reflect.TypeOf(opts)
@ -445,42 +444,36 @@ func paramsToReader(params map[string]string) (io.ReadSeeker, error) {
valueOpts := reflect.ValueOf(&opts).Elem()
// iterate params and set values based on json struct tags
for key, val := range params {
for key, vals := range params {
if opt, ok := jsonOpts[key]; ok {
field := valueOpts.FieldByName(opt.Name)
if field.IsValid() && field.CanSet() {
switch field.Kind() {
case reflect.Float32:
floatVal, err := strconv.ParseFloat(val, 32)
floatVal, err := strconv.ParseFloat(vals[0], 32)
if err != nil {
return nil, fmt.Errorf("invalid float value %s", val)
return nil, fmt.Errorf("invalid float value %s", vals)
}
field.SetFloat(floatVal)
case reflect.Int:
intVal, err := strconv.ParseInt(val, 10, 0)
intVal, err := strconv.ParseInt(vals[0], 10, 0)
if err != nil {
return nil, fmt.Errorf("invalid int value %s", val)
return nil, fmt.Errorf("invalid int value %s", vals)
}
field.SetInt(intVal)
case reflect.Bool:
boolVal, err := strconv.ParseBool(val)
boolVal, err := strconv.ParseBool(vals[0])
if err != nil {
return nil, fmt.Errorf("invalid bool value %s", val)
return nil, fmt.Errorf("invalid bool value %s", vals)
}
field.SetBool(boolVal)
case reflect.String:
field.SetString(val)
field.SetString(vals[0])
case reflect.Slice:
re := regexp.MustCompile(`"(.*?)"`) // matches everything enclosed in quotes
vals := re.FindAllStringSubmatch(val, -1)
var sliceVal []string
for _, v := range vals {
sliceVal = append(sliceVal, v[1]) // v[1] is the captured group, v[0] is the entire match
}
field.Set(reflect.ValueOf(sliceVal))
field.Set(reflect.ValueOf(vals))
default:
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
}