specify stop params separately
This commit is contained in:
parent
184ad8f057
commit
f5cbcb08e6
1 changed files with 13 additions and 20 deletions
|
@ -14,7 +14,6 @@ import (
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
@ -203,7 +202,7 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
|
||||||
}
|
}
|
||||||
|
|
||||||
var layers []*LayerReader
|
var layers []*LayerReader
|
||||||
params := make(map[string]string)
|
params := make(map[string][]string)
|
||||||
|
|
||||||
for _, c := range commands {
|
for _, c := range commands {
|
||||||
log.Printf("[%s] - %s\n", c.Name, c.Args)
|
log.Printf("[%s] - %s\n", c.Name, c.Args)
|
||||||
|
@ -287,8 +286,8 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
|
||||||
layer.MediaType = mediaType
|
layer.MediaType = mediaType
|
||||||
layers = append(layers, layer)
|
layers = append(layers, layer)
|
||||||
default:
|
default:
|
||||||
// runtime parameters
|
// runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop tokens)
|
||||||
params[c.Name] = c.Args
|
params[c.Name] = append(params[c.Name], c.Args)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -430,7 +429,7 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
|
||||||
return newLayer, nil
|
return newLayer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func paramsToReader(params map[string]string) (io.ReadSeeker, error) {
|
func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
|
||||||
opts := api.DefaultOptions()
|
opts := api.DefaultOptions()
|
||||||
typeOpts := reflect.TypeOf(opts)
|
typeOpts := reflect.TypeOf(opts)
|
||||||
|
|
||||||
|
@ -445,42 +444,36 @@ func paramsToReader(params map[string]string) (io.ReadSeeker, error) {
|
||||||
|
|
||||||
valueOpts := reflect.ValueOf(&opts).Elem()
|
valueOpts := reflect.ValueOf(&opts).Elem()
|
||||||
// iterate params and set values based on json struct tags
|
// iterate params and set values based on json struct tags
|
||||||
for key, val := range params {
|
for key, vals := range params {
|
||||||
if opt, ok := jsonOpts[key]; ok {
|
if opt, ok := jsonOpts[key]; ok {
|
||||||
field := valueOpts.FieldByName(opt.Name)
|
field := valueOpts.FieldByName(opt.Name)
|
||||||
if field.IsValid() && field.CanSet() {
|
if field.IsValid() && field.CanSet() {
|
||||||
switch field.Kind() {
|
switch field.Kind() {
|
||||||
case reflect.Float32:
|
case reflect.Float32:
|
||||||
floatVal, err := strconv.ParseFloat(val, 32)
|
floatVal, err := strconv.ParseFloat(vals[0], 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid float value %s", val)
|
return nil, fmt.Errorf("invalid float value %s", vals)
|
||||||
}
|
}
|
||||||
|
|
||||||
field.SetFloat(floatVal)
|
field.SetFloat(floatVal)
|
||||||
case reflect.Int:
|
case reflect.Int:
|
||||||
intVal, err := strconv.ParseInt(val, 10, 0)
|
intVal, err := strconv.ParseInt(vals[0], 10, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid int value %s", val)
|
return nil, fmt.Errorf("invalid int value %s", vals)
|
||||||
}
|
}
|
||||||
|
|
||||||
field.SetInt(intVal)
|
field.SetInt(intVal)
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
boolVal, err := strconv.ParseBool(val)
|
boolVal, err := strconv.ParseBool(vals[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid bool value %s", val)
|
return nil, fmt.Errorf("invalid bool value %s", vals)
|
||||||
}
|
}
|
||||||
|
|
||||||
field.SetBool(boolVal)
|
field.SetBool(boolVal)
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
field.SetString(val)
|
field.SetString(vals[0])
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
re := regexp.MustCompile(`"(.*?)"`) // matches everything enclosed in quotes
|
field.Set(reflect.ValueOf(vals))
|
||||||
vals := re.FindAllStringSubmatch(val, -1)
|
|
||||||
var sliceVal []string
|
|
||||||
for _, v := range vals {
|
|
||||||
sliceVal = append(sliceVal, v[1]) // v[1] is the captured group, v[0] is the entire match
|
|
||||||
}
|
|
||||||
field.Set(reflect.ValueOf(sliceVal))
|
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
|
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue