diff --git a/template/template.go b/template/template.go index 8d5ac51b..21e1614d 100644 --- a/template/template.go +++ b/template/template.go @@ -103,15 +103,16 @@ var response = parse.ActionNode{ } var funcs = template.FuncMap{ - "aggregate": func(v []*api.Message, role string) string { - var aggregated []string + // contents returns the contents of messages with an optional role filter + "contents": func(v []*api.Message, role ...string) string { + var parts []string for _, m := range v { - if m.Role == role { - aggregated = append(aggregated, m.Content) + if len(role) == 0 || role[0] == "" || m.Role == role[0] { + parts = append(parts, m.Content) } } - return strings.Join(aggregated, "\n\n") + return strings.Join(parts, "\n\n") }, } diff --git a/template/template_test.go b/template/template_test.go index 9cfa0bea..5e5f4257 100644 --- a/template/template_test.go +++ b/template/template_test.go @@ -216,7 +216,7 @@ func TestExecuteWithMessages(t *testing.T) { {"response", `[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}[/INST] {{ .Response }}`}, - {"messages", `{{- $system := aggregate $.Messages "system" -}} + {"messages", `{{- $system := contents .Messages "system" -}} {{- range $index, $_ := .Messages }} {{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }} {{- $system = "" }} @@ -243,7 +243,7 @@ func TestExecuteWithMessages(t *testing.T) { {"response", `[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}[/INST] {{ .Response }}`}, - {"messages", `{{- $system := aggregate $.Messages "system" -}} + {"messages", `{{- $system := contents .Messages "system" -}} {{- range $index, $_ := .Messages }} {{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }} {{- $system = "" }} @@ -363,3 +363,36 @@ Answer: `, }) } } + +func TestFuncs(t *testing.T) { + t.Run("contents", func(t *testing.T) { + cases := map[string]string{ + "": "A\n\nB\n\nC\n\nD\n\nE\n\nF", + "system": "A\n\nF", + "user": "B\n\nE", + "assistant": "C\n\nD", + } + + s := []*api.Message{ + {Role: "system", Content: "A"}, + {Role: "user", Content: "B"}, + {Role: "assistant", Content: "C"}, + {Role: "assistant", Content: "D"}, + {Role: "user", Content: "E"}, + {Role: "system", Content: "F"}, + } + + fn, ok := funcs["contents"].(func([]*api.Message, ...string) string) + if !ok { + t.Fatal("contents is not a function") + } + + for k, v := range cases { + t.Run(k, func(t *testing.T) { + if diff := cmp.Diff(fn(s, k), v); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + } + }) +}