fix system prompt (#5662)
* fix system prompt
* execute template when hitting previous roles
* fix tests

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
This commit is contained in:
parent
23ebbaa46e
commit
22c5451fc2
3 changed files with 51 additions and 30 deletions
|
@ -4,7 +4,6 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"slices"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
|
@ -17,26 +16,18 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
|
||||||
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
|
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
|
||||||
// latest message and 2) system messages
|
// latest message and 2) system messages
|
||||||
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
|
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
|
||||||
// pull out any system messages which should always be included in the prompt
|
|
||||||
var system []api.Message
|
var system []api.Message
|
||||||
msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
|
|
||||||
if m.Role == "system" {
|
|
||||||
system = append(system, m)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
})
|
|
||||||
|
|
||||||
if len(system) == 0 && m.System != "" {
|
|
||||||
// add model system prompt since it wasn't provided
|
|
||||||
system = append(system, api.Message{Role: "system", Content: m.System})
|
|
||||||
}
|
|
||||||
|
|
||||||
// always include the last message
|
// always include the last message
|
||||||
n := len(msgs) - 1
|
n := len(msgs) - 1
|
||||||
// in reverse, find all messages that fit into context window
|
// in reverse, find all messages that fit into context window
|
||||||
for i := n - 1; i >= 0; i-- {
|
for i := n - 1; i >= 0; i-- {
|
||||||
|
system = make([]api.Message, 0)
|
||||||
|
for j := range i {
|
||||||
|
if msgs[j].Role == "system" {
|
||||||
|
system = append(system, msgs[j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
|
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
)
|
)
|
||||||
|
@ -164,6 +165,19 @@ func TestChatPrompt(t *testing.T) {
|
||||||
prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "out of order system",
|
||||||
|
limit: 2048,
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "You're a test, Harry!"},
|
||||||
|
{Role: "assistant", Content: "I-I'm a what?"},
|
||||||
|
{Role: "system", Content: "You are the Test Who Lived."},
|
||||||
|
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||||
|
},
|
||||||
|
expect: expect{
|
||||||
|
prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpl, err := template.Parse(`
|
tmpl, err := template.Parse(`
|
||||||
|
@ -187,6 +201,10 @@ func TestChatPrompt(t *testing.T) {
|
||||||
t.Errorf("expected %q, got %q", tt.prompt, prompt)
|
t.Errorf("expected %q, got %q", tt.prompt, prompt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
if len(images) != len(tt.images) {
|
if len(images) != len(tt.images) {
|
||||||
t.Fatalf("expected %d images, got %d", len(tt.images), len(images))
|
t.Fatalf("expected %d images, got %d", len(tt.images), len(images))
|
||||||
}
|
}
|
||||||
|
|
|
@ -149,27 +149,19 @@ type Values struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Template) Execute(w io.Writer, v Values) error {
|
func (t *Template) Execute(w io.Writer, v Values) error {
|
||||||
system, collated := collate(v.Messages)
|
system, messages := collate(v.Messages)
|
||||||
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
|
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
|
||||||
return t.Template.Execute(w, map[string]any{
|
return t.Template.Execute(w, map[string]any{
|
||||||
"System": system,
|
"System": system,
|
||||||
"Messages": collated,
|
"Messages": messages,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
system = ""
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
var prompt, response string
|
var prompt, response string
|
||||||
for i, m := range collated {
|
for _, m := range messages {
|
||||||
switch m.Role {
|
execute := func () error {
|
||||||
case "system":
|
|
||||||
system = m.Content
|
|
||||||
case "user":
|
|
||||||
prompt = m.Content
|
|
||||||
case "assistant":
|
|
||||||
response = m.Content
|
|
||||||
}
|
|
||||||
|
|
||||||
if i != len(collated)-1 && prompt != "" && response != "" {
|
|
||||||
if err := t.Template.Execute(&b, map[string]any{
|
if err := t.Template.Execute(&b, map[string]any{
|
||||||
"System": system,
|
"System": system,
|
||||||
"Prompt": prompt,
|
"Prompt": prompt,
|
||||||
|
@ -181,6 +173,26 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||||
system = ""
|
system = ""
|
||||||
prompt = ""
|
prompt = ""
|
||||||
response = ""
|
response = ""
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch m.Role {
|
||||||
|
case "system":
|
||||||
|
if prompt != "" || response != "" {
|
||||||
|
if err := execute(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
system = m.Content
|
||||||
|
case "user":
|
||||||
|
if response != "" {
|
||||||
|
if err := execute(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
prompt = m.Content
|
||||||
|
case "assistant":
|
||||||
|
response = m.Content
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -199,7 +211,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||||
|
|
||||||
tree := parse.Tree{Root: nodes.(*parse.ListNode)}
|
tree := parse.Tree{Root: nodes.(*parse.ListNode)}
|
||||||
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
|
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
|
||||||
"System": "",
|
"System": system,
|
||||||
"Prompt": prompt,
|
"Prompt": prompt,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
Loading…
Reference in a new issue