diff --git a/llm/server.go b/llm/server.go
index 54fad92c..08dc04d5 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -679,7 +679,7 @@ type CompletionRequest struct {
Prompt string
Format string
Images []ImageData
- Options api.Options
+ Options *api.Options
}
type CompletionResponse struct {
diff --git a/server/images.go b/server/images.go
index a62991f1..688d5dca 100644
--- a/server/images.go
+++ b/server/images.go
@@ -34,6 +34,8 @@ import (
"github.com/ollama/ollama/version"
)
+var errCapabilityCompletion = errors.New("completion")
+
type Capability string
const CapabilityCompletion = Capability("completion")
@@ -62,7 +64,10 @@ type Model struct {
Template *template.Template
}
-func (m *Model) Has(caps ...Capability) bool {
+// CheckCapabilities checks if the model has the specified capabilities returning an error describing
+// any missing or unknown capabilities
+func (m *Model) CheckCapabilities(caps ...Capability) error {
+ var errs []error
for _, cap := range caps {
switch cap {
case CapabilityCompletion:
@@ -81,15 +86,19 @@ func (m *Model) Has(caps ...Capability) bool {
}
if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
- return false
+ errs = append(errs, errCapabilityCompletion)
}
default:
slog.Error("unknown capability", "capability", cap)
- return false
+ return fmt.Errorf("unknown capability: %s", cap)
}
}
- return true
+ if err := errors.Join(errs...); err != nil {
+ return fmt.Errorf("missing capabilities: %w", err)
+ }
+
+ return nil
}
func (m *Model) String() string {
diff --git a/server/prompt.go b/server/prompt.go
index bfc319a5..51d691a9 100644
--- a/server/prompt.go
+++ b/server/prompt.go
@@ -1,217 +1,83 @@
package server
import (
- "fmt"
+ "bytes"
+ "context"
"log/slog"
- "strings"
-
- "text/template/parse"
+ "slices"
"github.com/ollama/ollama/api"
+ "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/template"
)
-// isResponseNode checks if the node contains .Response
-func isResponseNode(node *parse.ActionNode) bool {
- for _, cmd := range node.Pipe.Cmds {
- for _, arg := range cmd.Args {
- if fieldNode, ok := arg.(*parse.FieldNode); ok && len(fieldNode.Ident) > 0 {
- if fieldNode.Ident[0] == "Response" {
- return true
- }
- }
+type tokenizeFunc func(context.Context, string) ([]int, error)
+
+// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
+// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
+// latest message and 2) system messages
+func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
+ // pull out any system messages which should always be included in the prompt
+ var system []api.Message
+ msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
+ if m.Role == "system" {
+ system = append(system, m)
+ return true
}
- }
- return false
-}
-// formatTemplateForResponse formats the template AST to:
-// 1. remove all nodes after the first .Response (if generate=true)
-// 2. add a .Response node to the end if it doesn't exist
-// TODO(jmorganca): this should recursively cut the template before the first .Response
-func formatTemplateForResponse(tmpl *template.Template, generate bool) {
- var found bool
- for i, node := range tmpl.Tree.Root.Nodes {
- if actionNode, ok := node.(*parse.ActionNode); ok {
- if isResponseNode(actionNode) {
- found = true
- if generate {
- tmpl.Tree.Root.Nodes = tmpl.Tree.Root.Nodes[:i+1]
- break
- }
- }
+ return false
+ })
+
+ if len(system) == 0 && m.System != "" {
+ // add model system prompt since it wasn't provided
+ system = append(system, api.Message{Role: "system", Content: m.System})
+ }
+
+ // always include the last message
+ n := len(msgs) - 1
+ // in reverse, find all messages that fit into context window
+ for i := n - 1; i >= 0; i-- {
+ var b bytes.Buffer
+ if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
+ return "", nil, err
}
- }
- if !found {
- // add the response node if it doesn't exist
- responseFieldNode := &parse.FieldNode{NodeType: parse.NodeField, Ident: []string{"Response"}}
- responsePipeNode := &parse.PipeNode{NodeType: parse.NodePipe, Cmds: []*parse.CommandNode{{NodeType: parse.NodeCommand, Args: []parse.Node{responseFieldNode}}}}
- responseActionNode := &parse.ActionNode{NodeType: parse.NodeAction, Pipe: responsePipeNode}
- tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, responseActionNode)
- }
-}
-
-// Prompt renders a prompt from a template. If generate is set to true,
-// the response and parts of the template following it are not rendered
-func Prompt(tmpl *template.Template, system, prompt, response string, generate bool) (string, error) {
- formatTemplateForResponse(tmpl, generate)
-
- vars := map[string]any{
- "System": system,
- "Prompt": prompt,
- "Response": response,
- }
-
- var sb strings.Builder
- if err := tmpl.Execute(&sb, vars); err != nil {
- return "", err
- }
-
- return sb.String(), nil
-}
-
-func countTokens(tmpl *template.Template, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
- rendered, err := Prompt(tmpl, system, prompt, response, false)
- if err != nil {
- return 0, err
- }
-
- tokens, err := encode(rendered)
- if err != nil {
- slog.Error("failed to encode prompt", "err", err)
- return 0, err
- }
-
- return len(tokens), err
-}
-
-// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
-func ChatPrompt(tmpl *template.Template, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
- type prompt struct {
- System string
- Prompt string
- Response string
-
- images []int
- tokens int
- }
-
- var p prompt
-
- // iterate through messages to build up {system,user,response} prompts
- var imgId int
- var prompts []prompt
- for _, msg := range messages {
- switch strings.ToLower(msg.Role) {
- case "system":
- if p.System != "" || p.Prompt != "" || p.Response != "" {
- prompts = append(prompts, p)
- p = prompt{}
- }
-
- p.System = msg.Content
- case "user":
- if p.Prompt != "" || p.Response != "" {
- prompts = append(prompts, p)
- p = prompt{}
- }
-
- var sb strings.Builder
- for range msg.Images {
- fmt.Fprintf(&sb, "[img-%d] ", imgId)
- p.images = append(p.images, imgId)
- imgId += 1
- }
-
- sb.WriteString(msg.Content)
- p.Prompt = sb.String()
- case "assistant":
- if p.Response != "" {
- prompts = append(prompts, p)
- p = prompt{}
- }
-
- p.Response = msg.Content
- default:
- return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
- }
- }
-
- // add final prompt
- if p.System != "" || p.Prompt != "" || p.Response != "" {
- prompts = append(prompts, p)
- }
-
- // calculate token lengths for each prompt, estimating 768 tokens per images
- for i, p := range prompts {
- tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
+ s, err := tokenize(ctx, b.String())
if err != nil {
- return "", err
+ return "", nil, err
}
- prompts[i].tokens = tokens + len(prompts[i].images)*768
- }
-
- // truncate images and prompts starting from the beginning of the list
- // until either one prompt remains or the total tokens fits the context window
- // TODO (jmorganca): this doesn't account for the context window room required for the response
- for {
- var required int
- for _, p := range prompts {
- required += p.tokens
+ c := len(s)
+ if m.ProjectorPaths != nil {
+ for _, m := range msgs[i:] {
+ // images are represented as 768 sized embeddings
+ // TODO: get embedding length from projector metadata
+ c += 768 * len(m.Images)
+ }
}
- required += 1 // for bos token
-
- if required <= window {
- slog.Debug("prompt now fits in context window", "required", required, "window", window)
+ if c > opts.NumCtx {
+ slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
break
+ } else {
+ n = i
}
-
- prompt := &prompts[0]
-
- if len(prompt.images) > 1 {
- img := prompt.images[0]
- slog.Debug("prompt longer than context window, removing image", "id", img, "required", required, "window", window)
- prompt.images = prompt.images[1:]
- prompt.Prompt = strings.Replace(prompt.Prompt, fmt.Sprintf(" [img-%d]", img), "", 1)
- prompt.tokens -= 768
- continue
- }
-
- if len(prompts) > 1 {
- slog.Debug("required tokens longer than context window, removing first prompt", "prompt", prompts[0].tokens, "required", required, "window", window)
- system := prompt.System
- prompts = prompts[1:]
-
- if system != "" && prompts[0].System == "" {
- prompts[0].System = system
-
- tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
- if err != nil {
- return "", err
- }
-
- prompts[0].tokens = tokens + len(prompts[0].images)*768
- }
-
- continue
- }
-
- // stop truncating if there's only one prompt left
- break
}
- var sb strings.Builder
- for i, p := range prompts {
- // last prompt should leave the response unrendered (for completion)
- rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1)
- if err != nil {
- return "", err
- }
- sb.WriteString(rendered)
+ // truncate any messages that do not fit into the context window
+ var b bytes.Buffer
+ if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
+ return "", nil, err
}
- return sb.String(), nil
+ for _, m := range msgs[n:] {
+ for _, i := range m.Images {
+ images = append(images, llm.ImageData{
+ ID: len(images),
+ Data: i,
+ })
+ }
+ }
+
+ return b.String(), images, nil
}
diff --git a/server/prompt_test.go b/server/prompt_test.go
index 7df58d0b..d4cee98c 100644
--- a/server/prompt_test.go
+++ b/server/prompt_test.go
@@ -1,6 +1,8 @@
package server
import (
+ "bytes"
+ "context"
"strings"
"testing"
@@ -8,208 +10,195 @@ import (
"github.com/ollama/ollama/template"
)
-func TestPrompt(t *testing.T) {
- tests := []struct {
- name string
- template string
- system string
- prompt string
- response string
- generate bool
- want string
- }{
- {
- name: "simple prompt",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- system: "You are a Wizard.",
- prompt: "What are the potion ingredients?",
- want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
- },
- {
- name: "implicit response",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- system: "You are a Wizard.",
- prompt: "What are the potion ingredients?",
- response: "I don't know.",
- want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.",
- },
- {
- name: "response",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
- system: "You are a Wizard.",
- prompt: "What are the potion ingredients?",
- response: "I don't know.",
- want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
- },
- {
- name: "cut",
- template: "{{ .System }}{{ .Prompt }}{{ .Response }}",
- system: "You are a Wizard.",
- prompt: "What are the potion ingredients?",
- response: "I don't know.",
- generate: true,
- want: "You are a Wizard.What are the potion ingredients?I don't know.",
- },
- {
- name: "nocut",
- template: "{{ .System }}{{ .Prompt }}{{ .Response }}",
- system: "You are a Wizard.",
- prompt: "What are the potion ingredients?",
- response: "I don't know.",
- want: "You are a Wizard.What are the potion ingredients?I don't know.",
- },
+func tokenize(_ context.Context, s string) (tokens []int, err error) {
+ for range strings.Fields(s) {
+ tokens = append(tokens, len(tokens))
}
- for _, tc := range tests {
- t.Run(tc.name, func(t *testing.T) {
- tmpl, err := template.Parse(tc.template)
- if err != nil {
- t.Fatal(err)
- }
-
- got, err := Prompt(tmpl, tc.system, tc.prompt, tc.response, tc.generate)
- if err != nil {
- t.Errorf("error = %v", err)
- }
-
- if got != tc.want {
- t.Errorf("got = %v, want %v", got, tc.want)
- }
- })
- }
+ return
}
func TestChatPrompt(t *testing.T) {
- tests := []struct {
- name string
- template string
- messages []api.Message
- window int
- want string
+ type expect struct {
+ prompt string
+ images [][]byte
+ }
+
+ cases := []struct {
+ name string
+ limit int
+ msgs []api.Message
+ expect
}{
{
- name: "simple prompt",
- template: "[INST] {{ .Prompt }} [/INST]",
- messages: []api.Message{
- {Role: "user", Content: "Hello"},
+ name: "messages",
+ limit: 64,
+ msgs: []api.Message{
+ {Role: "user", Content: "You're a test, Harry!"},
+ {Role: "assistant", Content: "I-I'm a what?"},
+ {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+ },
+ expect: expect{
+ prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
},
- window: 1024,
- want: "[INST] Hello [/INST]",
},
{
- name: "with system message",
- template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST]",
- messages: []api.Message{
- {Role: "system", Content: "You are a Wizard."},
- {Role: "user", Content: "Hello"},
+ name: "truncate messages",
+ limit: 1,
+ msgs: []api.Message{
+ {Role: "user", Content: "You're a test, Harry!"},
+ {Role: "assistant", Content: "I-I'm a what?"},
+ {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+ },
+ expect: expect{
+ prompt: "A test. And a thumping good one at that, I'd wager. ",
},
- window: 1024,
- want: "[INST] <>You are a Wizard.<> Hello [/INST]",
},
{
- name: "with response",
- template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}",
- messages: []api.Message{
- {Role: "system", Content: "You are a Wizard."},
- {Role: "user", Content: "Hello"},
- {Role: "assistant", Content: "I am?"},
+ name: "truncate messages with image",
+ limit: 64,
+ msgs: []api.Message{
+ {Role: "user", Content: "You're a test, Harry!"},
+ {Role: "assistant", Content: "I-I'm a what?"},
+ {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}},
+ },
+ expect: expect{
+ prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
+ images: [][]byte{
+ []byte("something"),
+ },
},
- window: 1024,
- want: "[INST] <>You are a Wizard.<> Hello [/INST] I am?",
},
{
- name: "with implicit response",
- template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST]",
- messages: []api.Message{
- {Role: "system", Content: "You are a Wizard."},
- {Role: "user", Content: "Hello"},
- {Role: "assistant", Content: "I am?"},
+ name: "truncate messages with images",
+ limit: 64,
+ msgs: []api.Message{
+ {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
+ {Role: "assistant", Content: "I-I'm a what?"},
+ {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
+ },
+ expect: expect{
+ prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
+ images: [][]byte{
+ []byte("somethingelse"),
+ },
},
- window: 1024,
- want: "[INST] <>You are a Wizard.<> Hello [/INST]I am?",
},
{
- name: "with conversation",
- template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
- messages: []api.Message{
- {Role: "system", Content: "You are a Wizard."},
- {Role: "user", Content: "What are the potion ingredients?"},
- {Role: "assistant", Content: "sugar"},
- {Role: "user", Content: "Anything else?"},
+ name: "messages with images",
+ limit: 2048,
+ msgs: []api.Message{
+ {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
+ {Role: "assistant", Content: "I-I'm a what?"},
+ {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
+ },
+ expect: expect{
+ prompt: "[img-0] You're a test, Harry! I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
+ images: [][]byte{
+ []byte("something"),
+ []byte("somethingelse"),
+ },
},
- window: 1024,
- want: "[INST] <>You are a Wizard.<> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ",
},
{
- name: "with truncation",
- template: "{{ .System }} {{ .Prompt }} {{ .Response }} ",
- messages: []api.Message{
- {Role: "system", Content: "You are a Wizard."},
- {Role: "user", Content: "Hello"},
- {Role: "assistant", Content: "I am?"},
- {Role: "user", Content: "Why is the sky blue?"},
- {Role: "assistant", Content: "The sky is blue from rayleigh scattering"},
+ name: "message with image tag",
+ limit: 2048,
+ msgs: []api.Message{
+ {Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
+ {Role: "assistant", Content: "I-I'm a what?"},
+ {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
+ },
+ expect: expect{
+ prompt: "You're a test, Harry! [img-0] I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
+ images: [][]byte{
+ []byte("something"),
+ []byte("somethingelse"),
+ },
},
- window: 10,
- want: "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering",
},
{
- name: "images",
- template: "{{ .System }} {{ .Prompt }}",
- messages: []api.Message{
- {Role: "system", Content: "You are a Wizard."},
- {Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
+ name: "messages with interleaved images",
+ limit: 2048,
+ msgs: []api.Message{
+ {Role: "user", Content: "You're a test, Harry!"},
+ {Role: "user", Images: []api.ImageData{[]byte("something")}},
+ {Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
+ {Role: "assistant", Content: "I-I'm a what?"},
+ {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+ },
+ expect: expect{
+ prompt: "You're a test, Harry!\n\n[img-0]\n\n[img-1] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
+ images: [][]byte{
+ []byte("something"),
+ []byte("somethingelse"),
+ },
},
- window: 1024,
- want: "You are a Wizard. [img-0] Hello",
},
{
- name: "images truncated",
- template: "{{ .System }} {{ .Prompt }}",
- messages: []api.Message{
- {Role: "system", Content: "You are a Wizard."},
- {Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
+ name: "truncate message with interleaved images",
+ limit: 1024,
+ msgs: []api.Message{
+ {Role: "user", Content: "You're a test, Harry!"},
+ {Role: "user", Images: []api.ImageData{[]byte("something")}},
+ {Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
+ {Role: "assistant", Content: "I-I'm a what?"},
+ {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+ },
+ expect: expect{
+ prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
+ images: [][]byte{
+ []byte("somethingelse"),
+ },
},
- window: 1024,
- want: "You are a Wizard. [img-0] [img-1] Hello",
},
{
- name: "empty list",
- template: "{{ .System }} {{ .Prompt }}",
- messages: []api.Message{},
- window: 1024,
- want: "",
- },
- {
- name: "empty prompt",
- template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
- messages: []api.Message{
- {Role: "user", Content: ""},
+ name: "message with system prompt",
+ limit: 2048,
+ msgs: []api.Message{
+ {Role: "system", Content: "You are the Test Who Lived."},
+ {Role: "user", Content: "You're a test, Harry!"},
+ {Role: "assistant", Content: "I-I'm a what?"},
+ {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+ },
+ expect: expect{
+ prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
},
- window: 1024,
- want: "",
},
}
- encode := func(s string) ([]int, error) {
- words := strings.Fields(s)
- return make([]int, len(words)), nil
+ tmpl, err := template.Parse(`
+{{- if .System }}{{ .System }} {{ end }}
+{{- if .Prompt }}{{ .Prompt }} {{ end }}
+{{- if .Response }}{{ .Response }} {{ end }}`)
+ if err != nil {
+ t.Fatal(err)
}
- for _, tc := range tests {
- t.Run(tc.name, func(t *testing.T) {
- tmpl, err := template.Parse(tc.template)
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
+ opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
+ prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs)
if err != nil {
t.Fatal(err)
}
- got, err := ChatPrompt(tmpl, tc.messages, tc.window, encode)
- if err != nil {
- t.Errorf("error = %v", err)
+ if tt.prompt != prompt {
+ t.Errorf("expected %q, got %q", tt.prompt, prompt)
}
- if got != tc.want {
- t.Errorf("got: %q, want: %q", got, tc.want)
+ if len(images) != len(tt.images) {
+ t.Fatalf("expected %d images, got %d", len(tt.images), len(images))
+ }
+
+ for i := range images {
+ if images[i].ID != i {
+ t.Errorf("expected ID %d, got %d", i, images[i].ID)
+ }
+
+ if !bytes.Equal(images[i].Data, tt.images[i]) {
+ t.Errorf("expected %q, got %q", tt.images[i], images[i])
+ }
}
})
}
diff --git a/server/routes.go b/server/routes.go
index ac6b713a..4059c7c5 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -1,13 +1,13 @@
package server
import (
+ "bytes"
"cmp"
"context"
"encoding/json"
"errors"
"fmt"
"io"
- "io/fs"
"log/slog"
"net"
"net/http"
@@ -54,6 +54,8 @@ func init() {
gin.SetMode(mode)
}
+var errRequired = errors.New("is required")
+
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
@@ -67,163 +69,140 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
return opts, nil
}
-func isSupportedImageType(image []byte) bool {
- contentType := http.DetectContentType(image)
- allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
- return slices.Contains(allowedTypes, contentType)
+// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
+// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
+func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
+ if name == "" {
+ return nil, nil, nil, fmt.Errorf("model %w", errRequired)
+ }
+
+ model, err := GetModel(name)
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ if err := model.CheckCapabilities(caps...); err != nil {
+ return nil, nil, nil, fmt.Errorf("%s %w", name, err)
+ }
+
+ opts, err := modelOptions(model, requestOpts)
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
+ var runner *runnerRef
+ select {
+ case runner = <-runnerCh:
+ case err = <-errCh:
+ return nil, nil, nil, err
+ }
+
+ return runner.llama, model, &opts, nil
}
func (s *Server) GenerateHandler(c *gin.Context) {
- checkpointStart := time.Now()
var req api.GenerateRequest
- err := c.ShouldBindJSON(&req)
-
- switch {
- case errors.Is(err, io.EOF):
+ if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
- case err != nil:
+ } else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
- // validate the request
- switch {
- case req.Model == "":
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
+ if req.Format != "" && req.Format != "json" {
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
return
- case len(req.Format) > 0 && req.Format != "json":
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
- return
- case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
+ } else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
return
}
- for _, img := range req.Images {
- if !isSupportedImageType(img) {
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
- return
- }
- }
-
- model, err := GetModel(req.Model)
- if err != nil {
- var pErr *fs.PathError
- if errors.As(err, &pErr) {
- c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
- return
- }
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ caps := []Capability{CapabilityCompletion}
+ r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+ if errors.Is(err, errCapabilityCompletion) {
+ c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
+ return
+ } else if err != nil {
+ handleScheduleError(c, req.Model, err)
return
}
- if !model.Has(CapabilityCompletion) {
- c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)})
- return
- }
-
- opts, err := modelOptions(model, req.Options)
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
- }
-
- rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
- var runner *runnerRef
- select {
- case runner = <-rCh:
- case err = <-eCh:
- handleErrorResponse(c, err)
- return
- }
-
- // an empty request loads the model
- // note: for a short while template was used in lieu
- // of `raw` mode so we need to check for it too
- if req.Prompt == "" && req.Template == "" && req.System == "" {
+ if req.Prompt == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
- CreatedAt: time.Now().UTC(),
Model: req.Model,
+ CreatedAt: time.Now().UTC(),
Done: true,
DoneReason: "load",
})
return
}
- tmpl, err := template.Parse(req.Template)
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
+ images := make([]llm.ImageData, len(req.Images))
+ for i := range req.Images {
+ images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
}
- checkpointLoaded := time.Now()
-
- var prompt string
- switch {
- case req.Raw:
- prompt = req.Prompt
- case req.Prompt != "":
- if req.Template == "" {
- tmpl = model.Template
+ prompt := req.Prompt
+ if !req.Raw {
+ var msgs []api.Message
+ if req.System != "" {
+ msgs = append(msgs, api.Message{Role: "system", Content: req.System})
+ } else if m.System != "" {
+ msgs = append(msgs, api.Message{Role: "system", Content: m.System})
}
- if req.System == "" {
- req.System = model.System
+ for _, i := range images {
+ msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
}
- slog.Debug("generate handler", "prompt", req.Prompt)
- slog.Debug("generate handler", "template", req.Template)
- slog.Debug("generate handler", "system", req.System)
+ msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
- var sb strings.Builder
- for i := range req.Images {
- fmt.Fprintf(&sb, "[img-%d] ", i)
+ tmpl := m.Template
+ if req.Template != "" {
+ tmpl, err = template.Parse(req.Template)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
}
- sb.WriteString(req.Prompt)
-
- p, err := Prompt(tmpl, req.System, sb.String(), "", true)
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
- }
-
- sb.Reset()
+ var b bytes.Buffer
if req.Context != nil {
- prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
+ s, err := r.Detokenize(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
- sb.WriteString(prev)
+ b.WriteString(s)
}
- sb.WriteString(p)
+ if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
- prompt = sb.String()
+ prompt = b.String()
}
- slog.Debug("generate handler", "prompt", prompt)
+ slog.Debug("generate request", "prompt", prompt, "images", images)
ch := make(chan any)
- var generated strings.Builder
go func() {
defer close(ch)
-
- fn := func(r llm.CompletionResponse) {
- // Build up the full response
- if _, err := generated.WriteString(r.Content); err != nil {
- ch <- gin.H{"error": err.Error()}
- return
- }
-
- resp := api.GenerateResponse{
+ if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
+ Prompt: prompt,
+ Images: images,
+ Format: req.Format,
+ Options: opts,
+ }, func(r llm.CompletionResponse) {
+ ch <- api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
- Done: r.Done,
Response: r.Content,
+ Done: r.Done,
DoneReason: r.DoneReason,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
@@ -232,77 +211,35 @@ func (s *Server) GenerateHandler(c *gin.Context) {
EvalDuration: r.EvalDuration,
},
}
-
- if r.Done {
- resp.TotalDuration = time.Since(checkpointStart)
- resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
-
- if !req.Raw {
- p, err := Prompt(tmpl, req.System, req.Prompt, generated.String(), false)
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
- }
-
- // TODO (jmorganca): encode() should not strip special tokens
- tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
- if err != nil {
- ch <- gin.H{"error": err.Error()}
- return
- }
-
- resp.Context = append(req.Context, tokens...)
- }
- }
-
- ch <- resp
- }
-
- var images []llm.ImageData
- for i := range req.Images {
- images = append(images, llm.ImageData{
- ID: i,
- Data: req.Images[i],
- })
- }
-
- // Start prediction
- req := llm.CompletionRequest{
- Prompt: prompt,
- Format: req.Format,
- Images: images,
- Options: opts,
- }
- if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
+ }); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
- // Accumulate responses into the final response
- var final api.GenerateResponse
+ var r api.GenerateResponse
var sb strings.Builder
- for resp := range ch {
- switch r := resp.(type) {
+ for rr := range ch {
+ switch t := rr.(type) {
case api.GenerateResponse:
- sb.WriteString(r.Response)
- final = r
+ sb.WriteString(t.Response)
+ r = t
case gin.H:
- if errorMsg, ok := r["error"].(string); ok {
- c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
- return
- } else {
- c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
- return
+ msg, ok := t["error"].(string)
+ if !ok {
+ msg = "unexpected error format in response"
}
+
+ c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
+ return
default:
- c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
return
}
}
- final.Response = sb.String()
- c.JSON(http.StatusOK, final)
+ r.Response = sb.String()
+ c.JSON(http.StatusOK, r)
return
}
@@ -311,44 +248,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest
- err := c.ShouldBindJSON(&req)
- switch {
- case errors.Is(err, io.EOF):
+ if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
- case err != nil:
+ } else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
- if req.Model == "" {
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
- return
- }
-
- model, err := GetModel(req.Model)
+ r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
if err != nil {
- var pErr *fs.PathError
- if errors.As(err, &pErr) {
- c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
- return
- }
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
- }
-
- opts, err := modelOptions(model, req.Options)
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
- }
-
- rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
- var runner *runnerRef
- select {
- case runner = <-rCh:
- case err = <-eCh:
- handleErrorResponse(c, err)
+ handleScheduleError(c, req.Model, err)
return
}
@@ -358,17 +268,14 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
- embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
+ embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
}
- resp := api.EmbeddingResponse{
- Embedding: embedding,
- }
- c.JSON(http.StatusOK, resp)
+ c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding})
}
func (s *Server) PullModelHandler(c *gin.Context) {
@@ -649,9 +556,9 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
}
- msgs := make([]api.Message, 0)
- for _, msg := range m.Messages {
- msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
+ msgs := make([]api.Message, len(m.Messages))
+ for i, msg := range m.Messages {
+ msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
}
n := model.ParseName(req.Model)
@@ -1214,132 +1121,55 @@ func (s *Server) ProcessHandler(c *gin.Context) {
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
}
-// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
-func chatPrompt(ctx context.Context, runner *runnerRef, template *template.Template, messages []api.Message, numCtx int) (string, error) {
- encode := func(s string) ([]int, error) {
- return runner.llama.Tokenize(ctx, s)
- }
-
- prompt, err := ChatPrompt(template, messages, numCtx, encode)
- if err != nil {
- return "", err
- }
-
- return prompt, nil
-}
-
func (s *Server) ChatHandler(c *gin.Context) {
- checkpointStart := time.Now()
-
var req api.ChatRequest
- err := c.ShouldBindJSON(&req)
- switch {
- case errors.Is(err, io.EOF):
+ if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
- case err != nil:
+ } else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
- // validate the request
- switch {
- case req.Model == "":
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
+ caps := []Capability{CapabilityCompletion}
+ r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+ if errors.Is(err, errCapabilityCompletion) {
+ c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
return
- case len(req.Format) > 0 && req.Format != "json":
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
+ } else if err != nil {
+ handleScheduleError(c, req.Model, err)
return
}
- model, err := GetModel(req.Model)
- if err != nil {
- var pErr *fs.PathError
- if errors.As(err, &pErr) {
- c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
- return
- }
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
- }
-
- if !model.Has(CapabilityCompletion) {
- c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)})
- return
- }
-
- opts, err := modelOptions(model, req.Options)
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
- }
-
- rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
- var runner *runnerRef
- select {
- case runner = <-rCh:
- case err = <-eCh:
- handleErrorResponse(c, err)
- return
- }
-
- checkpointLoaded := time.Now()
-
- // if the first message is not a system message, then add the model's default system message
- if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
- req.Messages = append([]api.Message{
- {
- Role: "system",
- Content: model.System,
- },
- }, req.Messages...)
- }
-
- prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
- if err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
- return
- }
-
- // an empty request loads the model
- if len(req.Messages) == 0 || prompt == "" {
- resp := api.ChatResponse{
- CreatedAt: time.Now().UTC(),
+ if len(req.Messages) == 0 {
+ c.JSON(http.StatusOK, api.ChatResponse{
Model: req.Model,
+ CreatedAt: time.Now().UTC(),
+ Message: api.Message{Role: "assistant"},
Done: true,
DoneReason: "load",
- Message: api.Message{Role: "assistant"},
- }
- c.JSON(http.StatusOK, resp)
+ })
return
}
- // only send images that are in the prompt
- var i int
- var images []llm.ImageData
- for _, m := range req.Messages {
- for _, img := range m.Images {
- if !isSupportedImageType(img) {
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
- return
- }
-
- if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
- images = append(images, llm.ImageData{Data: img, ID: i})
- }
- i += 1
- }
+ prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
}
- slog.Debug("chat handler", "prompt", prompt, "images", len(images))
+ slog.Debug("chat request", "images", len(images), "prompt", prompt)
ch := make(chan any)
-
go func() {
defer close(ch)
-
- fn := func(r llm.CompletionResponse) {
- resp := api.ChatResponse{
+ if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
+ Prompt: prompt,
+ Images: images,
+ Format: req.Format,
+ Options: opts,
+ }, func(r llm.CompletionResponse) {
+ ch <- api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content},
@@ -1352,64 +1182,52 @@ func (s *Server) ChatHandler(c *gin.Context) {
EvalDuration: r.EvalDuration,
},
}
-
- if r.Done {
- resp.TotalDuration = time.Since(checkpointStart)
- resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
- }
-
- ch <- resp
- }
-
- if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
- Prompt: prompt,
- Format: req.Format,
- Images: images,
- Options: opts,
- }, fn); err != nil {
+ }); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
- // Accumulate responses into the final response
- var final api.ChatResponse
+ var r api.ChatResponse
var sb strings.Builder
- for resp := range ch {
- switch r := resp.(type) {
+ for rr := range ch {
+ switch t := rr.(type) {
case api.ChatResponse:
- sb.WriteString(r.Message.Content)
- final = r
+ sb.WriteString(t.Message.Content)
+ r = t
case gin.H:
- if errorMsg, ok := r["error"].(string); ok {
- c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
- return
- } else {
- c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
- return
+ msg, ok := t["error"].(string)
+ if !ok {
+ msg = "unexpected error format in response"
}
+
+ c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
+ return
default:
- c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
return
}
}
- final.Message = api.Message{Role: "assistant", Content: sb.String()}
- c.JSON(http.StatusOK, final)
+ r.Message.Content = sb.String()
+ c.JSON(http.StatusOK, r)
return
}
streamResponse(c, ch)
}
-func handleErrorResponse(c *gin.Context, err error) {
- if errors.Is(err, context.Canceled) {
+func handleScheduleError(c *gin.Context, name string, err error) {
+ switch {
+ case errors.Is(err, errRequired):
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+ case errors.Is(err, context.Canceled):
c.JSON(499, gin.H{"error": "request canceled"})
- return
- }
- if errors.Is(err, ErrMaxQueue) {
+ case errors.Is(err, ErrMaxQueue):
c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
- return
+ case errors.Is(err, os.ErrNotExist):
+ c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
+ default:
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
diff --git a/template/template.go b/template/template.go
index d15f7156..b133b97e 100644
--- a/template/template.go
+++ b/template/template.go
@@ -5,6 +5,7 @@ import (
"embed"
"encoding/json"
"errors"
+ "fmt"
"io"
"math"
"slices"
@@ -14,6 +15,7 @@ import (
"text/template/parse"
"github.com/agnivade/levenshtein"
+ "github.com/ollama/ollama/api"
"golang.org/x/exp/maps"
)
@@ -74,30 +76,59 @@ func Named(s string) (*named, error) {
return nil, errors.New("no matching template found")
}
+var DefaultTemplate, _ = Parse("{{ .Prompt }}")
+
type Template struct {
*template.Template
raw string
}
+// response is a template node that can be added to templates that don't already have one
+var response = parse.ActionNode{
+ NodeType: parse.NodeAction,
+ Pipe: &parse.PipeNode{
+ NodeType: parse.NodePipe,
+ Cmds: []*parse.CommandNode{
+ {
+ NodeType: parse.NodeCommand,
+ Args: []parse.Node{
+ &parse.FieldNode{
+ NodeType: parse.NodeField,
+ Ident: []string{"Response"},
+ },
+ },
+ },
+ },
+ },
+}
+
+func Parse(s string) (*Template, error) {
+ tmpl := template.New("").Option("missingkey=zero")
+
+ tmpl, err := tmpl.Parse(s)
+ if err != nil {
+ return nil, err
+ }
+
+ t := Template{Template: tmpl, raw: s}
+ if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
+ // touch up the template and append {{ .Response }}
+ tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
+ }
+
+ return &t, nil
+}
+
func (t *Template) String() string {
return t.raw
}
-var DefaultTemplate, _ = Parse("{{ .Prompt }}")
-
-func Parse(s string) (*Template, error) {
- t, err := template.New("").Option("missingkey=zero").Parse(s)
- if err != nil {
- return nil, err
- }
-
- return &Template{Template: t, raw: s}, nil
-}
-
func (t *Template) Vars() []string {
var vars []string
- for _, n := range t.Tree.Root.Nodes {
- vars = append(vars, parseNode(n)...)
+ for _, tt := range t.Templates() {
+ for _, n := range tt.Root.Nodes {
+ vars = append(vars, parseNode(n)...)
+ }
}
set := make(map[string]struct{})
@@ -110,6 +141,103 @@ func (t *Template) Vars() []string {
return vars
}
+type Values struct {
+ Messages []api.Message
+}
+
+func (t *Template) Execute(w io.Writer, v Values) error {
+ system, collated := collate(v.Messages)
+ if slices.Contains(t.Vars(), "messages") {
+ return t.Template.Execute(w, map[string]any{
+ "System": system,
+ "Messages": collated,
+ })
+ }
+
+ var b bytes.Buffer
+ var prompt, response string
+ for i, m := range collated {
+ if m.Role == "user" {
+ prompt = m.Content
+ } else {
+ response = m.Content
+ }
+
+ if i != len(collated)-1 && prompt != "" && response != "" {
+ if err := t.Template.Execute(&b, map[string]any{
+ "System": "",
+ "Prompt": prompt,
+ "Response": response,
+ }); err != nil {
+ return err
+ }
+
+ prompt = ""
+ response = ""
+ }
+ }
+
+ var cut bool
+ tree := t.Template.Copy()
+ // for the last message, cut everything after "{{ .Response }}"
+ tree.Root.Nodes = slices.DeleteFunc(tree.Root.Nodes, func(n parse.Node) bool {
+ if slices.Contains(parseNode(n), "Response") {
+ cut = true
+ }
+
+ return cut
+ })
+
+ if err := template.Must(template.New("").AddParseTree("", tree)).Execute(&b, map[string]any{
+ "System": system,
+ "Prompt": prompt,
+ }); err != nil {
+ return err
+ }
+
+ _, err := io.Copy(w, &b)
+ return err
+}
+
+type messages []*api.Message
+
+// collate messages based on role. consecutive messages of the same role are merged
+// into a single message. collate also pulls out and merges messages with Role == "system"
+// which are templated separately. As a side effect, it mangles message content adding image
+// tags ([img-%d]) as needed
+func collate(msgs []api.Message) (system string, collated messages) {
+ var n int
+ for i := range msgs {
+ msg := msgs[i]
+ if msg.Role == "system" {
+ if system != "" {
+ system += "\n\n"
+ }
+
+ system += msg.Content
+ continue
+ }
+
+ for range msg.Images {
+ imageTag := fmt.Sprintf("[img-%d]", n)
+ if !strings.Contains(msg.Content, "[img]") {
+ msg.Content = strings.TrimSpace("[img] " + msg.Content)
+ }
+
+ msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
+ n++
+ }
+
+ if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
+ collated[len(collated)-1].Content += "\n\n" + msg.Content
+ } else {
+ collated = append(collated, &msg)
+ }
+ }
+
+ return
+}
+
func parseNode(n parse.Node) []string {
switch n := n.(type) {
case *parse.ActionNode:
@@ -152,6 +280,8 @@ func parseNode(n parse.Node) []string {
return names
case *parse.FieldNode:
return n.Ident
+ case *parse.TemplateNode:
+ return parseNode(n.Pipe)
}
return nil
diff --git a/template/template_test.go b/template/template_test.go
index eda4634f..ac16bd60 100644
--- a/template/template_test.go
+++ b/template/template_test.go
@@ -11,6 +11,7 @@ import (
"testing"
"text/template"
+ "github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
)
@@ -64,13 +65,12 @@ func TestParse(t *testing.T) {
template string
vars []string
}{
- {"{{ .Prompt }}", []string{"prompt"}},
- {"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}},
+ {"{{ .Prompt }}", []string{"prompt", "response"}},
+ {"{{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system"}},
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
- {"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "system", "tools"}},
+ {"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
- {"{{ .Prompt }} {{ .Suffix }}", []string{"prompt", "suffix"}},
}
for _, tt := range cases {
@@ -87,3 +87,159 @@ func TestParse(t *testing.T) {
})
}
}
+
+func TestExecuteWithMessages(t *testing.T) {
+ type template struct {
+ name string
+ template string
+ }
+ cases := []struct {
+ name string
+ templates []template
+ values Values
+ expected string
+ }{
+ {
+ "mistral",
+ []template{
+ {"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
+ {"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
+ {"messages", `{{- range $index, $_ := .Messages }}
+{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
+{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
+{{- end }}
+{{- end }}`},
+ },
+ Values{
+ Messages: []api.Message{
+ {Role: "user", Content: "Hello friend!"},
+ {Role: "assistant", Content: "Hello human!"},
+ {Role: "user", Content: "What is your name?"},
+ },
+ },
+ `[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
+ },
+ {
+ "mistral system",
+ []template{
+ {"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
+ {"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
+ {"messages", `
+{{- range $index, $_ := .Messages }}
+{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
+{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
+{{- end }}
+{{- end }}`},
+ },
+ Values{
+ Messages: []api.Message{
+ {Role: "system", Content: "You are a helpful assistant!"},
+ {Role: "user", Content: "Hello friend!"},
+ {Role: "assistant", Content: "Hello human!"},
+ {Role: "user", Content: "What is your name?"},
+ },
+ },
+ `[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!
+
+What is your name?[/INST] `,
+ },
+ {
+ "chatml",
+ []template{
+ // this does not have a "no response" test because it's impossible to render the same output
+ {"response", `{{ if .System }}<|im_start|>system
+{{ .System }}<|im_end|>
+{{ end }}{{ if .Prompt }}<|im_start|>user
+{{ .Prompt }}<|im_end|>
+{{ end }}<|im_start|>assistant
+{{ .Response }}<|im_end|>
+`},
+ {"messages", `
+{{- range $index, $_ := .Messages }}
+{{- if and (eq .Role "user") (eq (len (slice $.Messages $index)) 1) $.System }}<|im_start|>system
+{{ $.System }}<|im_end|>{{ "\n" }}
+{{- end }}<|im_start|>{{ .Role }}
+{{ .Content }}<|im_end|>{{ "\n" }}
+{{- end }}<|im_start|>assistant
+`},
+ },
+ Values{
+ Messages: []api.Message{
+ {Role: "system", Content: "You are a helpful assistant!"},
+ {Role: "user", Content: "Hello friend!"},
+ {Role: "assistant", Content: "Hello human!"},
+ {Role: "user", Content: "What is your name?"},
+ },
+ },
+ `<|im_start|>user
+Hello friend!<|im_end|>
+<|im_start|>assistant
+Hello human!<|im_end|>
+<|im_start|>system
+You are a helpful assistant!<|im_end|>
+<|im_start|>user
+What is your name?<|im_end|>
+<|im_start|>assistant
+`,
+ },
+ {
+ "moondream",
+ []template{
+ // this does not have a "no response" test because it's impossible to render the same output
+ {"response", `{{ if .Prompt }}Question: {{ .Prompt }}
+
+{{ end }}Answer: {{ .Response }}
+
+`},
+ {"messages", `
+{{- range .Messages }}
+{{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }}
+{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }}
+{{- end }}
+{{- end }}Answer: `},
+ },
+ Values{
+ Messages: []api.Message{
+ {Role: "user", Content: "What's in this image?", Images: []api.ImageData{[]byte("")}},
+ {Role: "assistant", Content: "It's a hot dog."},
+ {Role: "user", Content: "What's in _this_ image?"},
+ {Role: "user", Images: []api.ImageData{[]byte("")}},
+ {Role: "user", Content: "Is it a hot dog?"},
+ },
+ },
+ `Question: [img-0] What's in this image?
+
+Answer: It's a hot dog.
+
+Question: What's in _this_ image?
+
+[img-1]
+
+Is it a hot dog?
+
+Answer: `,
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ for _, ttt := range tt.templates {
+ t.Run(ttt.name, func(t *testing.T) {
+ tmpl, err := Parse(ttt.template)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var b bytes.Buffer
+ if err := tmpl.Execute(&b, tt.values); err != nil {
+ t.Fatal(err)
+ }
+
+ if b.String() != tt.expected {
+ t.Errorf("expected\n%s,\ngot\n%s", tt.expected, b.String())
+ }
+ })
+ }
+ })
+ }
+}