fix: set template without triple quotes

This commit is contained in:
Michael Yang 2024-01-05 15:51:33 -08:00
parent 9c2941e61b
commit 5580ae2472

View file

@ -139,8 +139,8 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
fmt.Print(readline.StartBracketedPaste)
defer fmt.Printf(readline.EndBracketedPaste)
var sb strings.Builder
var multiline MultilineState
var prompt string
for {
line, err := scanner.Readline()
@ -154,7 +154,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
}
scanner.Prompt.UseAlt = false
prompt = ""
sb.Reset()
continue
case err != nil:
@ -162,38 +162,41 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
}
switch {
case strings.HasPrefix(prompt, `"""`):
// if the prompt so far starts with """ then we're in multiline mode
// and we need to keep reading until we find a line that ends with """
cut, found := strings.CutSuffix(line, `"""`)
prompt += cut
if !found {
prompt += "\n"
case multiline != MultilineNone:
// check if there's a multiline terminating string
before, ok := strings.CutSuffix(line, `"""`)
sb.WriteString(before)
if !ok {
fmt.Fprintln(&sb)
continue
}
prompt = strings.TrimPrefix(prompt, `"""`)
scanner.Prompt.UseAlt = false
switch multiline {
case MultilineSystem:
opts.System = prompt
prompt = ""
opts.System = sb.String()
fmt.Println("Set system message.")
sb.Reset()
case MultilineTemplate:
opts.Template = prompt
prompt = ""
opts.Template = sb.String()
fmt.Println("Set prompt template.")
sb.Reset()
}
multiline = MultilineNone
case strings.HasPrefix(line, `"""`) && len(prompt) == 0:
scanner.Prompt.UseAlt = true
multiline = MultilinePrompt
prompt += line + "\n"
continue
scanner.Prompt.UseAlt = false
case strings.HasPrefix(line, `"""`):
line := strings.TrimPrefix(line, `"""`)
line, ok := strings.CutSuffix(line, `"""`)
sb.WriteString(line)
if !ok {
// no multiline terminating string; need more input
fmt.Fprintln(&sb)
multiline = MultilinePrompt
scanner.Prompt.UseAlt = true
break
}
case scanner.Pasting:
prompt += line + "\n"
fmt.Fprintln(&sb, line)
continue
case strings.HasPrefix(line, "/list"):
args := strings.Fields(line)
@ -251,33 +254,41 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
usageSet()
continue
}
line := strings.Join(args[2:], " ")
line = strings.TrimPrefix(line, `"""`)
if strings.HasPrefix(args[2], `"""`) {
cut, found := strings.CutSuffix(line, `"""`)
prompt += cut
if found {
if args[1] == "system" {
opts.System = prompt
fmt.Println("Set system message.")
} else {
opts.Template = prompt
fmt.Println("Set prompt template.")
}
prompt = ""
} else {
prompt = `"""` + prompt + "\n"
if args[1] == "system" {
multiline = MultilineSystem
} else {
multiline = MultilineTemplate
}
scanner.Prompt.UseAlt = true
}
} else {
opts.System = line
fmt.Println("Set system message.")
if args[1] == "system" {
multiline = MultilineSystem
} else if args[1] == "template" {
multiline = MultilineTemplate
}
line := strings.Join(args[2:], " ")
line, ok := strings.CutPrefix(line, `"""`)
if !ok {
multiline = MultilineNone
} else {
// only cut suffix if the line is multiline
line, ok = strings.CutSuffix(line, `"""`)
if ok {
multiline = MultilineNone
}
}
sb.WriteString(line)
if multiline != MultilineNone {
scanner.Prompt.UseAlt = true
continue
}
if args[1] == "system" {
opts.System = sb.String()
fmt.Println("Set system message.")
} else if args[1] == "template" {
opts.Template = sb.String()
fmt.Println("Set prompt template.")
}
sb.Reset()
continue
default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
}
@ -390,20 +401,20 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
}
}
if isFile {
prompt += line
} else {
if !isFile {
fmt.Printf("Unknown command '%s'. Type /? for help\n", args[0])
continue
}
sb.WriteString(line)
default:
prompt += line
sb.WriteString(line)
}
if len(prompt) > 0 && multiline == MultilineNone {
opts.Prompt = prompt
if sb.Len() > 0 && multiline == MultilineNone {
opts.Prompt = sb.String()
if multiModal {
newPrompt, images, err := extractFileData(prompt)
newPrompt, images, err := extractFileData(sb.String())
if err != nil {
return err
}
@ -419,15 +430,16 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
if len(opts.Images) == 0 {
fmt.Println("This model requires you to add a jpeg, png, or svg image.")
fmt.Println()
prompt = ""
sb.Reset()
continue
}
}
if err := generate(cmd, opts); err != nil {
return err
}
prompt = ""
sb.Reset()
}
}
}