From db356c85199f62e0ba7f7811b06777fb366f4634 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Fri, 22 Dec 2023 17:07:05 -0500 Subject: [PATCH] post-response templating (#1427) --- server/images.go | 86 ++++++++++++--- server/images_test.go | 249 ++++++++++++++++++++++++++++++++++++++++++ server/routes.go | 15 ++- 3 files changed, 334 insertions(+), 16 deletions(-) diff --git a/server/images.go b/server/images.go index 006a91d3..85b96395 100644 --- a/server/images.go +++ b/server/images.go @@ -18,6 +18,7 @@ import ( "strconv" "strings" "text/template" + "text/template/parse" "golang.org/x/exp/slices" @@ -57,17 +58,35 @@ type PromptVars struct { First bool } -func (m *Model) Prompt(p PromptVars) (string, error) { - var prompt strings.Builder - // Use the "missingkey=zero" option to handle missing variables without panicking - tmpl, err := template.New("").Option("missingkey=zero").Parse(m.Template) +// extractParts extracts the parts of the template before and after the {{.Response}} node. +func extractParts(tmplStr string) (pre string, post string, err error) { + tmpl, err := template.New("").Parse(tmplStr) if err != nil { - return "", err + return "", "", err } - if p.System == "" { - // use the default system message for this model if one is not specified - p.System = m.System + var foundResponse bool + + for _, node := range tmpl.Tree.Root.Nodes { + if node.Type() == parse.NodeAction && node.String() == "{{.Response}}" { + foundResponse = true + } + if !foundResponse { + pre += node.String() + } else { + post += node.String() + } + } + + return pre, post, nil +} + +func Prompt(promptTemplate string, p PromptVars) (string, error) { + var prompt strings.Builder + // Use the "missingkey=zero" option to handle missing variables without panicking + tmpl, err := template.New("").Option("missingkey=zero").Parse(promptTemplate) + if err != nil { + return "", err } vars := map[string]any{ @@ -82,20 +101,59 @@ func (m *Model) Prompt(p PromptVars) (string, error) { return "", err } prompt.WriteString(sb.String()) - prompt.WriteString(p.Response) + + if !strings.Contains(prompt.String(), p.Response) { + // if the response is not in the prompt template, append it to the end + prompt.WriteString(p.Response) + } + return prompt.String(), nil } +// PreResponsePrompt returns the prompt before the response tag +func (m *Model) PreResponsePrompt(p PromptVars) (string, error) { + if p.System == "" { + // use the default system prompt for this model if one is not specified + p.System = m.System + } + pre, _, err := extractParts(m.Template) + if err != nil { + return "", err + } + + return Prompt(pre, p) +} + +// PostResponseTemplate returns the template after the response tag +func (m *Model) PostResponseTemplate(p PromptVars) (string, error) { + if p.System == "" { + // use the default system prompt for this model if one is not specified + p.System = m.System + } + _, post, err := extractParts(m.Template) + if err != nil { + return "", err + } + + if post == "" { + // if there is no post-response template, return the provided response + return p.Response, nil + } + + return Prompt(post, p) +} + func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) { // build the prompt from the list of messages var prompt strings.Builder var currentImages []api.ImageData currentVars := PromptVars{ - First: true, + First: true, + System: m.System, } writePrompt := func() error { - p, err := m.Prompt(currentVars) + p, err := Prompt(m.Template, currentVars) if err != nil { return err } @@ -133,9 +191,11 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) // Append the last set of vars if they are non-empty if currentVars.Prompt != "" || currentVars.System != "" { - if err := writePrompt(); err != nil { - return "", nil, err + p, err := m.PreResponsePrompt(currentVars) + if err != nil { + return "", nil, fmt.Errorf("pre-response template: %w", err) } + prompt.WriteString(p) } return prompt.String(), currentImages, nil diff --git a/server/images_test.go b/server/images_test.go index 17c433cd..55408e7a 100644 --- a/server/images_test.go +++ b/server/images_test.go @@ -7,6 +7,232 @@ import ( "github.com/jmorganca/ollama/api" ) +func TestPrompt(t *testing.T) { + tests := []struct { + name string + template string + vars PromptVars + want string + wantErr bool + }{ + { + name: "System Prompt", + template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", + vars: PromptVars{ + System: "You are a Wizard.", + Prompt: "What are the potion ingredients?", + }, + want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]", + }, + { + name: "System Prompt with Response", + template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}", + vars: PromptVars{ + System: "You are a Wizard.", + Prompt: "What are the potion ingredients?", + Response: "I don't know.", + }, + want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.", + }, + { + name: "Conditional Logic Nodes", + template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}", + vars: PromptVars{ + First: true, + System: "You are a Wizard.", + Prompt: "What are the potion ingredients?", + Response: "I don't know.", + }, + want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST] I don't know.", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Prompt(tt.template, tt.vars) + if (err != nil) != tt.wantErr { + t.Errorf("Prompt() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Prompt() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestModel_PreResponsePrompt(t *testing.T) { + tests := []struct { + name string + template string + vars PromptVars + want string + wantErr bool + }{ + { + name: "No Response in Template", + template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", + vars: PromptVars{ + System: "You are a Wizard.", + Prompt: "What are the potion ingredients?", + }, + want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]", + }, + { + name: "Response in Template", + template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}", + vars: PromptVars{ + System: "You are a Wizard.", + Prompt: "What are the potion ingredients?", + }, + want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] ", + }, + { + name: "Response in Template with Trailing Formatting", + template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>", + vars: PromptVars{ + Prompt: "What are the potion ingredients?", + }, + want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n", + }, + { + name: "Response in Template with Alternative Formatting", + template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>", + vars: PromptVars{ + Prompt: "What are the potion ingredients?", + }, + want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n", + }, + } + + for _, tt := range tests { + m := Model{Template: tt.template} + t.Run(tt.name, func(t *testing.T) { + got, err := m.PreResponsePrompt(tt.vars) + if (err != nil) != tt.wantErr { + t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("PreResponsePrompt() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestModel_PostResponsePrompt(t *testing.T) { + tests := []struct { + name string + template string + vars PromptVars + want string + wantErr bool + }{ + { + name: "No Response in Template", + template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", + vars: PromptVars{ + Response: "I don't know.", + }, + want: "I don't know.", + }, + { + name: "Response in Template", + template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}", + vars: PromptVars{ + Response: "I don't know.", + }, + want: "I don't know.", + }, + { + name: "Response in Template with Trailing Formatting", + template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>", + vars: PromptVars{ + Response: "I don't know.", + }, + want: "I don't know.<|im_end|>", + }, + { + name: "Response in Template with Alternative Formatting", + template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>", + vars: PromptVars{ + Response: "I don't know.", + }, + want: "I don't know.<|im_end|>", + }, + } + + for _, tt := range tests { + m := Model{Template: tt.template} + t.Run(tt.name, func(t *testing.T) { + got, err := m.PostResponseTemplate(tt.vars) + if (err != nil) != tt.wantErr { + t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("PostResponseTemplate() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) { + tests := []struct { + name string + template string + preVars PromptVars + postVars PromptVars + want string + wantErr bool + }{ + { + name: "Response in Template", + template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>", + preVars: PromptVars{ + Prompt: "What are the potion ingredients?", + }, + postVars: PromptVars{ + Prompt: "What are the potion ingredients?", + Response: "Sugar.", + }, + want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSugar.<|im_end|>", + }, + { + name: "No Response in Template", + template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n", + preVars: PromptVars{ + Prompt: "What are the potion ingredients?", + }, + postVars: PromptVars{ + Prompt: "What are the potion ingredients?", + Response: "Spice.", + }, + want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSpice.", + }, + } + + for _, tt := range tests { + m := Model{Template: tt.template} + t.Run(tt.name, func(t *testing.T) { + pre, err := m.PreResponsePrompt(tt.preVars) + if (err != nil) != tt.wantErr { + t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr) + return + } + post, err := m.PostResponseTemplate(tt.postVars) + if err != nil { + t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr) + return + } + result := pre + post + if result != tt.want { + t.Errorf("Prompt() got = %v, want %v", result, tt.want) + } + }) + } +} + func TestChat(t *testing.T) { tests := []struct { name string @@ -30,6 +256,29 @@ func TestChat(t *testing.T) { }, want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]", }, + { + name: "First Message", + template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]", + msgs: []api.Message{ + { + Role: "system", + Content: "You are a Wizard.", + }, + { + Role: "user", + Content: "What are the potion ingredients?", + }, + { + Role: "assistant", + Content: "eye of newt", + }, + { + Role: "user", + Content: "Anything else?", + }, + }, + want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]", + }, { name: "Message History", template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", diff --git a/server/routes.go b/server/routes.go index 75e67a72..123f964b 100644 --- a/server/routes.go +++ b/server/routes.go @@ -195,6 +195,7 @@ func GenerateHandler(c *gin.Context) { checkpointLoaded := time.Now() var prompt string + var promptVars PromptVars switch { case req.Raw: prompt = req.Prompt @@ -217,11 +218,12 @@ func GenerateHandler(c *gin.Context) { prevCtx = strings.TrimPrefix(prevCtx, " ") rebuild.WriteString(prevCtx) } - p, err := model.Prompt(PromptVars{ + promptVars = PromptVars{ System: req.System, Prompt: req.Prompt, First: len(req.Context) == 0, - }) + } + p, err := model.PreResponsePrompt(promptVars) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -264,7 +266,14 @@ func GenerateHandler(c *gin.Context) { resp.LoadDuration = checkpointLoaded.Sub(checkpointStart) if !req.Raw { - embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String()) + // append the generated text to the history and template it if needed + promptVars.Response = generated.String() + result, err := model.PostResponseTemplate(promptVars) + if err != nil { + ch <- gin.H{"error": err.Error()} + return + } + embd, err := loaded.runner.Encode(c.Request.Context(), prompt+result) if err != nil { ch <- gin.H{"error": err.Error()} return