allow setting the system and template for prompts in the repl (#1335)

This commit is contained in:
Patrick Devine 2023-12-01 09:28:35 -08:00 committed by GitHub
parent 0409c1fa59
commit 6681d37861
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -464,6 +464,8 @@ type generateOptions struct {
Prompt string
WordWrap bool
Format string
System string
Template string
Options map[string]interface{}
}
@ -510,6 +512,8 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
Prompt: opts.Prompt,
Context: generateContext,
Format: opts.Format,
System: opts.System,
Template: opts.Template,
Options: opts.Options,
}
fn := func(response api.GenerateResponse) error {
@ -576,6 +580,15 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
return nil
}
type MultilineState int
const (
MultilineNone MultilineState = iota
MultilinePrompt
MultilineSystem
MultilineTemplate
)
func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
// load the model
loadOpts := generateOptions{
@ -599,7 +612,9 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
usageSet := func() {
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set parameter Set a parameter")
fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter")
fmt.Fprintln(os.Stderr, " /set system <string> Set system prompt")
fmt.Fprintln(os.Stderr, " /set template <string> Set prompt template")
fmt.Fprintln(os.Stderr, " /set history Enable history")
fmt.Fprintln(os.Stderr, " /set nohistory Disable history")
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
@ -650,6 +665,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
fmt.Print(readline.StartBracketedPaste)
defer fmt.Printf(readline.EndBracketedPaste)
var multiline MultilineState
var prompt string
for {
@ -684,8 +700,21 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
prompt = strings.TrimPrefix(prompt, `"""`)
scanner.Prompt.UseAlt = false
switch multiline {
case MultilineSystem:
opts.System = prompt
prompt = ""
fmt.Println("Set system template.\n")
case MultilineTemplate:
opts.Template = prompt
prompt = ""
fmt.Println("Set model template.\n")
}
multiline = MultilineNone
case strings.HasPrefix(line, `"""`) && len(prompt) == 0:
scanner.Prompt.UseAlt = true
multiline = MultilinePrompt
prompt += line + "\n"
continue
case scanner.Pasting:
@ -742,6 +771,37 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
}
fmt.Printf("Set parameter '%s' to '%s'\n\n", args[2], strings.Join(params, ", "))
opts.Options[args[2]] = fp[args[2]]
case "system", "template":
if len(args) < 3 {
usageSet()
continue
}
line := strings.Join(args[2:], " ")
line = strings.TrimPrefix(line, `"""`)
if strings.HasPrefix(args[2], `"""`) {
cut, found := strings.CutSuffix(line, `"""`)
prompt += cut + "\n"
if found {
opts.System = prompt
if args[1] == "system" {
fmt.Println("Set system template.\n")
} else {
fmt.Println("Set prompt template.\n")
}
prompt = ""
} else {
prompt = `"""` + prompt
if args[1] == "system" {
multiline = MultilineSystem
} else {
multiline = MultilineTemplate
}
scanner.Prompt.UseAlt = true
}
} else {
opts.System = line
fmt.Println("Set system template.\n")
}
default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
}
@ -786,16 +846,22 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
fmt.Println(resp.Parameters)
}
case "system":
if resp.System == "" {
switch {
case opts.System != "":
fmt.Println(opts.System + "\n")
case resp.System != "":
fmt.Println(resp.System + "\n")
default:
fmt.Print("No system prompt was specified for this model.\n\n")
} else {
fmt.Println(resp.System)
}
case "template":
if resp.Template == "" {
fmt.Print("No prompt template was specified for this model.\n\n")
} else {
switch {
case opts.Template != "":
fmt.Println(opts.Template + "\n")
case resp.Template != "":
fmt.Println(resp.Template)
default:
fmt.Print("No prompt template was specified for this model.\n\n")
}
default:
fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1])
@ -825,7 +891,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
prompt += line
}
if len(prompt) > 0 && prompt[0] != '/' {
if len(prompt) > 0 && multiline == MultilineNone {
opts.Prompt = prompt
if err := generate(cmd, opts); err != nil {
return err