rename aggregate to contents

This commit is contained in:
Michael Yang 2024-07-11 16:06:57 -07:00
parent 57ec6901eb
commit 5056bb9c01
2 changed files with 41 additions and 7 deletions

View file

@ -103,15 +103,16 @@ var response = parse.ActionNode{
}
var funcs = template.FuncMap{
"aggregate": func(v []*api.Message, role string) string {
var aggregated []string
// contents returns the contents of messages with an optional role filter
"contents": func(v []*api.Message, role ...string) string {
var parts []string
for _, m := range v {
if m.Role == role {
aggregated = append(aggregated, m.Content)
if len(role) == 0 || role[0] == "" || m.Role == role[0] {
parts = append(parts, m.Content)
}
}
return strings.Join(aggregated, "\n\n")
return strings.Join(parts, "\n\n")
},
}

View file

@ -216,7 +216,7 @@ func TestExecuteWithMessages(t *testing.T) {
{"response", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `{{- $system := aggregate $.Messages "system" -}}
{"messages", `{{- $system := contents .Messages "system" -}}
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }}
{{- $system = "" }}
@ -243,7 +243,7 @@ func TestExecuteWithMessages(t *testing.T) {
{"response", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `{{- $system := aggregate $.Messages "system" -}}
{"messages", `{{- $system := contents .Messages "system" -}}
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }}
{{- $system = "" }}
@ -363,3 +363,36 @@ Answer: `,
})
}
}
func TestFuncs(t *testing.T) {
t.Run("contents", func(t *testing.T) {
cases := map[string]string{
"": "A\n\nB\n\nC\n\nD\n\nE\n\nF",
"system": "A\n\nF",
"user": "B\n\nE",
"assistant": "C\n\nD",
}
s := []*api.Message{
{Role: "system", Content: "A"},
{Role: "user", Content: "B"},
{Role: "assistant", Content: "C"},
{Role: "assistant", Content: "D"},
{Role: "user", Content: "E"},
{Role: "system", Content: "F"},
}
fn, ok := funcs["contents"].(func([]*api.Message, ...string) string)
if !ok {
t.Fatal("contents is not a function")
}
for k, v := range cases {
t.Run(k, func(t *testing.T) {
if diff := cmp.Diff(fn(s, k), v); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
})
}