2024-06-10 21:54:42 +00:00
|
|
|
package template
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
|
|
|
"embed"
|
|
|
|
"encoding/json"
|
|
|
|
"errors"
|
2024-06-17 17:38:55 +00:00
|
|
|
"fmt"
|
2024-06-10 21:54:42 +00:00
|
|
|
"io"
|
|
|
|
"math"
|
|
|
|
"slices"
|
|
|
|
"strings"
|
|
|
|
"sync"
|
|
|
|
"text/template"
|
|
|
|
"text/template/parse"
|
|
|
|
|
|
|
|
"github.com/agnivade/levenshtein"
|
2024-06-17 17:38:55 +00:00
|
|
|
"github.com/ollama/ollama/api"
|
2024-06-10 21:54:42 +00:00
|
|
|
"golang.org/x/exp/maps"
|
|
|
|
)
|
|
|
|
|
|
|
|
// indexBytes is the raw JSON index of built-in templates; templatesOnce
// decodes it into []*named.
//
//go:embed index.json
var indexBytes []byte

// templatesFS holds the source of the built-in templates. templatesOnce reads
// <name>.gotmpl from it for every entry in the index.
//
//go:embed *.gotmpl
var templatesFS embed.FS
|
|
|
|
|
|
|
|
var templatesOnce = sync.OnceValues(func() ([]*named, error) {
|
|
|
|
var templates []*named
|
|
|
|
if err := json.Unmarshal(indexBytes, &templates); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, t := range templates {
|
|
|
|
bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
// normalize line endings
|
|
|
|
t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
|
|
|
|
}
|
|
|
|
|
|
|
|
return templates, nil
|
|
|
|
})
|
|
|
|
|
|
|
|
// named is one entry of the built-in template index (index.json).
type named struct {
	// Name identifies the template; its source is loaded from Name + ".gotmpl".
	Name string `json:"name"`

	// Template is the reference text that Named compares a candidate
	// template against.
	Template string `json:"template"`

	// Bytes is the template source read from the embedded .gotmpl file, with
	// line endings normalized to LF. It is filled in after the index is
	// decoded, so it is excluded from JSON: a stray "bytes" key in the index
	// must not be decoded into it, and it should not be emitted on marshal.
	Bytes []byte `json:"-"`
}
|
|
|
|
|
|
|
|
func (t named) Reader() io.Reader {
|
|
|
|
return bytes.NewReader(t.Bytes)
|
|
|
|
}
|
|
|
|
|
|
|
|
func Named(s string) (*named, error) {
|
|
|
|
templates, err := templatesOnce()
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
var template *named
|
|
|
|
score := math.MaxInt
|
|
|
|
for _, t := range templates {
|
|
|
|
if s := levenshtein.ComputeDistance(s, t.Template); s < score {
|
|
|
|
score = s
|
|
|
|
template = t
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if score < 100 {
|
|
|
|
return template, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil, errors.New("no matching template found")
|
|
|
|
}
|
|
|
|
|
2024-06-17 17:38:55 +00:00
|
|
|
var DefaultTemplate, _ = Parse("{{ .Prompt }}")
|
|
|
|
|
2024-06-10 21:54:42 +00:00
|
|
|
// Template is a parsed prompt template. It embeds *template.Template, so the
// standard text/template methods are available directly on it. It also
// keeps the source text it was parsed from.
type Template struct {
	*template.Template

	// raw is the unmodified source handed to Parse; String returns it.
	raw string
}
|
|
|
|
|
2024-06-20 18:00:08 +00:00
|
|
|
// response is a template node that can be added to templates that don't
// already have one. It has the same shape as the action parsed from
// "{{ .Response }}": an action whose pipeline is a single command with the
// lone field argument .Response. Positions and owning tree are left unset.
var response = func() parse.ActionNode {
	field := &parse.FieldNode{
		NodeType: parse.NodeField,
		Ident:    []string{"Response"},
	}

	command := &parse.CommandNode{
		NodeType: parse.NodeCommand,
		Args:     []parse.Node{field},
	}

	pipe := &parse.PipeNode{
		NodeType: parse.NodePipe,
		Cmds:     []*parse.CommandNode{command},
	}

	return parse.ActionNode{NodeType: parse.NodeAction, Pipe: pipe}
}()
|
|
|
|
|
2024-06-20 18:00:08 +00:00
|
|
|
func Parse(s string) (*Template, error) {
|
2024-07-12 18:48:06 +00:00
|
|
|
tmpl := template.New("").Option("missingkey=zero")
|
2024-06-17 17:38:55 +00:00
|
|
|
|
|
|
|
tmpl, err := tmpl.Parse(s)
|
2024-06-10 21:54:42 +00:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2024-06-17 17:38:55 +00:00
|
|
|
t := Template{Template: tmpl, raw: s}
|
|
|
|
if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
|
|
|
|
// touch up the template and append {{ .Response }}
|
|
|
|
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
|
|
|
|
}
|
|
|
|
|
|
|
|
return &t, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (t *Template) String() string {
|
|
|
|
return t.raw
|
2024-06-10 21:54:42 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
func (t *Template) Vars() []string {
|
|
|
|
var vars []string
|
2024-06-17 17:38:55 +00:00
|
|
|
for _, tt := range t.Templates() {
|
|
|
|
for _, n := range tt.Root.Nodes {
|
|
|
|
vars = append(vars, parseNode(n)...)
|
|
|
|
}
|
2024-06-10 21:54:42 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
set := make(map[string]struct{})
|
|
|
|
for _, n := range vars {
|
|
|
|
set[strings.ToLower(n)] = struct{}{}
|
|
|
|
}
|
|
|
|
|
|
|
|
vars = maps.Keys(set)
|
|
|
|
slices.Sort(vars)
|
|
|
|
return vars
|
|
|
|
}
|
|
|
|
|
2024-06-17 17:38:55 +00:00
|
|
|
type Values struct {
|
|
|
|
Messages []api.Message
|
2024-07-10 18:00:07 +00:00
|
|
|
|
|
|
|
// forceLegacy is a flag used to test compatibility with legacy templates
|
|
|
|
forceLegacy bool
|
2024-06-17 17:38:55 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
func (t *Template) Execute(w io.Writer, v Values) error {
|
2024-07-12 18:48:06 +00:00
|
|
|
system, collated := collate(v.Messages)
|
2024-07-10 18:00:07 +00:00
|
|
|
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
|
2024-06-17 17:38:55 +00:00
|
|
|
return t.Template.Execute(w, map[string]any{
|
2024-07-12 18:48:06 +00:00
|
|
|
"System": system,
|
2024-06-17 17:38:55 +00:00
|
|
|
"Messages": collated,
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
var b bytes.Buffer
|
2024-07-12 18:48:06 +00:00
|
|
|
var prompt, response string
|
2024-06-17 17:38:55 +00:00
|
|
|
for i, m := range collated {
|
2024-07-10 18:00:07 +00:00
|
|
|
switch m.Role {
|
2024-07-11 20:10:13 +00:00
|
|
|
case "system":
|
|
|
|
system = m.Content
|
2024-07-10 18:00:07 +00:00
|
|
|
case "user":
|
2024-06-17 17:38:55 +00:00
|
|
|
prompt = m.Content
|
2024-07-10 18:00:07 +00:00
|
|
|
case "assistant":
|
2024-06-17 17:38:55 +00:00
|
|
|
response = m.Content
|
|
|
|
}
|
|
|
|
|
|
|
|
if i != len(collated)-1 && prompt != "" && response != "" {
|
|
|
|
if err := t.Template.Execute(&b, map[string]any{
|
2024-07-10 18:00:07 +00:00
|
|
|
"System": system,
|
2024-06-17 17:38:55 +00:00
|
|
|
"Prompt": prompt,
|
|
|
|
"Response": response,
|
|
|
|
}); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2024-07-11 20:10:13 +00:00
|
|
|
system = ""
|
2024-06-17 17:38:55 +00:00
|
|
|
prompt = ""
|
|
|
|
response = ""
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
var cut bool
|
2024-07-10 18:00:07 +00:00
|
|
|
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
|
|
|
|
switch t := n.(type) {
|
|
|
|
case *parse.ActionNode:
|
|
|
|
case *parse.FieldNode:
|
|
|
|
if slices.Contains(t.Ident, "Response") {
|
|
|
|
cut = true
|
|
|
|
}
|
2024-06-17 17:38:55 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
return cut
|
|
|
|
})
|
|
|
|
|
2024-07-10 18:00:07 +00:00
|
|
|
tree := parse.Tree{Root: nodes.(*parse.ListNode)}
|
|
|
|
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
|
|
|
|
"System": "",
|
2024-06-17 17:38:55 +00:00
|
|
|
"Prompt": prompt,
|
|
|
|
}); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
_, err := io.Copy(w, &b)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2024-06-20 18:00:08 +00:00
|
|
|
// collate messages based on role. consecutive messages of the same role are merged
|
2024-07-12 18:48:06 +00:00
|
|
|
// into a single message. collate also collects and returns all system messages.
|
|
|
|
// collate mutates message content adding image tags ([img-%d]) as needed
|
|
|
|
func collate(msgs []api.Message) (string, []*api.Message) {
|
2024-06-17 17:38:55 +00:00
|
|
|
var n int
|
2024-07-12 18:48:06 +00:00
|
|
|
|
|
|
|
var system []string
|
|
|
|
var collated []*api.Message
|
2024-06-17 17:38:55 +00:00
|
|
|
for i := range msgs {
|
|
|
|
msg := msgs[i]
|
|
|
|
for range msg.Images {
|
|
|
|
imageTag := fmt.Sprintf("[img-%d]", n)
|
|
|
|
if !strings.Contains(msg.Content, "[img]") {
|
|
|
|
msg.Content = strings.TrimSpace("[img] " + msg.Content)
|
|
|
|
}
|
|
|
|
|
|
|
|
msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
|
|
|
|
n++
|
|
|
|
}
|
|
|
|
|
2024-07-12 18:48:06 +00:00
|
|
|
if msg.Role == "system" {
|
|
|
|
system = append(system, msg.Content)
|
|
|
|
}
|
|
|
|
|
2024-06-17 17:38:55 +00:00
|
|
|
if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
|
|
|
|
collated[len(collated)-1].Content += "\n\n" + msg.Content
|
|
|
|
} else {
|
|
|
|
collated = append(collated, &msg)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-07-12 18:48:06 +00:00
|
|
|
return strings.Join(system, "\n\n"), collated
|
2024-06-17 17:38:55 +00:00
|
|
|
}
|
|
|
|
|
2024-06-10 21:54:42 +00:00
|
|
|
// parseNode returns the identifiers of every field referenced by n and its
// descendants, in document order. For if/range/with nodes the pipeline comes
// first, then the body, then the else branch. A chained field such as .A.B
// contributes each of its identifiers ("A", "B"). Node kinds not handled
// below contribute nothing, including text, variables such as $.X, and
// chains.
//
// {{ template "name" }} has no pipeline, so TemplateNode.Pipe can be nil. Nil
// pipelines and lists are treated as empty rather than dereferenced.
func parseNode(n parse.Node) []string {
	switch n := n.(type) {
	case *parse.ActionNode:
		return parseNode(n.Pipe)
	case *parse.IfNode:
		return parseNode(&n.BranchNode)
	case *parse.RangeNode:
		return parseNode(&n.BranchNode)
	case *parse.WithNode:
		return parseNode(&n.BranchNode)
	case *parse.BranchNode:
		names := parseNode(n.Pipe)
		names = append(names, parseNode(n.List)...)
		// a nil ElseList is handled by the *parse.ListNode nil guard
		return append(names, parseNode(n.ElseList)...)
	case *parse.PipeNode:
		if n == nil {
			return nil
		}
		var names []string
		for _, c := range n.Cmds {
			for _, a := range c.Args {
				names = append(names, parseNode(a)...)
			}
		}
		return names
	case *parse.ListNode:
		if n == nil {
			return nil
		}
		var names []string
		for _, c := range n.Nodes {
			names = append(names, parseNode(c)...)
		}
		return names
	case *parse.FieldNode:
		return n.Ident
	case *parse.TemplateNode:
		return parseNode(n.Pipe)
	}

	return nil
}
|
2024-07-10 18:00:07 +00:00
|
|
|
|
|
|
|
// deleteNode walks the tree rooted at n and removes every node for which fn
// returns true, together with its subtree. It is currently used to remove the
// {{ .Response }} node from templates.
//
// fn is called on a node before its children. Nodes are mutated in place, so
// callers should pass a copy (Execute passes Root.Copy()). deleteNode returns
// the possibly modified root, or nil if fn matched the root itself.
//
// Removal rules:
//   - a removed child of a list is dropped from the list;
//   - an if/with/range whose branch is removed is removed entirely;
//   - a removed branch body becomes an empty list, because text/template
//     walks the body unconditionally and a nil list would be dereferenced;
//   - a removed else-list becomes absent (nil);
//   - a command left with no arguments removes its whole pipeline;
//   - an action whose pipeline is removed is itself removed.
func deleteNode(n parse.Node, fn func(parse.Node) bool) parse.Node {
	var walk func(n parse.Node) parse.Node
	walk = func(n parse.Node) parse.Node {
		if fn(n) {
			return nil
		}

		switch t := n.(type) {
		case *parse.ListNode:
			var nodes []parse.Node
			for _, c := range t.Nodes {
				if n := walk(c); n != nil {
					nodes = append(nodes, n)
				}
			}

			t.Nodes = nodes
			return t
		case *parse.IfNode:
			// walk mutates the embedded branch in place; nil means it was removed
			if walk(&t.BranchNode) == nil {
				return nil
			}
		case *parse.WithNode:
			if walk(&t.BranchNode) == nil {
				return nil
			}
		case *parse.RangeNode:
			if walk(&t.BranchNode) == nil {
				return nil
			}
		case *parse.BranchNode:
			if list, ok := walk(t.List).(*parse.ListNode); ok {
				t.List = list
			} else {
				t.List = &parse.ListNode{NodeType: parse.NodeList, Pos: t.List.Pos}
			}

			if t.ElseList != nil {
				// a nil result (else-list removed) leaves ElseList as a nil pointer
				list, _ := walk(t.ElseList).(*parse.ListNode)
				t.ElseList = list
			}
		case *parse.ActionNode:
			n := walk(t.Pipe)
			if n == nil {
				return nil
			}

			t.Pipe = n.(*parse.PipeNode)
		case *parse.PipeNode:
			var commands []*parse.CommandNode
			for _, c := range t.Cmds {
				var args []parse.Node
				for _, a := range c.Args {
					if n := walk(a); n != nil {
						args = append(args, n)
					}
				}

				if len(args) == 0 {
					return nil
				}

				c.Args = args
				commands = append(commands, c)
			}

			if len(commands) == 0 {
				return nil
			}

			t.Cmds = commands
		}

		return n
	}

	return walk(n)
}
|