Allow setting parameters in the REPL (#1294)

This commit is contained in:
Patrick Devine 2023-11-29 09:56:42 -08:00 committed by GitHub
parent 63097607b2
commit cde31cb220
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 154 additions and 87 deletions

View file

@ -6,6 +6,7 @@ import (
"math" "math"
"os" "os"
"reflect" "reflect"
"strconv"
"strings" "strings"
"time" "time"
) )
@ -360,3 +361,63 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
return nil return nil
} }
// FormatParams converts specified parameter options to their correct types
func FormatParams(params map[string][]string) (map[string]interface{}, error) {
opts := Options{}
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
// build map of json struct tags to their types
jsonOpts := make(map[string]reflect.StructField)
for _, field := range reflect.VisibleFields(typeOpts) {
jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
if jsonTag != "" {
jsonOpts[jsonTag] = field
}
}
out := make(map[string]interface{})
// iterate params and set values based on json struct tags
for key, vals := range params {
if opt, ok := jsonOpts[key]; !ok {
return nil, fmt.Errorf("unknown parameter '%s'", key)
} else {
field := valueOpts.FieldByName(opt.Name)
if field.IsValid() && field.CanSet() {
switch field.Kind() {
case reflect.Float32:
floatVal, err := strconv.ParseFloat(vals[0], 32)
if err != nil {
return nil, fmt.Errorf("invalid float value %s", vals)
}
out[key] = float32(floatVal)
case reflect.Int:
intVal, err := strconv.ParseInt(vals[0], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid int value %s", vals)
}
out[key] = intVal
case reflect.Bool:
boolVal, err := strconv.ParseBool(vals[0])
if err != nil {
return nil, fmt.Errorf("invalid bool value %s", vals)
}
out[key] = boolVal
case reflect.String:
out[key] = vals[0]
case reflect.Slice:
// TODO: only string slices are supported right now
out[key] = vals
default:
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
}
}
}
}
return out, nil
}

View file

