fix: only flush template in chat when current role encountered (#1426)

This commit is contained in:
Bruce MacDonald 2023-12-08 16:44:24 -05:00 committed by GitHub
parent e3f925fc1b
commit 3b0b8930d4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 92 additions and 15 deletions

View file

@ -103,16 +103,16 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
} }
for _, msg := range msgs { for _, msg := range msgs {
switch msg.Role { switch strings.ToLower(msg.Role) {
case "system": case "system":
if currentVars.Prompt != "" || currentVars.System != "" { if currentVars.System != "" {
if err := writePrompt(); err != nil { if err := writePrompt(); err != nil {
return "", err return "", err
} }
} }
currentVars.System = msg.Content currentVars.System = msg.Content
case "user": case "user":
if currentVars.Prompt != "" || currentVars.System != "" { if currentVars.Prompt != "" {
if err := writePrompt(); err != nil { if err := writePrompt(); err != nil {
return "", err return "", err
} }

View file

@ -1,21 +1,98 @@
package server package server
import ( import (
"strings"
"testing" "testing"
"github.com/jmorganca/ollama/api"
) )
func TestModelPrompt(t *testing.T) { func TestChat(t *testing.T) {
m := Model{ tests := []struct {
Template: "a{{ .Prompt }}b", name string
template string
msgs []api.Message
want string
wantErr string
}{
{
name: "Single Message",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
msgs: []api.Message{
{
Role: "system",
Content: "You are a Wizard.",
},
{
Role: "user",
Content: "What are the potion ingredients?",
},
},
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
},
{
name: "Message History",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
msgs: []api.Message{
{
Role: "system",
Content: "You are a Wizard.",
},
{
Role: "user",
Content: "What are the potion ingredients?",
},
{
Role: "assistant",
Content: "sugar",
},
{
Role: "user",
Content: "Anything else?",
},
},
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
},
{
name: "Assistant Only",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
msgs: []api.Message{
{
Role: "assistant",
Content: "everything nice",
},
},
want: "[INST] [/INST]everything nice",
},
{
name: "Invalid Role",
msgs: []api.Message{
{
Role: "not-a-role",
Content: "howdy",
},
},
wantErr: "invalid role: not-a-role",
},
} }
s, err := m.Prompt(PromptVars{
Prompt: "<h1>", for _, tt := range tests {
}) m := Model{
if err != nil { Template: tt.template,
t.Fatal(err) }
} t.Run(tt.name, func(t *testing.T) {
want := "a<h1>b" got, err := m.ChatPrompt(tt.msgs)
if s != want { if tt.wantErr != "" {
t.Errorf("got %q, want %q", s, want) if err == nil {
t.Errorf("ChatPrompt() expected error, got nil")
}
if !strings.Contains(err.Error(), tt.wantErr) {
t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
}
}
if got != tt.want {
t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
}
})
} }
} }