From 184ad8f05795fcbd6c02ae7b4cba4121f64a8720 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Thu, 27 Jul 2023 17:02:14 -0400 Subject: [PATCH] allow specifying stop conditions in modelfile --- api/types.go | 2 +- llama/llama.go | 2 +- server/images.go | 9 +++++++++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/api/types.go b/api/types.go index 2e3b0578..0b8603fe 100644 --- a/api/types.go +++ b/api/types.go @@ -178,7 +178,7 @@ type Options struct { MirostatTau float32 `json:"mirostat_tau,omitempty"` MirostatEta float32 `json:"mirostat_eta,omitempty"` PenalizeNewline bool `json:"penalize_newline,omitempty"` - StopConditions []string `json:"stop_conditions,omitempty"` + Stop []string `json:"stop,omitempty"` NumThread int `json:"num_thread,omitempty"` } diff --git a/llama/llama.go b/llama/llama.go index 9032bf5f..3b95cfab 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -246,7 +246,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) } func (llm *LLM) checkStopConditions(b bytes.Buffer) error { - for _, stopCondition := range llm.StopConditions { + for _, stopCondition := range llm.Stop { if stopCondition == b.String() { return io.EOF } else if strings.HasPrefix(stopCondition, b.String()) { diff --git a/server/images.go b/server/images.go index f6b995a9..933fe26c 100644 --- a/server/images.go +++ b/server/images.go @@ -14,6 +14,7 @@ import ( "path" "path/filepath" "reflect" + "regexp" "strconv" "strings" "text/template" @@ -472,6 +473,14 @@ func paramsToReader(params map[string]string) (io.ReadSeeker, error) { field.SetBool(boolVal) case reflect.String: field.SetString(val) + case reflect.Slice: + re := regexp.MustCompile(`"(.*?)"`) // matches everything enclosed in quotes + vals := re.FindAllStringSubmatch(val, -1) + var sliceVal []string + for _, v := range vals { + sliceVal = append(sliceVal, v[1]) // v[1] is the captured group, v[0] is the entire match + } + field.Set(reflect.ValueOf(sliceVal)) default: return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key) }