diff --git a/api/types.go b/api/types.go index 2a36a1f6..692c4445 100644 --- a/api/types.go +++ b/api/types.go @@ -6,6 +6,7 @@ import ( "math" "os" "reflect" + "strconv" "strings" "time" ) @@ -360,3 +361,63 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) { return nil } + +// FormatParams converts specified parameter options to their correct types +func FormatParams(params map[string][]string) (map[string]interface{}, error) { + opts := Options{} + valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct + typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct + + // build map of json struct tags to their types + jsonOpts := make(map[string]reflect.StructField) + for _, field := range reflect.VisibleFields(typeOpts) { + jsonTag := strings.Split(field.Tag.Get("json"), ",")[0] + if jsonTag != "" { + jsonOpts[jsonTag] = field + } + } + + out := make(map[string]interface{}) + // iterate params and set values based on json struct tags + for key, vals := range params { + if opt, ok := jsonOpts[key]; !ok { + return nil, fmt.Errorf("unknown parameter '%s'", key) + } else { + field := valueOpts.FieldByName(opt.Name) + if field.IsValid() && field.CanSet() { + switch field.Kind() { + case reflect.Float32: + floatVal, err := strconv.ParseFloat(vals[0], 32) + if err != nil { + return nil, fmt.Errorf("invalid float value %s", vals) + } + + out[key] = float32(floatVal) + case reflect.Int: + intVal, err := strconv.ParseInt(vals[0], 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid int value %s", vals) + } + + out[key] = intVal + case reflect.Bool: + boolVal, err := strconv.ParseBool(vals[0]) + if err != nil { + return nil, fmt.Errorf("invalid bool value %s", vals) + } + + out[key] = boolVal + case reflect.String: + out[key] = vals[0] + case reflect.Slice: + // TODO: only string slices are supported right now + out[key] = vals + default: + return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key) + } + } + } + } + + return out, nil +} diff --git a/cmd/cmd.go b/cmd/cmd.go index 970a398b..0d5ffb69 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -412,10 +412,19 @@ func PullHandler(cmd *cobra.Command, args []string) error { } func RunGenerate(cmd *cobra.Command, args []string) error { + interactive := true + + opts := generateOptions{ + Model: args[0], + WordWrap: os.Getenv("TERM") == "xterm-256color", + Options: map[string]interface{}{}, + } + format, err := cmd.Flags().GetString("format") if err != nil { return err } + opts.Format = format prompts := args[1:] @@ -427,34 +436,38 @@ func RunGenerate(cmd *cobra.Command, args []string) error { } prompts = append([]string{string(in)}, prompts...) + opts.WordWrap = false + interactive = false } - - // output is being piped - if !term.IsTerminal(int(os.Stdout.Fd())) { - return generate(cmd, args[0], strings.Join(prompts, " "), false, format) + opts.Prompt = strings.Join(prompts, " ") + if len(prompts) > 0 { + interactive = false } - wordWrap := os.Getenv("TERM") == "xterm-256color" - nowrap, err := cmd.Flags().GetBool("nowordwrap") if err != nil { return err } - if nowrap { - wordWrap = false + opts.WordWrap = !nowrap + + if !interactive { + return generate(cmd, opts) } - // prompts are provided via stdin or args so don't enter interactive mode - if len(prompts) > 0 { - return generate(cmd, args[0], strings.Join(prompts, " "), wordWrap, format) - } - - return generateInteractive(cmd, args[0], wordWrap, format) + return generateInteractive(cmd, opts) } type generateContextKey string -func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format string) error { +type generateOptions struct { + Model string + Prompt string + WordWrap bool + Format string + Options map[string]interface{} +} + +func generate(cmd *cobra.Command, opts generateOptions) error { client, err := api.ClientFromEnvironment() if err != nil { return err @@ -475,7 +488,7 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st termWidth, _, err := term.GetSize(int(os.Stdout.Fd())) if err != nil { - wordWrap = false + opts.WordWrap = false } cancelCtx, cancel := context.WithCancel(context.Background()) @@ -494,13 +507,19 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st var currentLineLength int var wordBuffer string - request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, Format: format} + request := api.GenerateRequest{ + Model: opts.Model, + Prompt: opts.Prompt, + Context: generateContext, + Format: opts.Format, + Options: opts.Options, + } fn := func(response api.GenerateResponse) error { p.StopAndClear() latest = response - if wordWrap { + if opts.WordWrap { for _, ch := range response.Response { if currentLineLength+1 > termWidth-5 { // backtrack the length of the last word and clear to the end of the line @@ -534,7 +553,7 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st } return err } - if prompt != "" { + if opts.Prompt != "" { fmt.Println() fmt.Println() } @@ -562,9 +581,13 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st return nil } -func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format string) error { +func generateInteractive(cmd *cobra.Command, opts generateOptions) error { // load the model - if err := generate(cmd, model, "", false, ""); err != nil { + loadOpts := generateOptions{ + Model: opts.Model, + Prompt: "", + } + if err := generate(cmd, loadOpts); err != nil { return err } @@ -581,6 +604,7 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format usageSet := func() { fmt.Fprintln(os.Stderr, "Available Commands:") + fmt.Fprintln(os.Stderr, " /set parameter Set a parameter") fmt.Fprintln(os.Stderr, " /set history Enable history") fmt.Fprintln(os.Stderr, " /set nohistory Disable history") fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap") @@ -602,6 +626,22 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format fmt.Fprintln(os.Stderr, "") } + // only list out the most common parameters + usageParameters := func() { + fmt.Fprintln(os.Stderr, "Available Parameters:") + fmt.Fprintln(os.Stderr, " /set parameter seed Random number seed") + fmt.Fprintln(os.Stderr, " /set parameter num_predict Max number of tokens to predict") + fmt.Fprintln(os.Stderr, " /set parameter top_k Pick from top k num of tokens") + fmt.Fprintln(os.Stderr, " /set parameter top_p Pick token based on sum of probabilities") + fmt.Fprintln(os.Stderr, " /set parameter num_ctx Set the context size") + fmt.Fprintln(os.Stderr, " /set parameter temperature Set creativity level") + fmt.Fprintln(os.Stderr, " /set parameter repeat_penalty How strongly to penalize repetitions") + fmt.Fprintln(os.Stderr, " /set parameter repeat_last_n Set how far back to look for repetitions") + fmt.Fprintln(os.Stderr, " /set parameter num_gpu The number of layers to send to the GPU") + fmt.Fprintln(os.Stderr, " /set parameter stop \"\", ... Set the stop parameters") + fmt.Fprintln(os.Stderr, "") + } + scanner, err := readline.New(readline.Prompt{ Prompt: ">>> ", AltPrompt: "... ", @@ -670,10 +710,10 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format case "nohistory": scanner.HistoryDisable() case "wordwrap": - wordWrap = true + opts.WordWrap = true fmt.Println("Set 'wordwrap' mode.") case "nowordwrap": - wordWrap = false + opts.WordWrap = false fmt.Println("Set 'nowordwrap' mode.") case "verbose": cmd.Flags().Set("verbose", "true") @@ -685,12 +725,28 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format if len(args) < 3 || args[2] != "json" { fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'") } else { - format = args[2] + opts.Format = args[2] fmt.Printf("Set format to '%s' mode.\n", args[2]) } case "noformat": - format = "" + opts.Format = "" fmt.Println("Disabled format.") + case "parameter": + if len(args) < 4 { + usageParameters() + continue + } + var params []string + for _, p := range args[3:] { + params = append(params, p) + } + fp, err := api.FormatParams(map[string][]string{args[2]: params}) + if err != nil { + fmt.Printf("Couldn't set parameter: %q\n\n", err) + continue + } + fmt.Printf("Set parameter '%s' to '%s'\n\n", args[2], strings.Join(params, ", ")) + opts.Options[args[2]] = fp[args[2]] default: fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1]) } @@ -705,7 +761,7 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format fmt.Println("error: couldn't connect to ollama server") return err } - resp, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model}) + resp, err := client.Show(cmd.Context(), &api.ShowRequest{Name: opts.Model}) if err != nil { fmt.Println("error: couldn't get model") return err @@ -724,6 +780,14 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format if resp.Parameters == "" { fmt.Print("No parameters were specified for this model.\n\n") } else { + if len(opts.Options) > 0 { + fmt.Println("User defined parameters:") + for k, v := range opts.Options { + fmt.Printf("%-*s %v\n", 30, k, v) + } + fmt.Println() + } + fmt.Println("Model defined parameters:") fmt.Println(resp.Parameters) } case "system": @@ -767,7 +831,8 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format } if len(prompt) > 0 && prompt[0] != '/' { - if err := generate(cmd, model, prompt, wordWrap, format); err != nil { + opts.Prompt = prompt + if err := generate(cmd, opts); err != nil { return err } diff --git a/server/images.go b/server/images.go index e6350e13..74c2107c 100644 --- a/server/images.go +++ b/server/images.go @@ -14,7 +14,6 @@ import ( "net/url" "os" "path/filepath" - "reflect" "runtime" "strconv" "strings" @@ -426,7 +425,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars if len(params) > 0 { fn(api.ProgressResponse{Status: "creating parameters layer"}) - formattedParams, err := formatParams(params) + formattedParams, err := api.FormatParams(params) if err != nil { return err } @@ -581,64 +580,6 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) { return newLayer, nil } -// formatParams converts specified parameter options to their correct types -func formatParams(params map[string][]string) (map[string]interface{}, error) { - opts := api.Options{} - valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct - typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct - - // build map of json struct tags to their types - jsonOpts := make(map[string]reflect.StructField) - for _, field := range reflect.VisibleFields(typeOpts) { - jsonTag := strings.Split(field.Tag.Get("json"), ",")[0] - if jsonTag != "" { - jsonOpts[jsonTag] = field - } - } - - out := make(map[string]interface{}) - // iterate params and set values based on json struct tags - for key, vals := range params { - if opt, ok := jsonOpts[key]; ok { - field := valueOpts.FieldByName(opt.Name) - if field.IsValid() && field.CanSet() { - switch field.Kind() { - case reflect.Float32: - floatVal, err := strconv.ParseFloat(vals[0], 32) - if err != nil { - return nil, fmt.Errorf("invalid float value %s", vals) - } - - out[key] = float32(floatVal) - case reflect.Int: - intVal, err := strconv.ParseInt(vals[0], 10, 64) - if err != nil { - return nil, fmt.Errorf("invalid int value %s", vals) - } - - out[key] = intVal - case reflect.Bool: - boolVal, err := strconv.ParseBool(vals[0]) - if err != nil { - return nil, fmt.Errorf("invalid bool value %s", vals) - } - - out[key] = boolVal - case reflect.String: - out[key] = vals[0] - case reflect.Slice: - // TODO: only string slices are supported right now - out[key] = vals - default: - return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key) - } - } - } - } - - return out, nil -} - func getLayerDigests(layers []*LayerReader) ([]string, error) { var digests []string for _, l := range layers {