From bf704423c5854a9d0875afb0d80af5bb484177d3 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Mon, 4 Dec 2023 16:35:29 -0800 Subject: [PATCH] revert cli to use /api/generate (#1383) --- cmd/cmd.go | 236 +++++++++++++++++++++++++++-------------------------- 1 file changed, 120 insertions(+), 116 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index d3c3b777..df0d90c8 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -159,54 +159,7 @@ func RunHandler(cmd *cobra.Command, args []string) error { return err } - interactive := true - - opts := runOptions{ - Model: name, - WordWrap: os.Getenv("TERM") == "xterm-256color", - Options: map[string]interface{}{}, - } - - format, err := cmd.Flags().GetString("format") - if err != nil { - return err - } - opts.Format = format - - prompts := args[1:] - - // prepend stdin to the prompt if provided - if !term.IsTerminal(int(os.Stdin.Fd())) { - in, err := io.ReadAll(os.Stdin) - if err != nil { - return err - } - - prompts = append([]string{string(in)}, prompts...) - opts.WordWrap = false - interactive = false - } - msg := api.Message{ - Role: "user", - Content: strings.Join(prompts, " "), - } - opts.Messages = append(opts.Messages, msg) - if len(prompts) > 0 { - interactive = false - } - - nowrap, err := cmd.Flags().GetBool("nowordwrap") - if err != nil { - return err - } - opts.WordWrap = !nowrap - - if !interactive { - _, err := chat(cmd, opts) - return err - } - - return chatInteractive(cmd, opts) + return RunGenerate(cmd, args) } func PushHandler(cmd *cobra.Command, args []string) error { @@ -458,26 +411,83 @@ func PullHandler(cmd *cobra.Command, args []string) error { return nil } -type runOptions struct { +func RunGenerate(cmd *cobra.Command, args []string) error { + interactive := true + + opts := generateOptions{ + Model: args[0], + WordWrap: os.Getenv("TERM") == "xterm-256color", + Options: map[string]interface{}{}, + } + + format, err := cmd.Flags().GetString("format") + if err != nil { + return err + } + opts.Format = format + + prompts := args[1:] + + // prepend stdin to the prompt if provided + if !term.IsTerminal(int(os.Stdin.Fd())) { + in, err := io.ReadAll(os.Stdin) + if err != nil { + return err + } + + prompts = append([]string{string(in)}, prompts...) + opts.WordWrap = false + interactive = false + } + opts.Prompt = strings.Join(prompts, " ") + if len(prompts) > 0 { + interactive = false + } + + nowrap, err := cmd.Flags().GetBool("nowordwrap") + if err != nil { + return err + } + opts.WordWrap = !nowrap + + if !interactive { + return generate(cmd, opts) + } + + return generateInteractive(cmd, opts) +} + +type generateContextKey string + +type generateOptions struct { Model string - Messages []api.Message + Prompt string WordWrap bool Format string + System string Template string Options map[string]interface{} } -func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { +func generate(cmd *cobra.Command, opts generateOptions) error { client, err := api.ClientFromEnvironment() if err != nil { - return nil, err + return err } p := progress.NewProgress(os.Stderr) defer p.StopAndClear() + spinner := progress.NewSpinner("") p.Add("", spinner) + var latest api.GenerateResponse + + generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int) + if !ok { + generateContext = []int{} + } + termWidth, _, err := term.GetSize(int(os.Stdout.Fd())) if err != nil { opts.WordWrap = false @@ -496,24 +506,24 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { var currentLineLength int var wordBuffer string - var latest api.ChatResponse - var fullResponse strings.Builder - var role string - fn := func(response api.ChatResponse) error { + request := api.GenerateRequest{ + Model: opts.Model, + Prompt: opts.Prompt, + Context: generateContext, + Format: opts.Format, + System: opts.System, + Template: opts.Template, + Options: opts.Options, + } + fn := func(response api.GenerateResponse) error { p.StopAndClear() + latest = response - if response.Message == nil { - // warm-up response or done - return nil - } - role = response.Message.Role - content := response.Message.Content - fullResponse.WriteString(content) termWidth, _, _ = term.GetSize(int(os.Stdout.Fd())) if opts.WordWrap && termWidth >= 10 { - for _, ch := range content { + for _, ch := range response.Response { if currentLineLength+1 > termWidth-5 { if len(wordBuffer) > termWidth-10 { fmt.Printf("%s%c", wordBuffer, ch) @@ -541,7 +551,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { } } } else { - fmt.Printf("%s%s", wordBuffer, content) + fmt.Printf("%s%s", wordBuffer, response.Response) if len(wordBuffer) > 0 { wordBuffer = "" } @@ -550,35 +560,35 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { return nil } - req := &api.ChatRequest{ - Model: opts.Model, - Messages: opts.Messages, - Format: opts.Format, - Template: opts.Template, - Options: opts.Options, - } - if err := client.Chat(cancelCtx, req, fn); err != nil { + if err := client.Generate(cancelCtx, &request, fn); err != nil { if errors.Is(err, context.Canceled) { - return nil, nil + return nil } - return nil, err + return err + } + if opts.Prompt != "" { + fmt.Println() + fmt.Println() } - if len(opts.Messages) > 0 { - fmt.Println() - fmt.Println() + if !latest.Done { + return nil } verbose, err := cmd.Flags().GetBool("verbose") if err != nil { - return nil, err + return err } if verbose { latest.Summary() } - return &api.Message{Role: role, Content: fullResponse.String()}, nil + ctx := cmd.Context() + ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context) + cmd.SetContext(ctx) + + return nil } type MultilineState int @@ -590,10 +600,13 @@ const ( MultilineTemplate ) -func chatInteractive(cmd *cobra.Command, opts runOptions) error { +func generateInteractive(cmd *cobra.Command, opts generateOptions) error { // load the model - loadOpts := runOptions{Model: opts.Model} - if _, err := chat(cmd, loadOpts); err != nil { + loadOpts := generateOptions{ + Model: opts.Model, + Prompt: "", + } + if err := generate(cmd, loadOpts); err != nil { return err } @@ -664,9 +677,7 @@ func chatInteractive(cmd *cobra.Command, opts runOptions) error { defer fmt.Printf(readline.EndBracketedPaste) var multiline MultilineState - var content string - var systemContent string - opts.Messages = make([]api.Message, 0) + var prompt string for { line, err := scanner.Readline() @@ -680,7 +691,7 @@ func chatInteractive(cmd *cobra.Command, opts runOptions) error { } scanner.Prompt.UseAlt = false - content = "" + prompt = "" continue case err != nil: @@ -688,37 +699,37 @@ func chatInteractive(cmd *cobra.Command, opts runOptions) error { } switch { - case strings.HasPrefix(content, `"""`): + case strings.HasPrefix(prompt, `"""`): // if the prompt so far starts with """ then we're in multiline mode // and we need to keep reading until we find a line that ends with """ cut, found := strings.CutSuffix(line, `"""`) - content += cut + "\n" + prompt += cut + "\n" if !found { continue } - content = strings.TrimPrefix(content, `"""`) + prompt = strings.TrimPrefix(prompt, `"""`) scanner.Prompt.UseAlt = false switch multiline { case MultilineSystem: - systemContent = content - content = "" + opts.System = prompt + prompt = "" fmt.Println("Set system template.\n") case MultilineTemplate: - opts.Template = content - content = "" + opts.Template = prompt + prompt = "" fmt.Println("Set model template.\n") } multiline = MultilineNone - case strings.HasPrefix(line, `"""`) && len(content) == 0: + case strings.HasPrefix(line, `"""`) && len(prompt) == 0: scanner.Prompt.UseAlt = true multiline = MultilinePrompt - content += line + "\n" + prompt += line + "\n" continue case scanner.Pasting: - content += line + "\n" + prompt += line + "\n" continue case strings.HasPrefix(line, "/list"): args := strings.Fields(line) @@ -780,17 +791,17 @@ func chatInteractive(cmd *cobra.Command, opts runOptions) error { line = strings.TrimPrefix(line, `"""`) if strings.HasPrefix(args[2], `"""`) { cut, found := strings.CutSuffix(line, `"""`) - content += cut + "\n" + prompt += cut + "\n" if found { - systemContent = content + opts.System = prompt if args[1] == "system" { fmt.Println("Set system template.\n") } else { fmt.Println("Set prompt template.\n") } - content = "" + prompt = "" } else { - content = `"""` + content + prompt = `"""` + prompt if args[1] == "system" { multiline = MultilineSystem } else { @@ -799,7 +810,7 @@ func chatInteractive(cmd *cobra.Command, opts runOptions) error { scanner.Prompt.UseAlt = true } } else { - systemContent = line + opts.System = line fmt.Println("Set system template.\n") } default: @@ -847,8 +858,8 @@ func chatInteractive(cmd *cobra.Command, opts runOptions) error { } case "system": switch { - case systemContent != "": - fmt.Println(systemContent + "\n") + case opts.System != "": + fmt.Println(opts.System + "\n") case resp.System != "": fmt.Println(resp.System + "\n") default: @@ -888,23 +899,16 @@ func chatInteractive(cmd *cobra.Command, opts runOptions) error { fmt.Printf("Unknown command '%s'. Type /? for help\n", args[0]) continue default: - content += line + prompt += line } - if len(content) > 0 && multiline == MultilineNone { - if systemContent != "" { - opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: systemContent}) - } - opts.Messages = append(opts.Messages, api.Message{Role: "user", Content: content}) - assistant, err := chat(cmd, opts) - if err != nil { + if len(prompt) > 0 && multiline == MultilineNone { + opts.Prompt = prompt + if err := generate(cmd, opts); err != nil { return err } - if assistant != nil { - opts.Messages = append(opts.Messages, *assistant) - } - content = "" + prompt = "" } } }