fix system prompt (#5662)

* fix system prompt

* execute template when hitting previous roles

* fix tests

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
This commit is contained in:
Michael Yang 2024-07-12 21:04:44 -07:00 committed by GitHub
parent 23ebbaa46e
commit 22c5451fc2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 51 additions and 30 deletions

View file

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"context" "context"
"log/slog" "log/slog"
"slices"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
@ -17,26 +16,18 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages // latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) { func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
// pull out any system messages which should always be included in the prompt
var system []api.Message var system []api.Message
msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
if m.Role == "system" {
system = append(system, m)
return true
}
return false
})
if len(system) == 0 && m.System != "" {
// add model system prompt since it wasn't provided
system = append(system, api.Message{Role: "system", Content: m.System})
}
// always include the last message // always include the last message
n := len(msgs) - 1 n := len(msgs) - 1
// in reverse, find all messages that fit into context window // in reverse, find all messages that fit into context window
for i := n - 1; i >= 0; i-- { for i := n - 1; i >= 0; i-- {
system = make([]api.Message, 0)
for j := range i {
if msgs[j].Role == "system" {
system = append(system, msgs[j])
}
}
var b bytes.Buffer var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil { if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
return "", nil, err return "", nil, err

View file

@ -6,6 +6,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/template" "github.com/ollama/ollama/template"
) )
@ -164,6 +165,19 @@ func TestChatPrompt(t *testing.T) {
prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ", prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
}, },
}, },
{
name: "out of order system",
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "system", Content: "You are the Test Who Lived."},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
},
},
} }
tmpl, err := template.Parse(` tmpl, err := template.Parse(`
@ -187,6 +201,10 @@ func TestChatPrompt(t *testing.T) {
t.Errorf("expected %q, got %q", tt.prompt, prompt) t.Errorf("expected %q, got %q", tt.prompt, prompt)
} }
if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
if len(images) != len(tt.images) { if len(images) != len(tt.images) {
t.Fatalf("expected %d images, got %d", len(tt.images), len(images)) t.Fatalf("expected %d images, got %d", len(tt.images), len(images))
} }

View file

@ -149,27 +149,19 @@ type Values struct {
} }
func (t *Template) Execute(w io.Writer, v Values) error { func (t *Template) Execute(w io.Writer, v Values) error {
system, collated := collate(v.Messages) system, messages := collate(v.Messages)
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") { if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{ return t.Template.Execute(w, map[string]any{
"System": system, "System": system,
"Messages": collated, "Messages": messages,
}) })
} }
system = ""
var b bytes.Buffer var b bytes.Buffer
var prompt, response string var prompt, response string
for i, m := range collated { for _, m := range messages {
switch m.Role { execute := func () error {
case "system":
system = m.Content
case "user":
prompt = m.Content
case "assistant":
response = m.Content
}
if i != len(collated)-1 && prompt != "" && response != "" {
if err := t.Template.Execute(&b, map[string]any{ if err := t.Template.Execute(&b, map[string]any{
"System": system, "System": system,
"Prompt": prompt, "Prompt": prompt,
@ -181,6 +173,26 @@ func (t *Template) Execute(w io.Writer, v Values) error {
system = "" system = ""
prompt = "" prompt = ""
response = "" response = ""
return nil
}
switch m.Role {
case "system":
if prompt != "" || response != "" {
if err := execute(); err != nil {
return err
}
}
system = m.Content
case "user":
if response != "" {
if err := execute(); err != nil {
return err
}
}
prompt = m.Content
case "assistant":
response = m.Content
} }
} }
@ -199,7 +211,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
tree := parse.Tree{Root: nodes.(*parse.ListNode)} tree := parse.Tree{Root: nodes.(*parse.ListNode)}
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{ if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
"System": "", "System": system,
"Prompt": prompt, "Prompt": prompt,
}); err != nil { }); err != nil {
return err return err