preserve last system message from modelfile (#2289)
This commit is contained in:
parent
583950c828
commit
a896079705
2 changed files with 66 additions and 17 deletions
|
@ -156,7 +156,7 @@ type ChatHistory struct {
|
||||||
func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
|
func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
|
||||||
// build the prompt from the list of messages
|
// build the prompt from the list of messages
|
||||||
var currentImages []api.ImageData
|
var currentImages []api.ImageData
|
||||||
var lastSystem string
|
lastSystem := m.System
|
||||||
currentVars := PromptVars{
|
currentVars := PromptVars{
|
||||||
First: true,
|
First: true,
|
||||||
System: m.System,
|
System: m.System,
|
||||||
|
@ -167,7 +167,8 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
|
||||||
for _, msg := range msgs {
|
for _, msg := range msgs {
|
||||||
switch strings.ToLower(msg.Role) {
|
switch strings.ToLower(msg.Role) {
|
||||||
case "system":
|
case "system":
|
||||||
if currentVars.System != "" {
|
// if this is the first message it overrides the system prompt in the modelfile
|
||||||
|
if !currentVars.First && currentVars.System != "" {
|
||||||
prompts = append(prompts, currentVars)
|
prompts = append(prompts, currentVars)
|
||||||
currentVars = PromptVars{}
|
currentVars = PromptVars{}
|
||||||
}
|
}
|
||||||
|
|
|
@ -257,14 +257,16 @@ func chatHistoryEqual(a, b ChatHistory) bool {
|
||||||
func TestChat(t *testing.T) {
|
func TestChat(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
template string
|
model Model
|
||||||
msgs []api.Message
|
msgs []api.Message
|
||||||
want ChatHistory
|
want ChatHistory
|
||||||
wantErr string
|
wantErr string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Single Message",
|
name: "Single Message",
|
||||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
model: Model{
|
||||||
|
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||||
|
},
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
{
|
{
|
||||||
Role: "system",
|
Role: "system",
|
||||||
|
@ -288,7 +290,9 @@ func TestChat(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Message History",
|
name: "Message History",
|
||||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
model: Model{
|
||||||
|
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||||
|
},
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
{
|
{
|
||||||
Role: "system",
|
Role: "system",
|
||||||
|
@ -324,7 +328,9 @@ func TestChat(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Assistant Only",
|
name: "Assistant Only",
|
||||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
model: Model{
|
||||||
|
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||||
|
},
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
{
|
{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
|
@ -340,6 +346,51 @@ func TestChat(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Last system message is preserved from modelfile",
|
||||||
|
model: Model{
|
||||||
|
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||||
|
System: "You are Mojo Jojo.",
|
||||||
|
},
|
||||||
|
msgs: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "hi",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: ChatHistory{
|
||||||
|
Prompts: []PromptVars{
|
||||||
|
{
|
||||||
|
System: "You are Mojo Jojo.",
|
||||||
|
Prompt: "hi",
|
||||||
|
First: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
LastSystem: "You are Mojo Jojo.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Last system message is preserved from messages",
|
||||||
|
model: Model{
|
||||||
|
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||||
|
System: "You are Mojo Jojo.",
|
||||||
|
},
|
||||||
|
msgs: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "system",
|
||||||
|
Content: "You are Professor Utonium.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: ChatHistory{
|
||||||
|
Prompts: []PromptVars{
|
||||||
|
{
|
||||||
|
System: "You are Professor Utonium.",
|
||||||
|
First: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
LastSystem: "You are Professor Utonium.",
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "Invalid Role",
|
name: "Invalid Role",
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
|
@ -353,11 +404,8 @@ func TestChat(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
m := Model{
|
|
||||||
Template: tt.template,
|
|
||||||
}
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := m.ChatPrompts(tt.msgs)
|
got, err := tt.model.ChatPrompts(tt.msgs)
|
||||||
if tt.wantErr != "" {
|
if tt.wantErr != "" {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("ChatPrompt() expected error, got nil")
|
t.Errorf("ChatPrompt() expected error, got nil")
|
||||||
|
|
Loading…
Reference in a new issue