post-response templating (#1427)
This commit is contained in:
parent
b80081022f
commit
db356c8519
3 changed files with 334 additions and 16 deletions
|
@ -18,6 +18,7 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
"text/template/parse"
|
||||||
|
|
||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
|
|
||||||
|
@ -57,17 +58,35 @@ type PromptVars struct {
|
||||||
First bool
|
First bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Prompt(p PromptVars) (string, error) {
|
// extractParts extracts the parts of the template before and after the {{.Response}} node.
|
||||||
var prompt strings.Builder
|
func extractParts(tmplStr string) (pre string, post string, err error) {
|
||||||
// Use the "missingkey=zero" option to handle missing variables without panicking
|
tmpl, err := template.New("").Parse(tmplStr)
|
||||||
tmpl, err := template.New("").Option("missingkey=zero").Parse(m.Template)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.System == "" {
|
var foundResponse bool
|
||||||
// use the default system message for this model if one is not specified
|
|
||||||
p.System = m.System
|
for _, node := range tmpl.Tree.Root.Nodes {
|
||||||
|
if node.Type() == parse.NodeAction && node.String() == "{{.Response}}" {
|
||||||
|
foundResponse = true
|
||||||
|
}
|
||||||
|
if !foundResponse {
|
||||||
|
pre += node.String()
|
||||||
|
} else {
|
||||||
|
post += node.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return pre, post, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Prompt(promptTemplate string, p PromptVars) (string, error) {
|
||||||
|
var prompt strings.Builder
|
||||||
|
// Use the "missingkey=zero" option to handle missing variables without panicking
|
||||||
|
tmpl, err := template.New("").Option("missingkey=zero").Parse(promptTemplate)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
vars := map[string]any{
|
vars := map[string]any{
|
||||||
|
@ -82,20 +101,59 @@ func (m *Model) Prompt(p PromptVars) (string, error) {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
prompt.WriteString(sb.String())
|
prompt.WriteString(sb.String())
|
||||||
prompt.WriteString(p.Response)
|
|
||||||
|
if !strings.Contains(prompt.String(), p.Response) {
|
||||||
|
// if the response is not in the prompt template, append it to the end
|
||||||
|
prompt.WriteString(p.Response)
|
||||||
|
}
|
||||||
|
|
||||||
return prompt.String(), nil
|
return prompt.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PreResponsePrompt returns the prompt before the response tag
|
||||||
|
func (m *Model) PreResponsePrompt(p PromptVars) (string, error) {
|
||||||
|
if p.System == "" {
|
||||||
|
// use the default system prompt for this model if one is not specified
|
||||||
|
p.System = m.System
|
||||||
|
}
|
||||||
|
pre, _, err := extractParts(m.Template)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return Prompt(pre, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PostResponseTemplate returns the template after the response tag
|
||||||
|
func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
|
||||||
|
if p.System == "" {
|
||||||
|
// use the default system prompt for this model if one is not specified
|
||||||
|
p.System = m.System
|
||||||
|
}
|
||||||
|
_, post, err := extractParts(m.Template)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if post == "" {
|
||||||
|
// if there is no post-response template, return the provided response
|
||||||
|
return p.Response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return Prompt(post, p)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
|
func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
|
||||||
// build the prompt from the list of messages
|
// build the prompt from the list of messages
|
||||||
var prompt strings.Builder
|
var prompt strings.Builder
|
||||||
var currentImages []api.ImageData
|
var currentImages []api.ImageData
|
||||||
currentVars := PromptVars{
|
currentVars := PromptVars{
|
||||||
First: true,
|
First: true,
|
||||||
|
System: m.System,
|
||||||
}
|
}
|
||||||
|
|
||||||
writePrompt := func() error {
|
writePrompt := func() error {
|
||||||
p, err := m.Prompt(currentVars)
|
p, err := Prompt(m.Template, currentVars)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -133,9 +191,11 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error)
|
||||||
|
|
||||||
// Append the last set of vars if they are non-empty
|
// Append the last set of vars if they are non-empty
|
||||||
if currentVars.Prompt != "" || currentVars.System != "" {
|
if currentVars.Prompt != "" || currentVars.System != "" {
|
||||||
if err := writePrompt(); err != nil {
|
p, err := m.PreResponsePrompt(currentVars)
|
||||||
return "", nil, err
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("pre-response template: %w", err)
|
||||||
}
|
}
|
||||||
|
prompt.WriteString(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
return prompt.String(), currentImages, nil
|
return prompt.String(), currentImages, nil
|
||||||
|
|
|
@ -7,6 +7,232 @@ import (
|
||||||
"github.com/jmorganca/ollama/api"
|
"github.com/jmorganca/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestPrompt(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
template string
|
||||||
|
vars PromptVars
|
||||||
|
want string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "System Prompt",
|
||||||
|
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||||
|
vars: PromptVars{
|
||||||
|
System: "You are a Wizard.",
|
||||||
|
Prompt: "What are the potion ingredients?",
|
||||||
|
},
|
||||||
|
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "System Prompt with Response",
|
||||||
|
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
|
||||||
|
vars: PromptVars{
|
||||||
|
System: "You are a Wizard.",
|
||||||
|
Prompt: "What are the potion ingredients?",
|
||||||
|
Response: "I don't know.",
|
||||||
|
},
|
||||||
|
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Conditional Logic Nodes",
|
||||||
|
template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
|
||||||
|
vars: PromptVars{
|
||||||
|
First: true,
|
||||||
|
System: "You are a Wizard.",
|
||||||
|
Prompt: "What are the potion ingredients?",
|
||||||
|
Response: "I don't know.",
|
||||||
|
},
|
||||||
|
want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := Prompt(tt.template, tt.vars)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Prompt() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("Prompt() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModel_PreResponsePrompt(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
template string
|
||||||
|
vars PromptVars
|
||||||
|
want string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "No Response in Template",
|
||||||
|
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||||
|
vars: PromptVars{
|
||||||
|
System: "You are a Wizard.",
|
||||||
|
Prompt: "What are the potion ingredients?",
|
||||||
|
},
|
||||||
|
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Response in Template",
|
||||||
|
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
|
||||||
|
vars: PromptVars{
|
||||||
|
System: "You are a Wizard.",
|
||||||
|
Prompt: "What are the potion ingredients?",
|
||||||
|
},
|
||||||
|
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] ",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Response in Template with Trailing Formatting",
|
||||||
|
template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
|
||||||
|
vars: PromptVars{
|
||||||
|
Prompt: "What are the potion ingredients?",
|
||||||
|
},
|
||||||
|
want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Response in Template with Alternative Formatting",
|
||||||
|
template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
|
||||||
|
vars: PromptVars{
|
||||||
|
Prompt: "What are the potion ingredients?",
|
||||||
|
},
|
||||||
|
want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
m := Model{Template: tt.template}
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := m.PreResponsePrompt(tt.vars)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("PreResponsePrompt() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModel_PostResponsePrompt(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
template string
|
||||||
|
vars PromptVars
|
||||||
|
want string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "No Response in Template",
|
||||||
|
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||||
|
vars: PromptVars{
|
||||||
|
Response: "I don't know.",
|
||||||
|
},
|
||||||
|
want: "I don't know.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Response in Template",
|
||||||
|
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
|
||||||
|
vars: PromptVars{
|
||||||
|
Response: "I don't know.",
|
||||||
|
},
|
||||||
|
want: "I don't know.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Response in Template with Trailing Formatting",
|
||||||
|
template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
|
||||||
|
vars: PromptVars{
|
||||||
|
Response: "I don't know.",
|
||||||
|
},
|
||||||
|
want: "I don't know.<|im_end|>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Response in Template with Alternative Formatting",
|
||||||
|
template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
|
||||||
|
vars: PromptVars{
|
||||||
|
Response: "I don't know.",
|
||||||
|
},
|
||||||
|
want: "I don't know.<|im_end|>",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
m := Model{Template: tt.template}
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := m.PostResponseTemplate(tt.vars)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("PostResponseTemplate() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
template string
|
||||||
|
preVars PromptVars
|
||||||
|
postVars PromptVars
|
||||||
|
want string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Response in Template",
|
||||||
|
template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
|
||||||
|
preVars: PromptVars{
|
||||||
|
Prompt: "What are the potion ingredients?",
|
||||||
|
},
|
||||||
|
postVars: PromptVars{
|
||||||
|
Prompt: "What are the potion ingredients?",
|
||||||
|
Response: "Sugar.",
|
||||||
|
},
|
||||||
|
want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSugar.<|im_end|>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No Response in Template",
|
||||||
|
template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n",
|
||||||
|
preVars: PromptVars{
|
||||||
|
Prompt: "What are the potion ingredients?",
|
||||||
|
},
|
||||||
|
postVars: PromptVars{
|
||||||
|
Prompt: "What are the potion ingredients?",
|
||||||
|
Response: "Spice.",
|
||||||
|
},
|
||||||
|
want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSpice.",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
m := Model{Template: tt.template}
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
pre, err := m.PreResponsePrompt(tt.preVars)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
post, err := m.PostResponseTemplate(tt.postVars)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result := pre + post
|
||||||
|
if result != tt.want {
|
||||||
|
t.Errorf("Prompt() got = %v, want %v", result, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestChat(t *testing.T) {
|
func TestChat(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -30,6 +256,29 @@ func TestChat(t *testing.T) {
|
||||||
},
|
},
|
||||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
|
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "First Message",
|
||||||
|
template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "system",
|
||||||
|
Content: "You are a Wizard.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "What are the potion ingredients?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "eye of newt",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "Anything else?",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "Message History",
|
name: "Message History",
|
||||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||||
|
|
|
@ -195,6 +195,7 @@ func GenerateHandler(c *gin.Context) {
|
||||||
checkpointLoaded := time.Now()
|
checkpointLoaded := time.Now()
|
||||||
|
|
||||||
var prompt string
|
var prompt string
|
||||||
|
var promptVars PromptVars
|
||||||
switch {
|
switch {
|
||||||
case req.Raw:
|
case req.Raw:
|
||||||
prompt = req.Prompt
|
prompt = req.Prompt
|
||||||
|
@ -217,11 +218,12 @@ func GenerateHandler(c *gin.Context) {
|
||||||
prevCtx = strings.TrimPrefix(prevCtx, " ")
|
prevCtx = strings.TrimPrefix(prevCtx, " ")
|
||||||
rebuild.WriteString(prevCtx)
|
rebuild.WriteString(prevCtx)
|
||||||
}
|
}
|
||||||
p, err := model.Prompt(PromptVars{
|
promptVars = PromptVars{
|
||||||
System: req.System,
|
System: req.System,
|
||||||
Prompt: req.Prompt,
|
Prompt: req.Prompt,
|
||||||
First: len(req.Context) == 0,
|
First: len(req.Context) == 0,
|
||||||
})
|
}
|
||||||
|
p, err := model.PreResponsePrompt(promptVars)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
|
@ -264,7 +266,14 @@ func GenerateHandler(c *gin.Context) {
|
||||||
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||||
|
|
||||||
if !req.Raw {
|
if !req.Raw {
|
||||||
embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
|
// append the generated text to the history and template it if needed
|
||||||
|
promptVars.Response = generated.String()
|
||||||
|
result, err := model.PostResponseTemplate(promptVars)
|
||||||
|
if err != nil {
|
||||||
|
ch <- gin.H{"error": err.Error()}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
embd, err := loaded.runner.Encode(c.Request.Context(), prompt+result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
return
|
return
|
||||||
|
|
Loading…
Add table
Reference in a new issue