allow specifying stop conditions in modelfile

This commit is contained in:
Bruce MacDonald 2023-07-27 17:02:14 -04:00
parent 822a0e36eb
commit 184ad8f057
3 changed files with 11 additions and 2 deletions

View file

@ -178,7 +178,7 @@ type Options struct {
MirostatTau float32 `json:"mirostat_tau,omitempty"`
MirostatEta float32 `json:"mirostat_eta,omitempty"`
PenalizeNewline bool `json:"penalize_newline,omitempty"`
StopConditions []string `json:"stop_conditions,omitempty"`
Stop []string `json:"stop,omitempty"`
NumThread int `json:"num_thread,omitempty"`
}

View file

@ -246,7 +246,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
}
func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
for _, stopCondition := range llm.StopConditions {
for _, stopCondition := range llm.Stop {
if stopCondition == b.String() {
return io.EOF
} else if strings.HasPrefix(stopCondition, b.String()) {

View file

@ -14,6 +14,7 @@ import (
"path"
"path/filepath"
"reflect"
"regexp"
"strconv"
"strings"
"text/template"
@ -472,6 +473,14 @@ func paramsToReader(params map[string]string) (io.ReadSeeker, error) {
field.SetBool(boolVal)
case reflect.String:
field.SetString(val)
case reflect.Slice:
re := regexp.MustCompile(`"(.*?)"`) // matches everything enclosed in quotes
vals := re.FindAllStringSubmatch(val, -1)
var sliceVal []string
for _, v := range vals {
sliceVal = append(sliceVal, v[1]) // v[1] is the captured group, v[0] is the entire match
}
field.Set(reflect.ValueOf(sliceVal))
default:
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
}