Allow setting parameters in the REPL (#1294)
This commit is contained in:
parent
63097607b2
commit
cde31cb220
3 changed files with 154 additions and 87 deletions
61
api/types.go
61
api/types.go
|
@ -6,6 +6,7 @@ import (
|
||||||
"math"
|
"math"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -360,3 +361,63 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FormatParams converts specified parameter options to their correct types
|
||||||
|
func FormatParams(params map[string][]string) (map[string]interface{}, error) {
|
||||||
|
opts := Options{}
|
||||||
|
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
|
||||||
|
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
|
||||||
|
|
||||||
|
// build map of json struct tags to their types
|
||||||
|
jsonOpts := make(map[string]reflect.StructField)
|
||||||
|
for _, field := range reflect.VisibleFields(typeOpts) {
|
||||||
|
jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
|
||||||
|
if jsonTag != "" {
|
||||||
|
jsonOpts[jsonTag] = field
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make(map[string]interface{})
|
||||||
|
// iterate params and set values based on json struct tags
|
||||||
|
for key, vals := range params {
|
||||||
|
if opt, ok := jsonOpts[key]; !ok {
|
||||||
|
return nil, fmt.Errorf("unknown parameter '%s'", key)
|
||||||
|
} else {
|
||||||
|
field := valueOpts.FieldByName(opt.Name)
|
||||||
|
if field.IsValid() && field.CanSet() {
|
||||||
|
switch field.Kind() {
|
||||||
|
case reflect.Float32:
|
||||||
|
floatVal, err := strconv.ParseFloat(vals[0], 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid float value %s", vals)
|
||||||
|
}
|
||||||
|
|
||||||
|
out[key] = float32(floatVal)
|
||||||
|
case reflect.Int:
|
||||||
|
intVal, err := strconv.ParseInt(vals[0], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid int value %s", vals)
|
||||||
|
}
|
||||||
|
|
||||||
|
out[key] = intVal
|
||||||
|
case reflect.Bool:
|
||||||
|
boolVal, err := strconv.ParseBool(vals[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid bool value %s", vals)
|
||||||
|
}
|
||||||
|
|
||||||
|
out[key] = boolVal
|
||||||
|
case reflect.String:
|
||||||
|
out[key] = vals[0]
|
||||||
|
case reflect.Slice:
|
||||||
|
// TODO: only string slices are supported right now
|
||||||
|
out[key] = vals
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
119
cmd/cmd.go
119
cmd/cmd.go
|
@ -412,10 +412,19 @@ func PullHandler(cmd *cobra.Command, args []string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunGenerate(cmd *cobra.Command, args []string) error {
|
func RunGenerate(cmd *cobra.Command, args []string) error {
|
||||||
|
interactive := true
|
||||||
|
|
||||||
|
opts := generateOptions{
|
||||||
|
Model: args[0],
|
||||||
|
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||||
|
Options: map[string]interface{}{},
|
||||||
|
}
|
||||||
|
|
||||||
format, err := cmd.Flags().GetString("format")
|
format, err := cmd.Flags().GetString("format")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
opts.Format = format
|
||||||
|
|
||||||
prompts := args[1:]
|
prompts := args[1:]
|
||||||
|
|
||||||
|
@ -427,34 +436,38 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
prompts = append([]string{string(in)}, prompts...)
|
prompts = append([]string{string(in)}, prompts...)
|
||||||
|
opts.WordWrap = false
|
||||||
|
interactive = false
|
||||||
}
|
}
|
||||||
|
opts.Prompt = strings.Join(prompts, " ")
|
||||||
// output is being piped
|
if len(prompts) > 0 {
|
||||||
if !term.IsTerminal(int(os.Stdout.Fd())) {
|
interactive = false
|
||||||
return generate(cmd, args[0], strings.Join(prompts, " "), false, format)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
wordWrap := os.Getenv("TERM") == "xterm-256color"
|
|
||||||
|
|
||||||
nowrap, err := cmd.Flags().GetBool("nowordwrap")
|
nowrap, err := cmd.Flags().GetBool("nowordwrap")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if nowrap {
|
opts.WordWrap = !nowrap
|
||||||
wordWrap = false
|
|
||||||
|
if !interactive {
|
||||||
|
return generate(cmd, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
// prompts are provided via stdin or args so don't enter interactive mode
|
return generateInteractive(cmd, opts)
|
||||||
if len(prompts) > 0 {
|
|
||||||
return generate(cmd, args[0], strings.Join(prompts, " "), wordWrap, format)
|
|
||||||
}
|
|
||||||
|
|
||||||
return generateInteractive(cmd, args[0], wordWrap, format)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type generateContextKey string
|
type generateContextKey string
|
||||||
|
|
||||||
func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format string) error {
|
type generateOptions struct {
|
||||||
|
Model string
|
||||||
|
Prompt string
|
||||||
|
WordWrap bool
|
||||||
|
Format string
|
||||||
|
Options map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func generate(cmd *cobra.Command, opts generateOptions) error {
|
||||||
client, err := api.ClientFromEnvironment()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -475,7 +488,7 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
|
||||||
|
|
||||||
termWidth, _, err := term.GetSize(int(os.Stdout.Fd()))
|
termWidth, _, err := term.GetSize(int(os.Stdout.Fd()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
wordWrap = false
|
opts.WordWrap = false
|
||||||
}
|
}
|
||||||
|
|
||||||
cancelCtx, cancel := context.WithCancel(context.Background())
|
cancelCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
@ -494,13 +507,19 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
|
||||||
var currentLineLength int
|
var currentLineLength int
|
||||||
var wordBuffer string
|
var wordBuffer string
|
||||||
|
|
||||||
request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, Format: format}
|
request := api.GenerateRequest{
|
||||||
|
Model: opts.Model,
|
||||||
|
Prompt: opts.Prompt,
|
||||||
|
Context: generateContext,
|
||||||
|
Format: opts.Format,
|
||||||
|
Options: opts.Options,
|
||||||
|
}
|
||||||
fn := func(response api.GenerateResponse) error {
|
fn := func(response api.GenerateResponse) error {
|
||||||
p.StopAndClear()
|
p.StopAndClear()
|
||||||
|
|
||||||
latest = response
|
latest = response
|
||||||
|
|
||||||
if wordWrap {
|
if opts.WordWrap {
|
||||||
for _, ch := range response.Response {
|
for _, ch := range response.Response {
|
||||||
if currentLineLength+1 > termWidth-5 {
|
if currentLineLength+1 > termWidth-5 {
|
||||||
// backtrack the length of the last word and clear to the end of the line
|
// backtrack the length of the last word and clear to the end of the line
|
||||||
|
@ -534,7 +553,7 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if prompt != "" {
|
if opts.Prompt != "" {
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
}
|
}
|
||||||
|
@ -562,9 +581,13 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format string) error {
|
func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
|
||||||
// load the model
|
// load the model
|
||||||
if err := generate(cmd, model, "", false, ""); err != nil {
|
loadOpts := generateOptions{
|
||||||
|
Model: opts.Model,
|
||||||
|
Prompt: "",
|
||||||
|
}
|
||||||
|
if err := generate(cmd, loadOpts); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -581,6 +604,7 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
|
||||||
|
|
||||||
usageSet := func() {
|
usageSet := func() {
|
||||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||||
|
fmt.Fprintln(os.Stderr, " /set parameter Set a parameter")
|
||||||
fmt.Fprintln(os.Stderr, " /set history Enable history")
|
fmt.Fprintln(os.Stderr, " /set history Enable history")
|
||||||
fmt.Fprintln(os.Stderr, " /set nohistory Disable history")
|
fmt.Fprintln(os.Stderr, " /set nohistory Disable history")
|
||||||
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
|
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
|
||||||
|
@ -602,6 +626,22 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
|
||||||
fmt.Fprintln(os.Stderr, "")
|
fmt.Fprintln(os.Stderr, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// only list out the most common parameters
|
||||||
|
usageParameters := func() {
|
||||||
|
fmt.Fprintln(os.Stderr, "Available Parameters:")
|
||||||
|
fmt.Fprintln(os.Stderr, " /set parameter seed <int> Random number seed")
|
||||||
|
fmt.Fprintln(os.Stderr, " /set parameter num_predict <int> Max number of tokens to predict")
|
||||||
|
fmt.Fprintln(os.Stderr, " /set parameter top_k <int> Pick from top k num of tokens")
|
||||||
|
fmt.Fprintln(os.Stderr, " /set parameter top_p <float> Pick token based on sum of probabilities")
|
||||||
|
fmt.Fprintln(os.Stderr, " /set parameter num_ctx <int> Set the context size")
|
||||||
|
fmt.Fprintln(os.Stderr, " /set parameter temperature <float> Set creativity level")
|
||||||
|
fmt.Fprintln(os.Stderr, " /set parameter repeat_penalty <float> How strongly to penalize repetitions")
|
||||||
|
fmt.Fprintln(os.Stderr, " /set parameter repeat_last_n <int> Set how far back to look for repetitions")
|
||||||
|
fmt.Fprintln(os.Stderr, " /set parameter num_gpu <int> The number of layers to send to the GPU")
|
||||||
|
fmt.Fprintln(os.Stderr, " /set parameter stop \"<string>\", ... Set the stop parameters")
|
||||||
|
fmt.Fprintln(os.Stderr, "")
|
||||||
|
}
|
||||||
|
|
||||||
scanner, err := readline.New(readline.Prompt{
|
scanner, err := readline.New(readline.Prompt{
|
||||||
Prompt: ">>> ",
|
Prompt: ">>> ",
|
||||||
AltPrompt: "... ",
|
AltPrompt: "... ",
|
||||||
|
@ -670,10 +710,10 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
|
||||||
case "nohistory":
|
case "nohistory":
|
||||||
scanner.HistoryDisable()
|
scanner.HistoryDisable()
|
||||||
case "wordwrap":
|
case "wordwrap":
|
||||||
wordWrap = true
|
opts.WordWrap = true
|
||||||
fmt.Println("Set 'wordwrap' mode.")
|
fmt.Println("Set 'wordwrap' mode.")
|
||||||
case "nowordwrap":
|
case "nowordwrap":
|
||||||
wordWrap = false
|
opts.WordWrap = false
|
||||||
fmt.Println("Set 'nowordwrap' mode.")
|
fmt.Println("Set 'nowordwrap' mode.")
|
||||||
case "verbose":
|
case "verbose":
|
||||||
cmd.Flags().Set("verbose", "true")
|
cmd.Flags().Set("verbose", "true")
|
||||||
|
@ -685,12 +725,28 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
|
||||||
if len(args) < 3 || args[2] != "json" {
|
if len(args) < 3 || args[2] != "json" {
|
||||||
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
|
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
|
||||||
} else {
|
} else {
|
||||||
format = args[2]
|
opts.Format = args[2]
|
||||||
fmt.Printf("Set format to '%s' mode.\n", args[2])
|
fmt.Printf("Set format to '%s' mode.\n", args[2])
|
||||||
}
|
}
|
||||||
case "noformat":
|
case "noformat":
|
||||||
format = ""
|
opts.Format = ""
|
||||||
fmt.Println("Disabled format.")
|
fmt.Println("Disabled format.")
|
||||||
|
case "parameter":
|
||||||
|
if len(args) < 4 {
|
||||||
|
usageParameters()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var params []string
|
||||||
|
for _, p := range args[3:] {
|
||||||
|
params = append(params, p)
|
||||||
|
}
|
||||||
|
fp, err := api.FormatParams(map[string][]string{args[2]: params})
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Couldn't set parameter: %q\n\n", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fmt.Printf("Set parameter '%s' to '%s'\n\n", args[2], strings.Join(params, ", "))
|
||||||
|
opts.Options[args[2]] = fp[args[2]]
|
||||||
default:
|
default:
|
||||||
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
|
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
|
||||||
}
|
}
|
||||||
|
@ -705,7 +761,7 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
|
||||||
fmt.Println("error: couldn't connect to ollama server")
|
fmt.Println("error: couldn't connect to ollama server")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
resp, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model})
|
resp, err := client.Show(cmd.Context(), &api.ShowRequest{Name: opts.Model})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("error: couldn't get model")
|
fmt.Println("error: couldn't get model")
|
||||||
return err
|
return err
|
||||||
|
@ -724,6 +780,14 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
|
||||||
if resp.Parameters == "" {
|
if resp.Parameters == "" {
|
||||||
fmt.Print("No parameters were specified for this model.\n\n")
|
fmt.Print("No parameters were specified for this model.\n\n")
|
||||||
} else {
|
} else {
|
||||||
|
if len(opts.Options) > 0 {
|
||||||
|
fmt.Println("User defined parameters:")
|
||||||
|
for k, v := range opts.Options {
|
||||||
|
fmt.Printf("%-*s %v\n", 30, k, v)
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
}
|
||||||
|
fmt.Println("Model defined parameters:")
|
||||||
fmt.Println(resp.Parameters)
|
fmt.Println(resp.Parameters)
|
||||||
}
|
}
|
||||||
case "system":
|
case "system":
|
||||||
|
@ -767,7 +831,8 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(prompt) > 0 && prompt[0] != '/' {
|
if len(prompt) > 0 && prompt[0] != '/' {
|
||||||
if err := generate(cmd, model, prompt, wordWrap, format); err != nil {
|
opts.Prompt = prompt
|
||||||
|
if err := generate(cmd, opts); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,6 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -426,7 +425,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||||
if len(params) > 0 {
|
if len(params) > 0 {
|
||||||
fn(api.ProgressResponse{Status: "creating parameters layer"})
|
fn(api.ProgressResponse{Status: "creating parameters layer"})
|
||||||
|
|
||||||
formattedParams, err := formatParams(params)
|
formattedParams, err := api.FormatParams(params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -581,64 +580,6 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
|
||||||
return newLayer, nil
|
return newLayer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// formatParams converts specified parameter options to their correct types
|
|
||||||
func formatParams(params map[string][]string) (map[string]interface{}, error) {
|
|
||||||
opts := api.Options{}
|
|
||||||
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
|
|
||||||
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
|
|
||||||
|
|
||||||
// build map of json struct tags to their types
|
|
||||||
jsonOpts := make(map[string]reflect.StructField)
|
|
||||||
for _, field := range reflect.VisibleFields(typeOpts) {
|
|
||||||
jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
|
|
||||||
if jsonTag != "" {
|
|
||||||
jsonOpts[jsonTag] = field
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
out := make(map[string]interface{})
|
|
||||||
// iterate params and set values based on json struct tags
|
|
||||||
for key, vals := range params {
|
|
||||||
if opt, ok := jsonOpts[key]; ok {
|
|
||||||
field := valueOpts.FieldByName(opt.Name)
|
|
||||||
if field.IsValid() && field.CanSet() {
|
|
||||||
switch field.Kind() {
|
|
||||||
case reflect.Float32:
|
|
||||||
floatVal, err := strconv.ParseFloat(vals[0], 32)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid float value %s", vals)
|
|
||||||
}
|
|
||||||
|
|
||||||
out[key] = float32(floatVal)
|
|
||||||
case reflect.Int:
|
|
||||||
intVal, err := strconv.ParseInt(vals[0], 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid int value %s", vals)
|
|
||||||
}
|
|
||||||
|
|
||||||
out[key] = intVal
|
|
||||||
case reflect.Bool:
|
|
||||||
boolVal, err := strconv.ParseBool(vals[0])
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid bool value %s", vals)
|
|
||||||
}
|
|
||||||
|
|
||||||
out[key] = boolVal
|
|
||||||
case reflect.String:
|
|
||||||
out[key] = vals[0]
|
|
||||||
case reflect.Slice:
|
|
||||||
// TODO: only string slices are supported right now
|
|
||||||
out[key] = vals
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getLayerDigests(layers []*LayerReader) ([]string, error) {
|
func getLayerDigests(layers []*LayerReader) ([]string, error) {
|
||||||
var digests []string
|
var digests []string
|
||||||
for _, l := range layers {
|
for _, l := range layers {
|
||||||
|
|
Loading…
Reference in a new issue