diff --git a/server/images.go b/server/images.go index 2700cb56..d2918ad8 100644 --- a/server/images.go +++ b/server/images.go @@ -103,16 +103,16 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, error) { } for _, msg := range msgs { - switch msg.Role { + switch strings.ToLower(msg.Role) { case "system": - if currentVars.Prompt != "" || currentVars.System != "" { + if currentVars.System != "" { if err := writePrompt(); err != nil { return "", err } } currentVars.System = msg.Content case "user": - if currentVars.Prompt != "" || currentVars.System != "" { + if currentVars.Prompt != "" { if err := writePrompt(); err != nil { return "", err } diff --git a/server/images_test.go b/server/images_test.go index 85e8d4bd..28cb39a3 100644 --- a/server/images_test.go +++ b/server/images_test.go @@ -1,21 +1,98 @@ package server import ( + "strings" "testing" + + "github.com/jmorganca/ollama/api" ) -func TestModelPrompt(t *testing.T) { - m := Model{ - Template: "a{{ .Prompt }}b", +func TestChat(t *testing.T) { + tests := []struct { + name string + template string + msgs []api.Message + want string + wantErr string + }{ + { + name: "Single Message", + template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", + msgs: []api.Message{ + { + Role: "system", + Content: "You are a Wizard.", + }, + { + Role: "user", + Content: "What are the potion ingredients?", + }, + }, + want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]", + }, + { + name: "Message History", + template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", + msgs: []api.Message{ + { + Role: "system", + Content: "You are a Wizard.", + }, + { + Role: "user", + Content: "What are the potion ingredients?", + }, + { + Role: "assistant", + Content: "sugar", + }, + { + Role: "user", + Content: "Anything else?", + }, + }, + want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]", + }, + { + name: "Assistant Only", + template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", + msgs: []api.Message{ + { + Role: "assistant", + Content: "everything nice", + }, + }, + want: "[INST] [/INST]everything nice", + }, + { + name: "Invalid Role", + msgs: []api.Message{ + { + Role: "not-a-role", + Content: "howdy", + }, + }, + wantErr: "invalid role: not-a-role", + }, } - s, err := m.Prompt(PromptVars{ - Prompt: "