205 lines
6.1 KiB
Go
205 lines
6.1 KiB
Go
package server
|
|
|
|
import (
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/ollama/ollama/api"
|
|
)
|
|
|
|
func TestPrompt(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
template string
|
|
system string
|
|
prompt string
|
|
response string
|
|
generate bool
|
|
want string
|
|
}{
|
|
{
|
|
name: "simple prompt",
|
|
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
system: "You are a Wizard.",
|
|
prompt: "What are the potion ingredients?",
|
|
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
|
|
},
|
|
{
|
|
name: "implicit response",
|
|
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
system: "You are a Wizard.",
|
|
prompt: "What are the potion ingredients?",
|
|
response: "I don't know.",
|
|
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.",
|
|
},
|
|
{
|
|
name: "response",
|
|
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
|
|
system: "You are a Wizard.",
|
|
prompt: "What are the potion ingredients?",
|
|
response: "I don't know.",
|
|
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
|
|
},
|
|
{
|
|
name: "cut",
|
|
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
|
|
system: "You are a Wizard.",
|
|
prompt: "What are the potion ingredients?",
|
|
response: "I don't know.",
|
|
generate: true,
|
|
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.",
|
|
},
|
|
{
|
|
name: "nocut",
|
|
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
|
|
system: "You are a Wizard.",
|
|
prompt: "What are the potion ingredients?",
|
|
response: "I don't know.",
|
|
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.</assistant>",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
got, err := Prompt(tc.template, tc.system, tc.prompt, tc.response, tc.generate)
|
|
if err != nil {
|
|
t.Errorf("error = %v", err)
|
|
}
|
|
|
|
if got != tc.want {
|
|
t.Errorf("got = %v, want %v", got, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestChatPrompt(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
template string
|
|
messages []api.Message
|
|
window int
|
|
want string
|
|
}{
|
|
{
|
|
name: "simple prompt",
|
|
template: "[INST] {{ .Prompt }} [/INST]",
|
|
messages: []api.Message{
|
|
{Role: "user", Content: "Hello"},
|
|
},
|
|
window: 1024,
|
|
want: "[INST] Hello [/INST]",
|
|
},
|
|
{
|
|
name: "with system message",
|
|
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
|
|
messages: []api.Message{
|
|
{Role: "system", Content: "You are a Wizard."},
|
|
{Role: "user", Content: "Hello"},
|
|
},
|
|
window: 1024,
|
|
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
|
|
},
|
|
{
|
|
name: "with response",
|
|
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}",
|
|
messages: []api.Message{
|
|
{Role: "system", Content: "You are a Wizard."},
|
|
{Role: "user", Content: "Hello"},
|
|
{Role: "assistant", Content: "I am?"},
|
|
},
|
|
window: 1024,
|
|
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST] I am?",
|
|
},
|
|
{
|
|
name: "with implicit response",
|
|
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
|
|
messages: []api.Message{
|
|
{Role: "system", Content: "You are a Wizard."},
|
|
{Role: "user", Content: "Hello"},
|
|
{Role: "assistant", Content: "I am?"},
|
|
},
|
|
window: 1024,
|
|
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]I am?",
|
|
},
|
|
{
|
|
name: "with conversation",
|
|
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
|
|
messages: []api.Message{
|
|
{Role: "system", Content: "You are a Wizard."},
|
|
{Role: "user", Content: "What are the potion ingredients?"},
|
|
{Role: "assistant", Content: "sugar"},
|
|
{Role: "user", Content: "Anything else?"},
|
|
},
|
|
window: 1024,
|
|
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ",
|
|
},
|
|
{
|
|
name: "with truncation",
|
|
template: "{{ .System }} {{ .Prompt }} {{ .Response }} ",
|
|
messages: []api.Message{
|
|
{Role: "system", Content: "You are a Wizard."},
|
|
{Role: "user", Content: "Hello"},
|
|
{Role: "assistant", Content: "I am?"},
|
|
{Role: "user", Content: "Why is the sky blue?"},
|
|
{Role: "assistant", Content: "The sky is blue from rayleigh scattering"},
|
|
},
|
|
window: 10,
|
|
want: "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering",
|
|
},
|
|
{
|
|
name: "images",
|
|
template: "{{ .System }} {{ .Prompt }}",
|
|
messages: []api.Message{
|
|
{Role: "system", Content: "You are a Wizard."},
|
|
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
|
|
},
|
|
window: 1024,
|
|
want: "You are a Wizard. [img-0] Hello",
|
|
},
|
|
{
|
|
name: "images truncated",
|
|
template: "{{ .System }} {{ .Prompt }}",
|
|
messages: []api.Message{
|
|
{Role: "system", Content: "You are a Wizard."},
|
|
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
|
|
},
|
|
window: 1024,
|
|
want: "You are a Wizard. [img-0] [img-1] Hello",
|
|
},
|
|
{
|
|
name: "empty list",
|
|
template: "{{ .System }} {{ .Prompt }}",
|
|
messages: []api.Message{},
|
|
window: 1024,
|
|
want: "",
|
|
},
|
|
{
|
|
name: "empty prompt",
|
|
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
|
|
messages: []api.Message{
|
|
{Role: "user", Content: ""},
|
|
},
|
|
window: 1024,
|
|
want: "",
|
|
},
|
|
}
|
|
|
|
encode := func(s string) ([]int, error) {
|
|
words := strings.Fields(s)
|
|
return make([]int, len(words)), nil
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
|
|
if err != nil {
|
|
t.Errorf("error = %v", err)
|
|
}
|
|
|
|
if got != tc.want {
|
|
t.Errorf("got: %q, want: %q", got, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|