diff --git a/server/images.go b/server/images.go index 39617c66..503dd8e2 100644 --- a/server/images.go +++ b/server/images.go @@ -156,7 +156,7 @@ type ChatHistory struct { func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) { // build the prompt from the list of messages var currentImages []api.ImageData - var lastSystem string + lastSystem := m.System currentVars := PromptVars{ First: true, System: m.System, @@ -167,7 +167,8 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) { for _, msg := range msgs { switch strings.ToLower(msg.Role) { case "system": - if currentVars.System != "" { + // if this is the first message it overrides the system prompt in the modelfile + if !currentVars.First && currentVars.System != "" { prompts = append(prompts, currentVars) currentVars = PromptVars{} } diff --git a/server/images_test.go b/server/images_test.go index 08e39998..0f63a19b 100644 --- a/server/images_test.go +++ b/server/images_test.go @@ -256,15 +256,17 @@ func chatHistoryEqual(a, b ChatHistory) bool { func TestChat(t *testing.T) { tests := []struct { - name string - template string - msgs []api.Message - want ChatHistory - wantErr string + name string + model Model + msgs []api.Message + want ChatHistory + wantErr string }{ { - name: "Single Message", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", + name: "Single Message", + model: Model{ + Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", + }, msgs: []api.Message{ { Role: "system", @@ -287,8 +289,10 @@ func TestChat(t *testing.T) { }, }, { - name: "Message History", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", + name: "Message History", + model: Model{ + Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", + }, msgs: []api.Message{ { Role: "system", @@ -323,8 +327,10 @@ func TestChat(t *testing.T) { }, }, { - name: "Assistant Only", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", + name: "Assistant Only", + model: Model{ + Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", + }, msgs: []api.Message{ { Role: "assistant", @@ -340,6 +346,51 @@ func TestChat(t *testing.T) { }, }, }, + { + name: "Last system message is preserved from modelfile", + model: Model{ + Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", + System: "You are Mojo Jojo.", + }, + msgs: []api.Message{ + { + Role: "user", + Content: "hi", + }, + }, + want: ChatHistory{ + Prompts: []PromptVars{ + { + System: "You are Mojo Jojo.", + Prompt: "hi", + First: true, + }, + }, + LastSystem: "You are Mojo Jojo.", + }, + }, + { + name: "Last system message is preserved from messages", + model: Model{ + Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", + System: "You are Mojo Jojo.", + }, + msgs: []api.Message{ + { + Role: "system", + Content: "You are Professor Utonium.", + }, + }, + want: ChatHistory{ + Prompts: []PromptVars{ + { + System: "You are Professor Utonium.", + First: true, + }, + }, + LastSystem: "You are Professor Utonium.", + }, + }, { name: "Invalid Role", msgs: []api.Message{ @@ -353,11 +404,8 @@ func TestChat(t *testing.T) { } for _, tt := range tests { - m := Model{ - Template: tt.template, - } t.Run(tt.name, func(t *testing.T) { - got, err := m.ChatPrompts(tt.msgs) + got, err := tt.model.ChatPrompts(tt.msgs) if tt.wantErr != "" { if err == nil { t.Errorf("ChatPrompt() expected error, got nil")