post-response templating (#1427)

This commit is contained in:
Bruce MacDonald 2023-12-22 17:07:05 -05:00 committed by GitHub
parent b80081022f
commit db356c8519
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 334 additions and 16 deletions

View file

@ -18,6 +18,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"text/template" "text/template"
"text/template/parse"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
@ -57,17 +58,35 @@ type PromptVars struct {
First bool First bool
} }
func (m *Model) Prompt(p PromptVars) (string, error) { // extractParts extracts the parts of the template before and after the {{.Response}} node.
var prompt strings.Builder func extractParts(tmplStr string) (pre string, post string, err error) {
// Use the "missingkey=zero" option to handle missing variables without panicking tmpl, err := template.New("").Parse(tmplStr)
tmpl, err := template.New("").Option("missingkey=zero").Parse(m.Template)
if err != nil { if err != nil {
return "", err return "", "", err
} }
if p.System == "" { var foundResponse bool
// use the default system message for this model if one is not specified
p.System = m.System for _, node := range tmpl.Tree.Root.Nodes {
if node.Type() == parse.NodeAction && node.String() == "{{.Response}}" {
foundResponse = true
}
if !foundResponse {
pre += node.String()
} else {
post += node.String()
}
}
return pre, post, nil
}
func Prompt(promptTemplate string, p PromptVars) (string, error) {
var prompt strings.Builder
// Use the "missingkey=zero" option to handle missing variables without panicking
tmpl, err := template.New("").Option("missingkey=zero").Parse(promptTemplate)
if err != nil {
return "", err
} }
vars := map[string]any{ vars := map[string]any{
@ -82,20 +101,59 @@ func (m *Model) Prompt(p PromptVars) (string, error) {
return "", err return "", err
} }
prompt.WriteString(sb.String()) prompt.WriteString(sb.String())
if !strings.Contains(prompt.String(), p.Response) {
// if the response is not in the prompt template, append it to the end
prompt.WriteString(p.Response) prompt.WriteString(p.Response)
}
return prompt.String(), nil return prompt.String(), nil
} }
// PreResponsePrompt renders the part of the model template that extractParts
// reports as coming before the {{.Response}} node, filling it with the given
// variables via Prompt.
//
// When p.System is empty the model's own System value is substituted. Any
// error from splitting the template is returned unchanged.
func (m *Model) PreResponsePrompt(p PromptVars) (string, error) {
	before, _, err := extractParts(m.Template)
	if err != nil {
		return "", err
	}

	// p is a copy, so defaulting the system prompt here never leaks to the caller
	if len(p.System) == 0 {
		p.System = m.System
	}

	return Prompt(before, p)
}
// PostResponseTemplate renders the part of the model template that
// extractParts reports as starting at the {{.Response}} node, filling it with
// the given variables via Prompt.
//
// If that part is empty, p.Response is returned as-is. When p.System is empty
// the model's own System value is substituted before rendering. Any error from
// splitting the template is returned unchanged.
func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
	_, after, err := extractParts(m.Template)
	switch {
	case err != nil:
		return "", err
	case after == "":
		// nothing to template after the response: hand back the raw response
		return p.Response, nil
	}

	// p is a copy, so defaulting the system prompt here never leaks to the caller
	if len(p.System) == 0 {
		p.System = m.System
	}

	return Prompt(after, p)
}
func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) { func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
// build the prompt from the list of messages // build the prompt from the list of messages
var prompt strings.Builder var prompt strings.Builder
var currentImages []api.ImageData var currentImages []api.ImageData
currentVars := PromptVars{ currentVars := PromptVars{
First: true, First: true,
System: m.System,
} }
writePrompt := func() error { writePrompt := func() error {
p, err := m.Prompt(currentVars) p, err := Prompt(m.Template, currentVars)
if err != nil { if err != nil {
return err return err
} }
@ -133,9 +191,11 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error)
// Append the last set of vars if they are non-empty // Append the last set of vars if they are non-empty
if currentVars.Prompt != "" || currentVars.System != "" { if currentVars.Prompt != "" || currentVars.System != "" {
if err := writePrompt(); err != nil { p, err := m.PreResponsePrompt(currentVars)
return "", nil, err if err != nil {
return "", nil, fmt.Errorf("pre-response template: %w", err)
} }
prompt.WriteString(p)
} }
return prompt.String(), currentImages, nil return prompt.String(), currentImages, nil

View file

@ -7,6 +7,232 @@ import (
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
) )
// TestPrompt renders a handful of templates through Prompt and compares the
// output with the expected string. The cases cover a template without a
// response field, one with a response field, and one using the First
// conditional.
func TestPrompt(t *testing.T) {
	type promptCase struct {
		name     string
		template string
		vars     PromptVars
		want     string
		wantErr  bool
	}

	cases := []promptCase{
		{
			name:     "System Prompt",
			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
			vars: PromptVars{
				System: "You are a Wizard.",
				Prompt: "What are the potion ingredients?",
			},
			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
		},
		{
			name:     "System Prompt with Response",
			template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
			vars: PromptVars{
				System:   "You are a Wizard.",
				Prompt:   "What are the potion ingredients?",
				Response: "I don't know.",
			},
			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
		},
		{
			name:     "Conditional Logic Nodes",
			template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
			vars: PromptVars{
				First:    true,
				System:   "You are a Wizard.",
				Prompt:   "What are the potion ingredients?",
				Response: "I don't know.",
			},
			want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := Prompt(tc.template, tc.vars)
			switch {
			case (err != nil) != tc.wantErr:
				// error expectation not met; the output is not worth comparing
				t.Errorf("Prompt() error = %v, wantErr %v", err, tc.wantErr)
			case got != tc.want:
				t.Errorf("Prompt() got = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestModel_PreResponsePrompt checks the text PreResponsePrompt produces for
// templates with and without a response field, including templates that carry
// extra formatting after the response and both spacing styles of the action
// ({{ .Response }} and {{.Response}}).
func TestModel_PreResponsePrompt(t *testing.T) {
	type preCase struct {
		name     string
		template string
		vars     PromptVars
		want     string
		wantErr  bool
	}

	cases := []preCase{
		{
			name:     "No Response in Template",
			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
			vars: PromptVars{
				System: "You are a Wizard.",
				Prompt: "What are the potion ingredients?",
			},
			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
		},
		{
			name:     "Response in Template",
			template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
			vars: PromptVars{
				System: "You are a Wizard.",
				Prompt: "What are the potion ingredients?",
			},
			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] ",
		},
		{
			name:     "Response in Template with Trailing Formatting",
			template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
			vars: PromptVars{
				Prompt: "What are the potion ingredients?",
			},
			want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
		},
		{
			name:     "Response in Template with Alternative Formatting",
			template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
			vars: PromptVars{
				Prompt: "What are the potion ingredients?",
			},
			want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
		},
	}

	for _, tc := range cases {
		model := Model{Template: tc.template}
		t.Run(tc.name, func(t *testing.T) {
			got, err := model.PreResponsePrompt(tc.vars)
			switch {
			case (err != nil) != tc.wantErr:
				// error expectation not met; the output is not worth comparing
				t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tc.wantErr)
			case got != tc.want:
				t.Errorf("PreResponsePrompt() got = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestModel_PostResponsePrompt checks the text PostResponseTemplate produces
// for a given response: the bare response when the template has no response
// field or nothing after it, and the response followed by the trailing
// formatting when the template continues past the response field.
func TestModel_PostResponsePrompt(t *testing.T) {
	type postCase struct {
		name     string
		template string
		vars     PromptVars
		want     string
		wantErr  bool
	}

	cases := []postCase{
		{
			name:     "No Response in Template",
			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
			vars: PromptVars{
				Response: "I don't know.",
			},
			want: "I don't know.",
		},
		{
			name:     "Response in Template",
			template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
			vars: PromptVars{
				Response: "I don't know.",
			},
			want: "I don't know.",
		},
		{
			name:     "Response in Template with Trailing Formatting",
			template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
			vars: PromptVars{
				Response: "I don't know.",
			},
			want: "I don't know.<|im_end|>",
		},
		{
			name:     "Response in Template with Alternative Formatting",
			template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
			vars: PromptVars{
				Response: "I don't know.",
			},
			want: "I don't know.<|im_end|>",
		},
	}

	for _, tc := range cases {
		model := Model{Template: tc.template}
		t.Run(tc.name, func(t *testing.T) {
			got, err := model.PostResponseTemplate(tc.vars)
			switch {
			case (err != nil) != tc.wantErr:
				// error expectation not met; the output is not worth comparing
				t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tc.wantErr)
			case got != tc.want:
				t.Errorf("PostResponseTemplate() got = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestModel_PreResponsePrompt_PostResponsePrompt checks that the output of
// PreResponsePrompt followed by the output of PostResponseTemplate forms the
// complete expected prompt, both when the template contains a response field
// and when it does not.
//
// A case with wantErr set is satisfied by an error from either call; the
// concatenated output is only compared when both calls succeed.
func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
	tests := []struct {
		name     string
		template string
		preVars  PromptVars
		postVars PromptVars
		want     string
		wantErr  bool
	}{
		{
			name:     "Response in Template",
			template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
			preVars: PromptVars{
				Prompt: "What are the potion ingredients?",
			},
			postVars: PromptVars{
				Prompt:   "What are the potion ingredients?",
				Response: "Sugar.",
			},
			want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSugar.<|im_end|>",
		},
		{
			name:     "No Response in Template",
			template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n",
			preVars: PromptVars{
				Prompt: "What are the potion ingredients?",
			},
			postVars: PromptVars{
				Prompt:   "What are the potion ingredients?",
				Response: "Spice.",
			},
			want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSpice.",
		},
	}

	for _, tt := range tests {
		m := Model{Template: tt.template}
		t.Run(tt.name, func(t *testing.T) {
			pre, err := m.PreResponsePrompt(tt.preVars)
			if err != nil {
				// an error here is only a failure when the case did not expect one;
				// either way there is no output left to compare
				if !tt.wantErr {
					t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
				}
				return
			}

			post, err := m.PostResponseTemplate(tt.postVars)
			if (err != nil) != tt.wantErr {
				// reached with err == nil when an expected error never occurred in either call
				t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if err != nil {
				// expected error: nothing to compare
				return
			}

			result := pre + post
			if result != tt.want {
				t.Errorf("PreResponsePrompt() + PostResponseTemplate() got = %v, want %v", result, tt.want)
			}
		})
	}
}
func TestChat(t *testing.T) { func TestChat(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@ -30,6 +256,29 @@ func TestChat(t *testing.T) {
}, },
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]", want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
}, },
{
name: "First Message",
template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
msgs: []api.Message{
{
Role: "system",
Content: "You are a Wizard.",
},
{
Role: "user",
Content: "What are the potion ingredients?",
},
{
Role: "assistant",
Content: "eye of newt",
},
{
Role: "user",
Content: "Anything else?",
},
},
want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]",
},
{ {
name: "Message History", name: "Message History",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",

View file

@ -195,6 +195,7 @@ func GenerateHandler(c *gin.Context) {
checkpointLoaded := time.Now() checkpointLoaded := time.Now()
var prompt string var prompt string
var promptVars PromptVars
switch { switch {
case req.Raw: case req.Raw:
prompt = req.Prompt prompt = req.Prompt
@ -217,11 +218,12 @@ func GenerateHandler(c *gin.Context) {
prevCtx = strings.TrimPrefix(prevCtx, " ") prevCtx = strings.TrimPrefix(prevCtx, " ")
rebuild.WriteString(prevCtx) rebuild.WriteString(prevCtx)
} }
p, err := model.Prompt(PromptVars{ promptVars = PromptVars{
System: req.System, System: req.System,
Prompt: req.Prompt, Prompt: req.Prompt,
First: len(req.Context) == 0, First: len(req.Context) == 0,
}) }
p, err := model.PreResponsePrompt(promptVars)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@ -264,7 +266,14 @@ func GenerateHandler(c *gin.Context) {
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart) resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw { if !req.Raw {
embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String()) // append the generated text to the history and template it if needed
promptVars.Response = generated.String()
result, err := model.PostResponseTemplate(promptVars)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
embd, err := loaded.runner.Encode(c.Request.Context(), prompt+result)
if err != nil { if err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
return return