template: preprocess message and collect system
This commit is contained in:
parent
179737feb7
commit
36c87c433b
2 changed files with 23 additions and 67 deletions
|
@ -102,22 +102,8 @@ var response = parse.ActionNode{
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
var funcs = template.FuncMap{
|
|
||||||
// contents returns the contents of messages with an optional role filter
|
|
||||||
"contents": func(v []*api.Message, role ...string) string {
|
|
||||||
var parts []string
|
|
||||||
for _, m := range v {
|
|
||||||
if len(role) == 0 || role[0] == "" || m.Role == role[0] {
|
|
||||||
parts = append(parts, m.Content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return strings.Join(parts, "\n\n")
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
func Parse(s string) (*Template, error) {
|
func Parse(s string) (*Template, error) {
|
||||||
tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
|
tmpl := template.New("").Option("missingkey=zero")
|
||||||
|
|
||||||
tmpl, err := tmpl.Parse(s)
|
tmpl, err := tmpl.Parse(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -163,15 +149,16 @@ type Values struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Template) Execute(w io.Writer, v Values) error {
|
func (t *Template) Execute(w io.Writer, v Values) error {
|
||||||
collated := collate(v.Messages)
|
system, collated := collate(v.Messages)
|
||||||
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
|
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
|
||||||
return t.Template.Execute(w, map[string]any{
|
return t.Template.Execute(w, map[string]any{
|
||||||
|
"System": system,
|
||||||
"Messages": collated,
|
"Messages": collated,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
var system, prompt, response string
|
var prompt, response string
|
||||||
for i, m := range collated {
|
for i, m := range collated {
|
||||||
switch m.Role {
|
switch m.Role {
|
||||||
case "system":
|
case "system":
|
||||||
|
@ -223,11 +210,13 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// collate messages based on role. consecutive messages of the same role are merged
|
// collate messages based on role. consecutive messages of the same role are merged
|
||||||
// into a single message. collate also pulls out and merges messages with Role == "system"
|
// into a single message. collate also collects and returns all system messages.
|
||||||
// which are templated separately. As a side effect, it mangles message content adding image
|
// collate mutates message content adding image tags ([img-%d]) as needed
|
||||||
// tags ([img-%d]) as needed
|
func collate(msgs []api.Message) (string, []*api.Message) {
|
||||||
func collate(msgs []api.Message) (collated []*api.Message) {
|
|
||||||
var n int
|
var n int
|
||||||
|
|
||||||
|
var system []string
|
||||||
|
var collated []*api.Message
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
msg := msgs[i]
|
msg := msgs[i]
|
||||||
for range msg.Images {
|
for range msg.Images {
|
||||||
|
@ -240,6 +229,10 @@ func collate(msgs []api.Message) (collated []*api.Message) {
|
||||||
n++
|
n++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if msg.Role == "system" {
|
||||||
|
system = append(system, msg.Content)
|
||||||
|
}
|
||||||
|
|
||||||
if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
|
if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
|
||||||
collated[len(collated)-1].Content += "\n\n" + msg.Content
|
collated[len(collated)-1].Content += "\n\n" + msg.Content
|
||||||
} else {
|
} else {
|
||||||
|
@ -247,7 +240,7 @@ func collate(msgs []api.Message) (collated []*api.Message) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return strings.Join(system, "\n\n"), collated
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseNode(n parse.Node) []string {
|
func parseNode(n parse.Node) []string {
|
||||||
|
|
|
@ -216,13 +216,11 @@ func TestExecuteWithMessages(t *testing.T) {
|
||||||
{"response", `[INST] {{ if .System }}{{ .System }}
|
{"response", `[INST] {{ if .System }}{{ .System }}
|
||||||
|
|
||||||
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
||||||
{"messages", `{{- $system := contents .Messages "system" -}}
|
{"messages", `[INST] {{ if .System }}{{ .System }}
|
||||||
{{- range $index, $_ := .Messages }}
|
|
||||||
{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }}
|
|
||||||
{{- $system = "" }}
|
|
||||||
|
|
||||||
{{ end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
|
{{ end }}
|
||||||
{{- end }}
|
{{- range .Messages }}
|
||||||
|
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
|
||||||
{{- end }}`},
|
{{- end }}`},
|
||||||
},
|
},
|
||||||
Values{
|
Values{
|
||||||
|
@ -243,13 +241,11 @@ func TestExecuteWithMessages(t *testing.T) {
|
||||||
{"response", `[INST] {{ if .System }}{{ .System }}
|
{"response", `[INST] {{ if .System }}{{ .System }}
|
||||||
|
|
||||||
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
||||||
{"messages", `{{- $system := contents .Messages "system" -}}
|
{"messages", `[INST] {{ if .System }}{{ .System }}
|
||||||
{{- range $index, $_ := .Messages }}
|
|
||||||
{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }}
|
|
||||||
{{- $system = "" }}
|
|
||||||
|
|
||||||
{{ end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
|
{{ end }}
|
||||||
{{- end }}
|
{{- range .Messages }}
|
||||||
|
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
|
||||||
{{- end }}`},
|
{{- end }}`},
|
||||||
},
|
},
|
||||||
Values{
|
Values{
|
||||||
|
@ -363,36 +359,3 @@ Answer: `,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFuncs(t *testing.T) {
|
|
||||||
t.Run("contents", func(t *testing.T) {
|
|
||||||
cases := map[string]string{
|
|
||||||
"": "A\n\nB\n\nC\n\nD\n\nE\n\nF",
|
|
||||||
"system": "A\n\nF",
|
|
||||||
"user": "B\n\nE",
|
|
||||||
"assistant": "C\n\nD",
|
|
||||||
}
|
|
||||||
|
|
||||||
s := []*api.Message{
|
|
||||||
{Role: "system", Content: "A"},
|
|
||||||
{Role: "user", Content: "B"},
|
|
||||||
{Role: "assistant", Content: "C"},
|
|
||||||
{Role: "assistant", Content: "D"},
|
|
||||||
{Role: "user", Content: "E"},
|
|
||||||
{Role: "system", Content: "F"},
|
|
||||||
}
|
|
||||||
|
|
||||||
fn, ok := funcs["contents"].(func([]*api.Message, ...string) string)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("contents is not a function")
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v := range cases {
|
|
||||||
t.Run(k, func(t *testing.T) {
|
|
||||||
if diff := cmp.Diff(fn(s, k), v); diff != "" {
|
|
||||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in a new issue