From 5580ae2472f982db5c3aae1433a02a56e0b967ec Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 5 Jan 2024 15:51:33 -0800 Subject: [PATCH] fix: set template without triple quotes --- cmd/interactive.go | 128 +++++++++++++++++++++++++-------------------- 1 file changed, 70 insertions(+), 58 deletions(-) diff --git a/cmd/interactive.go b/cmd/interactive.go index 8d567689..70602b11 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -139,8 +139,8 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error { fmt.Print(readline.StartBracketedPaste) defer fmt.Printf(readline.EndBracketedPaste) + var sb strings.Builder var multiline MultilineState - var prompt string for { line, err := scanner.Readline() @@ -154,7 +154,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error { } scanner.Prompt.UseAlt = false - prompt = "" + sb.Reset() continue case err != nil: @@ -162,38 +162,41 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error { } switch { - case strings.HasPrefix(prompt, `"""`): - // if the prompt so far starts with """ then we're in multiline mode - // and we need to keep reading until we find a line that ends with """ - cut, found := strings.CutSuffix(line, `"""`) - prompt += cut - - if !found { - prompt += "\n" + case multiline != MultilineNone: + // check if there's a multiline terminating string + before, ok := strings.CutSuffix(line, `"""`) + sb.WriteString(before) + if !ok { + fmt.Fprintln(&sb) continue } - prompt = strings.TrimPrefix(prompt, `"""`) - scanner.Prompt.UseAlt = false - switch multiline { case MultilineSystem: - opts.System = prompt - prompt = "" + opts.System = sb.String() fmt.Println("Set system message.") + sb.Reset() case MultilineTemplate: - opts.Template = prompt - prompt = "" + opts.Template = sb.String() fmt.Println("Set prompt template.") + sb.Reset() } + multiline = MultilineNone - case strings.HasPrefix(line, `"""`) && len(prompt) == 0: - scanner.Prompt.UseAlt = true - multiline = MultilinePrompt - prompt += line + "\n" - continue + scanner.Prompt.UseAlt = false + case strings.HasPrefix(line, `"""`): + line := strings.TrimPrefix(line, `"""`) + line, ok := strings.CutSuffix(line, `"""`) + sb.WriteString(line) + if !ok { + // no multiline terminating string; need more input + fmt.Fprintln(&sb) + multiline = MultilinePrompt + scanner.Prompt.UseAlt = true + break + } case scanner.Pasting: - prompt += line + "\n" + fmt.Fprintln(&sb, line) continue case strings.HasPrefix(line, "/list"): args := strings.Fields(line) @@ -251,33 +254,41 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error { usageSet() continue } - line := strings.Join(args[2:], " ") - line = strings.TrimPrefix(line, `"""`) - if strings.HasPrefix(args[2], `"""`) { - cut, found := strings.CutSuffix(line, `"""`) - prompt += cut - if found { - if args[1] == "system" { - opts.System = prompt - fmt.Println("Set system message.") - } else { - opts.Template = prompt - fmt.Println("Set prompt template.") - } - prompt = "" - } else { - prompt = `"""` + prompt + "\n" - if args[1] == "system" { - multiline = MultilineSystem - } else { - multiline = MultilineTemplate - } - scanner.Prompt.UseAlt = true - } - } else { - opts.System = line - fmt.Println("Set system message.") + + if args[1] == "system" { + multiline = MultilineSystem + } else if args[1] == "template" { + multiline = MultilineTemplate } + + line := strings.Join(args[2:], " ") + line, ok := strings.CutPrefix(line, `"""`) + if !ok { + multiline = MultilineNone + } else { + // only cut suffix if the line is multiline + line, ok = strings.CutSuffix(line, `"""`) + if ok { + multiline = MultilineNone + } + } + + sb.WriteString(line) + if multiline != MultilineNone { + scanner.Prompt.UseAlt = true + continue + } + + if args[1] == "system" { + opts.System = sb.String() + fmt.Println("Set system message.") + } else if args[1] == "template" { + opts.Template = sb.String() + fmt.Println("Set prompt template.") + } + + sb.Reset() + continue default: fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1]) } @@ -390,20 +401,20 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error { } } - if isFile { - prompt += line - } else { + if !isFile { fmt.Printf("Unknown command '%s'. Type /? for help\n", args[0]) continue } + + sb.WriteString(line) default: - prompt += line + sb.WriteString(line) } - if len(prompt) > 0 && multiline == MultilineNone { - opts.Prompt = prompt + if sb.Len() > 0 && multiline == MultilineNone { + opts.Prompt = sb.String() if multiModal { - newPrompt, images, err := extractFileData(prompt) + newPrompt, images, err := extractFileData(sb.String()) if err != nil { return err } @@ -419,15 +430,16 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error { if len(opts.Images) == 0 { fmt.Println("This model requires you to add a jpeg, png, or svg image.") fmt.Println() - prompt = "" + sb.Reset() continue } } + if err := generate(cmd, opts); err != nil { return err } - prompt = "" + sb.Reset() } } }