ollama/template/template.go

447 lines
9.1 KiB
Go
Raw Normal View History

2024-06-10 14:54:42 -07:00
package template
import (
"bytes"
"embed"
"encoding/json"
"errors"
2024-06-17 10:38:55 -07:00
"fmt"
2024-06-10 14:54:42 -07:00
"io"
"math"
"slices"
"strings"
"sync"
"text/template"
"text/template/parse"
"github.com/agnivade/levenshtein"
2024-06-17 10:38:55 -07:00
"github.com/ollama/ollama/api"
2024-06-10 14:54:42 -07:00
"golang.org/x/exp/maps"
)
//go:embed index.json
var indexBytes []byte
//go:embed *.gotmpl
//go:embed *.json
2024-06-10 14:54:42 -07:00
var templatesFS embed.FS
var templatesOnce = sync.OnceValues(func() ([]*named, error) {
var templates []*named
if err := json.Unmarshal(indexBytes, &templates); err != nil {
return nil, err
}
for _, t := range templates {
bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
if err != nil {
return nil, err
}
// normalize line endings
t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
params, err := templatesFS.ReadFile(t.Name + ".json")
if err != nil {
continue
}
if err := json.Unmarshal(params, &t.Parameters); err != nil {
return nil, err
}
2024-06-10 14:54:42 -07:00
}
return templates, nil
})
type named struct {
Name string `json:"name"`
Template string `json:"template"`
Bytes []byte
Parameters *struct {
Stop []string `json:"stop"`
}
2024-06-10 14:54:42 -07:00
}
func (t named) Reader() io.Reader {
return bytes.NewReader(t.Bytes)
}
// Named returns the bundled template whose reference text is closest to s by
// Levenshtein edit distance. On ties the earliest template in the index wins.
// It returns an error if the bundled templates cannot be loaded, or if even
// the closest one is 100 or more edits away from s.
func Named(s string) (*named, error) {
	templates, err := templatesOnce()
	if err != nil {
		return nil, err
	}

	var best *named
	bestDistance := math.MaxInt
	for _, candidate := range templates {
		distance := levenshtein.ComputeDistance(s, candidate.Template)
		if distance < bestDistance {
			best, bestDistance = candidate, distance
		}
	}

	if bestDistance >= 100 {
		return nil, errors.New("no matching template found")
	}
	return best, nil
}
2024-06-17 10:38:55 -07:00
var DefaultTemplate, _ = Parse("{{ .Prompt }}")
2024-06-10 14:54:42 -07:00
type Template struct {
*template.Template
raw string
}
2024-06-20 11:00:08 -07:00
// response is a template node that can be added to templates that don't already have one
2024-06-17 10:38:55 -07:00
var response = parse.ActionNode{
NodeType: parse.NodeAction,
Pipe: &parse.PipeNode{
NodeType: parse.NodePipe,
Cmds: []*parse.CommandNode{
{
NodeType: parse.NodeCommand,
Args: []parse.Node{
&parse.FieldNode{
NodeType: parse.NodeField,
Ident: []string{"Response"},
},
},
},
},
},
2024-06-10 14:54:42 -07:00
}
// funcs are the extra functions available to every parsed template.
//
// json renders its argument as compact JSON. A value that cannot be marshaled
// renders as the empty string.
var funcs = template.FuncMap{
	"json": func(v any) string {
		if encoded, err := json.Marshal(v); err == nil {
			return string(encoded)
		}
		return ""
	},
}
2024-06-20 11:00:08 -07:00
func Parse(s string) (*Template, error) {
2024-06-20 13:45:47 -07:00
tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
2024-06-17 10:38:55 -07:00
tmpl, err := tmpl.Parse(s)
2024-06-10 14:54:42 -07:00
if err != nil {
return nil, err
}
2024-06-17 10:38:55 -07:00
t := Template{Template: tmpl, raw: s}
if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
// touch up the template and append {{ .Response }}
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
}
return &t, nil
}
func (t *Template) String() string {
return t.raw
2024-06-10 14:54:42 -07:00
}
func (t *Template) Vars() []string {
var vars []string
2024-06-17 10:38:55 -07:00
for _, tt := range t.Templates() {
for _, n := range tt.Root.Nodes {
2024-06-20 13:45:47 -07:00
vars = append(vars, Identifiers(n)...)
2024-06-17 10:38:55 -07:00
}
2024-06-10 14:54:42 -07:00
}
set := make(map[string]struct{})
for _, n := range vars {
set[strings.ToLower(n)] = struct{}{}
}
vars = maps.Keys(set)
slices.Sort(vars)
return vars
}
2024-06-17 10:38:55 -07:00
type Values struct {
Messages []api.Message
api.Tools
Prompt string
Suffix string
// forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy bool
2024-06-17 10:38:55 -07:00
}
// Subtree searches the template's root tree depth-first for the first node
// for which fn returns true. It returns a new template whose root list holds
// just that node, with the package funcs installed, or nil if nothing matches.
//
// fn is called on every node visited, including each {{ if }}, {{ with }} and
// {{ range }} node and then its embedded branch. Only list nodes and branch
// bodies (including else lists) are descended into. Action pipelines, branch
// conditions and associated templates are not searched.
func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
	var find func(parse.Node) parse.Node
	find = func(node parse.Node) parse.Node {
		if fn(node) {
			return node
		}

		switch v := node.(type) {
		case *parse.ListNode:
			for _, child := range v.Nodes {
				if match := find(child); match != nil {
					return match
				}
			}
		case *parse.BranchNode:
			for _, list := range [...]*parse.ListNode{v.List, v.ElseList} {
				if list == nil {
					continue
				}
				if match := find(list); match != nil {
					return match
				}
			}
		case *parse.IfNode:
			return find(&v.BranchNode)
		case *parse.WithNode:
			return find(&v.BranchNode)
		case *parse.RangeNode:
			return find(&v.BranchNode)
		}
		return nil
	}

	match := find(t.Tree.Root)
	if match == nil {
		return nil
	}

	sub := &template.Template{
		Tree: &parse.Tree{
			Root: &parse.ListNode{Nodes: []parse.Node{match}},
		},
	}
	return sub.Funcs(funcs)
}
// Execute renders the template with v and writes the result to w.
//
// It picks one of three rendering modes:
//
//   - If both v.Prompt and v.Suffix are non-empty, the template is executed
//     once with Prompt, Suffix and an empty Response. Messages are ignored.
//   - Otherwise, if the template references .Messages (and forceLegacy is not
//     set), it is executed once with the collated messages, the combined
//     system prompt, v.Tools and an empty Response.
//   - Otherwise the template is treated as a legacy prompt/response template
//     and executed once per exchange. The final exchange is rendered only up
//     to and including its {{ .Response }} field.
//
// In legacy mode output is buffered and copied to w only after every exchange
// has rendered successfully.
func (t *Template) Execute(w io.Writer, v Values) error {
	system, messages := collate(v.Messages)

	if v.Prompt != "" && v.Suffix != "" {
		return t.Template.Execute(w, map[string]any{
			"Prompt":   v.Prompt,
			"Suffix":   v.Suffix,
			"Response": "",
		})
	}

	if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
		return t.Template.Execute(w, map[string]any{
			"System":   system,
			"Messages": messages,
			"Tools":    v.Tools,
			"Response": "",
		})
	}

	// Legacy mode applies each system message where it appears in the
	// conversation, rather than using the combined system prompt from collate.
	system = ""
	var b bytes.Buffer
	var prompt, response string

	// flush renders the pending system/prompt/response exchange into b and
	// clears it.
	flush := func() error {
		if err := t.Template.Execute(&b, map[string]any{
			"System":   system,
			"Prompt":   prompt,
			"Response": response,
		}); err != nil {
			return err
		}
		system, prompt, response = "", "", ""
		return nil
	}

	for _, m := range messages {
		switch m.Role {
		case "system":
			// A system message closes any exchange already in progress.
			if prompt != "" || response != "" {
				if err := flush(); err != nil {
					return err
				}
			}
			system = m.Content
		case "user":
			// A user message after an assistant reply starts a new exchange.
			if response != "" {
				if err := flush(); err != nil {
					return err
				}
			}
			prompt = m.Content
		case "assistant":
			response = m.Content
		}
	}

	// Truncate a copy of the template after the first field that references
	// .Response, so the final exchange ends at the response placeholder.
	var cut bool
	nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
		if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
			cut = true
			return false
		}
		return cut
	})

	// Run the truncated tree inside a clone of the parsed template. A bare
	// template.New("") has neither the package funcs (e.g. json) nor any
	// {{ define }}d templates, so executing a tree that uses them would fail.
	// missingkey=zero is set explicitly to match Parse.
	final, err := t.Template.Clone()
	if err != nil {
		return err
	}
	final.Option("missingkey=zero")
	if _, err := final.AddParseTree(final.Name(), &parse.Tree{Root: nodes.(*parse.ListNode)}); err != nil {
		return err
	}
	if err := final.Execute(&b, map[string]any{
		"System":   system,
		"Prompt":   prompt,
		"Response": response,
	}); err != nil {
		return err
	}

	_, err = io.Copy(w, &b)
	return err
}
// collate merges consecutive messages that share a role into one message,
// joining their contents with a blank line ("\n\n"). It returns the contents
// of every system message joined the same way, plus the merged messages.
//
// Each attached image gets a tag numbered across all messages ([img-0],
// [img-1], ...). The tag replaces the first "[img]" placeholder in the
// content; if no placeholder remains, one is prepended first. Tagging and
// merging operate on copies of the input messages, so the caller's slice is
// left unchanged.
func collate(msgs []api.Message) (string, []*api.Message) {
	var systems []string
	var merged []*api.Message
	imageIndex := 0

	for _, m := range msgs {
		msg := m // a fresh copy per iteration; &msg is stored below

		for range msg.Images {
			tag := fmt.Sprintf("[img-%d]", imageIndex)
			imageIndex++
			if !strings.Contains(msg.Content, "[img]") {
				msg.Content = strings.TrimSpace("[img] " + msg.Content)
			}
			msg.Content = strings.Replace(msg.Content, "[img]", tag, 1)
		}

		if msg.Role == "system" {
			systems = append(systems, msg.Content)
		}

		if last := len(merged) - 1; last >= 0 && merged[last].Role == msg.Role {
			merged[last].Content += "\n\n" + msg.Content
			continue
		}
		merged = append(merged, &msg)
	}

	return strings.Join(systems, "\n\n"), merged
}
// Identifiers walks the node tree rooted at n and returns every field and
// variable identifier it finds, in document order and possibly with
// duplicates. A field chain contributes every segment, so .Foo.Bar yields
// "Foo" and "Bar".
//
// It descends into lists, actions, {{ template }} invocations, the condition
// and bodies of {{ if }}, {{ with }} and {{ range }}, and the arguments of
// every command in a pipeline. Other node types contribute nothing. Nil nodes
// are tolerated, such as the missing pipeline of {{ template "name" }}.
func Identifiers(n parse.Node) []string {
	switch n := n.(type) {
	case *parse.ListNode:
		if n == nil {
			return nil
		}
		var names []string
		for _, c := range n.Nodes {
			names = append(names, Identifiers(c)...)
		}
		return names
	case *parse.TemplateNode:
		// {{ template "name" }} without a pipeline has a nil Pipe; the
		// PipeNode case handles that.
		return Identifiers(n.Pipe)
	case *parse.ActionNode:
		return Identifiers(n.Pipe)
	case *parse.BranchNode:
		names := Identifiers(n.Pipe)
		names = append(names, Identifiers(n.List)...)
		names = append(names, Identifiers(n.ElseList)...)
		return names
	case *parse.IfNode:
		return Identifiers(&n.BranchNode)
	case *parse.RangeNode:
		return Identifiers(&n.BranchNode)
	case *parse.WithNode:
		return Identifiers(&n.BranchNode)
	case *parse.PipeNode:
		// A nil *PipeNode arrives here wrapped in a non-nil interface, so it
		// must be checked explicitly before its commands are read.
		if n == nil {
			return nil
		}
		var names []string
		for _, c := range n.Cmds {
			for _, a := range c.Args {
				names = append(names, Identifiers(a)...)
			}
		}
		return names
	case *parse.FieldNode:
		return n.Ident
	case *parse.VariableNode:
		return n.Ident
	}
	return nil
}
// deleteNode walks the node tree rooted at n and removes every node for which
// fn returns true. fn is called on each node before its children, in document
// order. Nodes are modified in place, so callers that need the original tree
// should pass a copy (e.g. Root.Copy()). It returns the modified root, or nil
// if fn matched the root itself.
//
// The bodies and else lists of {{ if }}, {{ with }} and {{ range }} are walked,
// but their conditions are not.
//   - A deleted else list simply removes the else clause.
//   - A deleted main list is replaced by an empty list so the branch stays
//     executable.
//   - An action whose pipeline loses all arguments of any command is deleted.
//
// This is currently used to remove everything after the {{ .Response }} node
// from templates.
func deleteNode(n parse.Node, fn func(parse.Node) bool) parse.Node {
	var walk func(n parse.Node) parse.Node
	walk = func(n parse.Node) parse.Node {
		if fn(n) {
			return nil
		}

		switch t := n.(type) {
		case *parse.ListNode:
			var nodes []parse.Node
			for _, c := range t.Nodes {
				if n := walk(c); n != nil {
					nodes = append(nodes, n)
				}
			}
			t.Nodes = nodes
			return t
		case *parse.IfNode:
			// The embedded branch is edited in place; if it was deleted
			// outright, delete the whole node.
			if walk(&t.BranchNode) == nil {
				return nil
			}
		case *parse.WithNode:
			if walk(&t.BranchNode) == nil {
				return nil
			}
		case *parse.RangeNode:
			if walk(&t.BranchNode) == nil {
				return nil
			}
		case *parse.BranchNode:
			// walk returns nil for a deleted list, and a plain type assertion
			// on nil would panic, so use the comma-ok form.
			if list, ok := walk(t.List).(*parse.ListNode); ok {
				t.List = list
			} else {
				t.List = &parse.ListNode{NodeType: parse.NodeList}
			}
			if t.ElseList != nil {
				if list, ok := walk(t.ElseList).(*parse.ListNode); ok {
					t.ElseList = list
				} else {
					t.ElseList = nil
				}
			}
		case *parse.ActionNode:
			pipe, ok := walk(t.Pipe).(*parse.PipeNode)
			if !ok {
				return nil
			}
			t.Pipe = pipe
		case *parse.PipeNode:
			var commands []*parse.CommandNode
			for _, c := range t.Cmds {
				var args []parse.Node
				for _, a := range c.Args {
					if n := walk(a); n != nil {
						args = append(args, n)
					}
				}
				if len(args) == 0 {
					return nil
				}
				c.Args = args
				commands = append(commands, c)
			}
			if len(commands) == 0 {
				return nil
			}
			t.Cmds = commands
		}
		return n
	}
	return walk(n)
}