Merge pull request #1614 from jmorganca/mxyng/fix-set-template

fix: set template without triple quotes
This commit is contained in:
Michael Yang 2024-01-09 09:36:24 -08:00 committed by GitHub
commit 62023177f6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -139,8 +139,8 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
fmt.Print(readline.StartBracketedPaste) fmt.Print(readline.StartBracketedPaste)
defer fmt.Printf(readline.EndBracketedPaste) defer fmt.Printf(readline.EndBracketedPaste)
var sb strings.Builder
var multiline MultilineState var multiline MultilineState
var prompt string
for { for {
line, err := scanner.Readline() line, err := scanner.Readline()
@ -154,7 +154,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
} }
scanner.Prompt.UseAlt = false scanner.Prompt.UseAlt = false
prompt = "" sb.Reset()
continue continue
case err != nil: case err != nil:
@ -162,38 +162,41 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
} }
switch { switch {
case strings.HasPrefix(prompt, `"""`): case multiline != MultilineNone:
// if the prompt so far starts with """ then we're in multiline mode // check if there's a multiline terminating string
// and we need to keep reading until we find a line that ends with """ before, ok := strings.CutSuffix(line, `"""`)
cut, found := strings.CutSuffix(line, `"""`) sb.WriteString(before)
prompt += cut if !ok {
fmt.Fprintln(&sb)
if !found {
prompt += "\n"
continue continue
} }
prompt = strings.TrimPrefix(prompt, `"""`)
scanner.Prompt.UseAlt = false
switch multiline { switch multiline {
case MultilineSystem: case MultilineSystem:
opts.System = prompt opts.System = sb.String()
prompt = ""
fmt.Println("Set system message.") fmt.Println("Set system message.")
sb.Reset()
case MultilineTemplate: case MultilineTemplate:
opts.Template = prompt opts.Template = sb.String()
prompt = ""
fmt.Println("Set prompt template.") fmt.Println("Set prompt template.")
sb.Reset()
} }
multiline = MultilineNone multiline = MultilineNone
case strings.HasPrefix(line, `"""`) && len(prompt) == 0: scanner.Prompt.UseAlt = false
scanner.Prompt.UseAlt = true case strings.HasPrefix(line, `"""`):
line := strings.TrimPrefix(line, `"""`)
line, ok := strings.CutSuffix(line, `"""`)
sb.WriteString(line)
if !ok {
// no multiline terminating string; need more input
fmt.Fprintln(&sb)
multiline = MultilinePrompt multiline = MultilinePrompt
prompt += line + "\n" scanner.Prompt.UseAlt = true
continue break
}
case scanner.Pasting: case scanner.Pasting:
prompt += line + "\n" fmt.Fprintln(&sb, line)
continue continue
case strings.HasPrefix(line, "/list"): case strings.HasPrefix(line, "/list"):
args := strings.Fields(line) args := strings.Fields(line)
@ -251,33 +254,41 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
usageSet() usageSet()
continue continue
} }
line := strings.Join(args[2:], " ")
line = strings.TrimPrefix(line, `"""`)
if strings.HasPrefix(args[2], `"""`) {
cut, found := strings.CutSuffix(line, `"""`)
prompt += cut
if found {
if args[1] == "system" {
opts.System = prompt
fmt.Println("Set system message.")
} else {
opts.Template = prompt
fmt.Println("Set prompt template.")
}
prompt = ""
} else {
prompt = `"""` + prompt + "\n"
if args[1] == "system" { if args[1] == "system" {
multiline = MultilineSystem multiline = MultilineSystem
} else { } else if args[1] == "template" {
multiline = MultilineTemplate multiline = MultilineTemplate
} }
scanner.Prompt.UseAlt = true
} line := strings.Join(args[2:], " ")
line, ok := strings.CutPrefix(line, `"""`)
if !ok {
multiline = MultilineNone
} else { } else {
opts.System = line // only cut suffix if the line is multiline
fmt.Println("Set system message.") line, ok = strings.CutSuffix(line, `"""`)
if ok {
multiline = MultilineNone
} }
}
sb.WriteString(line)
if multiline != MultilineNone {
scanner.Prompt.UseAlt = true
continue
}
if args[1] == "system" {
opts.System = sb.String()
fmt.Println("Set system message.")
} else if args[1] == "template" {
opts.Template = sb.String()
fmt.Println("Set prompt template.")
}
sb.Reset()
continue
default: default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1]) fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
} }
@ -390,20 +401,20 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
} }
} }
if isFile { if !isFile {
prompt += line
} else {
fmt.Printf("Unknown command '%s'. Type /? for help\n", args[0]) fmt.Printf("Unknown command '%s'. Type /? for help\n", args[0])
continue continue
} }
sb.WriteString(line)
default: default:
prompt += line sb.WriteString(line)
} }
if len(prompt) > 0 && multiline == MultilineNone { if sb.Len() > 0 && multiline == MultilineNone {
opts.Prompt = prompt opts.Prompt = sb.String()
if multiModal { if multiModal {
newPrompt, images, err := extractFileData(prompt) newPrompt, images, err := extractFileData(sb.String())
if err != nil { if err != nil {
return err return err
} }
@ -419,15 +430,16 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
if len(opts.Images) == 0 { if len(opts.Images) == 0 {
fmt.Println("This model requires you to add a jpeg, png, or svg image.") fmt.Println("This model requires you to add a jpeg, png, or svg image.")
fmt.Println() fmt.Println()
prompt = "" sb.Reset()
continue continue
} }
} }
if err := generate(cmd, opts); err != nil { if err := generate(cmd, opts); err != nil {
return err return err
} }
prompt = "" sb.Reset()
} }
} }
} }