diff --git a/server/prompt_test.go b/server/prompt_test.go index d4cee98c..1435b143 100644 --- a/server/prompt_test.go +++ b/server/prompt_test.go @@ -161,7 +161,7 @@ func TestChatPrompt(t *testing.T) { {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."}, }, expect: expect{ - prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ", + prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ", }, }, } diff --git a/server/routes_create_test.go b/server/routes_create_test.go index 269a0ba1..40477937 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -546,8 +546,8 @@ func TestCreateDetectTemplate(t *testing.T) { checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"), - filepath.Join(p, "blobs", "sha256-9512c372dfc7d84d6065b8dd2b601aeed8cc1a78e7a7aa784a42fff37f5524b7"), - filepath.Join(p, "blobs", "sha256-b8b78cb8c6eefd14c06f1af042e6161255bf87bbf2dd14fce57cdac893db8139"), + filepath.Join(p, "blobs", "sha256-68b0323b2f21572bc09ba07554b16b379a5713ee48ef8c25a7661a1f71cfce77"), + filepath.Join(p, "blobs", "sha256-eb72fb7c550ee1f1dec4039bd65382acecf5f7536a30fb7ccace39a8d0cb590b"), }) }) diff --git a/template/template.go b/template/template.go index b133b97e..0b8f2434 100644 --- a/template/template.go +++ b/template/template.go @@ -143,11 +143,14 @@ func (t *Template) Vars() []string { type Values struct { Messages []api.Message + + // forceLegacy is a flag used to test compatibility with legacy templates + forceLegacy bool } func (t *Template) Execute(w io.Writer, v Values) error { system, collated := collate(v.Messages) - if slices.Contains(t.Vars(), "messages") { + if !v.forceLegacy && slices.Contains(t.Vars(), "messages") { return t.Template.Execute(w, map[string]any{ "System": system, "Messages": collated, @@ -157,15 +160,19 @@ func (t *Template) Execute(w io.Writer, v Values) error { var b bytes.Buffer var prompt, response string for i, m := range collated { - if m.Role == "user" { + switch m.Role { + case "user": prompt = m.Content - } else { + if i != 0 { + system = "" + } + case "assistant": response = m.Content } if i != len(collated)-1 && prompt != "" && response != "" { if err := t.Template.Execute(&b, map[string]any{ - "System": "", + "System": system, "Prompt": prompt, "Response": response, }); err != nil { @@ -178,18 +185,21 @@ func (t *Template) Execute(w io.Writer, v Values) error { } var cut bool - tree := t.Template.Copy() - // for the last message, cut everything after "{{ .Response }}" - tree.Root.Nodes = slices.DeleteFunc(tree.Root.Nodes, func(n parse.Node) bool { - if slices.Contains(parseNode(n), "Response") { - cut = true + nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool { + switch t := n.(type) { + case *parse.ActionNode: + case *parse.FieldNode: + if slices.Contains(t.Ident, "Response") { + cut = true + } } return cut }) - if err := template.Must(template.New("").AddParseTree("", tree)).Execute(&b, map[string]any{ - "System": system, + tree := parse.Tree{Root: nodes.(*parse.ListNode)} + if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{ + "System": "", "Prompt": prompt, }); err != nil { return err @@ -286,3 +296,72 @@ func parseNode(n parse.Node) []string { return nil } + +// deleteNode walks the node list and deletes nodes that match the predicate +// this is currently to remove the {{ .Response }} node from templates +func deleteNode(n parse.Node, fn func(parse.Node) bool) parse.Node { + var walk func(n parse.Node) parse.Node + walk = func(n parse.Node) parse.Node { + if fn(n) { + return nil + } + + switch t := n.(type) { + case *parse.ListNode: + var nodes []parse.Node + for _, c := range t.Nodes { + if n := walk(c); n != nil { + nodes = append(nodes, n) + } + } + + t.Nodes = nodes + return t + case *parse.IfNode: + t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode)) + case *parse.WithNode: + t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode)) + case *parse.RangeNode: + t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode)) + case *parse.BranchNode: + t.List = walk(t.List).(*parse.ListNode) + if t.ElseList != nil { + t.ElseList = walk(t.ElseList).(*parse.ListNode) + } + case *parse.ActionNode: + n := walk(t.Pipe) + if n == nil { + return nil + } + + t.Pipe = n.(*parse.PipeNode) + case *parse.PipeNode: + var commands []*parse.CommandNode + for _, c := range t.Cmds { + var args []parse.Node + for _, a := range c.Args { + if n := walk(a); n != nil { + args = append(args, n) + } + } + + if len(args) == 0 { + return nil + } + + c.Args = args + commands = append(commands, c) + } + + if len(commands) == 0 { + return nil + } + + t.Cmds = commands + } + + return n + } + + return walk(n) +} diff --git a/template/template_test.go b/template/template_test.go index 428cdc77..e702a186 100644 --- a/template/template_test.go +++ b/template/template_test.go @@ -105,8 +105,8 @@ func TestTemplate(t *testing.T) { } for n, tt := range cases { + var actual bytes.Buffer t.Run(n, func(t *testing.T) { - var actual bytes.Buffer if err := tmpl.Execute(&actual, Values{Messages: tt}); err != nil { t.Fatal(err) } @@ -120,6 +120,25 @@ func TestTemplate(t *testing.T) { t.Errorf("mismatch (-got +want):\n%s", diff) } }) + + t.Run("legacy", func(t *testing.T) { + var legacy bytes.Buffer + if err := tmpl.Execute(&legacy, Values{Messages: tt, forceLegacy: true}); err != nil { + t.Fatal(err) + } + + legacyBytes := legacy.Bytes() + if slices.Contains([]string{"chatqa.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && legacyBytes[len(legacyBytes)-1] == ' ' { + t.Log("removing trailing space from legacy output") + legacyBytes = legacyBytes[:len(legacyBytes)-1] + } else if slices.Contains([]string{"codellama-70b-instruct.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl"}, match) { + t.Skip("legacy outputs cannot be compared to messages outputs") + } + + if diff := cmp.Diff(legacyBytes, actual.Bytes()); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) } }) } @@ -136,6 +155,21 @@ func TestParse(t *testing.T) { {"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}}, {"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}}, {"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}}, + {`{{- if .Messages }} +{{- if .System }}<|im_start|>system +{{ .System }}<|im_end|> +{{ end }} +{{- range .Messages }}<|im_start|>{{ .Role }} +{{ .Content }}<|im_end|> +{{ end }}<|im_start|>assistant +{{ else -}} +{{ if .System }}<|im_start|>system +{{ .System }}<|im_end|> +{{ end }}{{ if .Prompt }}<|im_start|>user +{{ .Prompt }}<|im_end|> +{{ end }}<|im_start|>assistant +{{ .Response }}<|im_end|> +{{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}}, } for _, tt := range cases { @@ -145,9 +179,8 @@ func TestParse(t *testing.T) { t.Fatal(err) } - vars := tmpl.Vars() - if !slices.Equal(tt.vars, vars) { - t.Errorf("expected %v, got %v", tt.vars, vars) + if diff := cmp.Diff(tmpl.Vars(), tt.vars); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) } }) } @@ -170,7 +203,7 @@ func TestExecuteWithMessages(t *testing.T) { {"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `}, {"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`}, {"messages", `{{- range $index, $_ := .Messages }} -{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }} +{{- if eq .Role "user" }}[INST] {{ if and (eq $index 0) $.System }}{{ $.System }}{{ "\n\n" }} {{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }} {{- end }} {{- end }}`}, @@ -191,7 +224,7 @@ func TestExecuteWithMessages(t *testing.T) { {"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`}, {"messages", ` {{- range $index, $_ := .Messages }} -{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }} +{{- if eq .Role "user" }}[INST] {{ if and (eq $index 0) $.System }}{{ $.System }}{{ "\n\n" }} {{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }} {{- end }} {{- end }}`}, @@ -204,9 +237,9 @@ func TestExecuteWithMessages(t *testing.T) { {Role: "user", Content: "What is your name?"}, }, }, - `[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant! + `[INST] You are a helpful assistant! -What is your name?[/INST] `, +Hello friend![/INST] Hello human![INST] What is your name?[/INST] `, }, { "chatml", @@ -221,7 +254,7 @@ What is your name?[/INST] `, `}, {"messages", ` {{- range $index, $_ := .Messages }} -{{- if and (eq .Role "user") (eq (len (slice $.Messages $index)) 1) $.System }}<|im_start|>system +{{- if and (eq .Role "user") (eq $index 0) $.System }}<|im_start|>system {{ $.System }}<|im_end|>{{ "\n" }} {{- end }}<|im_start|>{{ .Role }} {{ .Content }}<|im_end|>{{ "\n" }} @@ -236,12 +269,12 @@ What is your name?[/INST] `, {Role: "user", Content: "What is your name?"}, }, }, - `<|im_start|>user + `<|im_start|>system +You are a helpful assistant!<|im_end|> +<|im_start|>user Hello friend!<|im_end|> <|im_start|>assistant Hello human!<|im_end|> -<|im_start|>system -You are a helpful assistant!<|im_end|> <|im_start|>user What is your name?<|im_end|> <|im_start|>assistant @@ -300,8 +333,8 @@ Answer: `, t.Fatal(err) } - if b.String() != tt.expected { - t.Errorf("expected\n%s,\ngot\n%s", tt.expected, b.String()) + if diff := cmp.Diff(b.String(), tt.expected); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) } }) }