diff --git a/docs/modelfile.md b/docs/modelfile.md
index b92af782..1d0030f4 100644
--- a/docs/modelfile.md
+++ b/docs/modelfile.md
@@ -86,7 +86,7 @@ There are two ways to view `Modelfile`s underlying the models in [ollama.com/lib
# FROM llama2:13b
FROM /root/.ollama/models/blobs/sha256:123abc
- TEMPLATE """[INST] {{ if and .First .System }}<<SYS>>{{ .System }}<</SYS>>
+ TEMPLATE """[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>>
{{ end }}{{ .Prompt }} [/INST] """
SYSTEM """"""
@@ -154,31 +154,23 @@ PARAMETER
### TEMPLATE
-`TEMPLATE` of the full prompt template to be passed into the model. It may include (optionally) a system message and a user's prompt. This is used to create a full custom prompt, and syntax may be model specific. You can usually find the template for a given model in the readme for that model.
+`TEMPLATE` of the full prompt template to be passed into the model. It may include (optionally) a system message, a user's message and the response from the model. Note: syntax may be model specific. Templates use Go [template syntax](https://pkg.go.dev/text/template).
#### Template Variables
-| Variable | Description |
-| ----------------- | ------------------------------------------------------------------------------------------------------------- |
-| `{{ .System }}` | The system message used to specify custom behavior, this must also be set in the Modelfile as an instruction. |
-| `{{ .Prompt }}` | The incoming prompt, this is not specified in the model file and will be set based on input. |
-| `{{ .Response }}` | The response from the LLM, if not specified response is appended to the end of the template. |
-| `{{ .First }}` | A boolean value used to render specific template information for the first generation of a session. |
+| Variable | Description |
+| ----------------- | --------------------------------------------------------------------------------------------- |
+| `{{ .System }}` | The system message used to specify custom behavior. |
+| `{{ .Prompt }}` | The user prompt message. |
+| `{{ .Response }}` | The response from the model. When generating a response, text after this variable is omitted. |
-```modelfile
-TEMPLATE """
-{{- if .First }}
-### System:
-{{ .System }}
-{{- end }}
-
-### User:
-{{ .Prompt }}
-
-### Response:
+```
+TEMPLATE """{{ if .System }}<|im_start|>system
+{{ .System }}<|im_end|>
+{{ end }}{{ if .Prompt }}<|im_start|>user
+{{ .Prompt }}<|im_end|>
+{{ end }}<|im_start|>assistant
"""
-
-SYSTEM """"""
```
### SYSTEM
diff --git a/server/images.go b/server/images.go
index fb1c48e1..55b68456 100644
--- a/server/images.go
+++ b/server/images.go
@@ -19,7 +19,6 @@ import (
"strconv"
"strings"
"text/template"
- "text/template/parse"
"golang.org/x/exp/slices"
@@ -58,162 +57,6 @@ type Message struct {
Content string `json:"content"`
}
-type PromptVars struct {
- System string
- Prompt string
- Response string
- First bool
- Images []llm.ImageData
-}
-
-// extractParts extracts the parts of the template before and after the {{.Response}} node.
-func extractParts(tmplStr string) (pre string, post string, err error) {
- tmpl, err := template.New("").Parse(tmplStr)
- if err != nil {
- return "", "", err
- }
-
- var foundResponse bool
-
- for _, node := range tmpl.Tree.Root.Nodes {
- if node.Type() == parse.NodeAction && node.String() == "{{.Response}}" {
- foundResponse = true
- }
- if !foundResponse {
- pre += node.String()
- } else {
- post += node.String()
- }
- }
-
- return pre, post, nil
-}
-
-func Prompt(promptTemplate string, p PromptVars) (string, error) {
- var prompt strings.Builder
- // Use the "missingkey=zero" option to handle missing variables without panicking
- tmpl, err := template.New("").Option("missingkey=zero").Parse(promptTemplate)
- if err != nil {
- return "", err
- }
-
- vars := map[string]any{
- "System": p.System,
- "Prompt": p.Prompt,
- "Response": p.Response,
- "First": p.First,
- }
-
- var sb strings.Builder
- if err := tmpl.Execute(&sb, vars); err != nil {
- return "", err
- }
- prompt.WriteString(sb.String())
-
- if !strings.Contains(prompt.String(), p.Response) {
- // if the response is not in the prompt template, append it to the end
- prompt.WriteString(p.Response)
- }
-
- return prompt.String(), nil
-}
-
-// PreResponsePrompt returns the prompt before the response tag
-func (m *Model) PreResponsePrompt(p PromptVars) (string, error) {
- pre, _, err := extractParts(m.Template)
- if err != nil {
- return "", err
- }
-
- return Prompt(pre, p)
-}
-
-// PostResponseTemplate returns the template after the response tag
-func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
- if p.System == "" {
- // use the default system prompt for this model if one is not specified
- p.System = m.System
- }
- _, post, err := extractParts(m.Template)
- if err != nil {
- return "", err
- }
-
- if post == "" {
- // if there is no post-response template, return the provided response
- return p.Response, nil
- }
-
- return Prompt(post, p)
-}
-
-type ChatHistory struct {
- Prompts []PromptVars
- LastSystem string
-}
-
-// ChatPrompts returns a list of formatted chat prompts from a list of messages
-func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
- // build the prompt from the list of messages
- lastSystem := m.System
- currentVars := PromptVars{
- First: true,
- System: m.System,
- }
-
- prompts := []PromptVars{}
- var images []llm.ImageData
-
- for _, msg := range msgs {
- switch strings.ToLower(msg.Role) {
- case "system":
- // if this is the first message it overrides the system prompt in the modelfile
- if !currentVars.First && currentVars.System != "" {
- prompts = append(prompts, currentVars)
- currentVars = PromptVars{}
- }
- currentVars.System = msg.Content
- lastSystem = msg.Content
- case "user":
- if currentVars.Prompt != "" {
- prompts = append(prompts, currentVars)
- currentVars = PromptVars{}
- }
-
- currentVars.Prompt = msg.Content
-
- if len(m.ProjectorPaths) > 0 {
- for i := range msg.Images {
- id := len(images) + i
- currentVars.Prompt += fmt.Sprintf(" [img-%d]", id)
- currentVars.Images = append(currentVars.Images, llm.ImageData{
- ID: id,
- Data: msg.Images[i],
- })
- }
-
- images = append(images, currentVars.Images...)
- }
- case "assistant":
- currentVars.Response = msg.Content
- prompts = append(prompts, currentVars)
- currentVars = PromptVars{}
- default:
- return nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
- }
- }
-
- // Append the last set of vars if they are non-empty
- if currentVars.Prompt != "" || currentVars.System != "" {
- prompts = append(prompts, currentVars)
- }
-
- return &ChatHistory{
- Prompts: prompts,
- LastSystem: lastSystem,
- }, nil
-}
-
type ManifestV2 struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
diff --git a/server/images_test.go b/server/images_test.go
deleted file mode 100644
index 4c2a7cac..00000000
--- a/server/images_test.go
+++ /dev/null
@@ -1,442 +0,0 @@
-package server
-
-import (
- "bytes"
- "strings"
- "testing"
-
- "github.com/jmorganca/ollama/api"
-)
-
-func TestPrompt(t *testing.T) {
- tests := []struct {
- name string
- template string
- vars PromptVars
- want string
- wantErr bool
- }{
- {
- name: "System Prompt",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- vars: PromptVars{
- System: "You are a Wizard.",
- Prompt: "What are the potion ingredients?",
- },
- want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
- },
- {
- name: "System Prompt with Response",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
- vars: PromptVars{
- System: "You are a Wizard.",
- Prompt: "What are the potion ingredients?",
- Response: "I don't know.",
- },
- want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
- },
- {
- name: "Conditional Logic Nodes",
- template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
- vars: PromptVars{
- First: true,
- System: "You are a Wizard.",
- Prompt: "What are the potion ingredients?",
- Response: "I don't know.",
- },
- want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := Prompt(tt.template, tt.vars)
- if (err != nil) != tt.wantErr {
- t.Errorf("Prompt() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if got != tt.want {
- t.Errorf("Prompt() got = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func TestModel_PreResponsePrompt(t *testing.T) {
- tests := []struct {
- name string
- template string
- vars PromptVars
- want string
- wantErr bool
- }{
- {
- name: "No Response in Template",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- vars: PromptVars{
- System: "You are a Wizard.",
- Prompt: "What are the potion ingredients?",
- },
- want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
- },
- {
- name: "Response in Template",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
- vars: PromptVars{
- System: "You are a Wizard.",
- Prompt: "What are the potion ingredients?",
- },
- want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] ",
- },
- {
- name: "Response in Template with Trailing Formatting",
- template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
- vars: PromptVars{
- Prompt: "What are the potion ingredients?",
- },
- want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
- },
- {
- name: "Response in Template with Alternative Formatting",
- template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
- vars: PromptVars{
- Prompt: "What are the potion ingredients?",
- },
- want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
- },
- }
-
- for _, tt := range tests {
- m := Model{Template: tt.template}
- t.Run(tt.name, func(t *testing.T) {
- got, err := m.PreResponsePrompt(tt.vars)
- if (err != nil) != tt.wantErr {
- t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if got != tt.want {
- t.Errorf("PreResponsePrompt() got = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func TestModel_PostResponsePrompt(t *testing.T) {
- tests := []struct {
- name string
- template string
- vars PromptVars
- want string
- wantErr bool
- }{
- {
- name: "No Response in Template",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- vars: PromptVars{
- Response: "I don't know.",
- },
- want: "I don't know.",
- },
- {
- name: "Response in Template",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
- vars: PromptVars{
- Response: "I don't know.",
- },
- want: "I don't know.",
- },
- {
- name: "Response in Template with Trailing Formatting",
- template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
- vars: PromptVars{
- Response: "I don't know.",
- },
- want: "I don't know.<|im_end|>",
- },
- {
- name: "Response in Template with Alternative Formatting",
- template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
- vars: PromptVars{
- Response: "I don't know.",
- },
- want: "I don't know.<|im_end|>",
- },
- }
-
- for _, tt := range tests {
- m := Model{Template: tt.template}
- t.Run(tt.name, func(t *testing.T) {
- got, err := m.PostResponseTemplate(tt.vars)
- if (err != nil) != tt.wantErr {
- t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if got != tt.want {
- t.Errorf("PostResponseTemplate() got = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
- tests := []struct {
- name string
- template string
- preVars PromptVars
- postVars PromptVars
- want string
- wantErr bool
- }{
- {
- name: "Response in Template",
- template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
- preVars: PromptVars{
- Prompt: "What are the potion ingredients?",
- },
- postVars: PromptVars{
- Prompt: "What are the potion ingredients?",
- Response: "Sugar.",
- },
- want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSugar.<|im_end|>",
- },
- {
- name: "No Response in Template",
- template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n",
- preVars: PromptVars{
- Prompt: "What are the potion ingredients?",
- },
- postVars: PromptVars{
- Prompt: "What are the potion ingredients?",
- Response: "Spice.",
- },
- want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSpice.",
- },
- }
-
- for _, tt := range tests {
- m := Model{Template: tt.template}
- t.Run(tt.name, func(t *testing.T) {
- pre, err := m.PreResponsePrompt(tt.preVars)
- if (err != nil) != tt.wantErr {
- t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- post, err := m.PostResponseTemplate(tt.postVars)
- if err != nil {
- t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- result := pre + post
- if result != tt.want {
- t.Errorf("Prompt() got = %v, want %v", result, tt.want)
- }
- })
- }
-}
-
-func chatHistoryEqual(a, b ChatHistory) bool {
- if len(a.Prompts) != len(b.Prompts) {
- return false
- }
- for i, v := range a.Prompts {
-
- if v.First != b.Prompts[i].First {
- return false
- }
-
- if v.Response != b.Prompts[i].Response {
- return false
- }
-
- if v.Prompt != b.Prompts[i].Prompt {
- return false
- }
-
- if v.System != b.Prompts[i].System {
- return false
- }
-
- if len(v.Images) != len(b.Prompts[i].Images) {
- return false
- }
-
- for j, img := range v.Images {
- if img.ID != b.Prompts[i].Images[j].ID {
- return false
- }
-
- if !bytes.Equal(img.Data, b.Prompts[i].Images[j].Data) {
- return false
- }
- }
- }
- return a.LastSystem == b.LastSystem
-}
-
-func TestChat(t *testing.T) {
- tests := []struct {
- name string
- model Model
- msgs []api.Message
- want ChatHistory
- wantErr string
- }{
- {
- name: "Single Message",
- model: Model{
- Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- },
- msgs: []api.Message{
- {
- Role: "system",
- Content: "You are a Wizard.",
- },
- {
- Role: "user",
- Content: "What are the potion ingredients?",
- },
- },
- want: ChatHistory{
- Prompts: []PromptVars{
- {
- System: "You are a Wizard.",
- Prompt: "What are the potion ingredients?",
- First: true,
- },
- },
- LastSystem: "You are a Wizard.",
- },
- },
- {
- name: "Message History",
- model: Model{
- Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- },
- msgs: []api.Message{
- {
- Role: "system",
- Content: "You are a Wizard.",
- },
- {
- Role: "user",
- Content: "What are the potion ingredients?",
- },
- {
- Role: "assistant",
- Content: "sugar",
- },
- {
- Role: "user",
- Content: "Anything else?",
- },
- },
- want: ChatHistory{
- Prompts: []PromptVars{
- {
- System: "You are a Wizard.",
- Prompt: "What are the potion ingredients?",
- Response: "sugar",
- First: true,
- },
- {
- Prompt: "Anything else?",
- },
- },
- LastSystem: "You are a Wizard.",
- },
- },
- {
- name: "Assistant Only",
- model: Model{
- Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- },
- msgs: []api.Message{
- {
- Role: "assistant",
- Content: "everything nice",
- },
- },
- want: ChatHistory{
- Prompts: []PromptVars{
- {
- Response: "everything nice",
- First: true,
- },
- },
- },
- },
- {
- name: "Last system message is preserved from modelfile",
- model: Model{
- Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- System: "You are Mojo Jojo.",
- },
- msgs: []api.Message{
- {
- Role: "user",
- Content: "hi",
- },
- },
- want: ChatHistory{
- Prompts: []PromptVars{
- {
- System: "You are Mojo Jojo.",
- Prompt: "hi",
- First: true,
- },
- },
- LastSystem: "You are Mojo Jojo.",
- },
- },
- {
- name: "Last system message is preserved from messages",
- model: Model{
- Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- System: "You are Mojo Jojo.",
- },
- msgs: []api.Message{
- {
- Role: "system",
- Content: "You are Professor Utonium.",
- },
- },
- want: ChatHistory{
- Prompts: []PromptVars{
- {
- System: "You are Professor Utonium.",
- First: true,
- },
- },
- LastSystem: "You are Professor Utonium.",
- },
- },
- {
- name: "Invalid Role",
- msgs: []api.Message{
- {
- Role: "not-a-role",
- Content: "howdy",
- },
- },
- wantErr: "invalid role: not-a-role",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := tt.model.ChatPrompts(tt.msgs)
- if tt.wantErr != "" {
- if err == nil {
- t.Errorf("ChatPrompt() expected error, got nil")
- }
- if !strings.Contains(err.Error(), tt.wantErr) {
- t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
- }
- return
- }
- if !chatHistoryEqual(*got, tt.want) {
- t.Errorf("ChatPrompt() got = %#v, want %#v", got, tt.want)
- }
- })
- }
-}
diff --git a/server/prompt.go b/server/prompt.go
new file mode 100644
index 00000000..c83075d9
--- /dev/null
+++ b/server/prompt.go
@@ -0,0 +1,224 @@
+package server
+
+import (
+ "fmt"
+ "log/slog"
+ "strings"
+ "text/template"
+ "text/template/parse"
+
+ "github.com/jmorganca/ollama/api"
+)
+
+// isResponseNode checks if the node contains .Response
+func isResponseNode(node *parse.ActionNode) bool {
+ for _, cmd := range node.Pipe.Cmds {
+ for _, arg := range cmd.Args {
+ if fieldNode, ok := arg.(*parse.FieldNode); ok && len(fieldNode.Ident) > 0 {
+ if fieldNode.Ident[0] == "Response" {
+ return true
+ }
+ }
+ }
+ }
+ return false
+}
+
+// formatTemplateForResponse formats the template AST to:
+// 1. remove all nodes after the first .Response (if generate=true)
+// 2. add a .Response node to the end if it doesn't exist
+// TODO(jmorganca): this should recursively cut the template before the first .Response
+func formatTemplateForResponse(tmpl *template.Template, generate bool) {
+ var found bool
+ for i, node := range tmpl.Tree.Root.Nodes {
+ if actionNode, ok := node.(*parse.ActionNode); ok {
+ if isResponseNode(actionNode) {
+ found = true
+ if generate {
+ tmpl.Tree.Root.Nodes = tmpl.Tree.Root.Nodes[:i+1]
+ break
+ }
+ }
+ }
+ }
+
+ if !found {
+ // add the response node if it doesn't exist
+ responseFieldNode := &parse.FieldNode{NodeType: parse.NodeField, Ident: []string{"Response"}}
+ responsePipeNode := &parse.PipeNode{NodeType: parse.NodePipe, Cmds: []*parse.CommandNode{{NodeType: parse.NodeCommand, Args: []parse.Node{responseFieldNode}}}}
+ responseActionNode := &parse.ActionNode{NodeType: parse.NodeAction, Pipe: responsePipeNode}
+ tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, responseActionNode)
+ }
+}
+
+// Prompt renders a prompt from a template. If generate is set to true,
+// the response and parts of the template following it are not rendered
+func Prompt(tmpl, system, prompt, response string, generate bool) (string, error) {
+ parsed, err := template.New("").Option("missingkey=zero").Parse(tmpl)
+ if err != nil {
+ return "", err
+ }
+
+ formatTemplateForResponse(parsed, generate)
+
+ vars := map[string]any{
+ "System": system,
+ "Prompt": prompt,
+ "Response": response,
+ }
+
+ var sb strings.Builder
+ if err := parsed.Execute(&sb, vars); err != nil {
+ return "", err
+ }
+
+ return sb.String(), nil
+}
+
+func countTokens(tmpl string, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
+ rendered, err := Prompt(tmpl, system, prompt, response, false)
+ if err != nil {
+ return 0, err
+ }
+
+ tokens, err := encode(rendered)
+ if err != nil {
+ slog.Error("failed to encode prompt", "err", err)
+ return 0, err
+ }
+
+ return len(tokens), err
+}
+
+// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
+func ChatPrompt(tmpl string, system string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
+ type prompt struct {
+ System string
+ Prompt string
+ Response string
+
+ images []int
+ tokens int
+ }
+
+ var p prompt
+
+ // Set the first system prompt to the model's system prompt
+ if system != "" {
+ p.System = system
+ }
+
+ // iterate through messages to build up {system,user,response} prompts
+ var imgId int
+ var prompts []prompt
+ for _, msg := range messages {
+ switch strings.ToLower(msg.Role) {
+ case "system":
+ if p.System != "" || p.Prompt != "" || p.Response != "" {
+ prompts = append(prompts, p)
+ p = prompt{}
+ }
+
+ p.System = msg.Content
+ case "user":
+ if p.Prompt != "" || p.Response != "" {
+ prompts = append(prompts, p)
+ p = prompt{}
+ }
+
+ p.Prompt = msg.Content
+
+ for range msg.Images {
+ p.Prompt += fmt.Sprintf(" [img-%d]", imgId)
+ p.images = append(p.images, imgId)
+ imgId += 1
+ }
+ case "assistant":
+ if p.Response != "" {
+ prompts = append(prompts, p)
+ p = prompt{}
+ }
+
+ p.Response = msg.Content
+ default:
+ return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
+ }
+ }
+
+ // add final prompt
+ if p.System != "" || p.Prompt != "" || p.Response != "" {
+ prompts = append(prompts, p)
+ }
+
+ // calculate token lengths for each prompt, estimating 768 tokens per images
+ for i, p := range prompts {
+ tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
+ if err != nil {
+ return "", err
+ }
+
+ prompts[i].tokens = tokens + len(prompts[i].images)*768
+ }
+
+ // truncate images and prompts starting from the beginning of the list
+ // until either one prompt remains or the total tokens fits the context window
+ // TODO (jmorganca): this doesn't account for the context window room required for the response
+ for {
+ var required int
+ for _, p := range prompts {
+ required += p.tokens
+ }
+
+ required += 1 // for bos token
+
+ if required <= window {
+ slog.Debug("prompt now fits in context window", "required", required, "window", window)
+ break
+ }
+
+ prompt := &prompts[0]
+
+ if len(prompt.images) > 1 {
+ img := prompt.images[0]
+ slog.Debug("prompt longer than context window, removing image", "id", img, "required", required, "window", window)
+ prompt.images = prompt.images[1:]
+ prompt.Prompt = strings.Replace(prompt.Prompt, fmt.Sprintf(" [img-%d]", img), "", 1)
+ prompt.tokens -= 768
+ continue
+ }
+
+ if len(prompts) > 1 {
+ slog.Debug("required tokens longer than context window, removing first prompt", "prompt", prompts[0].tokens, "required", required, "window", window)
+ system := prompt.System
+ prompts = prompts[1:]
+
+ if system != "" && prompts[0].System == "" {
+ prompts[0].System = system
+
+ tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
+ if err != nil {
+ return "", err
+ }
+
+ prompts[0].tokens = tokens + len(prompts[0].images)*768
+ }
+
+ continue
+ }
+
+ // stop truncating if there's only one prompt left
+ break
+ }
+
+ var sb strings.Builder
+ for i, p := range prompts {
+ // last prompt should leave the response unrendered (for completion)
+ rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1)
+ if err != nil {
+ return "", err
+ }
+ sb.WriteString(rendered)
+ }
+
+ return sb.String(), nil
+}
diff --git a/server/prompt_test.go b/server/prompt_test.go
new file mode 100644
index 00000000..0ac8e314
--- /dev/null
+++ b/server/prompt_test.go
@@ -0,0 +1,234 @@
+package server
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/jmorganca/ollama/api"
+)
+
+func TestPrompt(t *testing.T) {
+ tests := []struct {
+ name string
+ template string
+ system string
+ prompt string
+ response string
+ generate bool
+ want string
+ }{
+ {
+ name: "simple prompt",
+ template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+ system: "You are a Wizard.",
+ prompt: "What are the potion ingredients?",
+ want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
+ },
+ {
+ name: "implicit response",
+ template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+ system: "You are a Wizard.",
+ prompt: "What are the potion ingredients?",
+ response: "I don't know.",
+ want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.",
+ },
+ {
+ name: "response",
+ template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
+ system: "You are a Wizard.",
+ prompt: "What are the potion ingredients?",
+ response: "I don't know.",
+ want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
+ },
+ {
+ name: "cut",
+ template: "{{ .System }}{{ .Prompt }}{{ .Response }}",
+ system: "You are a Wizard.",
+ prompt: "What are the potion ingredients?",
+ response: "I don't know.",
+ generate: true,
+ want: "You are a Wizard.What are the potion ingredients?I don't know.",
+ },
+ {
+ name: "nocut",
+ template: "{{ .System }}{{ .Prompt }}{{ .Response }}",
+ system: "You are a Wizard.",
+ prompt: "What are the potion ingredients?",
+ response: "I don't know.",
+ want: "You are a Wizard.What are the potion ingredients?I don't know.",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ got, err := Prompt(tc.template, tc.system, tc.prompt, tc.response, tc.generate)
+ if err != nil {
+ t.Errorf("error = %v", err)
+ }
+
+ if got != tc.want {
+ t.Errorf("got = %v, want %v", got, tc.want)
+ }
+ })
+ }
+}
+
+func TestChatPrompt(t *testing.T) {
+ tests := []struct {
+ name string
+ template string
+ system string
+ messages []api.Message
+ window int
+ want string
+ }{
+ {
+ name: "simple prompt",
+ template: "[INST] {{ .Prompt }} [/INST]",
+ messages: []api.Message{
+ {Role: "user", Content: "Hello"},
+ },
+ window: 1024,
+ want: "[INST] Hello [/INST]",
+ },
+ {
+ name: "with default system message",
+ system: "You are a Wizard.",
+ template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
+ messages: []api.Message{
+ {Role: "user", Content: "Hello"},
+ },
+ window: 1024,
+ want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
+ },
+ {
+ name: "with system message",
+ template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
+ messages: []api.Message{
+ {Role: "system", Content: "You are a Wizard."},
+ {Role: "user", Content: "Hello"},
+ },
+ window: 1024,
+ want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
+ },
+ {
+ name: "with response",
+ template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}",
+ messages: []api.Message{
+ {Role: "system", Content: "You are a Wizard."},
+ {Role: "user", Content: "Hello"},
+ {Role: "assistant", Content: "I am?"},
+ },
+ window: 1024,
+ want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST] I am?",
+ },
+ {
+ name: "with implicit response",
+ template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
+ messages: []api.Message{
+ {Role: "system", Content: "You are a Wizard."},
+ {Role: "user", Content: "Hello"},
+ {Role: "assistant", Content: "I am?"},
+ },
+ window: 1024,
+ want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]I am?",
+ },
+ {
+ name: "with conversation",
+ template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
+ messages: []api.Message{
+ {Role: "system", Content: "You are a Wizard."},
+ {Role: "user", Content: "What are the potion ingredients?"},
+ {Role: "assistant", Content: "sugar"},
+ {Role: "user", Content: "Anything else?"},
+ },
+ window: 1024,
+ want: "[INST] <<SYS>>You are a Wizard.<</SYS>> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ",
+ },
+ {
+ name: "with truncation",
+ template: "{{ .System }} {{ .Prompt }} {{ .Response }} ",
+ messages: []api.Message{
+ {Role: "system", Content: "You are a Wizard."},
+ {Role: "user", Content: "Hello"},
+ {Role: "assistant", Content: "I am?"},
+ {Role: "user", Content: "Why is the sky blue?"},
+ {Role: "assistant", Content: "The sky is blue from rayleigh scattering"},
+ },
+ window: 10,
+ want: "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering",
+ },
+ {
+ name: "images",
+ template: "{{ .System }} {{ .Prompt }}",
+ messages: []api.Message{
+ {Role: "system", Content: "You are a Wizard."},
+ {Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
+ },
+ window: 1024,
+ want: "You are a Wizard. Hello [img-0]",
+ },
+ {
+ name: "images truncated",
+ template: "{{ .System }} {{ .Prompt }}",
+ messages: []api.Message{
+ {Role: "system", Content: "You are a Wizard."},
+ {Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
+ },
+ window: 1024,
+ want: "You are a Wizard. Hello [img-1]",
+ },
+ {
+ name: "empty list",
+ template: "{{ .System }} {{ .Prompt }}",
+ messages: []api.Message{},
+ window: 1024,
+ want: "",
+ },
+ {
+ name: "empty list default system",
+ system: "You are a Wizard.",
+ template: "{{ .System }} {{ .Prompt }}",
+ messages: []api.Message{},
+ window: 1024,
+ want: "You are a Wizard. ",
+ },
+ {
+ name: "empty user message",
+ system: "You are a Wizard.",
+ template: "{{ .System }} {{ .Prompt }}",
+ messages: []api.Message{
+ {Role: "user", Content: ""},
+ },
+ window: 1024,
+ want: "You are a Wizard. ",
+ },
+ {
+ name: "empty prompt",
+ template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
+ messages: []api.Message{
+ {Role: "user", Content: ""},
+ },
+ window: 1024,
+ want: "",
+ },
+ }
+
+ encode := func(s string) ([]int, error) {
+ words := strings.Fields(s)
+ return make([]int, len(words)), nil
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ got, err := ChatPrompt(tc.template, tc.system, tc.messages, tc.window, encode)
+ if err != nil {
+ t.Errorf("error = %v", err)
+ }
+
+ if got != tc.want {
+ t.Errorf("got = %v, want %v", got, tc.want)
+ }
+ })
+ }
+}
diff --git a/server/routes.go b/server/routes.go
index 9abaea42..bd943ee1 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -214,6 +214,8 @@ func GenerateHandler(c *gin.Context) {
}
// an empty request loads the model
+ // note: for a short while template was used in lieu
+ // of `raw` mode so we need to check for it too
if req.Prompt == "" && req.Template == "" && req.System == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
CreatedAt: time.Now().UTC(),
@@ -226,50 +228,48 @@ func GenerateHandler(c *gin.Context) {
checkpointLoaded := time.Now()
var prompt string
- var promptVars PromptVars
switch {
case req.Raw:
prompt = req.Prompt
case req.Prompt != "":
- if req.Template != "" {
- // override the default model template
- model.Template = req.Template
+ if req.Template == "" {
+ req.Template = model.Template
}
- var rebuild strings.Builder
+ if req.System == "" {
+ req.System = model.System
+ }
+
+ slog.Debug("generate handler", "prompt", req.Prompt)
+ slog.Debug("generate handler", "template", req.Template)
+ slog.Debug("generate handler", "system", req.System)
+
+ var sb strings.Builder
if req.Context != nil {
- // TODO: context is deprecated, at some point the context logic within this conditional should be removed
- prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context)
+ prev, err := loaded.runner.Decode(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
- // Remove leading spaces from prevCtx if present
- prevCtx = strings.TrimPrefix(prevCtx, " ")
- rebuild.WriteString(prevCtx)
- }
- promptVars = PromptVars{
- System: req.System,
- Prompt: req.Prompt,
- First: len(req.Context) == 0,
- }
-
- if promptVars.System == "" {
- promptVars.System = model.System
+ sb.WriteString(prev)
}
+ // write image tags
+ // TODO: limit the number of images to fit in the context similar to the chat endpoint
for i := range req.Images {
- promptVars.Prompt += fmt.Sprintf(" [img-%d]", i)
+ req.Prompt += fmt.Sprintf(" [img-%d]", i)
}
- p, err := model.PreResponsePrompt(promptVars)
+ p, err := Prompt(req.Template, req.System, req.Prompt, "", true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
- rebuild.WriteString(p)
- prompt = rebuild.String()
+
+ sb.WriteString(p)
+
+ prompt = sb.String()
}
slog.Debug("generate handler", "prompt", prompt)
@@ -308,19 +308,20 @@ func GenerateHandler(c *gin.Context) {
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
- // append the generated text to the history and template it if needed
- promptVars.Response = generated.String()
- result, err := model.PostResponseTemplate(promptVars)
+ p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+
+ // TODO (jmorganca): encode() should not strip special tokens
+ tokens, err := loaded.runner.Encode(c.Request.Context(), p)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
- embd, err := loaded.runner.Encode(c.Request.Context(), prompt+result)
- if err != nil {
- ch <- gin.H{"error": err.Error()}
- return
- }
- resp.Context = embd
+
+ resp.Context = append(req.Context, tokens...)
}
}
@@ -1090,6 +1091,20 @@ func streamResponse(c *gin.Context, ch chan any) {
})
}
+// chatPrompt builds up a prompt from a series of messages for the currently `loaded` model
+func chatPrompt(ctx context.Context, messages []api.Message) (string, error) {
+ encode := func(s string) ([]int, error) {
+ return loaded.runner.Encode(ctx, s)
+ }
+
+ prompt, err := ChatPrompt(loaded.Model.Template, loaded.Model.System, messages, loaded.Options.NumCtx, encode)
+ if err != nil {
+ return "", err
+ }
+
+ return prompt, nil
+}
+
func ChatHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
@@ -1117,15 +1132,6 @@ func ChatHandler(c *gin.Context) {
return
}
- for _, msg := range req.Messages {
- for _, img := range msg.Images {
- if !isSupportedImageType(img) {
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
- return
- }
- }
- }
-
model, err := GetModel(req.Model)
if err != nil {
var pErr *fs.PathError
@@ -1161,20 +1167,14 @@ func ChatHandler(c *gin.Context) {
checkpointLoaded := time.Now()
- chat, err := model.ChatPrompts(req.Messages)
+ prompt, err := chatPrompt(c.Request.Context(), req.Messages)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
- prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model)
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
- }
-
// an empty request loads the model
- if len(prompt) == 0 {
+ if len(req.Messages) == 0 || prompt == "" {
resp := api.ChatResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model,
@@ -1185,7 +1185,24 @@ func ChatHandler(c *gin.Context) {
return
}
- slog.Debug("chat handler", "prompt", prompt)
+ // only send images that are in the prompt
+ var i int
+ var images []llm.ImageData
+ for _, m := range req.Messages {
+ for _, img := range m.Images {
+ if !isSupportedImageType(img) {
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
+ return
+ }
+
+ if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
+ images = append(images, llm.ImageData{Data: img, ID: i})
+ }
+ i += 1
+ }
+ }
+
+ slog.Debug("chat handler", "prompt", prompt, "images", len(images))
ch := make(chan any)
@@ -1260,115 +1277,3 @@ func ChatHandler(c *gin.Context) {
streamResponse(c, ch)
}
-
-// promptInfo stores the variables used to template a prompt, and the token length of the resulting template for some model
-type promptInfo struct {
- vars PromptVars
- tokenLen int
-}
-
-// trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
-// while preserving the most recent system message.
-func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, []llm.ImageData, error) {
- if len(chat.Prompts) == 0 {
- return "", nil, nil
- }
-
- var promptsToAdd []promptInfo
- var totalTokenLength int
- var systemPromptIncluded bool
-
- var images []llm.ImageData
- // reverse iterate through the prompts to build the prompt string in a way that fits the max context length
- for i := len(chat.Prompts) - 1; i >= 0; i-- {
- prompt := chat.Prompts[i]
- promptText, err := promptString(model, prompt, i == len(chat.Prompts)-1)
- if err != nil {
- return "", nil, err
- }
-
- encodedTokens, err := loaded.runner.Encode(ctx, promptText)
- if err != nil {
- return "", nil, err
- }
-
- if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
- break // reached max context length, stop adding more prompts
- }
-
- for j := range prompt.Images {
- if totalTokenLength+768 > loaded.NumCtx {
- // this decreases the token length but overestimating is fine
- prompt.Prompt = strings.ReplaceAll(prompt.Prompt, fmt.Sprintf(" [img-%d]", prompt.Images[j].ID), "")
- continue
- }
-
- totalTokenLength += 768
- images = append(images, prompt.Images[j])
- }
-
- totalTokenLength += len(encodedTokens)
- systemPromptIncluded = systemPromptIncluded || prompt.System != ""
- promptsToAdd = append(promptsToAdd, promptInfo{vars: prompt, tokenLen: len(encodedTokens)})
- }
-
- // ensure the system prompt is included, if not already
- if chat.LastSystem != "" && !systemPromptIncluded {
- var err error
- promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
- if err != nil {
- return "", nil, err
- }
- }
-
- promptsToAdd[len(promptsToAdd)-1].vars.First = true
-
- // construct the final prompt string from the prompts which fit within the context window
- var result string
- for i, prompt := range promptsToAdd {
- promptText, err := promptString(model, prompt.vars, i == 0)
- if err != nil {
- return "", nil, err
- }
- result = promptText + result
- }
-
- return result, images, nil
-}
-
-// promptString applies the model template to the prompt
-func promptString(model *Model, vars PromptVars, isMostRecent bool) (string, error) {
- if isMostRecent {
- p, err := model.PreResponsePrompt(vars)
- if err != nil {
- return "", fmt.Errorf("pre-response template: %w", err)
- }
- return p, nil
- }
- p, err := Prompt(model.Template, vars)
- if err != nil {
- return "", err
- }
- return p, nil
-}
-
-// includeSystemPrompt adjusts the prompts to include the system prompt.
-func includeSystemPrompt(ctx context.Context, systemPrompt string, totalTokenLength int, promptsToAdd []promptInfo) ([]promptInfo, error) {
- systemTokens, err := loaded.runner.Encode(ctx, systemPrompt)
- if err != nil {
- return nil, err
- }
-
- for i := len(promptsToAdd) - 1; i >= 0; i-- {
- if totalTokenLength+len(systemTokens) <= loaded.NumCtx {
- promptsToAdd[i].vars.System = systemPrompt
- return promptsToAdd[:i+1], nil
- }
- totalTokenLength -= promptsToAdd[i].tokenLen
- }
-
- // if got here, system did not fit anywhere, so return the most recent prompt with the system message set
- recent := promptsToAdd[len(promptsToAdd)-1]
- recent.vars.System = systemPrompt
- return []promptInfo{recent}, nil
-}
diff --git a/server/routes_test.go b/server/routes_test.go
index 2a0308b8..9cf96f10 100644
--- a/server/routes_test.go
+++ b/server/routes_test.go
@@ -241,237 +241,6 @@ func Test_Routes(t *testing.T) {
}
}
-func Test_ChatPrompt(t *testing.T) {
- tests := []struct {
- name string
- template string
- chat *ChatHistory
- numCtx int
- runner MockLLM
- want string
- wantErr string
- }{
- {
- name: "Single Message",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- chat: &ChatHistory{
- Prompts: []PromptVars{
- {
- System: "You are a Wizard.",
- Prompt: "What are the potion ingredients?",
- First: true,
- },
- },
- LastSystem: "You are a Wizard.",
- },
- numCtx: 1,
- runner: MockLLM{
- encoding: []int{1}, // fit the ctxLen
- },
- want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
- },
- {
- name: "First Message",
- template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
- chat: &ChatHistory{
- Prompts: []PromptVars{
- {
- System: "You are a Wizard.",
- Prompt: "What are the potion ingredients?",
- Response: "eye of newt",
- First: true,
- },
- {
- Prompt: "Anything else?",
- },
- },
- LastSystem: "You are a Wizard.",
- },
- numCtx: 2,
- runner: MockLLM{
- encoding: []int{1}, // fit the ctxLen
- },
- want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]",
- },
- {
- name: "Message History",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- chat: &ChatHistory{
- Prompts: []PromptVars{
- {
- System: "You are a Wizard.",
- Prompt: "What are the potion ingredients?",
- Response: "sugar",
- First: true,
- },
- {
- Prompt: "Anything else?",
- },
- },
- LastSystem: "You are a Wizard.",
- },
- numCtx: 4,
- runner: MockLLM{
- encoding: []int{1}, // fit the ctxLen, 1 for each message
- },
- want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
- },
- {
- name: "Assistant Only",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- chat: &ChatHistory{
- Prompts: []PromptVars{
- {
- Response: "everything nice",
- First: true,
- },
- },
- },
- numCtx: 1,
- runner: MockLLM{
- encoding: []int{1},
- },
- want: "[INST] [/INST]everything nice",
- },
- {
- name: "Message History Truncated, No System",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- chat: &ChatHistory{
- Prompts: []PromptVars{
- {
- Prompt: "What are the potion ingredients?",
- Response: "sugar",
- First: true,
- },
- {
- Prompt: "Anything else?",
- Response: "spice",
- },
- {
- Prompt: "... and?",
- },
- },
- },
- numCtx: 2, // only 1 message from history and most recent message
- runner: MockLLM{
- encoding: []int{1},
- },
- want: "[INST] Anything else? [/INST]spice[INST] ... and? [/INST]",
- },
- {
- name: "System is Preserved when Truncated",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- chat: &ChatHistory{
- Prompts: []PromptVars{
- {
- Prompt: "What are the magic words?",
- Response: "abracadabra",
- },
- {
- Prompt: "What is the spell for invisibility?",
- },
- },
- LastSystem: "You are a wizard.",
- },
- numCtx: 2,
- runner: MockLLM{
- encoding: []int{1},
- },
- want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
- },
- {
- name: "System is Preserved when Length Exceeded",
- template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
- chat: &ChatHistory{
- Prompts: []PromptVars{
- {
- Prompt: "What are the magic words?",
- Response: "abracadabra",
- },
- {
- Prompt: "What is the spell for invisibility?",
- },
- },
- LastSystem: "You are a wizard.",
- },
- numCtx: 1,
- runner: MockLLM{
- encoding: []int{1},
- },
- want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
- },
- {
- name: "First is Preserved when Truncated",
- template: "[INST] {{ if .First }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]",
-
- chat: &ChatHistory{
- Prompts: []PromptVars{
- // first message omitted for test
- {
- Prompt: "Do you have a magic hat?",
- Response: "Of course.",
- },
- {
- Prompt: "What is the spell for invisibility?",
- },
- },
- LastSystem: "You are a wizard.",
- },
- numCtx: 3, // two most recent messages and room for system message
- runner: MockLLM{
- encoding: []int{1},
- },
- want: "[INST] You are a wizard. Do you have a magic hat? [/INST]Of course.[INST] What is the spell for invisibility? [/INST]",
- },
- {
- name: "Most recent message is returned when longer than ctxLen",
- template: "[INST] {{ .Prompt }} [/INST]",
-
- chat: &ChatHistory{
- Prompts: []PromptVars{
- {
- Prompt: "What is the spell for invisibility?",
- First: true,
- },
- },
- },
- numCtx: 1, // two most recent messages
- runner: MockLLM{
- encoding: []int{1, 2},
- },
- want: "[INST] What is the spell for invisibility? [/INST]",
- },
- }
-
- for _, testCase := range tests {
- tt := testCase
- m := &Model{
- Template: tt.template,
- }
- t.Run(tt.name, func(t *testing.T) {
- loaded.runner = &tt.runner
- loaded.Options = &api.Options{
- Runner: api.Runner{
- NumCtx: tt.numCtx,
- },
- }
- // TODO: add tests for trimming images
- got, _, err := trimmedPrompt(context.Background(), tt.chat, m)
- if tt.wantErr != "" {
- if err == nil {
- t.Errorf("ChatPrompt() expected error, got nil")
- }
- if !strings.Contains(err.Error(), tt.wantErr) {
- t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
- }
- }
- if got != tt.want {
- t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
type MockLLM struct {
encoding []int
}