do not automatically aggregate system messages

This commit is contained in:
Michael Yang 2024-07-11 13:10:13 -07:00
parent 791650ddef
commit e64f9ebb44
2 changed files with 27 additions and 23 deletions

View file

@ -102,8 +102,21 @@ var response = parse.ActionNode{
}, },
} }
// funcs is the set of helper functions attached to templates via Funcs.
var funcs = template.FuncMap{
	// aggregate gathers the Content of each message in v whose Role equals
	// role, preserving order, and joins the pieces with a blank line ("\n\n").
	// The result is the empty string when no message matches.
	"aggregate": func(v []*api.Message, role string) string {
		var matched []string
		for i := range v {
			if v[i].Role != role {
				continue
			}
			matched = append(matched, v[i].Content)
		}
		return strings.Join(matched, "\n\n")
	},
}
func Parse(s string) (*Template, error) { func Parse(s string) (*Template, error) {
tmpl := template.New("").Option("missingkey=zero") tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
tmpl, err := tmpl.Parse(s) tmpl, err := tmpl.Parse(s)
if err != nil { if err != nil {
@ -149,23 +162,21 @@ type Values struct {
} }
func (t *Template) Execute(w io.Writer, v Values) error { func (t *Template) Execute(w io.Writer, v Values) error {
system, collated := collate(v.Messages) collated := collate(v.Messages)
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") { if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{ return t.Template.Execute(w, map[string]any{
"System": system,
"Messages": collated, "Messages": collated,
}) })
} }
var b bytes.Buffer var b bytes.Buffer
var prompt, response string var system, prompt, response string
for i, m := range collated { for i, m := range collated {
switch m.Role { switch m.Role {
case "system":
system = m.Content
case "user": case "user":
prompt = m.Content prompt = m.Content
if i != 0 {
system = ""
}
case "assistant": case "assistant":
response = m.Content response = m.Content
} }
@ -179,6 +190,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
return err return err
} }
system = ""
prompt = "" prompt = ""
response = "" response = ""
} }
@ -209,25 +221,14 @@ func (t *Template) Execute(w io.Writer, v Values) error {
return err return err
} }
type messages []*api.Message
// collate messages based on role. consecutive messages of the same role are merged // collate messages based on role. consecutive messages of the same role are merged
// into a single message. collate also pulls out and merges messages with Role == "system" // into a single message. collate also pulls out and merges messages with Role == "system"
// which are templated separately. As a side effect, it mangles message content adding image // which are templated separately. As a side effect, it mangles message content adding image
// tags ([img-%d]) as needed // tags ([img-%d]) as needed
func collate(msgs []api.Message) (system string, collated messages) { func collate(msgs []api.Message) (collated []*api.Message) {
var n int var n int
for i := range msgs { for i := range msgs {
msg := msgs[i] msg := msgs[i]
if msg.Role == "system" {
if system != "" {
system += "\n\n"
}
system += msg.Content
continue
}
for range msg.Images { for range msg.Images {
imageTag := fmt.Sprintf("[img-%d]", n) imageTag := fmt.Sprintf("[img-%d]", n)
if !strings.Contains(msg.Content, "[img]") { if !strings.Contains(msg.Content, "[img]") {

View file

@ -122,6 +122,7 @@ func TestTemplate(t *testing.T) {
}) })
t.Run("legacy", func(t *testing.T) { t.Run("legacy", func(t *testing.T) {
t.Skip("legacy outputs are currently default outputs")
var legacy bytes.Buffer var legacy bytes.Buffer
if err := tmpl.Execute(&legacy, Values{Messages: tt, forceLegacy: true}); err != nil { if err := tmpl.Execute(&legacy, Values{Messages: tt, forceLegacy: true}); err != nil {
t.Fatal(err) t.Fatal(err)
@ -154,11 +155,13 @@ func TestParse(t *testing.T) {
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}}, {"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}}, {"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}}, {"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}}, {`{{- range .Messages }}
{{- if eq .Role "system" }}SYSTEM:
{{- else if eq .Role "user" }}USER:
{{- else if eq .Role "assistant" }}ASSISTANT:
{{- end }} {{ .Content }}
{{- end }}`, []string{"content", "messages", "role"}},
{`{{- if .Messages }} {`{{- if .Messages }}
{{- if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}
{{- range .Messages }}<|im_start|>{{ .Role }} {{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|> {{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant {{ end }}<|im_start|>assistant