From 2909dce89435fdf7227233483a64558ac593e9ec Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Thu, 4 Jan 2024 15:20:26 -0800 Subject: [PATCH] split up interactive generation --- cmd/cmd.go | 500 ------------------------------------------- cmd/interactive.go | 515 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 515 insertions(+), 500 deletions(-) create mode 100644 cmd/interactive.go diff --git a/cmd/cmd.go b/cmd/cmd.go index 1bcb3a23..179ba9d1 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -17,7 +17,6 @@ import ( "os/exec" "os/signal" "path/filepath" - "regexp" "runtime" "strings" "syscall" @@ -26,14 +25,12 @@ import ( "github.com/olekukonko/tablewriter" "github.com/spf13/cobra" "golang.org/x/crypto/ssh" - "golang.org/x/exp/slices" "golang.org/x/term" "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/format" "github.com/jmorganca/ollama/parser" "github.com/jmorganca/ollama/progress" - "github.com/jmorganca/ollama/readline" "github.com/jmorganca/ollama/server" "github.com/jmorganca/ollama/version" ) @@ -621,459 +618,6 @@ func generate(cmd *cobra.Command, opts generateOptions) error { return nil } -type MultilineState int - -const ( - MultilineNone MultilineState = iota - MultilinePrompt - MultilineSystem - MultilineTemplate -) - -func modelIsMultiModal(cmd *cobra.Command, name string) bool { - // get model details - client, err := api.ClientFromEnvironment() - if err != nil { - fmt.Println("error: couldn't connect to ollama server") - return false - } - - req := api.ShowRequest{Name: name} - resp, err := client.Show(cmd.Context(), &req) - if err != nil { - return false - } - - return slices.Contains(resp.Details.Families, "clip") -} - -func generateInteractive(cmd *cobra.Command, opts generateOptions) error { - multiModal := modelIsMultiModal(cmd, opts.Model) - - // load the model - loadOpts := generateOptions{ - Model: opts.Model, - Prompt: "", - Images: []ImageData{}, - } - if err := generate(cmd, loadOpts); err != nil { - return err - } - - usage := func() { - fmt.Fprintln(os.Stderr, "Available Commands:") - fmt.Fprintln(os.Stderr, " /set Set session variables") - fmt.Fprintln(os.Stderr, " /show Show model information") - fmt.Fprintln(os.Stderr, " /bye Exit") - fmt.Fprintln(os.Stderr, " /?, /help Help for a command") - fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts") - fmt.Fprintln(os.Stderr, "") - fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.") - fmt.Fprintln(os.Stderr, "") - } - - usageSet := func() { - fmt.Fprintln(os.Stderr, "Available Commands:") - fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter") - fmt.Fprintln(os.Stderr, " /set system Set system message") - fmt.Fprintln(os.Stderr, " /set template Set prompt template") - fmt.Fprintln(os.Stderr, " /set history Enable history") - fmt.Fprintln(os.Stderr, " /set nohistory Disable history") - fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap") - fmt.Fprintln(os.Stderr, " /set nowordwrap Disable wordwrap") - fmt.Fprintln(os.Stderr, " /set format json Enable JSON mode") - fmt.Fprintln(os.Stderr, " /set noformat Disable formatting") - fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats") - fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats") - fmt.Fprintln(os.Stderr, "") - } - - usageShortcuts := func() { - fmt.Fprintln(os.Stderr, "Available keyboard shortcuts:") - fmt.Fprintln(os.Stderr, " Ctrl + a Move to the beginning of the line (Home)") - fmt.Fprintln(os.Stderr, " Ctrl + e Move to the end of the line (End)") - fmt.Fprintln(os.Stderr, " Alt + b Move back (left) one word") - fmt.Fprintln(os.Stderr, " Alt + f Move forward (right) one word") - fmt.Fprintln(os.Stderr, " Ctrl + k Delete the sentence after the cursor") - fmt.Fprintln(os.Stderr, " Ctrl + u Delete the sentence before the cursor") - fmt.Fprintln(os.Stderr, "") - fmt.Fprintln(os.Stderr, " Ctrl + l Clear the screen") - fmt.Fprintln(os.Stderr, " Ctrl + c Stop the model from responding") - fmt.Fprintln(os.Stderr, " Ctrl + d Exit ollama (/bye)") - fmt.Fprintln(os.Stderr, "") - } - - usageShow := func() { - fmt.Fprintln(os.Stderr, "Available Commands:") - fmt.Fprintln(os.Stderr, " /show license Show model license") - fmt.Fprintln(os.Stderr, " /show modelfile Show Modelfile for this model") - fmt.Fprintln(os.Stderr, " /show parameters Show parameters for this model") - fmt.Fprintln(os.Stderr, " /show system Show system message") - fmt.Fprintln(os.Stderr, " /show template Show prompt template") - fmt.Fprintln(os.Stderr, "") - } - - // only list out the most common parameters - usageParameters := func() { - fmt.Fprintln(os.Stderr, "Available Parameters:") - fmt.Fprintln(os.Stderr, " /set parameter seed Random number seed") - fmt.Fprintln(os.Stderr, " /set parameter num_predict Max number of tokens to predict") - fmt.Fprintln(os.Stderr, " /set parameter top_k Pick from top k num of tokens") - fmt.Fprintln(os.Stderr, " /set parameter top_p Pick token based on sum of probabilities") - fmt.Fprintln(os.Stderr, " /set parameter num_ctx Set the context size") - fmt.Fprintln(os.Stderr, " /set parameter temperature Set creativity level") - fmt.Fprintln(os.Stderr, " /set parameter repeat_penalty How strongly to penalize repetitions") - fmt.Fprintln(os.Stderr, " /set parameter repeat_last_n Set how far back to look for repetitions") - fmt.Fprintln(os.Stderr, " /set parameter num_gpu The number of layers to send to the GPU") - fmt.Fprintln(os.Stderr, " /set parameter stop \"\", ... Set the stop parameters") - fmt.Fprintln(os.Stderr, "") - } - - scanner, err := readline.New(readline.Prompt{ - Prompt: ">>> ", - AltPrompt: "... ", - Placeholder: "Send a message (/? for help)", - AltPlaceholder: `Use """ to end multi-line input`, - }) - if err != nil { - return err - } - - fmt.Print(readline.StartBracketedPaste) - defer fmt.Printf(readline.EndBracketedPaste) - - var multiline MultilineState - var prompt string - - for { - line, err := scanner.Readline() - switch { - case errors.Is(err, io.EOF): - fmt.Println() - return nil - case errors.Is(err, readline.ErrInterrupt): - if line == "" { - fmt.Println("\nUse Ctrl + d or /bye to exit.") - } - - scanner.Prompt.UseAlt = false - prompt = "" - - continue - case err != nil: - return err - } - - switch { - case strings.HasPrefix(prompt, `"""`): - // if the prompt so far starts with """ then we're in multiline mode - // and we need to keep reading until we find a line that ends with """ - cut, found := strings.CutSuffix(line, `"""`) - prompt += cut - - if !found { - prompt += "\n" - continue - } - - prompt = strings.TrimPrefix(prompt, `"""`) - scanner.Prompt.UseAlt = false - - switch multiline { - case MultilineSystem: - opts.System = prompt - prompt = "" - fmt.Println("Set system message.") - case MultilineTemplate: - opts.Template = prompt - prompt = "" - fmt.Println("Set prompt template.") - } - multiline = MultilineNone - case strings.HasPrefix(line, `"""`) && len(prompt) == 0: - scanner.Prompt.UseAlt = true - multiline = MultilinePrompt - prompt += line + "\n" - continue - case scanner.Pasting: - prompt += line + "\n" - continue - case strings.HasPrefix(line, "/list"): - args := strings.Fields(line) - if err := ListHandler(cmd, args[1:]); err != nil { - return err - } - case strings.HasPrefix(line, "/set"): - args := strings.Fields(line) - if len(args) > 1 { - switch args[1] { - case "history": - scanner.HistoryEnable() - case "nohistory": - scanner.HistoryDisable() - case "wordwrap": - opts.WordWrap = true - fmt.Println("Set 'wordwrap' mode.") - case "nowordwrap": - opts.WordWrap = false - fmt.Println("Set 'nowordwrap' mode.") - case "verbose": - cmd.Flags().Set("verbose", "true") - fmt.Println("Set 'verbose' mode.") - case "quiet": - cmd.Flags().Set("verbose", "false") - fmt.Println("Set 'quiet' mode.") - case "format": - if len(args) < 3 || args[2] != "json" { - fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'") - } else { - opts.Format = args[2] - fmt.Printf("Set format to '%s' mode.\n", args[2]) - } - case "noformat": - opts.Format = "" - fmt.Println("Disabled format.") - case "parameter": - if len(args) < 4 { - usageParameters() - continue - } - var params []string - for _, p := range args[3:] { - params = append(params, p) - } - fp, err := api.FormatParams(map[string][]string{args[2]: params}) - if err != nil { - fmt.Printf("Couldn't set parameter: %q\n\n", err) - continue - } - fmt.Printf("Set parameter '%s' to '%s'\n\n", args[2], strings.Join(params, ", ")) - opts.Options[args[2]] = fp[args[2]] - case "system", "template": - if len(args) < 3 { - usageSet() - continue - } - line := strings.Join(args[2:], " ") - line = strings.TrimPrefix(line, `"""`) - if strings.HasPrefix(args[2], `"""`) { - cut, found := strings.CutSuffix(line, `"""`) - prompt += cut - if found { - if args[1] == "system" { - opts.System = prompt - fmt.Println("Set system message.") - } else { - opts.Template = prompt - fmt.Println("Set prompt template.") - } - prompt = "" - } else { - prompt = `"""` + prompt + "\n" - if args[1] == "system" { - multiline = MultilineSystem - } else { - multiline = MultilineTemplate - } - scanner.Prompt.UseAlt = true - } - } else { - opts.System = line - fmt.Println("Set system message.") - } - default: - fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1]) - } - } else { - usageSet() - } - case strings.HasPrefix(line, "/show"): - args := strings.Fields(line) - if len(args) > 1 { - client, err := api.ClientFromEnvironment() - if err != nil { - fmt.Println("error: couldn't connect to ollama server") - return err - } - resp, err := client.Show(cmd.Context(), &api.ShowRequest{Name: opts.Model}) - if err != nil { - fmt.Println("error: couldn't get model") - return err - } - - switch args[1] { - case "license": - if resp.License == "" { - fmt.Print("No license was specified for this model.\n\n") - } else { - fmt.Println(resp.License) - } - case "modelfile": - fmt.Println(resp.Modelfile) - case "parameters": - if resp.Parameters == "" { - fmt.Print("No parameters were specified for this model.\n\n") - } else { - if len(opts.Options) > 0 { - fmt.Println("User defined parameters:") - for k, v := range opts.Options { - fmt.Printf("%-*s %v\n", 30, k, v) - } - fmt.Println() - } - fmt.Println("Model defined parameters:") - fmt.Println(resp.Parameters) - } - case "system": - switch { - case opts.System != "": - fmt.Println(opts.System + "\n") - case resp.System != "": - fmt.Println(resp.System + "\n") - default: - fmt.Print("No system message was specified for this model.\n\n") - } - case "template": - switch { - case opts.Template != "": - fmt.Println(opts.Template + "\n") - case resp.Template != "": - fmt.Println(resp.Template) - default: - fmt.Print("No prompt template was specified for this model.\n\n") - } - default: - fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1]) - } - } else { - usageShow() - } - case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"): - args := strings.Fields(line) - if len(args) > 1 { - switch args[1] { - case "set", "/set": - usageSet() - case "show", "/show": - usageShow() - case "shortcut", "shortcuts": - usageShortcuts() - } - } else { - usage() - } - case line == "/exit", line == "/bye": - return nil - case strings.HasPrefix(line, "/"): - args := strings.Fields(line) - isFile := false - - if multiModal { - for _, f := range extractFileNames(line) { - if strings.HasPrefix(f, args[0]) { - isFile = true - break - } - } - } - - if isFile { - prompt += line - } else { - fmt.Printf("Unknown command '%s'. Type /? for help\n", args[0]) - continue - } - default: - prompt += line - } - - if len(prompt) > 0 && multiline == MultilineNone { - opts.Prompt = prompt - if multiModal { - newPrompt, images, err := extractFileData(prompt) - if err != nil { - return err - } - opts.Prompt = newPrompt - - // reset the context if we find another image - if len(images) > 0 { - opts.Images = images - ctx := cmd.Context() - ctx = context.WithValue(ctx, generateContextKey("context"), []int{}) - cmd.SetContext(ctx) - } - if len(opts.Images) == 0 { - fmt.Println("This model requires you to add a jpeg, png, or svg image.") - fmt.Println() - prompt = "" - continue - } - } - if err := generate(cmd, opts); err != nil { - return err - } - - prompt = "" - } - } -} - -func normalizeFilePath(fp string) string { - // Define a map of escaped characters and their replacements - replacements := map[string]string{ - "\\ ": " ", // Escaped space - "\\(": "(", // Escaped left parenthesis - "\\)": ")", // Escaped right parenthesis - "\\[": "[", // Escaped left square bracket - "\\]": "]", // Escaped right square bracket - "\\{": "{", // Escaped left curly brace - "\\}": "}", // Escaped right curly brace - "\\$": "$", // Escaped dollar sign - "\\&": "&", // Escaped ampersand - "\\;": ";", // Escaped semicolon - "\\'": "'", // Escaped single quote - "\\\\": "\\", // Escaped backslash - "\\*": "*", // Escaped asterisk - "\\?": "?", // Escaped question mark - } - - for escaped, actual := range replacements { - fp = strings.ReplaceAll(fp, escaped, actual) - } - return fp -} - -func extractFileNames(input string) []string { - // Regex to match file paths starting with / or ./ and include escaped spaces (\ or %20) - // and followed by more characters and a file extension - regexPattern := `(?:\./|/)[\S\\ ]+?\.(?i:jpg|jpeg|png|svg)\b` - re := regexp.MustCompile(regexPattern) - - return re.FindAllString(input, -1) -} - -func extractFileData(input string) (string, []ImageData, error) { - filePaths := extractFileNames(input) - var imgs []ImageData - - for _, fp := range filePaths { - nfp := normalizeFilePath(fp) - data, err := getImageData(nfp) - if err != nil { - if os.IsNotExist(err) { - continue - } - fmt.Printf("Couldn't process image: %q\n", err) - return "", imgs, err - } - fmt.Printf("Added image '%s'\n", nfp) - input = strings.ReplaceAll(input, fp, "") - imgs = append(imgs, data) - } - return input, imgs, nil -} - func RunServer(cmd *cobra.Command, _ []string) error { host, port, err := net.SplitHostPort(os.Getenv("OLLAMA_HOST")) if err != nil { @@ -1095,50 +639,6 @@ func RunServer(cmd *cobra.Command, _ []string) error { return server.Serve(ln) } -func getImageData(filePath string) ([]byte, error) { - file, err := os.Open(filePath) - if err != nil { - return nil, err - } - defer file.Close() - - buf := make([]byte, 512) - _, err = file.Read(buf) - if err != nil { - return nil, err - } - - contentType := http.DetectContentType(buf) - allowedTypes := []string{"image/jpeg", "image/jpg", "image/svg+xml", "image/png"} - if !slices.Contains(allowedTypes, contentType) { - return nil, fmt.Errorf("invalid image type: %s", contentType) - } - - info, err := file.Stat() - if err != nil { - return nil, err - } - - // Check if the file size exceeds 100MB - var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes - if info.Size() > maxSize { - return nil, fmt.Errorf("file size exceeds maximum limit (100MB)") - } - - buf = make([]byte, info.Size()) - _, err = file.Seek(0, 0) - if err != nil { - return nil, err - } - - _, err = io.ReadFull(file, buf) - if err != nil { - return nil, err - } - - return buf, nil -} - func initializeKeypair() error { home, err := os.UserHomeDir() if err != nil { diff --git a/cmd/interactive.go b/cmd/interactive.go new file mode 100644 index 00000000..62e7c7e5 --- /dev/null +++ b/cmd/interactive.go @@ -0,0 +1,515 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "os" + "regexp" + "strings" + + "github.com/spf13/cobra" + "golang.org/x/exp/slices" + + "github.com/jmorganca/ollama/api" + "github.com/jmorganca/ollama/readline" +) + +type MultilineState int + +const ( + MultilineNone MultilineState = iota + MultilinePrompt + MultilineSystem + MultilineTemplate +) + +func modelIsMultiModal(cmd *cobra.Command, name string) bool { + // get model details + client, err := api.ClientFromEnvironment() + if err != nil { + fmt.Println("error: couldn't connect to ollama server") + return false + } + + req := api.ShowRequest{Name: name} + resp, err := client.Show(cmd.Context(), &req) + if err != nil { + return false + } + + return slices.Contains(resp.Details.Families, "clip") +} + +func generateInteractive(cmd *cobra.Command, opts generateOptions) error { + multiModal := modelIsMultiModal(cmd, opts.Model) + + // load the model + loadOpts := generateOptions{ + Model: opts.Model, + Prompt: "", + Images: []ImageData{}, + } + if err := generate(cmd, loadOpts); err != nil { + return err + } + + usage := func() { + fmt.Fprintln(os.Stderr, "Available Commands:") + fmt.Fprintln(os.Stderr, " /set Set session variables") + fmt.Fprintln(os.Stderr, " /show Show model information") + fmt.Fprintln(os.Stderr, " /bye Exit") + fmt.Fprintln(os.Stderr, " /?, /help Help for a command") + fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts") + fmt.Fprintln(os.Stderr, "") + fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.") + fmt.Fprintln(os.Stderr, "") + } + + usageSet := func() { + fmt.Fprintln(os.Stderr, "Available Commands:") + fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter") + fmt.Fprintln(os.Stderr, " /set system Set system message") + fmt.Fprintln(os.Stderr, " /set template Set prompt template") + fmt.Fprintln(os.Stderr, " /set history Enable history") + fmt.Fprintln(os.Stderr, " /set nohistory Disable history") + fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap") + fmt.Fprintln(os.Stderr, " /set nowordwrap Disable wordwrap") + fmt.Fprintln(os.Stderr, " /set format json Enable JSON mode") + fmt.Fprintln(os.Stderr, " /set noformat Disable formatting") + fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats") + fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats") + fmt.Fprintln(os.Stderr, "") + } + + usageShortcuts := func() { + fmt.Fprintln(os.Stderr, "Available keyboard shortcuts:") + fmt.Fprintln(os.Stderr, " Ctrl + a Move to the beginning of the line (Home)") + fmt.Fprintln(os.Stderr, " Ctrl + e Move to the end of the line (End)") + fmt.Fprintln(os.Stderr, " Alt + b Move back (left) one word") + fmt.Fprintln(os.Stderr, " Alt + f Move forward (right) one word") + fmt.Fprintln(os.Stderr, " Ctrl + k Delete the sentence after the cursor") + fmt.Fprintln(os.Stderr, " Ctrl + u Delete the sentence before the cursor") + fmt.Fprintln(os.Stderr, "") + fmt.Fprintln(os.Stderr, " Ctrl + l Clear the screen") + fmt.Fprintln(os.Stderr, " Ctrl + c Stop the model from responding") + fmt.Fprintln(os.Stderr, " Ctrl + d Exit ollama (/bye)") + fmt.Fprintln(os.Stderr, "") + } + + usageShow := func() { + fmt.Fprintln(os.Stderr, "Available Commands:") + fmt.Fprintln(os.Stderr, " /show license Show model license") + fmt.Fprintln(os.Stderr, " /show modelfile Show Modelfile for this model") + fmt.Fprintln(os.Stderr, " /show parameters Show parameters for this model") + fmt.Fprintln(os.Stderr, " /show system Show system message") + fmt.Fprintln(os.Stderr, " /show template Show prompt template") + fmt.Fprintln(os.Stderr, "") + } + + // only list out the most common parameters + usageParameters := func() { + fmt.Fprintln(os.Stderr, "Available Parameters:") + fmt.Fprintln(os.Stderr, " /set parameter seed Random number seed") + fmt.Fprintln(os.Stderr, " /set parameter num_predict Max number of tokens to predict") + fmt.Fprintln(os.Stderr, " /set parameter top_k Pick from top k num of tokens") + fmt.Fprintln(os.Stderr, " /set parameter top_p Pick token based on sum of probabilities") + fmt.Fprintln(os.Stderr, " /set parameter num_ctx Set the context size") + fmt.Fprintln(os.Stderr, " /set parameter temperature Set creativity level") + fmt.Fprintln(os.Stderr, " /set parameter repeat_penalty How strongly to penalize repetitions") + fmt.Fprintln(os.Stderr, " /set parameter repeat_last_n Set how far back to look for repetitions") + fmt.Fprintln(os.Stderr, " /set parameter num_gpu The number of layers to send to the GPU") + fmt.Fprintln(os.Stderr, " /set parameter stop \"\", ... Set the stop parameters") + fmt.Fprintln(os.Stderr, "") + } + + scanner, err := readline.New(readline.Prompt{ + Prompt: ">>> ", + AltPrompt: "... ", + Placeholder: "Send a message (/? for help)", + AltPlaceholder: `Use """ to end multi-line input`, + }) + if err != nil { + return err + } + + fmt.Print(readline.StartBracketedPaste) + defer fmt.Printf(readline.EndBracketedPaste) + + var multiline MultilineState + var prompt string + + for { + line, err := scanner.Readline() + switch { + case errors.Is(err, io.EOF): + fmt.Println() + return nil + case errors.Is(err, readline.ErrInterrupt): + if line == "" { + fmt.Println("\nUse Ctrl + d or /bye to exit.") + } + + scanner.Prompt.UseAlt = false + prompt = "" + + continue + case err != nil: + return err + } + + switch { + case strings.HasPrefix(prompt, `"""`): + // if the prompt so far starts with """ then we're in multiline mode + // and we need to keep reading until we find a line that ends with """ + cut, found := strings.CutSuffix(line, `"""`) + prompt += cut + + if !found { + prompt += "\n" + continue + } + + prompt = strings.TrimPrefix(prompt, `"""`) + scanner.Prompt.UseAlt = false + + switch multiline { + case MultilineSystem: + opts.System = prompt + prompt = "" + fmt.Println("Set system message.") + case MultilineTemplate: + opts.Template = prompt + prompt = "" + fmt.Println("Set prompt template.") + } + multiline = MultilineNone + case strings.HasPrefix(line, `"""`) && len(prompt) == 0: + scanner.Prompt.UseAlt = true + multiline = MultilinePrompt + prompt += line + "\n" + continue + case scanner.Pasting: + prompt += line + "\n" + continue + case strings.HasPrefix(line, "/list"): + args := strings.Fields(line) + if err := ListHandler(cmd, args[1:]); err != nil { + return err + } + case strings.HasPrefix(line, "/set"): + args := strings.Fields(line) + if len(args) > 1 { + switch args[1] { + case "history": + scanner.HistoryEnable() + case "nohistory": + scanner.HistoryDisable() + case "wordwrap": + opts.WordWrap = true + fmt.Println("Set 'wordwrap' mode.") + case "nowordwrap": + opts.WordWrap = false + fmt.Println("Set 'nowordwrap' mode.") + case "verbose": + cmd.Flags().Set("verbose", "true") + fmt.Println("Set 'verbose' mode.") + case "quiet": + cmd.Flags().Set("verbose", "false") + fmt.Println("Set 'quiet' mode.") + case "format": + if len(args) < 3 || args[2] != "json" { + fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'") + } else { + opts.Format = args[2] + fmt.Printf("Set format to '%s' mode.\n", args[2]) + } + case "noformat": + opts.Format = "" + fmt.Println("Disabled format.") + case "parameter": + if len(args) < 4 { + usageParameters() + continue + } + var params []string + for _, p := range args[3:] { + params = append(params, p) + } + fp, err := api.FormatParams(map[string][]string{args[2]: params}) + if err != nil { + fmt.Printf("Couldn't set parameter: %q\n\n", err) + continue + } + fmt.Printf("Set parameter '%s' to '%s'\n\n", args[2], strings.Join(params, ", ")) + opts.Options[args[2]] = fp[args[2]] + case "system", "template": + if len(args) < 3 { + usageSet() + continue + } + line := strings.Join(args[2:], " ") + line = strings.TrimPrefix(line, `"""`) + if strings.HasPrefix(args[2], `"""`) { + cut, found := strings.CutSuffix(line, `"""`) + prompt += cut + if found { + if args[1] == "system" { + opts.System = prompt + fmt.Println("Set system message.") + } else { + opts.Template = prompt + fmt.Println("Set prompt template.") + } + prompt = "" + } else { + prompt = `"""` + prompt + "\n" + if args[1] == "system" { + multiline = MultilineSystem + } else { + multiline = MultilineTemplate + } + scanner.Prompt.UseAlt = true + } + } else { + opts.System = line + fmt.Println("Set system message.") + } + default: + fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1]) + } + } else { + usageSet() + } + case strings.HasPrefix(line, "/show"): + args := strings.Fields(line) + if len(args) > 1 { + client, err := api.ClientFromEnvironment() + if err != nil { + fmt.Println("error: couldn't connect to ollama server") + return err + } + resp, err := client.Show(cmd.Context(), &api.ShowRequest{Name: opts.Model}) + if err != nil { + fmt.Println("error: couldn't get model") + return err + } + + switch args[1] { + case "license": + if resp.License == "" { + fmt.Print("No license was specified for this model.\n\n") + } else { + fmt.Println(resp.License) + } + case "modelfile": + fmt.Println(resp.Modelfile) + case "parameters": + if resp.Parameters == "" { + fmt.Print("No parameters were specified for this model.\n\n") + } else { + if len(opts.Options) > 0 { + fmt.Println("User defined parameters:") + for k, v := range opts.Options { + fmt.Printf("%-*s %v\n", 30, k, v) + } + fmt.Println() + } + fmt.Println("Model defined parameters:") + fmt.Println(resp.Parameters) + } + case "system": + switch { + case opts.System != "": + fmt.Println(opts.System + "\n") + case resp.System != "": + fmt.Println(resp.System + "\n") + default: + fmt.Print("No system message was specified for this model.\n\n") + } + case "template": + switch { + case opts.Template != "": + fmt.Println(opts.Template + "\n") + case resp.Template != "": + fmt.Println(resp.Template) + default: + fmt.Print("No prompt template was specified for this model.\n\n") + } + default: + fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1]) + } + } else { + usageShow() + } + case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"): + args := strings.Fields(line) + if len(args) > 1 { + switch args[1] { + case "set", "/set": + usageSet() + case "show", "/show": + usageShow() + case "shortcut", "shortcuts": + usageShortcuts() + } + } else { + usage() + } + case line == "/exit", line == "/bye": + return nil + case strings.HasPrefix(line, "/"): + args := strings.Fields(line) + isFile := false + + if multiModal { + for _, f := range extractFileNames(line) { + if strings.HasPrefix(f, args[0]) { + isFile = true + break + } + } + } + + if isFile { + prompt += line + } else { + fmt.Printf("Unknown command '%s'. Type /? for help\n", args[0]) + continue + } + default: + prompt += line + } + + if len(prompt) > 0 && multiline == MultilineNone { + opts.Prompt = prompt + if multiModal { + newPrompt, images, err := extractFileData(prompt) + if err != nil { + return err + } + opts.Prompt = newPrompt + + // reset the context if we find another image + if len(images) > 0 { + opts.Images = images + ctx := cmd.Context() + ctx = context.WithValue(ctx, generateContextKey("context"), []int{}) + cmd.SetContext(ctx) + } + if len(opts.Images) == 0 { + fmt.Println("This model requires you to add a jpeg, png, or svg image.") + fmt.Println() + prompt = "" + continue + } + } + if err := generate(cmd, opts); err != nil { + return err + } + + prompt = "" + } + } +} + +func normalizeFilePath(fp string) string { + // Define a map of escaped characters and their replacements + replacements := map[string]string{ + "\\ ": " ", // Escaped space + "\\(": "(", // Escaped left parenthesis + "\\)": ")", // Escaped right parenthesis + "\\[": "[", // Escaped left square bracket + "\\]": "]", // Escaped right square bracket + "\\{": "{", // Escaped left curly brace + "\\}": "}", // Escaped right curly brace + "\\$": "$", // Escaped dollar sign + "\\&": "&", // Escaped ampersand + "\\;": ";", // Escaped semicolon + "\\'": "'", // Escaped single quote + "\\\\": "\\", // Escaped backslash + "\\*": "*", // Escaped asterisk + "\\?": "?", // Escaped question mark + } + + for escaped, actual := range replacements { + fp = strings.ReplaceAll(fp, escaped, actual) + } + return fp +} + +func extractFileNames(input string) []string { + // Regex to match file paths starting with / or ./ and include escaped spaces (\ or %20) + // and followed by more characters and a file extension + regexPattern := `(?:\./|/)[\S\\ ]+?\.(?i:jpg|jpeg|png|svg)\b` + re := regexp.MustCompile(regexPattern) + + return re.FindAllString(input, -1) +} + +func extractFileData(input string) (string, []ImageData, error) { + filePaths := extractFileNames(input) + var imgs []ImageData + + for _, fp := range filePaths { + nfp := normalizeFilePath(fp) + data, err := getImageData(nfp) + if err != nil { + if os.IsNotExist(err) { + continue + } + fmt.Printf("Couldn't process image: %q\n", err) + return "", imgs, err + } + fmt.Printf("Added image '%s'\n", nfp) + input = strings.ReplaceAll(input, fp, "") + imgs = append(imgs, data) + } + return input, imgs, nil +} + +func getImageData(filePath string) ([]byte, error) { + file, err := os.Open(filePath) + if err != nil { + return nil, err + } + defer file.Close() + + buf := make([]byte, 512) + _, err = file.Read(buf) + if err != nil { + return nil, err + } + + contentType := http.DetectContentType(buf) + allowedTypes := []string{"image/jpeg", "image/jpg", "image/svg+xml", "image/png"} + if !slices.Contains(allowedTypes, contentType) { + return nil, fmt.Errorf("invalid image type: %s", contentType) + } + + info, err := file.Stat() + if err != nil { + return nil, err + } + + // Check if the file size exceeds 100MB + var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes + if info.Size() > maxSize { + return nil, fmt.Errorf("file size exceeds maximum limit (100MB)") + } + + buf = make([]byte, info.Size()) + _, err = file.Seek(0, 0) + if err != nil { + return nil, err + } + + _, err = io.ReadFull(file, buf) + if err != nil { + return nil, err + } + + return buf, nil +}