preserve last system message from modelfile (#2289)

This commit is contained in:
Bruce MacDonald 2024-01-31 21:45:01 -05:00 committed by GitHub
parent 583950c828
commit a896079705
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 66 additions and 17 deletions

View file

@ -156,7 +156,7 @@ type ChatHistory struct {
func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) { func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
// build the prompt from the list of messages // build the prompt from the list of messages
var currentImages []api.ImageData var currentImages []api.ImageData
var lastSystem string lastSystem := m.System
currentVars := PromptVars{ currentVars := PromptVars{
First: true, First: true,
System: m.System, System: m.System,
@ -167,7 +167,8 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
for _, msg := range msgs { for _, msg := range msgs {
switch strings.ToLower(msg.Role) { switch strings.ToLower(msg.Role) {
case "system": case "system":
if currentVars.System != "" { // if this is the first message it overrides the system prompt in the modelfile
if !currentVars.First && currentVars.System != "" {
prompts = append(prompts, currentVars) prompts = append(prompts, currentVars)
currentVars = PromptVars{} currentVars = PromptVars{}
} }

View file

@ -257,14 +257,16 @@ func chatHistoryEqual(a, b ChatHistory) bool {
func TestChat(t *testing.T) { func TestChat(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
template string model Model
msgs []api.Message msgs []api.Message
want ChatHistory want ChatHistory
wantErr string wantErr string
}{ }{
{ {
name: "Single Message", name: "Single Message",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
},
msgs: []api.Message{ msgs: []api.Message{
{ {
Role: "system", Role: "system",
@ -288,7 +290,9 @@ func TestChat(t *testing.T) {
}, },
{ {
name: "Message History", name: "Message History",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
},
msgs: []api.Message{ msgs: []api.Message{
{ {
Role: "system", Role: "system",
@ -324,7 +328,9 @@ func TestChat(t *testing.T) {
}, },
{ {
name: "Assistant Only", name: "Assistant Only",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
},
msgs: []api.Message{ msgs: []api.Message{
{ {
Role: "assistant", Role: "assistant",
@ -340,6 +346,51 @@ func TestChat(t *testing.T) {
}, },
}, },
}, },
{
name: "Last system message is preserved from modelfile",
model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
System: "You are Mojo Jojo.",
},
msgs: []api.Message{
{
Role: "user",
Content: "hi",
},
},
want: ChatHistory{
Prompts: []PromptVars{
{
System: "You are Mojo Jojo.",
Prompt: "hi",
First: true,
},
},
LastSystem: "You are Mojo Jojo.",
},
},
{
name: "Last system message is preserved from messages",
model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
System: "You are Mojo Jojo.",
},
msgs: []api.Message{
{
Role: "system",
Content: "You are Professor Utonium.",
},
},
want: ChatHistory{
Prompts: []PromptVars{
{
System: "You are Professor Utonium.",
First: true,
},
},
LastSystem: "You are Professor Utonium.",
},
},
{ {
name: "Invalid Role", name: "Invalid Role",
msgs: []api.Message{ msgs: []api.Message{
@ -353,11 +404,8 @@ func TestChat(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
m := Model{
Template: tt.template,
}
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := m.ChatPrompts(tt.msgs) got, err := tt.model.ChatPrompts(tt.msgs)
if tt.wantErr != "" { if tt.wantErr != "" {
if err == nil { if err == nil {
t.Errorf("ChatPrompt() expected error, got nil") t.Errorf("ChatPrompt() expected error, got nil")