diff --git a/server/images.go b/server/images.go index 2700cb56..d2918ad8 100644 --- a/server/images.go +++ b/server/images.go @@ -103,16 +103,16 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, error) { } for _, msg := range msgs { - switch msg.Role { + switch strings.ToLower(msg.Role) { case "system": - if currentVars.Prompt != "" || currentVars.System != "" { + if currentVars.System != "" { if err := writePrompt(); err != nil { return "", err } } currentVars.System = msg.Content case "user": - if currentVars.Prompt != "" || currentVars.System != "" { + if currentVars.Prompt != "" { if err := writePrompt(); err != nil { return "", err } diff --git a/server/images_test.go b/server/images_test.go index 85e8d4bd..28cb39a3 100644 --- a/server/images_test.go +++ b/server/images_test.go @@ -1,21 +1,98 @@ package server import ( + "strings" "testing" + + "github.com/jmorganca/ollama/api" ) -func TestModelPrompt(t *testing.T) { - m := Model{ - Template: "a{{ .Prompt }}b", +func TestChat(t *testing.T) { + tests := []struct { + name string + template string + msgs []api.Message + want string + wantErr string + }{ + { + name: "Single Message", + template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", + msgs: []api.Message{ + { + Role: "system", + Content: "You are a Wizard.", + }, + { + Role: "user", + Content: "What are the potion ingredients?", + }, + }, + want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]", + }, + { + name: "Message History", + template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", + msgs: []api.Message{ + { + Role: "system", + Content: "You are a Wizard.", + }, + { + Role: "user", + Content: "What are the potion ingredients?", + }, + { + Role: "assistant", + Content: "sugar", + }, + { + Role: "user", + Content: "Anything else?", + }, + }, + want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]", + }, + { + name: "Assistant Only", + template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", + msgs: []api.Message{ + { + Role: "assistant", + Content: "everything nice", + }, + }, + want: "[INST] [/INST]everything nice", + }, + { + name: "Invalid Role", + msgs: []api.Message{ + { + Role: "not-a-role", + Content: "howdy", + }, + }, + wantErr: "invalid role: not-a-role", + }, } - s, err := m.Prompt(PromptVars{ - Prompt: "

", - }) - if err != nil { - t.Fatal(err) - } - want := "a

b" - if s != want { - t.Errorf("got %q, want %q", s, want) + + for _, tt := range tests { + m := Model{ + Template: tt.template, + } + t.Run(tt.name, func(t *testing.T) { + got, err := m.ChatPrompt(tt.msgs) + if tt.wantErr != "" { + if err == nil { + t.Errorf("ChatPrompt() expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr) + } + } + if got != tt.want { + t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want) + } + }) } }