diff --git a/cmd/cmd.go b/cmd/cmd.go index 794e8780..76e3c7a9 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -35,8 +35,6 @@ import ( "github.com/jmorganca/ollama/version" ) -type ImageData []byte - func CreateHandler(cmd *cobra.Command, args []string) error { filename, _ := cmd.Flags().GetString("file") filename, err := filepath.Abs(filename) @@ -415,11 +413,10 @@ func PullHandler(cmd *cobra.Command, args []string) error { func RunGenerate(cmd *cobra.Command, args []string) error { interactive := true - opts := generateOptions{ + opts := runOptions{ Model: args[0], WordWrap: os.Getenv("TERM") == "xterm-256color", Options: map[string]interface{}{}, - Images: []ImageData{}, } format, err := cmd.Flags().GetString("format") @@ -460,18 +457,135 @@ func RunGenerate(cmd *cobra.Command, args []string) error { type generateContextKey string -type generateOptions struct { +type runOptions struct { Model string Prompt string + Messages []api.Message WordWrap bool Format string System string Template string - Images []ImageData + Images []api.ImageData Options map[string]interface{} } -func generate(cmd *cobra.Command, opts generateOptions) error { +type displayResponseState struct { + lineLength int + wordBuffer string +} + +func displayResponse(content string, wordWrap bool, state *displayResponseState) { + termWidth, _, _ := term.GetSize(int(os.Stdout.Fd())) + if wordWrap && termWidth >= 10 { + for _, ch := range content { + if state.lineLength+1 > termWidth-5 { + if len(state.wordBuffer) > termWidth-10 { + fmt.Printf("%s%c", state.wordBuffer, ch) + state.wordBuffer = "" + state.lineLength = 0 + continue + } + + // backtrack the length of the last word and clear to the end of the line + fmt.Printf("\x1b[%dD\x1b[K\n", len(state.wordBuffer)) + fmt.Printf("%s%c", state.wordBuffer, ch) + state.lineLength = len(state.wordBuffer) + 1 + } else { + fmt.Print(string(ch)) + state.lineLength += 1 + + switch ch { + case ' ': + state.wordBuffer = "" + case '\n': + state.lineLength = 0 + default: + state.wordBuffer += string(ch) + } + } + } + } else { + fmt.Printf("%s%s", state.wordBuffer, content) + if len(state.wordBuffer) > 0 { + state.wordBuffer = "" + } + } +} + +func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { + client, err := api.ClientFromEnvironment() + if err != nil { + return nil, err + } + + p := progress.NewProgress(os.Stderr) + defer p.StopAndClear() + + spinner := progress.NewSpinner("") + p.Add("", spinner) + + cancelCtx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT) + + go func() { + <-sigChan + cancel() + }() + + var state *displayResponseState = &displayResponseState{} + var latest api.ChatResponse + var fullResponse strings.Builder + var role string + + fn := func(response api.ChatResponse) error { + p.StopAndClear() + + latest = response + + role = response.Message.Role + content := response.Message.Content + fullResponse.WriteString(content) + + displayResponse(content, opts.WordWrap, state) + + return nil + } + + req := &api.ChatRequest{ + Model: opts.Model, + Messages: opts.Messages, + Format: opts.Format, + Options: opts.Options, + } + + if err := client.Chat(cancelCtx, req, fn); err != nil { + if errors.Is(err, context.Canceled) { + return nil, nil + } + return nil, err + } + + if len(opts.Messages) > 0 { + fmt.Println() + fmt.Println() + } + + verbose, err := cmd.Flags().GetBool("verbose") + if err != nil { + return nil, err + } + + if verbose { + latest.Summary() + } + + return &api.Message{Role: role, Content: fullResponse.String()}, nil +} + +func generate(cmd *cobra.Command, opts runOptions) error { client, err := api.ClientFromEnvironment() if err != nil { return err @@ -490,11 +604,6 @@ func generate(cmd *cobra.Command, opts generateOptions) error { generateContext = []int{} } - termWidth, _, err := term.GetSize(int(os.Stdout.Fd())) - if err != nil { - opts.WordWrap = false - } - ctx, cancel := context.WithCancel(cmd.Context()) defer cancel() @@ -506,57 +615,19 @@ func generate(cmd *cobra.Command, opts generateOptions) error { cancel() }() - var currentLineLength int - var wordBuffer string + var state *displayResponseState = &displayResponseState{} fn := func(response api.GenerateResponse) error { p.StopAndClear() latest = response + content := response.Response - termWidth, _, _ = term.GetSize(int(os.Stdout.Fd())) - if opts.WordWrap && termWidth >= 10 { - for _, ch := range response.Response { - if currentLineLength+1 > termWidth-5 { - if len(wordBuffer) > termWidth-10 { - fmt.Printf("%s%c", wordBuffer, ch) - wordBuffer = "" - currentLineLength = 0 - continue - } - - // backtrack the length of the last word and clear to the end of the line - fmt.Printf("\x1b[%dD\x1b[K\n", len(wordBuffer)) - fmt.Printf("%s%c", wordBuffer, ch) - currentLineLength = len(wordBuffer) + 1 - } else { - fmt.Print(string(ch)) - currentLineLength += 1 - - switch ch { - case ' ': - wordBuffer = "" - case '\n': - currentLineLength = 0 - default: - wordBuffer += string(ch) - } - } - } - } else { - fmt.Printf("%s%s", wordBuffer, response.Response) - if len(wordBuffer) > 0 { - wordBuffer = "" - } - } + displayResponse(content, opts.WordWrap, state) return nil } - images := make([]api.ImageData, 0) - for _, i := range opts.Images { - images = append(images, api.ImageData(i)) - } request := api.GenerateRequest{ Model: opts.Model, Prompt: opts.Prompt, @@ -565,7 +636,6 @@ func generate(cmd *cobra.Command, opts generateOptions) error { System: opts.System, Template: opts.Template, Options: opts.Options, - Images: images, } if err := client.Generate(ctx, &request, fn); err != nil { diff --git a/cmd/interactive.go b/cmd/interactive.go index 40cbd60d..4ad82797 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -1,7 +1,6 @@ package cmd import ( - "context" "errors" "fmt" "io" @@ -43,16 +42,16 @@ func modelIsMultiModal(cmd *cobra.Command, name string) bool { return slices.Contains(resp.Details.Families, "clip") } -func generateInteractive(cmd *cobra.Command, opts generateOptions) error { +func generateInteractive(cmd *cobra.Command, opts runOptions) error { multiModal := modelIsMultiModal(cmd, opts.Model) // load the model - loadOpts := generateOptions{ - Model: opts.Model, - Prompt: "", - Images: []ImageData{}, + loadOpts := runOptions{ + Model: opts.Model, + Prompt: "", + Messages: []api.Message{}, } - if err := generate(cmd, loadOpts); err != nil { + if _, err := chat(cmd, loadOpts); err != nil { return err } @@ -141,6 +140,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error { var sb strings.Builder var multiline MultilineState + opts.Messages = make([]api.Message, 0) for { line, err := scanner.Readline() @@ -409,22 +409,26 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error { } if sb.Len() > 0 && multiline == MultilineNone { - opts.Prompt = sb.String() + newMessage := api.Message{Role: "user", Content: sb.String()} + if multiModal { - newPrompt, images, err := extractFileData(sb.String()) + msg, images, err := extractFileData(sb.String()) if err != nil { return err } - opts.Prompt = newPrompt + newMessage.Content = msg // reset the context if we find another image if len(images) > 0 { - opts.Images = images - ctx := cmd.Context() - ctx = context.WithValue(ctx, generateContextKey("context"), []int{}) - cmd.SetContext(ctx) + newMessage.Images = append(newMessage.Images, images...) + // reset the context for the new image + opts.Messages = []api.Message{} + } else { + if len(opts.Messages) > 1 { + newMessage.Images = append(newMessage.Images, opts.Messages[len(opts.Messages)-2].Images...) + } } - if len(opts.Images) == 0 { + if len(newMessage.Images) == 0 { fmt.Println("This model requires you to add a jpeg, png, or svg image.") fmt.Println() sb.Reset() @@ -432,9 +436,18 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error { } } - if err := generate(cmd, opts); err != nil { + if opts.System != "" { + opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System}) + } + opts.Messages = append(opts.Messages, newMessage) + + assistant, err := chat(cmd, opts) + if err != nil { return err } + if assistant != nil { + opts.Messages = append(opts.Messages, *assistant) + } sb.Reset() } @@ -476,9 +489,9 @@ func extractFileNames(input string) []string { return re.FindAllString(input, -1) } -func extractFileData(input string) (string, []ImageData, error) { +func extractFileData(input string) (string, []api.ImageData, error) { filePaths := extractFileNames(input) - var imgs []ImageData + var imgs []api.ImageData for _, fp := range filePaths { nfp := normalizeFilePath(fp)