From f5cbcb08e65a039f1ad0c5d543cdf154c47496e1 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Fri, 28 Jul 2023 11:29:00 -0400 Subject: [PATCH] specify stop params separately --- server/images.go | 33 +++++++++++++-------------------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/server/images.go b/server/images.go index 933fe26c..6bd7e882 100644 --- a/server/images.go +++ b/server/images.go @@ -14,7 +14,6 @@ import ( "path" "path/filepath" "reflect" - "regexp" "strconv" "strings" "text/template" @@ -203,7 +202,7 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e } var layers []*LayerReader - params := make(map[string]string) + params := make(map[string][]string) for _, c := range commands { log.Printf("[%s] - %s\n", c.Name, c.Args) @@ -287,8 +286,8 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e layer.MediaType = mediaType layers = append(layers, layer) default: - // runtime parameters - params[c.Name] = c.Args + // runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop tokens) + params[c.Name] = append(params[c.Name], c.Args) } } @@ -430,7 +429,7 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) { return newLayer, nil } -func paramsToReader(params map[string]string) (io.ReadSeeker, error) { +func paramsToReader(params map[string][]string) (io.ReadSeeker, error) { opts := api.DefaultOptions() typeOpts := reflect.TypeOf(opts) @@ -445,42 +444,36 @@ func paramsToReader(params map[string]string) (io.ReadSeeker, error) { valueOpts := reflect.ValueOf(&opts).Elem() // iterate params and set values based on json struct tags - for key, val := range params { + for key, vals := range params { if opt, ok := jsonOpts[key]; ok { field := valueOpts.FieldByName(opt.Name) if field.IsValid() && field.CanSet() { switch field.Kind() { case reflect.Float32: - floatVal, err := strconv.ParseFloat(val, 32) + floatVal, err := strconv.ParseFloat(vals[0], 32) if err != nil { - return nil, fmt.Errorf("invalid float value %s", val) + return nil, fmt.Errorf("invalid float value %s", vals) } field.SetFloat(floatVal) case reflect.Int: - intVal, err := strconv.ParseInt(val, 10, 0) + intVal, err := strconv.ParseInt(vals[0], 10, 0) if err != nil { - return nil, fmt.Errorf("invalid int value %s", val) + return nil, fmt.Errorf("invalid int value %s", vals) } field.SetInt(intVal) case reflect.Bool: - boolVal, err := strconv.ParseBool(val) + boolVal, err := strconv.ParseBool(vals[0]) if err != nil { - return nil, fmt.Errorf("invalid bool value %s", val) + return nil, fmt.Errorf("invalid bool value %s", vals) } field.SetBool(boolVal) case reflect.String: - field.SetString(val) + field.SetString(vals[0]) case reflect.Slice: - re := regexp.MustCompile(`"(.*?)"`) // matches everything enclosed in quotes - vals := re.FindAllStringSubmatch(val, -1) - var sliceVal []string - for _, v := range vals { - sliceVal = append(sliceVal, v[1]) // v[1] is the captured group, v[0] is the entire match - } - field.Set(reflect.ValueOf(sliceVal)) + field.Set(reflect.ValueOf(vals)) default: return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key) }