This commit is contained in:
Michael Yang 2024-07-03 13:49:14 -07:00
parent ac7a842e55
commit 326363b3a7
2 changed files with 7 additions and 117 deletions

View file

@ -102,25 +102,8 @@ var response = parse.ActionNode{
}, },
} }
var funcs = template.FuncMap{
"toJson": func(v any) string {
b, err := json.Marshal(v)
if err != nil {
return ""
}
return string(b)
},
"add": func(a, b int) int {
return a + b
},
"sub": func(a, b int) int {
return a - b
},
}
func Parse(s string) (*Template, error) { func Parse(s string) (*Template, error) {
tmpl := template.New("").Option("missingkey=zero").Funcs(funcs) tmpl := template.New("").Option("missingkey=zero")
tmpl, err := tmpl.Parse(s) tmpl, err := tmpl.Parse(s)
if err != nil { if err != nil {

View file

@ -8,7 +8,6 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"slices" "slices"
"strconv"
"testing" "testing"
"text/template" "text/template"
@ -16,98 +15,6 @@ import (
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
) )
func TestFuncs(t *testing.T) {
t.Run("toJson", func(t *testing.T) {
cases := []struct {
input any
expected string
}{
{nil, "null"},
{true, "true"},
{false, "false"},
{0, "0"},
{1, "1"},
{1.0, "1"},
{1.1, "1.1"},
{"", `""`},
{"hello", `"hello"`},
{[]int{1, 2, 3}, "[1,2,3]"},
{[]string{"a", "b", "c"}, `["a","b","c"]`},
{map[string]int{"a": 1, "b": 2}, `{"a":1,"b":2}`},
{map[string]string{"a": "b", "c": "d"}, `{"a":"b","c":"d"}`},
}
for _, tt := range cases {
t.Run(tt.expected, func(t *testing.T) {
toJson, ok := funcs["toJson"].(func(any) string)
if !ok {
t.Fatal("toJson is not a function")
}
if s := toJson(tt.input); s != tt.expected {
t.Errorf("expected %q, got %q", tt.expected, s)
}
})
}
})
t.Run("add", func(t *testing.T) {
cases := []struct {
a, b int
expected int
}{
{0, 0, 0},
{0, 1, 1},
{1, 0, 1},
{1, 1, 2},
{1, -1, 0},
{-1, 1, 0},
{-1, -1, -2},
}
for _, tt := range cases {
t.Run(strconv.Itoa(tt.expected), func(t *testing.T) {
add, ok := funcs["add"].(func(int, int) int)
if !ok {
t.Fatal("add is not a function")
}
if n := add(tt.a, tt.b); n != tt.expected {
t.Errorf("expected %d, got %d", tt.expected, n)
}
})
}
})
t.Run("sub", func(t *testing.T) {
cases := []struct {
a, b int
expected int
}{
{0, 0, 0},
{0, 1, -1},
{1, 0, 1},
{1, 1, 0},
{1, -1, 2},
{-1, 1, -2},
{-1, -1, 0},
}
for _, tt := range cases {
t.Run(strconv.Itoa(tt.expected), func(t *testing.T) {
sub, ok := funcs["sub"].(func(int, int) int)
if !ok {
t.Fatal("sub is not a function")
}
if n := sub(tt.a, tt.b); n != tt.expected {
t.Errorf("expected %d, got %d", tt.expected, n)
}
})
}
})
}
func TestNamed(t *testing.T) { func TestNamed(t *testing.T) {
f, err := os.Open(filepath.Join("testdata", "templates.jsonl")) f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
if err != nil { if err != nil {
@ -197,8 +104,8 @@ func TestExecuteWithMessages(t *testing.T) {
[]template{ []template{
{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `}, {"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`}, {"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `{{- range .Messages }} {"messages", `{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq (index $.Messages (sub (len $.Messages) 1)) .) $.System }}{{ $.System }}{{ "\n\n" }} {{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }} {{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }} {{- end }}
{{- end }}`}, {{- end }}`},
@ -218,8 +125,8 @@ func TestExecuteWithMessages(t *testing.T) {
{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `}, {"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`}, {"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", ` {"messages", `
{{- range .Messages }} {{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq (index $.Messages (sub (len $.Messages) 1)) .) $.System }}{{ $.System }}{{ "\n\n" }} {{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }} {{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }} {{- end }}
{{- end }}`}, {{- end }}`},
@ -248,8 +155,8 @@ What is your name?[/INST] `,
{{ .Response }}<|im_end|> {{ .Response }}<|im_end|>
`}, `},
{"messages", ` {"messages", `
{{- range .Messages }} {{- range $index, $_ := .Messages }}
{{- if and (eq .Role "user") (eq (index $.Messages (sub (len $.Messages) 1)) .) $.System }}<|im_start|>system {{- if and (eq .Role "user") (eq (len (slice $.Messages $index)) 1) $.System }}<|im_start|>system
{{ $.System }}<|im_end|>{{ "\n" }} {{ $.System }}<|im_end|>{{ "\n" }}
{{- end }}<|im_start|>{{ .Role }} {{- end }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>{{ "\n" }} {{ .Content }}<|im_end|>{{ "\n" }}