rename templates to template

This commit is contained in:
Michael Yang 2024-06-10 14:54:42 -07:00
parent 3f0b309ad4
commit 58e3fff311
29 changed files with 301 additions and 164 deletions

View file

@ -28,6 +28,7 @@ import (
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
@ -48,12 +49,13 @@ type Model struct {
ParentModel string ParentModel string
AdapterPaths []string AdapterPaths []string
ProjectorPaths []string ProjectorPaths []string
Template string
System string System string
License []string License []string
Digest string Digest string
Options map[string]interface{} Options map[string]interface{}
Messages []Message Messages []Message
Template *template.Template
} }
func (m *Model) IsEmbedding() bool { func (m *Model) IsEmbedding() bool {
@ -82,10 +84,10 @@ func (m *Model) String() string {
}) })
} }
if m.Template != "" { if m.Template != nil {
modelfile.Commands = append(modelfile.Commands, parser.Command{ modelfile.Commands = append(modelfile.Commands, parser.Command{
Name: "template", Name: "template",
Args: m.Template, Args: m.Template.String(),
}) })
} }
@ -191,8 +193,7 @@ func GetModel(name string) (*Model, error) {
Name: mp.GetFullTagname(), Name: mp.GetFullTagname(),
ShortName: mp.GetShortTagname(), ShortName: mp.GetShortTagname(),
Digest: digest, Digest: digest,
Template: "{{ .Prompt }}", Template: template.DefaultTemplate,
License: []string{},
} }
filename, err := GetBlobsPath(manifest.Config.Digest) filename, err := GetBlobsPath(manifest.Config.Digest)
@ -228,13 +229,17 @@ func GetModel(name string) (*Model, error) {
model.AdapterPaths = append(model.AdapterPaths, filename) model.AdapterPaths = append(model.AdapterPaths, filename)
case "application/vnd.ollama.image.projector": case "application/vnd.ollama.image.projector":
model.ProjectorPaths = append(model.ProjectorPaths, filename) model.ProjectorPaths = append(model.ProjectorPaths, filename)
case "application/vnd.ollama.image.template": case "application/vnd.ollama.image.prompt",
"application/vnd.ollama.image.template":
bts, err := os.ReadFile(filename) bts, err := os.ReadFile(filename)
if err != nil { if err != nil {
return nil, err return nil, err
} }
model.Template = string(bts) model.Template, err = template.Parse(string(bts))
if err != nil {
return nil, err
}
case "application/vnd.ollama.image.system": case "application/vnd.ollama.image.system":
bts, err := os.ReadFile(filename) bts, err := os.ReadFile(filename)
if err != nil { if err != nil {
@ -242,13 +247,6 @@ func GetModel(name string) (*Model, error) {
} }
model.System = string(bts) model.System = string(bts)
case "application/vnd.ollama.image.prompt":
bts, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
model.Template = string(bts)
case "application/vnd.ollama.image.params": case "application/vnd.ollama.image.params":
params, err := os.Open(filename) params, err := os.Open(filename)
if err != nil { if err != nil {

View file

@ -16,7 +16,7 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert" "github.com/ollama/ollama/convert"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/templates" "github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
@ -258,7 +258,7 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) { func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
for _, layer := range layers { for _, layer := range layers {
if s := layer.GGML.KV().ChatTemplate(); s != "" { if s := layer.GGML.KV().ChatTemplate(); s != "" {
if t, err := templates.NamedTemplate(s); err != nil { if t, err := template.Named(s); err != nil {
slog.Debug("template detection", "error", err) slog.Debug("template detection", "error", err)
} else { } else {
tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template") tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")

View file

@ -4,10 +4,11 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"strings" "strings"
"text/template"
"text/template/parse" "text/template/parse"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
) )
// isResponseNode checks if the node contains .Response // isResponseNode checks if the node contains .Response
@ -53,13 +54,8 @@ func formatTemplateForResponse(tmpl *template.Template, generate bool) {
// Prompt renders a prompt from a template. If generate is set to true, // Prompt renders a prompt from a template. If generate is set to true,
// the response and parts of the template following it are not rendered // the response and parts of the template following it are not rendered
func Prompt(tmpl, system, prompt, response string, generate bool) (string, error) { func Prompt(tmpl *template.Template, system, prompt, response string, generate bool) (string, error) {
parsed, err := template.New("").Option("missingkey=zero").Parse(tmpl) formatTemplateForResponse(tmpl, generate)
if err != nil {
return "", err
}
formatTemplateForResponse(parsed, generate)
vars := map[string]any{ vars := map[string]any{
"System": system, "System": system,
@ -68,14 +64,14 @@ func Prompt(tmpl, system, prompt, response string, generate bool) (string, error
} }
var sb strings.Builder var sb strings.Builder
if err := parsed.Execute(&sb, vars); err != nil { if err := tmpl.Execute(&sb, vars); err != nil {
return "", err return "", err
} }
return sb.String(), nil return sb.String(), nil
} }
func countTokens(tmpl string, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) { func countTokens(tmpl *template.Template, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
rendered, err := Prompt(tmpl, system, prompt, response, false) rendered, err := Prompt(tmpl, system, prompt, response, false)
if err != nil { if err != nil {
return 0, err return 0, err
@ -91,7 +87,7 @@ func countTokens(tmpl string, system string, prompt string, response string, enc
} }
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size // ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) { func ChatPrompt(tmpl *template.Template, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
type prompt struct { type prompt struct {
System string System string
Prompt string Prompt string

View file

@ -5,6 +5,7 @@ import (
"testing" "testing"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
) )
func TestPrompt(t *testing.T) { func TestPrompt(t *testing.T) {
@ -61,7 +62,12 @@ func TestPrompt(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
got, err := Prompt(tc.template, tc.system, tc.prompt, tc.response, tc.generate) tmpl, err := template.Parse(tc.template)
if err != nil {
t.Fatal(err)
}
got, err := Prompt(tmpl, tc.system, tc.prompt, tc.response, tc.generate)
if err != nil { if err != nil {
t.Errorf("error = %v", err) t.Errorf("error = %v", err)
} }
@ -192,7 +198,12 @@ func TestChatPrompt(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode) tmpl, err := template.Parse(tc.template)
if err != nil {
t.Fatal(err)
}
got, err := ChatPrompt(tmpl, tc.messages, tc.window, encode)
if err != nil { if err != nil {
t.Errorf("error = %v", err) t.Errorf("error = %v", err)
} }

View file

@ -31,6 +31,7 @@ import (
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai" "github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
@ -161,6 +162,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
tmpl, err := template.Parse(req.Template)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
checkpointLoaded := time.Now() checkpointLoaded := time.Now()
var prompt string var prompt string
@ -169,7 +176,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
prompt = req.Prompt prompt = req.Prompt
case req.Prompt != "": case req.Prompt != "":
if req.Template == "" { if req.Template == "" {
req.Template = model.Template model.Template, err = template.Parse(req.Template)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
} }
if req.System == "" { if req.System == "" {
@ -187,7 +198,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
sb.WriteString(req.Prompt) sb.WriteString(req.Prompt)
p, err := Prompt(req.Template, req.System, sb.String(), "", true) p, err := Prompt(tmpl, req.System, sb.String(), "", true)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@ -242,7 +253,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart) resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw { if !req.Raw {
p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false) p, err := Prompt(tmpl, req.System, req.Prompt, generated.String(), false)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@ -680,7 +691,10 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
} }
if req.Template != "" { if req.Template != "" {
m.Template = req.Template m.Template, err = template.Parse(req.Template)
if err != nil {
return nil, err
}
} }
msgs := make([]api.Message, 0) msgs := make([]api.Message, 0)
@ -701,7 +715,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
resp := &api.ShowResponse{ resp := &api.ShowResponse{
License: strings.Join(m.License, "\n"), License: strings.Join(m.License, "\n"),
System: m.System, System: m.System,
Template: m.Template, Template: m.Template.String(),
Details: modelDetails, Details: modelDetails,
Messages: msgs, Messages: msgs,
ModifiedAt: manifest.fi.ModTime(), ModifiedAt: manifest.fi.ModTime(),
@ -1246,7 +1260,7 @@ func (s *Server) ProcessHandler(c *gin.Context) {
} }
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) { func chatPrompt(ctx context.Context, runner *runnerRef, template *template.Template, messages []api.Message, numCtx int) (string, error) {
encode := func(s string) ([]int, error) { encode := func(s string) ([]int, error) {
return runner.llama.Tokenize(ctx, s) return runner.llama.Tokenize(ctx, s)
} }

158
template/template.go Normal file
View file

@ -0,0 +1,158 @@
package template
import (
"bytes"
"embed"
"encoding/json"
"errors"
"io"
"math"
"slices"
"strings"
"sync"
"text/template"
"text/template/parse"
"github.com/agnivade/levenshtein"
"golang.org/x/exp/maps"
)
//go:embed index.json
var indexBytes []byte
//go:embed *.gotmpl
var templatesFS embed.FS
var templatesOnce = sync.OnceValues(func() ([]*named, error) {
var templates []*named
if err := json.Unmarshal(indexBytes, &templates); err != nil {
return nil, err
}
for _, t := range templates {
bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
if err != nil {
return nil, err
}
// normalize line endings
t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
}
return templates, nil
})
type named struct {
Name string `json:"name"`
Template string `json:"template"`
Bytes []byte
}
func (t named) Reader() io.Reader {
return bytes.NewReader(t.Bytes)
}
func Named(s string) (*named, error) {
templates, err := templatesOnce()
if err != nil {
return nil, err
}
var template *named
score := math.MaxInt
for _, t := range templates {
if s := levenshtein.ComputeDistance(s, t.Template); s < score {
score = s
template = t
}
}
if score < 100 {
return template, nil
}
return nil, errors.New("no matching template found")
}
type Template struct {
*template.Template
raw string
}
func (t *Template) String() string {
return t.raw
}
var DefaultTemplate, _ = Parse("{{ .Prompt }}")
func Parse(s string) (*Template, error) {
t, err := template.New("").Option("missingkey=zero").Parse(s)
if err != nil {
return nil, err
}
return &Template{Template: t, raw: s}, nil
}
func (t *Template) Vars() []string {
var vars []string
for _, n := range t.Tree.Root.Nodes {
vars = append(vars, parseNode(n)...)
}
set := make(map[string]struct{})
for _, n := range vars {
set[strings.ToLower(n)] = struct{}{}
}
vars = maps.Keys(set)
slices.Sort(vars)
return vars
}
func parseNode(n parse.Node) []string {
switch n := n.(type) {
case *parse.ActionNode:
return parseNode(n.Pipe)
case *parse.IfNode:
names := parseNode(n.Pipe)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
case *parse.RangeNode:
names := parseNode(n.Pipe)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
case *parse.WithNode:
names := parseNode(n.Pipe)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
case *parse.PipeNode:
var names []string
for _, c := range n.Cmds {
for _, a := range c.Args {
names = append(names, parseNode(a)...)
}
}
return names
case *parse.ListNode:
var names []string
for _, n := range n.Nodes {
names = append(names, parseNode(n)...)
}
return names
case *parse.FieldNode:
return n.Ident
}
return nil
}

89
template/template_test.go Normal file
View file

@ -0,0 +1,89 @@
package template
import (
"bufio"
"bytes"
"encoding/json"
"io"
"os"
"path/filepath"
"slices"
"testing"
"text/template"
"github.com/ollama/ollama/llm"
)
func TestNamed(t *testing.T) {
f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
if err != nil {
t.Fatal(err)
}
defer f.Close()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
var ss map[string]string
if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
t.Fatal(err)
}
for k, v := range ss {
t.Run(k, func(t *testing.T) {
kv := llm.KV{"tokenizer.chat_template": v}
s := kv.ChatTemplate()
r, err := Named(s)
if err != nil {
t.Fatal(err)
}
if r.Name != k {
t.Errorf("expected %q, got %q", k, r.Name)
}
var b bytes.Buffer
if _, err := io.Copy(&b, r.Reader()); err != nil {
t.Fatal(err)
}
tmpl, err := template.New(s).Parse(b.String())
if err != nil {
t.Fatal(err)
}
if tmpl.Tree.Root.String() == "" {
t.Errorf("empty %s template", k)
}
})
}
}
}
func TestParse(t *testing.T) {
cases := []struct {
template string
capabilities []string
}{
{"{{ .Prompt }}", []string{"prompt"}},
{"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}},
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "system", "tools"}},
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
{"{{ .Prompt }} {{ .Suffix }}", []string{"prompt", "suffix"}},
}
for _, tt := range cases {
t.Run("", func(t *testing.T) {
tmpl, err := Parse(tt.template)
if err != nil {
t.Fatal(err)
}
vars := tmpl.Vars()
if !slices.Equal(tt.capabilities, vars) {
t.Errorf("expected %v, got %v", tt.capabilities, vars)
}
})
}
}

View file

@ -1,70 +0,0 @@
package templates
import (
"bytes"
"embed"
"encoding/json"
"errors"
"io"
"math"
"sync"
"github.com/agnivade/levenshtein"
)
//go:embed index.json
var indexBytes []byte
//go:embed *.gotmpl
var templatesFS embed.FS
var templatesOnce = sync.OnceValues(func() ([]*Template, error) {
var templates []*Template
if err := json.Unmarshal(indexBytes, &templates); err != nil {
return nil, err
}
for _, t := range templates {
bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
if err != nil {
return nil, err
}
// normalize line endings
t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
}
return templates, nil
})
type Template struct {
Name string `json:"name"`
Template string `json:"template"`
Bytes []byte
}
func (t Template) Reader() io.Reader {
return bytes.NewReader(t.Bytes)
}
func NamedTemplate(s string) (*Template, error) {
templates, err := templatesOnce()
if err != nil {
return nil, err
}
var template *Template
score := math.MaxInt
for _, t := range templates {
if s := levenshtein.ComputeDistance(s, t.Template); s < score {
score = s
template = t
}
}
if score < 100 {
return template, nil
}
return nil, errors.New("no matching template found")
}

View file

@ -1,59 +0,0 @@
package templates
import (
"bufio"
"bytes"
"encoding/json"
"io"
"os"
"path/filepath"
"testing"
"text/template"
"github.com/ollama/ollama/llm"
)
func TestKVChatTemplate(t *testing.T) {
f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
if err != nil {
t.Fatal(err)
}
defer f.Close()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
var ss map[string]string
if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
t.Fatal(err)
}
for k, v := range ss {
t.Run(k, func(t *testing.T) {
kv := llm.KV{"tokenizer.chat_template": v}
s := kv.ChatTemplate()
r, err := NamedTemplate(s)
if err != nil {
t.Fatal(err)
}
if r.Name != k {
t.Errorf("expected %q, got %q", k, r.Name)
}
var b bytes.Buffer
if _, err := io.Copy(&b, r.Reader()); err != nil {
t.Fatal(err)
}
tmpl, err := template.New(s).Parse(b.String())
if err != nil {
t.Fatal(err)
}
if tmpl.Tree.Root.String() == "" {
t.Errorf("empty %s template", k)
}
})
}
}
}