@ -412,10 +412,19 @@ func PullHandler(cmd *cobra.Command, args []string) error {
} }
func RunGenerate(cmd *cobra.Command, args []string) error { func RunGenerate(cmd *cobra.Command, args []string) error {
interactive := true
opts := generateOptions{
Model: args[0],
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]interface{}{},
}
format, err := cmd.Flags().GetString("format") format, err := cmd.Flags().GetString("format")
if err != nil { if err != nil {
return err return err
} }
opts.Format = format
prompts := args[1:] prompts := args[1:]
@ -427,34 +436,38 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
} }
prompts = append([]string{string(in)}, prompts...) prompts = append([]string{string(in)}, prompts...)
opts.WordWrap = false
interactive = false
} }
opts.Prompt = strings.Join(prompts, " ")
// output is being piped if len(prompts) > 0 {
if !term.IsTerminal(int(os.Stdout.Fd())) { interactive = false
return generate(cmd, args[0], strings.Join(prompts, " "), false, format)
} }
wordWrap := os.Getenv("TERM") == "xterm-256color"
nowrap, err := cmd.Flags().GetBool("nowordwrap") nowrap, err := cmd.Flags().GetBool("nowordwrap")
if err != nil { if err != nil {
return err return err
} }
if nowrap { opts.WordWrap = !nowrap
wordWrap = false
if !interactive {
return generate(cmd, opts)
} }
// prompts are provided via stdin or args so don't enter interactive mode return generateInteractive(cmd, opts)
if len(prompts) > 0 {
return generate(cmd, args[0], strings.Join(prompts, " "), wordWrap, format)
}
return generateInteractive(cmd, args[0], wordWrap, format)
} }
type generateContextKey string type generateContextKey string
func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format string) error { type generateOptions struct {
Model string
Prompt string
WordWrap bool
Format string
Options map[string]interface{}
}
func generate(cmd *cobra.Command, opts generateOptions) error {
client, err := api.ClientFromEnvironment() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
return err return err
@ -475,7 +488,7 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
termWidth, _, err := term.GetSize(int(os.Stdout.Fd())) termWidth, _, err := term.GetSize(int(os.Stdout.Fd()))
if err != nil { if err != nil {
wordWrap = false opts.WordWrap = false
} }
cancelCtx, cancel := context.WithCancel(context.Background()) cancelCtx, cancel := context.WithCancel(context.Background())
@ -494,13 +507,19 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
var currentLineLength int var currentLineLength int
var wordBuffer string var wordBuffer string
request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, Format: format} request := api.GenerateRequest{
Model: opts.Model,
Prompt: opts.Prompt,
Context: generateContext,
Format: opts.Format,
Options: opts.Options,
}
fn := func(response api.GenerateResponse) error { fn := func(response api.GenerateResponse) error {
p.StopAndClear() p.StopAndClear()
latest = response latest = response
if wordWrap { if opts.WordWrap {
for _, ch := range response.Response { for _, ch := range response.Response {
if currentLineLength+1 > termWidth-5 { if currentLineLength+1 > termWidth-5 {
// backtrack the length of the last word and clear to the end of the line // backtrack the length of the last word and clear to the end of the line
@ -534,7 +553,7 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
} }
return err return err
} }
if prompt != "" { if opts.Prompt != "" {
fmt.Println() fmt.Println()
fmt.Println() fmt.Println()
} }
@ -562,9 +581,13 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
return nil return nil
} }
func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format string) error { func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
// load the model // load the model
if err := generate(cmd, model, "", false, ""); err != nil { loadOpts := generateOptions{
Model: opts.Model,
Prompt: "",
}
if err := generate(cmd, loadOpts); err != nil {
return err return err
} }
@ -581,6 +604,7 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
usageSet := func() { usageSet := func() {
fmt.Fprintln(os.Stderr, "Available Commands:") fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set parameter Set a parameter")
fmt.Fprintln(os.Stderr, " /set history Enable history") fmt.Fprintln(os.Stderr, " /set history Enable history")
fmt.Fprintln(os.Stderr, " /set nohistory Disable history") fmt.Fprintln(os.Stderr, " /set nohistory Disable history")
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap") fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
@ -602,6 +626,22 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, "")
} }
// only list out the most common parameters
usageParameters := func() {
fmt.Fprintln(os.Stderr, "Available Parameters:")
fmt.Fprintln(os.Stderr, " /set parameter seed <int> Random number seed")
fmt.Fprintln(os.Stderr, " /set parameter num_predict <int> Max number of tokens to predict")
fmt.Fprintln(os.Stderr, " /set parameter top_k <int> Pick from top k num of tokens")
fmt.Fprintln(os.Stderr, " /set parameter top_p <float> Pick token based on sum of probabilities")
fmt.Fprintln(os.Stderr, " /set parameter num_ctx <int> Set the context size")
fmt.Fprintln(os.Stderr, " /set parameter temperature <float> Set creativity level")
fmt.Fprintln(os.Stderr, " /set parameter repeat_penalty <float> How strongly to penalize repetitions")
fmt.Fprintln(os.Stderr, " /set parameter repeat_last_n <int> Set how far back to look for repetitions")
fmt.Fprintln(os.Stderr, " /set parameter num_gpu <int> The number of layers to send to the GPU")
fmt.Fprintln(os.Stderr, " /set parameter stop \"<string>\", ... Set the stop parameters")
fmt.Fprintln(os.Stderr, "")
}
scanner, err := readline.New(readline.Prompt{ scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ", Prompt: ">>> ",
AltPrompt: "... ", AltPrompt: "... ",
@ -670,10 +710,10 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
case "nohistory": case "nohistory":
scanner.HistoryDisable() scanner.HistoryDisable()
case "wordwrap": case "wordwrap":
wordWrap = true opts.WordWrap = true
fmt.Println("Set 'wordwrap' mode.") fmt.Println("Set 'wordwrap' mode.")
case "nowordwrap": case "nowordwrap":
wordWrap = false opts.WordWrap = false
fmt.Println("Set 'nowordwrap' mode.") fmt.Println("Set 'nowordwrap' mode.")
case "verbose": case "verbose":
cmd.Flags().Set("verbose", "true") cmd.Flags().Set("verbose", "true")
@ -685,12 +725,28 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
if len(args) < 3 || args[2] != "json" { if len(args) < 3 || args[2] != "json" {
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'") fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
} else { } else {
format = args[2] opts.Format = args[2]
fmt.Printf("Set format to '%s' mode.\n", args[2]) fmt.Printf("Set format to '%s' mode.\n", args[2])
} }
case "noformat": case "noformat":
format = "" opts.Format = ""
fmt.Println("Disabled format.") fmt.Println("Disabled format.")
case "parameter":
if len(args) < 4 {
usageParameters()
continue
}
var params []string
for _, p := range args[3:] {
params = append(params, p)
}
fp, err := api.FormatParams(map[string][]string{args[2]: params})
if err != nil {
fmt.Printf("Couldn't set parameter: %q\n\n", err)
continue
}
fmt.Printf("Set parameter '%s' to '%s'\n\n", args[2], strings.Join(params, ", "))
opts.Options[args[2]] = fp[args[2]]
default: default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1]) fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
} }
@ -705,7 +761,7 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
fmt.Println("error: couldn't connect to ollama server") fmt.Println("error: couldn't connect to ollama server")
return err return err
} }
resp, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model}) resp, err := client.Show(cmd.Context(), &api.ShowRequest{Name: opts.Model})
if err != nil { if err != nil {
fmt.Println("error: couldn't get model") fmt.Println("error: couldn't get model")
return err return err
@ -724,6 +780,14 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
if resp.Parameters == "" { if resp.Parameters == "" {
fmt.Print("No parameters were specified for this model.\n\n") fmt.Print("No parameters were specified for this model.\n\n")
} else { } else {
if len(opts.Options) > 0 {
fmt.Println("User defined parameters:")
for k, v := range opts.Options {
fmt.Printf("%-*s %v\n", 30, k, v)
}
fmt.Println()
}
fmt.Println("Model defined parameters:")
fmt.Println(resp.Parameters) fmt.Println(resp.Parameters)
} }
case "system": case "system":
@ -767,7 +831,8 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
} }
if len(prompt) > 0 && prompt[0] != '/' { if len(prompt) > 0 && prompt[0] != '/' {
if err := generate(cmd, model, prompt, wordWrap, format); err != nil { opts.Prompt = prompt
if err := generate(cmd, opts); err != nil {
return err return err
} }

View file

@ -14,7 +14,6 @@ import (
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
@ -426,7 +425,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
if len(params) > 0 { if len(params) > 0 {
fn(api.ProgressResponse{Status: "creating parameters layer"}) fn(api.ProgressResponse{Status: "creating parameters layer"})
formattedParams, err := formatParams(params) formattedParams, err := api.FormatParams(params)
if err != nil { if err != nil {
return err return err
} }
@ -581,64 +580,6 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
return newLayer, nil return newLayer, nil
} }
// formatParams converts specified parameter options to their correct types
func formatParams(params map[string][]string) (map[string]interface{}, error) {
opts := api.Options{}
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
// build map of json struct tags to their types
jsonOpts := make(map[string]reflect.StructField)
for _, field := range reflect.VisibleFields(typeOpts) {
jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
if jsonTag != "" {
jsonOpts[jsonTag] = field
}
}
out := make(map[string]interface{})
// iterate params and set values based on json struct tags
for key, vals := range params {
if opt, ok := jsonOpts[key]; ok {
field := valueOpts.FieldByName(opt.Name)
if field.IsValid() && field.CanSet() {
switch field.Kind() {
case reflect.Float32:
floatVal, err := strconv.ParseFloat(vals[0], 32)
if err != nil {
return nil, fmt.Errorf("invalid float value %s", vals)
}
out[key] = float32(floatVal)
case reflect.Int:
intVal, err := strconv.ParseInt(vals[0], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid int value %s", vals)
}
out[key] = intVal
case reflect.Bool:
boolVal, err := strconv.ParseBool(vals[0])
if err != nil {
return nil, fmt.Errorf("invalid bool value %s", vals)
}
out[key] = boolVal
case reflect.String:
out[key] = vals[0]
case reflect.Slice:
// TODO: only string slices are supported right now
out[key] = vals
default:
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
}
}
}
}
return out, nil
}
func getLayerDigests(layers []*LayerReader) ([]string, error) { func getLayerDigests(layers []*LayerReader) ([]string, error) {
var digests []string var digests []string
for _, l := range layers { for _, l := range layers {