From e64f9ebb44b584d94094274f62acd90a5195dd89 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 11 Jul 2024 13:10:13 -0700 Subject: [PATCH 1/3] do no automatically aggregate system messages --- template/template.go | 39 ++++++++++++++++++++------------------- template/template_test.go | 11 +++++++---- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/template/template.go b/template/template.go index 0b8f2434..8d5ac51b 100644 --- a/template/template.go +++ b/template/template.go @@ -102,8 +102,21 @@ var response = parse.ActionNode{ }, } +var funcs = template.FuncMap{ + "aggregate": func(v []*api.Message, role string) string { + var aggregated []string + for _, m := range v { + if m.Role == role { + aggregated = append(aggregated, m.Content) + } + } + + return strings.Join(aggregated, "\n\n") + }, +} + func Parse(s string) (*Template, error) { - tmpl := template.New("").Option("missingkey=zero") + tmpl := template.New("").Option("missingkey=zero").Funcs(funcs) tmpl, err := tmpl.Parse(s) if err != nil { @@ -149,23 +162,21 @@ type Values struct { } func (t *Template) Execute(w io.Writer, v Values) error { - system, collated := collate(v.Messages) + collated := collate(v.Messages) if !v.forceLegacy && slices.Contains(t.Vars(), "messages") { return t.Template.Execute(w, map[string]any{ - "System": system, "Messages": collated, }) } var b bytes.Buffer - var prompt, response string + var system, prompt, response string for i, m := range collated { switch m.Role { + case "system": + system = m.Content case "user": prompt = m.Content - if i != 0 { - system = "" - } case "assistant": response = m.Content } @@ -179,6 +190,7 @@ func (t *Template) Execute(w io.Writer, v Values) error { return err } + system = "" prompt = "" response = "" } @@ -209,25 +221,14 @@ func (t *Template) Execute(w io.Writer, v Values) error { return err } -type messages []*api.Message - // collate messages based on role. consecutive messages of the same role are merged // into a single message. collate also pulls out and merges messages with Role == "system" // which are templated separately. As a side effect, it mangles message content adding image // tags ([img-%d]) as needed -func collate(msgs []api.Message) (system string, collated messages) { +func collate(msgs []api.Message) (collated []*api.Message) { var n int for i := range msgs { msg := msgs[i] - if msg.Role == "system" { - if system != "" { - system += "\n\n" - } - - system += msg.Content - continue - } - for range msg.Images { imageTag := fmt.Sprintf("[img-%d]", n) if !strings.Contains(msg.Content, "[img]") { diff --git a/template/template_test.go b/template/template_test.go index e702a186..b020eb67 100644 --- a/template/template_test.go +++ b/template/template_test.go @@ -122,6 +122,7 @@ func TestTemplate(t *testing.T) { }) t.Run("legacy", func(t *testing.T) { + t.Skip("legacy outputs are currently default outputs") var legacy bytes.Buffer if err := tmpl.Execute(&legacy, Values{Messages: tt, forceLegacy: true}); err != nil { t.Fatal(err) @@ -154,11 +155,13 @@ func TestParse(t *testing.T) { {"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}}, {"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}}, {"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}}, - {"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}}, + {`{{- range .Messages }} +{{- if eq .Role "system" }}SYSTEM: +{{- else if eq .Role "user" }}USER: +{{- else if eq .Role "assistant" }}ASSISTANT: +{{- end }} {{ .Content }} +{{- end }}`, []string{"content", "messages", "role"}}, {`{{- if .Messages }} -{{- if .System }}<|im_start|>system -{{ .System }}<|im_end|> -{{ end }} {{- range .Messages }}<|im_start|>{{ .Role }} {{ .Content }}<|im_end|> {{ end }}<|im_start|>assistant From 57ec6901eb59cca9d0c29adca3f0fd4b95c1c989 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 11 Jul 2024 13:11:40 -0700 Subject: [PATCH 2/3] revert embedded templates to use prompt/response This reverts commit 19753c18c01183b4c974e36e89b0c7cbdcc3c38a. for compat. messages will be added at a later date --- server/routes_create_test.go | 4 +- template/alfred.gotmpl | 9 +-- template/alpaca.gotmpl | 13 ---- template/chatml.gotmpl | 9 --- template/chatqa.gotmpl | 12 ---- template/codellama-70b-instruct.gotmpl | 15 +---- template/falcon-instruct.gotmpl | 10 ---- template/gemma-instruct.gotmpl | 12 ---- template/granite-instruct.gotmpl | 14 ----- template/llama2-chat.gotmpl | 18 ++---- template/llama3-instruct.gotmpl | 14 +---- template/magicoder.gotmpl | 13 ---- template/mistral-instruct.gotmpl | 13 +--- template/openchat.gotmpl | 12 +--- template/phi-3.gotmpl | 9 --- template/solar-instruct.gotmpl | 14 ----- template/starcoder2-instruct.gotmpl | 15 ----- template/template_test.go | 59 ++++++++++++------- .../system-user-assistant-user | 4 +- .../llama2-chat.gotmpl/user-assistant-user | 4 +- .../system-user-assistant-user | 5 +- template/vicuna.gotmpl | 11 ---- template/zephyr.gotmpl | 9 --- 23 files changed, 63 insertions(+), 235 deletions(-) diff --git a/server/routes_create_test.go b/server/routes_create_test.go index 40477937..04174b92 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -546,8 +546,8 @@ func TestCreateDetectTemplate(t *testing.T) { checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"), - filepath.Join(p, "blobs", "sha256-68b0323b2f21572bc09ba07554b16b379a5713ee48ef8c25a7661a1f71cfce77"), - filepath.Join(p, "blobs", "sha256-eb72fb7c550ee1f1dec4039bd65382acecf5f7536a30fb7ccace39a8d0cb590b"), + filepath.Join(p, "blobs", "sha256-c608dc615584cd20d9d830363dabf8a4783ae5d34245c3d8c115edb3bc7b28e4"), + filepath.Join(p, "blobs", "sha256-f836ee110db21567f826332e4cedd746c06d10664fd5a9ea3659e3683a944510"), }) }) diff --git a/template/alfred.gotmpl b/template/alfred.gotmpl index 71bc6706..cecb9d2c 100644 --- a/template/alfred.gotmpl +++ b/template/alfred.gotmpl @@ -1,8 +1 @@ -{{- if .Messages }} -{{- if .System }}{{ .System }} -{{- end }} -{{- range .Messages }}{{ .Content }} -{{- end }} -{{- else -}} -{{ if .System }}{{ .System }}{{ end }}{{ if .Prompt }}{{ .Prompt }}{{ end }}{{ .Response }} -{{- end -}} \ No newline at end of file +{{ if .System }}{{ .System }}{{ end }}{{ if .Prompt }}{{ .Prompt }}{{ end }}{{ .Response }} \ No newline at end of file diff --git a/template/alpaca.gotmpl b/template/alpaca.gotmpl index e9becb3d..ec7a8edc 100644 --- a/template/alpaca.gotmpl +++ b/template/alpaca.gotmpl @@ -1,15 +1,3 @@ -{{- if .Messages }} -{{- if .System }}{{ .System }} - -{{ end }} -{{- range .Messages }} -{{- if eq .Role "user" }}### Instruction: -{{- else if eq .Role "assistant" }}### Response: -{{- end }} -{{ .Content }} - -{{ end }}### Response: -{{ else -}} {{ if .System }}{{ .System }} {{ end }}{{ if .Prompt }}### Instruction: @@ -18,4 +6,3 @@ {{ end }}### Response: {{ .Response }} -{{ end -}} \ No newline at end of file diff --git a/template/chatml.gotmpl b/template/chatml.gotmpl index eb8ab0dc..fb672601 100644 --- a/template/chatml.gotmpl +++ b/template/chatml.gotmpl @@ -1,15 +1,6 @@ -{{- if .Messages }} -{{- if .System }}<|im_start|>system -{{ .System }}<|im_end|> -{{ end }} -{{- range .Messages }}<|im_start|>{{ .Role }} -{{ .Content }}<|im_end|> -{{ end }}<|im_start|>assistant -{{ else -}} {{ if .System }}<|im_start|>system {{ .System }}<|im_end|> {{ end }}{{ if .Prompt }}<|im_start|>user {{ .Prompt }}<|im_end|> {{ end }}<|im_start|>assistant {{ .Response }}<|im_end|> -{{ end -}} \ No newline at end of file diff --git a/template/chatqa.gotmpl b/template/chatqa.gotmpl index 41c6ced5..91679a72 100644 --- a/template/chatqa.gotmpl +++ b/template/chatqa.gotmpl @@ -1,18 +1,6 @@ -{{- if .Messages }} -{{- if .System }}System: {{ .System }} - -{{ end }} -{{- range .Messages }} -{{- if eq .Role "user" }}User: -{{- else if eq .Role "assistant" }}Assistant: -{{- end }} {{ .Content }} - -{{ end }}Assistant: -{{- else -}} {{ if .System }}System: {{ .System }} {{ end }}{{ if .Prompt }}User: {{ .Prompt }} {{ end }}Assistant: {{ .Response }} -{{ end -}} \ No newline at end of file diff --git a/template/codellama-70b-instruct.gotmpl b/template/codellama-70b-instruct.gotmpl index 0a313d38..e5856042 100644 --- a/template/codellama-70b-instruct.gotmpl +++ b/template/codellama-70b-instruct.gotmpl @@ -1,19 +1,10 @@ -{{- if .Messages }} -{{- if .System }}Source: system - - {{ .System }} {{ end }} -{{- range .Messages }}Source: {{ .Role }} - - {{ .Content }} {{ end }}Source: assistant -Destination: user - - {{ else -}} {{ if .System }}Source: system {{ .System }} {{ end }}Source: user {{ .Prompt }} Source: assistant +{{- if not .Response }} Destination: user +{{- end }} - {{ .Response }} -{{- end -}} \ No newline at end of file + {{ .Response }} \ No newline at end of file diff --git a/template/falcon-instruct.gotmpl b/template/falcon-instruct.gotmpl index 3a403007..0a5fe48e 100644 --- a/template/falcon-instruct.gotmpl +++ b/template/falcon-instruct.gotmpl @@ -1,15 +1,5 @@ -{{- if .Messages }} -{{- if .System }}System: {{ .System }} -{{ end }} -{{- range .Messages }} -{{- if eq .Role "user" }}User: -{{ else if eq .Role "assistant" }}Falcon: -{{ end }}{{ .Content }} -{{ end }}Falcon: -{{ else -}} {{ if .System }}System: {{ .System }} {{ end }}{{ if .Prompt }}User: {{ .Prompt }} {{ end }}Falcon: {{ .Response }} -{{ end -}} \ No newline at end of file diff --git a/template/gemma-instruct.gotmpl b/template/gemma-instruct.gotmpl index 6d778a70..3c3a8425 100644 --- a/template/gemma-instruct.gotmpl +++ b/template/gemma-instruct.gotmpl @@ -1,17 +1,5 @@ -{{- if .Messages }} -{{- range $index, $_ := .Messages }} -{{- if eq .Role "user" }}user -{{- if and $.System (eq $index 0) }} -{{ $.System }} -{{- end }} -{{- else if eq .Role "assistant" }}model -{{- end }} -{{ .Content }} -{{ end }}model -{{ else -}} user {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} model {{ .Response }} -{{ end -}} \ No newline at end of file diff --git a/template/granite-instruct.gotmpl b/template/granite-instruct.gotmpl index 4a85a97b..56690fce 100644 --- a/template/granite-instruct.gotmpl +++ b/template/granite-instruct.gotmpl @@ -1,16 +1,3 @@ -{{- if .Messages }} -{{- if .System }}System: -{{ .System }} - -{{ end }} -{{- range .Messages }} -{{- if eq .Role "user" }}Question: -{{- else if eq .Role "assistant" }}Answer: -{{- end }} -{{ .Content }} - -{{ end }}Answer: -{{ else -}} {{ if .System }}System: {{ .System }} @@ -20,4 +7,3 @@ {{ end }}Answer: {{ .Response }} -{{ end -}} \ No newline at end of file diff --git a/template/llama2-chat.gotmpl b/template/llama2-chat.gotmpl index 1816fefd..013b414e 100644 --- a/template/llama2-chat.gotmpl +++ b/template/llama2-chat.gotmpl @@ -1,16 +1,6 @@ -{{- if .Messages }} -{{- range $index, $_ := .Messages }} -{{- if eq .Role "user" }}[INST] {{ if eq $index 0 }}<> -{{- if $.System }} -{{ $.System }} +[INST] <> +{{- if .System }} +{{ .System }} {{ end }}<> -{{ end }}{{ .Content }} -{{- else }} [/INST] {{ .Content }} -{{- end }} -{{- end }} [/INST] -{{- else -}} -[INST] <>{{ if .System }}{{ .System }}{{ end }}<> - -{{ .Prompt }} [/INST] {{ .Response }} -{{- end -}} \ No newline at end of file +{{ .Prompt }} [/INST] {{ .Response }} \ No newline at end of file diff --git a/template/llama3-instruct.gotmpl b/template/llama3-instruct.gotmpl index 7947b8da..36d0218b 100644 --- a/template/llama3-instruct.gotmpl +++ b/template/llama3-instruct.gotmpl @@ -1,19 +1,7 @@ -{{- if .Messages }} -{{- if .System }}<|start_header_id|>system<|end_header_id|> - -{{ .System }}<|eot_id|> -{{- end }} -{{- range .Messages }}<|start_header_id|>{{ .Role }}<|end_header_id|> - -{{ .Content }}<|eot_id|> -{{- end }}<|start_header_id|>assistant<|end_header_id|> - -{{ else -}} {{ if .System }}<|start_header_id|>system<|end_header_id|> {{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|> {{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|> -{{ .Response }}<|eot_id|> -{{- end -}} \ No newline at end of file +{{ .Response }}<|eot_id|> \ No newline at end of file diff --git a/template/magicoder.gotmpl b/template/magicoder.gotmpl index 9227b666..52abc01a 100644 --- a/template/magicoder.gotmpl +++ b/template/magicoder.gotmpl @@ -1,15 +1,3 @@ -{{- if .Messages }} -{{- if .System }}{{ .System }} - -{{ end }} -{{- range .Messages }} -{{- if eq .Role "user" }}@@ Instruction -{{- else if eq .Role "assistant" }}@@ Response -{{- end }} -{{ .Content }} - -{{ end }}@@ Response -{{ else -}} {{ if .System }}{{ .System }} {{ end }}{{ if .Prompt }}@@ Instruction @@ -18,4 +6,3 @@ {{ end }}@@ Response {{ .Response }} -{{ end -}} \ No newline at end of file diff --git a/template/mistral-instruct.gotmpl b/template/mistral-instruct.gotmpl index 1d746dfd..e489bd4c 100644 --- a/template/mistral-instruct.gotmpl +++ b/template/mistral-instruct.gotmpl @@ -1,10 +1,3 @@ -{{- if .Messages }} -{{- range $index, $_ := .Messages }} -{{- if eq .Role "user" }}[INST] {{ if and $.System (eq (len (slice $.Messages $index)) 1) }}{{ $.System }} -{{ end }}{{ .Content }} -{{- else if eq .Role "assistant" }}[/INST] {{ .Content }} -{{- end }} -{{- end }}[/INST] -{{- else -}} -[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}[/INST] {{ .Response }} -{{- end -}} \ No newline at end of file +[INST] {{ if .System }}{{ .System }} + +{{ end }}{{ .Prompt }}[/INST] {{ .Response }} \ No newline at end of file diff --git a/template/openchat.gotmpl b/template/openchat.gotmpl index 649f0509..9c183834 100644 --- a/template/openchat.gotmpl +++ b/template/openchat.gotmpl @@ -1,11 +1 @@ -{{- if .Messages }} -{{- if .System }}GPT4 Correct System: {{ .System }}<|end_of_turn|> -{{- end }} -{{- range .Messages }}GPT4 Correct -{{- if eq .Role "user" }} User: -{{- else if eq .Role "assistant" }} Assistant: -{{- end }} {{ .Content }}<|end_of_turn|> -{{- end }}GPT4 Correct Assistant: -{{- else -}} -{{ if .System }}GPT4 Correct System: {{ .System }}<|end_of_turn|>{{ end }}GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|> -{{- end -}} \ No newline at end of file +{{ if .System }}GPT4 Correct System: {{ .System }}<|end_of_turn|>{{ end }}GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|> \ No newline at end of file diff --git a/template/phi-3.gotmpl b/template/phi-3.gotmpl index 4ca56e95..6c3610dd 100644 --- a/template/phi-3.gotmpl +++ b/template/phi-3.gotmpl @@ -1,15 +1,6 @@ -{{- if .Messages }} -{{- if .System }}<|system|> -{{ .System }}<|end|> -{{ end }} -{{- range .Messages }}<|{{ .Role }}|> -{{ .Content }}<|end|> -{{ end }}<|assistant|> -{{ else -}} {{ if .System }}<|system|> {{ .System }}<|end|> {{ end }}{{ if .Prompt }}<|user|> {{ .Prompt }}<|end|> {{ end }}<|assistant|> {{ .Response }}<|end|> -{{ end -}} \ No newline at end of file diff --git a/template/solar-instruct.gotmpl b/template/solar-instruct.gotmpl index 8a8331ca..1c14960d 100644 --- a/template/solar-instruct.gotmpl +++ b/template/solar-instruct.gotmpl @@ -1,16 +1,3 @@ -{{- if .Messages }} -{{- if .System }}### System: -{{ .System }} - -{{ end }} -{{- range .Messages }} -{{- if eq .Role "user" }}### User: -{{ .Content }} -{{ else if eq .Role "assistant" }}### Assistant: -{{ .Content }} -{{ end }} -{{ end }}### Assistant: -{{ else -}} {{ if .System }}### System: {{ .System }} @@ -20,4 +7,3 @@ {{ end }}### Assistant: {{ .Response }} -{{ end -}} \ No newline at end of file diff --git a/template/starcoder2-instruct.gotmpl b/template/starcoder2-instruct.gotmpl index 17c6ad75..6c93a7ab 100644 --- a/template/starcoder2-instruct.gotmpl +++ b/template/starcoder2-instruct.gotmpl @@ -1,17 +1,3 @@ -{{- if .Messages }} -{{- if .System }}{{ .System }} - -{{ end }} -{{- range .Messages }} -{{- if eq .Role "user" }}### Instruction -{{ .Content }} - -{{ else if eq .Role "assistant" }}### Response -{{ .Content }}<|endoftext|> - -{{ end }} -{{- end }}### Response -{{ else -}} {{ if .System }}{{ .System }} {{ end }}{{ if .Prompt }}### Instruction @@ -20,4 +6,3 @@ {{ end }}### Response {{ .Response }}<|endoftext|> -{{ end -}} \ No newline at end of file diff --git a/template/template_test.go b/template/template_test.go index b020eb67..9cfa0bea 100644 --- a/template/template_test.go +++ b/template/template_test.go @@ -116,7 +116,14 @@ func TestTemplate(t *testing.T) { t.Fatal(err) } - if diff := cmp.Diff(actual.Bytes(), expect); diff != "" { + bts := actual.Bytes() + + if slices.Contains([]string{"chatqa.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && bts[len(bts)-1] == ' ' { + t.Log("removing trailing space from output") + bts = bts[:len(bts)-1] + } + + if diff := cmp.Diff(bts, expect); diff != "" { t.Errorf("mismatch (-got +want):\n%s", diff) } }) @@ -203,11 +210,18 @@ func TestExecuteWithMessages(t *testing.T) { { "mistral", []template{ - {"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `}, - {"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`}, - {"messages", `{{- range $index, $_ := .Messages }} -{{- if eq .Role "user" }}[INST] {{ if and (eq $index 0) $.System }}{{ $.System }}{{ "\n\n" }} -{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }} + {"no response", `[INST] {{ if .System }}{{ .System }} + +{{ end }}{{ .Prompt }}[/INST] `}, + {"response", `[INST] {{ if .System }}{{ .System }} + +{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`}, + {"messages", `{{- $system := aggregate $.Messages "system" -}} +{{- range $index, $_ := .Messages }} +{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }} +{{- $system = "" }} + +{{ end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }} {{- end }} {{- end }}`}, }, @@ -223,12 +237,18 @@ func TestExecuteWithMessages(t *testing.T) { { "mistral system", []template{ - {"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `}, - {"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`}, - {"messages", ` + {"no response", `[INST] {{ if .System }}{{ .System }} + +{{ end }}{{ .Prompt }}[/INST] `}, + {"response", `[INST] {{ if .System }}{{ .System }} + +{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`}, + {"messages", `{{- $system := aggregate $.Messages "system" -}} {{- range $index, $_ := .Messages }} -{{- if eq .Role "user" }}[INST] {{ if and (eq $index 0) $.System }}{{ $.System }}{{ "\n\n" }} -{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }} +{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }} +{{- $system = "" }} + +{{ end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }} {{- end }} {{- end }}`}, }, @@ -256,12 +276,9 @@ Hello friend![/INST] Hello human![INST] What is your name?[/INST] `, {{ .Response }}<|im_end|> `}, {"messages", ` -{{- range $index, $_ := .Messages }} -{{- if and (eq .Role "user") (eq $index 0) $.System }}<|im_start|>system -{{ $.System }}<|im_end|>{{ "\n" }} -{{- end }}<|im_start|>{{ .Role }} -{{ .Content }}<|im_end|>{{ "\n" }} -{{- end }}<|im_start|>assistant +{{- range $index, $_ := .Messages }}<|im_start|>{{ .Role }} +{{ .Content }}<|im_end|> +{{ end }}<|im_start|>assistant `}, }, Values{ @@ -294,9 +311,11 @@ What is your name?<|im_end|> `}, {"messages", ` {{- range .Messages }} -{{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }} -{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }} -{{- end }} +{{- if eq .Role "user" }}Question: {{ .Content }} + +{{ else if eq .Role "assistant" }}Answer: {{ .Content }} + +{{ end }} {{- end }}Answer: `}, }, Values{ diff --git a/template/testdata/llama2-chat.gotmpl/system-user-assistant-user b/template/testdata/llama2-chat.gotmpl/system-user-assistant-user index fc2679bf..9db81cb4 100644 --- a/template/testdata/llama2-chat.gotmpl/system-user-assistant-user +++ b/template/testdata/llama2-chat.gotmpl/system-user-assistant-user @@ -2,4 +2,6 @@ You are a helpful assistant. <> -Hello, how are you? [/INST] I'm doing great. How can I help you today?[INST] I'd like to show off how chat templating works! [/INST] \ No newline at end of file +Hello, how are you? [/INST] I'm doing great. How can I help you today?[INST] <><> + +I'd like to show off how chat templating works! [/INST] \ No newline at end of file diff --git a/template/testdata/llama2-chat.gotmpl/user-assistant-user b/template/testdata/llama2-chat.gotmpl/user-assistant-user index 42b4c529..ca58954f 100644 --- a/template/testdata/llama2-chat.gotmpl/user-assistant-user +++ b/template/testdata/llama2-chat.gotmpl/user-assistant-user @@ -1,3 +1,5 @@ [INST] <><> -Hello, how are you? [/INST] I'm doing great. How can I help you today?[INST] I'd like to show off how chat templating works! [/INST] \ No newline at end of file +Hello, how are you? [/INST] I'm doing great. How can I help you today?[INST] <><> + +I'd like to show off how chat templating works! [/INST] \ No newline at end of file diff --git a/template/testdata/mistral-instruct.gotmpl/system-user-assistant-user b/template/testdata/mistral-instruct.gotmpl/system-user-assistant-user index b6b4bf93..2f1edaec 100644 --- a/template/testdata/mistral-instruct.gotmpl/system-user-assistant-user +++ b/template/testdata/mistral-instruct.gotmpl/system-user-assistant-user @@ -1,2 +1,3 @@ -[INST] Hello, how are you?[/INST] I'm doing great. How can I help you today?[INST] You are a helpful assistant. -I'd like to show off how chat templating works![/INST] \ No newline at end of file +[INST] You are a helpful assistant. + +Hello, how are you?[/INST] I'm doing great. How can I help you today?[INST] I'd like to show off how chat templating works![/INST] \ No newline at end of file diff --git a/template/vicuna.gotmpl b/template/vicuna.gotmpl index 01465b99..515b2fe9 100644 --- a/template/vicuna.gotmpl +++ b/template/vicuna.gotmpl @@ -1,15 +1,4 @@ -{{- if .Messages }} -{{- if .System }}{{ .System }} - -{{ end }} -{{- range .Messages }} -{{- if eq .Role "user" }}USER: {{ .Content }} -{{ else if eq .Role "assistant" }}ASSISTANT: {{ .Content }} -{{ end }} -{{- end }}ASSISTANT: -{{- else -}} {{ if .System }}{{ .System }} {{ end }}{{ if .Prompt }}USER: {{ .Prompt }} {{ end }}ASSISTANT: {{ .Response }} -{{ end -}} \ No newline at end of file diff --git a/template/zephyr.gotmpl b/template/zephyr.gotmpl index 3ca1d1a1..1f889f26 100644 --- a/template/zephyr.gotmpl +++ b/template/zephyr.gotmpl @@ -1,15 +1,6 @@ -{{- if .Messages }} -{{- if .System }}<|system|> -{{ .System }} -{{ end }} -{{- range .Messages }}<|{{ .Role }}|> -{{ .Content }} -{{ end }}<|assistant|> -{{ else -}} {{ if .System }}<|system|> {{ .System }} {{ end }}{{ if .Prompt }}<|user|> {{ .Prompt }} {{ end }}<|assistant|> {{ .Response }} -{{ end -}} \ No newline at end of file From 5056bb9c010f06316b0ff280b879b9c36a7c995c Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 11 Jul 2024 16:06:57 -0700 Subject: [PATCH 3/3] rename aggregate to contents --- template/template.go | 11 ++++++----- template/template_test.go | 37 +++++++++++++++++++++++++++++++++++-- 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/template/template.go b/template/template.go index 8d5ac51b..21e1614d 100644 --- a/template/template.go +++ b/template/template.go @@ -103,15 +103,16 @@ var response = parse.ActionNode{ } var funcs = template.FuncMap{ - "aggregate": func(v []*api.Message, role string) string { - var aggregated []string + // contents returns the contents of messages with an optional role filter + "contents": func(v []*api.Message, role ...string) string { + var parts []string for _, m := range v { - if m.Role == role { - aggregated = append(aggregated, m.Content) + if len(role) == 0 || role[0] == "" || m.Role == role[0] { + parts = append(parts, m.Content) } } - return strings.Join(aggregated, "\n\n") + return strings.Join(parts, "\n\n") }, } diff --git a/template/template_test.go b/template/template_test.go index 9cfa0bea..5e5f4257 100644 --- a/template/template_test.go +++ b/template/template_test.go @@ -216,7 +216,7 @@ func TestExecuteWithMessages(t *testing.T) { {"response", `[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}[/INST] {{ .Response }}`}, - {"messages", `{{- $system := aggregate $.Messages "system" -}} + {"messages", `{{- $system := contents .Messages "system" -}} {{- range $index, $_ := .Messages }} {{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }} {{- $system = "" }} @@ -243,7 +243,7 @@ func TestExecuteWithMessages(t *testing.T) { {"response", `[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}[/INST] {{ .Response }}`}, - {"messages", `{{- $system := aggregate $.Messages "system" -}} + {"messages", `{{- $system := contents .Messages "system" -}} {{- range $index, $_ := .Messages }} {{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }} {{- $system = "" }} @@ -363,3 +363,36 @@ Answer: `, }) } } + +func TestFuncs(t *testing.T) { + t.Run("contents", func(t *testing.T) { + cases := map[string]string{ + "": "A\n\nB\n\nC\n\nD\n\nE\n\nF", + "system": "A\n\nF", + "user": "B\n\nE", + "assistant": "C\n\nD", + } + + s := []*api.Message{ + {Role: "system", Content: "A"}, + {Role: "user", Content: "B"}, + {Role: "assistant", Content: "C"}, + {Role: "assistant", Content: "D"}, + {Role: "user", Content: "E"}, + {Role: "system", Content: "F"}, + } + + fn, ok := funcs["contents"].(func([]*api.Message, ...string) string) + if !ok { + t.Fatal("contents is not a function") + } + + for k, v := range cases { + t.Run(k, func(t *testing.T) { + if diff := cmp.Diff(fn(s, k), v); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + } + }) +}