diff --git a/template/template.go b/template/template.go index 0b8f2434..8d5ac51b 100644 --- a/template/template.go +++ b/template/template.go @@ -102,8 +102,21 @@ var response = parse.ActionNode{ }, } +var funcs = template.FuncMap{ + "aggregate": func(v []*api.Message, role string) string { + var aggregated []string + for _, m := range v { + if m.Role == role { + aggregated = append(aggregated, m.Content) + } + } + + return strings.Join(aggregated, "\n\n") + }, +} + func Parse(s string) (*Template, error) { - tmpl := template.New("").Option("missingkey=zero") + tmpl := template.New("").Option("missingkey=zero").Funcs(funcs) tmpl, err := tmpl.Parse(s) if err != nil { @@ -149,23 +162,21 @@ type Values struct { } func (t *Template) Execute(w io.Writer, v Values) error { - system, collated := collate(v.Messages) + collated := collate(v.Messages) if !v.forceLegacy && slices.Contains(t.Vars(), "messages") { return t.Template.Execute(w, map[string]any{ - "System": system, "Messages": collated, }) } var b bytes.Buffer - var prompt, response string + var system, prompt, response string for i, m := range collated { switch m.Role { + case "system": + system = m.Content case "user": prompt = m.Content - if i != 0 { - system = "" - } case "assistant": response = m.Content } @@ -179,6 +190,7 @@ func (t *Template) Execute(w io.Writer, v Values) error { return err } + system = "" prompt = "" response = "" } @@ -209,25 +221,14 @@ func (t *Template) Execute(w io.Writer, v Values) error { return err } -type messages []*api.Message - // collate messages based on role. consecutive messages of the same role are merged // into a single message. collate also pulls out and merges messages with Role == "system" // which are templated separately. As a side effect, it mangles message content adding image // tags ([img-%d]) as needed -func collate(msgs []api.Message) (system string, collated messages) { +func collate(msgs []api.Message) (collated []*api.Message) { var n int for i := range msgs { msg := msgs[i] - if msg.Role == "system" { - if system != "" { - system += "\n\n" - } - - system += msg.Content - continue - } - for range msg.Images { imageTag := fmt.Sprintf("[img-%d]", n) if !strings.Contains(msg.Content, "[img]") { diff --git a/template/template_test.go b/template/template_test.go index e702a186..b020eb67 100644 --- a/template/template_test.go +++ b/template/template_test.go @@ -122,6 +122,7 @@ func TestTemplate(t *testing.T) { }) t.Run("legacy", func(t *testing.T) { + t.Skip("legacy outputs are currently default outputs") var legacy bytes.Buffer if err := tmpl.Execute(&legacy, Values{Messages: tt, forceLegacy: true}); err != nil { t.Fatal(err) @@ -154,11 +155,13 @@ func TestParse(t *testing.T) { {"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}}, {"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}}, {"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}}, - {"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}}, + {`{{- range .Messages }} +{{- if eq .Role "system" }}SYSTEM: +{{- else if eq .Role "user" }}USER: +{{- else if eq .Role "assistant" }}ASSISTANT: +{{- end }} {{ .Content }} +{{- end }}`, []string{"content", "messages", "role"}}, {`{{- if .Messages }} -{{- if .System }}<|im_start|>system -{{ .System }}<|im_end|> -{{ end }} {{- range .Messages }}<|im_start|>{{ .Role }} {{ .Content }}<|im_end|> {{ end }}<|im_start|>assistant