preserve last system message from modelfile (#2289)

This commit is contained in:
Bruce MacDonald 2024-01-31 21:45:01 -05:00 committed by GitHub
parent 583950c828
commit a896079705
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 66 additions and 17 deletions

View file

@ -156,7 +156,7 @@ type ChatHistory struct {
func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
// build the prompt from the list of messages
var currentImages []api.ImageData
var lastSystem string
lastSystem := m.System
currentVars := PromptVars{
First: true,
System: m.System,
@ -167,7 +167,8 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
for _, msg := range msgs {
switch strings.ToLower(msg.Role) {
case "system":
if currentVars.System != "" {
// if this is the first message it overrides the system prompt in the modelfile
if !currentVars.First && currentVars.System != "" {
prompts = append(prompts, currentVars)
currentVars = PromptVars{}
}

View file

@ -256,15 +256,17 @@ func chatHistoryEqual(a, b ChatHistory) bool {
func TestChat(t *testing.T) {
tests := []struct {
name string
template string
msgs []api.Message
want ChatHistory
wantErr string
name string
model Model
msgs []api.Message
want ChatHistory
wantErr string
}{
{
name: "Single Message",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
name: "Single Message",
model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
},
msgs: []api.Message{
{
Role: "system",
@ -287,8 +289,10 @@ func TestChat(t *testing.T) {
},
},
{
name: "Message History",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
name: "Message History",
model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
},
msgs: []api.Message{
{
Role: "system",
@ -323,8 +327,10 @@ func TestChat(t *testing.T) {
},
},
{
name: "Assistant Only",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
name: "Assistant Only",
model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
},
msgs: []api.Message{
{
Role: "assistant",
@ -340,6 +346,51 @@ func TestChat(t *testing.T) {
},
},
},
{
name: "Last system message is preserved from modelfile",
model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
System: "You are Mojo Jojo.",
},
msgs: []api.Message{
{
Role: "user",
Content: "hi",
},
},
want: ChatHistory{
Prompts: []PromptVars{
{
System: "You are Mojo Jojo.",
Prompt: "hi",
First: true,
},
},
LastSystem: "You are Mojo Jojo.",
},
},
{
name: "Last system message is preserved from messages",
model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
System: "You are Mojo Jojo.",
},
msgs: []api.Message{
{
Role: "system",
Content: "You are Professor Utonium.",
},
},
want: ChatHistory{
Prompts: []PromptVars{
{
System: "You are Professor Utonium.",
First: true,
},
},
LastSystem: "You are Professor Utonium.",
},
},
{
name: "Invalid Role",
msgs: []api.Message{
@ -353,11 +404,8 @@ func TestChat(t *testing.T) {
}
for _, tt := range tests {
m := Model{
Template: tt.template,
}
t.Run(tt.name, func(t *testing.T) {
got, err := m.ChatPrompts(tt.msgs)
got, err := tt.model.ChatPrompts(tt.msgs)
if tt.wantErr != "" {
if err == nil {
t.Errorf("ChatPrompt() expected error, got nil")