package server

import (
	"strings"
	"testing"

	"github.com/ollama/ollama/api"
)

// TestPrompt checks that Prompt renders a single system/prompt/response turn
// through a text template, with and without the generate flag.
func TestPrompt(t *testing.T) {
	tests := []struct {
		name     string
		template string
		system   string
		prompt   string
		response string
		generate bool
		want     string
	}{
		{
			name:     "simple prompt",
			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
			system:   "You are a Wizard.",
			prompt:   "What are the potion ingredients?",
			want:     "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
		},
		{
			// The template has no {{ .Response }}, yet the response is expected
			// to be appended directly after the rendered template.
			name:     "implicit response",
			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
			system:   "You are a Wizard.",
			prompt:   "What are the potion ingredients?",
			response: "I don't know.",
			want:     "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.",
		},
		{
			name:     "response",
			template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
			system:   "You are a Wizard.",
			prompt:   "What are the potion ingredients?",
			response: "I don't know.",
			want:     "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
		},
		{
			name:     "cut",
			template: "{{ .System }}{{ .Prompt }}{{ .Response }}",
			system:   "You are a Wizard.",
			prompt:   "What are the potion ingredients?",
			response: "I don't know.",
			generate: true,
			want:     "You are a Wizard.What are the potion ingredients?I don't know.",
		},
		{
			name:     "nocut",
			template: "{{ .System }}{{ .Prompt }}{{ .Response }}",
			system:   "You are a Wizard.",
			prompt:   "What are the potion ingredients?",
			response: "I don't know.",
			want:     "You are a Wizard.What are the potion ingredients?I don't know.",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got, err := Prompt(tc.template, tc.system, tc.prompt, tc.response, tc.generate)
			if err != nil {
				// got is meaningless once Prompt fails; stop this subtest.
				t.Fatalf("unexpected error: %v", err)
			}
			// %q exposes the leading/trailing spaces these cases differ by.
			if got != tc.want {
				t.Errorf("got: %q, want: %q", got, tc.want)
			}
		})
	}
}

// TestChatPrompt checks that ChatPrompt renders a message history through a
// text template. window is the budget passed to ChatPrompt; the encode stub
// below counts each whitespace-separated word as one token, and the
// "with truncation" case expects earlier turns to be dropped when it is small.
func TestChatPrompt(t *testing.T) {
	tests := []struct {
		name     string
		template string
		messages []api.Message
		window   int
		want     string
	}{
		{
			name:     "simple prompt",
			template: "[INST] {{ .Prompt }} [/INST]",
			messages: []api.Message{
				{Role: "user", Content: "Hello"},
			},
			window: 1024,
			want:   "[INST] Hello [/INST]",
		},
		{
			name:     "with system message",
			template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST]",
			messages: []api.Message{
				{Role: "system", Content: "You are a Wizard."},
				{Role: "user", Content: "Hello"},
			},
			window: 1024,
			want:   "[INST] <>You are a Wizard.<> Hello [/INST]",
		},
		{
			name:     "with response",
			template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}",
			messages: []api.Message{
				{Role: "system", Content: "You are a Wizard."},
				{Role: "user", Content: "Hello"},
				{Role: "assistant", Content: "I am?"},
			},
			window: 1024,
			want:   "[INST] <>You are a Wizard.<> Hello [/INST] I am?",
		},
		{
			// No {{ .Response }} in the template: the assistant reply is
			// expected to be appended directly after the rendered turn.
			name:     "with implicit response",
			template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST]",
			messages: []api.Message{
				{Role: "system", Content: "You are a Wizard."},
				{Role: "user", Content: "Hello"},
				{Role: "assistant", Content: "I am?"},
			},
			window: 1024,
			want:   "[INST] <>You are a Wizard.<> Hello [/INST]I am?",
		},
		{
			name:     "with conversation",
			template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
			messages: []api.Message{
				{Role: "system", Content: "You are a Wizard."},
				{Role: "user", Content: "What are the potion ingredients?"},
				{Role: "assistant", Content: "sugar"},
				{Role: "user", Content: "Anything else?"},
			},
			window: 1024,
			want:   "[INST] <>You are a Wizard.<> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ",
		},
		{
			// A 10-word window: the first user/assistant exchange is expected
			// to be dropped while the system message is kept.
			name:     "with truncation",
			template: "{{ .System }} {{ .Prompt }} {{ .Response }} ",
			messages: []api.Message{
				{Role: "system", Content: "You are a Wizard."},
				{Role: "user", Content: "Hello"},
				{Role: "assistant", Content: "I am?"},
				{Role: "user", Content: "Why is the sky blue?"},
				{Role: "assistant", Content: "The sky is blue from rayleigh scattering"},
			},
			window: 10,
			want:   "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering",
		},
		{
			name:     "images",
			template: "{{ .System }} {{ .Prompt }}",
			messages: []api.Message{
				{Role: "system", Content: "You are a Wizard."},
				{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
			},
			window: 1024,
			want:   "You are a Wizard. [img-0] Hello",
		},
		{
			name:     "images truncated",
			template: "{{ .System }} {{ .Prompt }}",
			messages: []api.Message{
				{Role: "system", Content: "You are a Wizard."},
				{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
			},
			window: 1024,
			want:   "You are a Wizard. [img-0] [img-1] Hello",
		},
		{
			name:     "empty list",
			template: "{{ .System }} {{ .Prompt }}",
			messages: []api.Message{},
			window:   1024,
			want:     "",
		},
		{
			name:     "empty prompt",
			template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
			messages: []api.Message{
				{Role: "user", Content: ""},
			},
			window: 1024,
			want:   "",
		},
	}

	// encode is a tokenizer stub: one (zero-valued) token per whitespace-separated
	// word, so only the token count matters to ChatPrompt here.
	encode := func(s string) ([]int, error) {
		words := strings.Fields(s)
		return make([]int, len(words)), nil
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
			if err != nil {
				// got is meaningless once ChatPrompt fails; stop this subtest.
				t.Fatalf("unexpected error: %v", err)
			}
			if got != tc.want {
				t.Errorf("got: %q, want: %q", got, tc.want)
			}
		})
	}
}