From 48a273f80ba1f0fc7a5ed8881c0dc14fc664ea4e Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Mon, 12 Feb 2024 15:06:57 -0800 Subject: [PATCH] Fix issues with templating prompt in chat mode (#2460) --- docs/modelfile.md | 34 ++-- server/images.go | 157 --------------- server/images_test.go | 442 ------------------------------------------ server/prompt.go | 224 +++++++++++++++++++++ server/prompt_test.go | 234 ++++++++++++++++++++++ server/routes.go | 229 +++++++--------------- server/routes_test.go | 231 ---------------------- 7 files changed, 538 insertions(+), 1013 deletions(-) delete mode 100644 server/images_test.go create mode 100644 server/prompt.go create mode 100644 server/prompt_test.go diff --git a/docs/modelfile.md b/docs/modelfile.md index b92af782..1d0030f4 100644 --- a/docs/modelfile.md +++ b/docs/modelfile.md @@ -86,7 +86,7 @@ There are two ways to view `Modelfile`s underlying the models in [ollama.com/lib # FROM llama2:13b FROM /root/.ollama/models/blobs/sha256:123abc - TEMPLATE """[INST] {{ if and .First .System }}<>{{ .System }}<> + TEMPLATE """[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST] """ SYSTEM """""" @@ -154,31 +154,23 @@ PARAMETER ### TEMPLATE -`TEMPLATE` of the full prompt template to be passed into the model. It may include (optionally) a system message and a user's prompt. This is used to create a full custom prompt, and syntax may be model specific. You can usually find the template for a given model in the readme for that model. +`TEMPLATE` of the full prompt template to be passed into the model. It may include (optionally) a system message, a user's message and the response from the model. Note: syntax may be model specific. Templates use Go [template syntax](https://pkg.go.dev/text/template). #### Template Variables -| Variable | Description | -| ----------------- | ------------------------------------------------------------------------------------------------------------- | -| `{{ .System }}` | The system message used to specify custom behavior, this must also be set in the Modelfile as an instruction. | -| `{{ .Prompt }}` | The incoming prompt, this is not specified in the model file and will be set based on input. | -| `{{ .Response }}` | The response from the LLM, if not specified response is appended to the end of the template. | -| `{{ .First }}` | A boolean value used to render specific template information for the first generation of a session. | +| Variable | Description | +| ----------------- | --------------------------------------------------------------------------------------------- | +| `{{ .System }}` | The system message used to specify custom behavior. | +| `{{ .Prompt }}` | The user prompt message. | +| `{{ .Response }}` | The response from the model. When generating a response, text after this variable is omitted. | -```modelfile -TEMPLATE """ -{{- if .First }} -### System: -{{ .System }} -{{- end }} - -### User: -{{ .Prompt }} - -### Response: +``` +TEMPLATE """{{ if .System }}<|im_start|>system +{{ .System }}<|im_end|> +{{ end }}{{ if .Prompt }}<|im_start|>user +{{ .Prompt }}<|im_end|> +{{ end }}<|im_start|>assistant """ - -SYSTEM """""" ``` ### SYSTEM diff --git a/server/images.go b/server/images.go index fb1c48e1..55b68456 100644 --- a/server/images.go +++ b/server/images.go @@ -19,7 +19,6 @@ import ( "strconv" "strings" "text/template" - "text/template/parse" "golang.org/x/exp/slices" @@ -58,162 +57,6 @@ type Message struct { Content string `json:"content"` } -type PromptVars struct { - System string - Prompt string - Response string - First bool - Images []llm.ImageData -} - -// extractParts extracts the parts of the template before and after the {{.Response}} node. -func extractParts(tmplStr string) (pre string, post string, err error) { - tmpl, err := template.New("").Parse(tmplStr) - if err != nil { - return "", "", err - } - - var foundResponse bool - - for _, node := range tmpl.Tree.Root.Nodes { - if node.Type() == parse.NodeAction && node.String() == "{{.Response}}" { - foundResponse = true - } - if !foundResponse { - pre += node.String() - } else { - post += node.String() - } - } - - return pre, post, nil -} - -func Prompt(promptTemplate string, p PromptVars) (string, error) { - var prompt strings.Builder - // Use the "missingkey=zero" option to handle missing variables without panicking - tmpl, err := template.New("").Option("missingkey=zero").Parse(promptTemplate) - if err != nil { - return "", err - } - - vars := map[string]any{ - "System": p.System, - "Prompt": p.Prompt, - "Response": p.Response, - "First": p.First, - } - - var sb strings.Builder - if err := tmpl.Execute(&sb, vars); err != nil { - return "", err - } - prompt.WriteString(sb.String()) - - if !strings.Contains(prompt.String(), p.Response) { - // if the response is not in the prompt template, append it to the end - prompt.WriteString(p.Response) - } - - return prompt.String(), nil -} - -// PreResponsePrompt returns the prompt before the response tag -func (m *Model) PreResponsePrompt(p PromptVars) (string, error) { - pre, _, err := extractParts(m.Template) - if err != nil { - return "", err - } - - return Prompt(pre, p) -} - -// PostResponseTemplate returns the template after the response tag -func (m *Model) PostResponseTemplate(p PromptVars) (string, error) { - if p.System == "" { - // use the default system prompt for this model if one is not specified - p.System = m.System - } - _, post, err := extractParts(m.Template) - if err != nil { - return "", err - } - - if post == "" { - // if there is no post-response template, return the provided response - return p.Response, nil - } - - return Prompt(post, p) -} - -type ChatHistory struct { - Prompts []PromptVars - LastSystem string -} - -// ChatPrompts returns a list of formatted chat prompts from a list of messages -func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) { - // build the prompt from the list of messages - lastSystem := m.System - currentVars := PromptVars{ - First: true, - System: m.System, - } - - prompts := []PromptVars{} - var images []llm.ImageData - - for _, msg := range msgs { - switch strings.ToLower(msg.Role) { - case "system": - // if this is the first message it overrides the system prompt in the modelfile - if !currentVars.First && currentVars.System != "" { - prompts = append(prompts, currentVars) - currentVars = PromptVars{} - } - currentVars.System = msg.Content - lastSystem = msg.Content - case "user": - if currentVars.Prompt != "" { - prompts = append(prompts, currentVars) - currentVars = PromptVars{} - } - - currentVars.Prompt = msg.Content - - if len(m.ProjectorPaths) > 0 { - for i := range msg.Images { - id := len(images) + i - currentVars.Prompt += fmt.Sprintf(" [img-%d]", id) - currentVars.Images = append(currentVars.Images, llm.ImageData{ - ID: id, - Data: msg.Images[i], - }) - } - - images = append(images, currentVars.Images...) - } - case "assistant": - currentVars.Response = msg.Content - prompts = append(prompts, currentVars) - currentVars = PromptVars{} - default: - return nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role) - } - } - - // Append the last set of vars if they are non-empty - if currentVars.Prompt != "" || currentVars.System != "" { - prompts = append(prompts, currentVars) - } - - return &ChatHistory{ - Prompts: prompts, - LastSystem: lastSystem, - }, nil -} - type ManifestV2 struct { SchemaVersion int `json:"schemaVersion"` MediaType string `json:"mediaType"` diff --git a/server/images_test.go b/server/images_test.go deleted file mode 100644 index 4c2a7cac..00000000 --- a/server/images_test.go +++ /dev/null @@ -1,442 +0,0 @@ -package server - -import ( - "bytes" - "strings" - "testing" - - "github.com/jmorganca/ollama/api" -) - -func TestPrompt(t *testing.T) { - tests := []struct { - name string - template string - vars PromptVars - want string - wantErr bool - }{ - { - name: "System Prompt", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - vars: PromptVars{ - System: "You are a Wizard.", - Prompt: "What are the potion ingredients?", - }, - want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]", - }, - { - name: "System Prompt with Response", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}", - vars: PromptVars{ - System: "You are a Wizard.", - Prompt: "What are the potion ingredients?", - Response: "I don't know.", - }, - want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.", - }, - { - name: "Conditional Logic Nodes", - template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}", - vars: PromptVars{ - First: true, - System: "You are a Wizard.", - Prompt: "What are the potion ingredients?", - Response: "I don't know.", - }, - want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST] I don't know.", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := Prompt(tt.template, tt.vars) - if (err != nil) != tt.wantErr { - t.Errorf("Prompt() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("Prompt() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestModel_PreResponsePrompt(t *testing.T) { - tests := []struct { - name string - template string - vars PromptVars - want string - wantErr bool - }{ - { - name: "No Response in Template", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - vars: PromptVars{ - System: "You are a Wizard.", - Prompt: "What are the potion ingredients?", - }, - want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]", - }, - { - name: "Response in Template", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}", - vars: PromptVars{ - System: "You are a Wizard.", - Prompt: "What are the potion ingredients?", - }, - want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] ", - }, - { - name: "Response in Template with Trailing Formatting", - template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>", - vars: PromptVars{ - Prompt: "What are the potion ingredients?", - }, - want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n", - }, - { - name: "Response in Template with Alternative Formatting", - template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>", - vars: PromptVars{ - Prompt: "What are the potion ingredients?", - }, - want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n", - }, - } - - for _, tt := range tests { - m := Model{Template: tt.template} - t.Run(tt.name, func(t *testing.T) { - got, err := m.PreResponsePrompt(tt.vars) - if (err != nil) != tt.wantErr { - t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("PreResponsePrompt() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestModel_PostResponsePrompt(t *testing.T) { - tests := []struct { - name string - template string - vars PromptVars - want string - wantErr bool - }{ - { - name: "No Response in Template", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - vars: PromptVars{ - Response: "I don't know.", - }, - want: "I don't know.", - }, - { - name: "Response in Template", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}", - vars: PromptVars{ - Response: "I don't know.", - }, - want: "I don't know.", - }, - { - name: "Response in Template with Trailing Formatting", - template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>", - vars: PromptVars{ - Response: "I don't know.", - }, - want: "I don't know.<|im_end|>", - }, - { - name: "Response in Template with Alternative Formatting", - template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>", - vars: PromptVars{ - Response: "I don't know.", - }, - want: "I don't know.<|im_end|>", - }, - } - - for _, tt := range tests { - m := Model{Template: tt.template} - t.Run(tt.name, func(t *testing.T) { - got, err := m.PostResponseTemplate(tt.vars) - if (err != nil) != tt.wantErr { - t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("PostResponseTemplate() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) { - tests := []struct { - name string - template string - preVars PromptVars - postVars PromptVars - want string - wantErr bool - }{ - { - name: "Response in Template", - template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>", - preVars: PromptVars{ - Prompt: "What are the potion ingredients?", - }, - postVars: PromptVars{ - Prompt: "What are the potion ingredients?", - Response: "Sugar.", - }, - want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSugar.<|im_end|>", - }, - { - name: "No Response in Template", - template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n", - preVars: PromptVars{ - Prompt: "What are the potion ingredients?", - }, - postVars: PromptVars{ - Prompt: "What are the potion ingredients?", - Response: "Spice.", - }, - want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSpice.", - }, - } - - for _, tt := range tests { - m := Model{Template: tt.template} - t.Run(tt.name, func(t *testing.T) { - pre, err := m.PreResponsePrompt(tt.preVars) - if (err != nil) != tt.wantErr { - t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr) - return - } - post, err := m.PostResponseTemplate(tt.postVars) - if err != nil { - t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr) - return - } - result := pre + post - if result != tt.want { - t.Errorf("Prompt() got = %v, want %v", result, tt.want) - } - }) - } -} - -func chatHistoryEqual(a, b ChatHistory) bool { - if len(a.Prompts) != len(b.Prompts) { - return false - } - for i, v := range a.Prompts { - - if v.First != b.Prompts[i].First { - return false - } - - if v.Response != b.Prompts[i].Response { - return false - } - - if v.Prompt != b.Prompts[i].Prompt { - return false - } - - if v.System != b.Prompts[i].System { - return false - } - - if len(v.Images) != len(b.Prompts[i].Images) { - return false - } - - for j, img := range v.Images { - if img.ID != b.Prompts[i].Images[j].ID { - return false - } - - if !bytes.Equal(img.Data, b.Prompts[i].Images[j].Data) { - return false - } - } - } - return a.LastSystem == b.LastSystem -} - -func TestChat(t *testing.T) { - tests := []struct { - name string - model Model - msgs []api.Message - want ChatHistory - wantErr string - }{ - { - name: "Single Message", - model: Model{ - Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - }, - msgs: []api.Message{ - { - Role: "system", - Content: "You are a Wizard.", - }, - { - Role: "user", - Content: "What are the potion ingredients?", - }, - }, - want: ChatHistory{ - Prompts: []PromptVars{ - { - System: "You are a Wizard.", - Prompt: "What are the potion ingredients?", - First: true, - }, - }, - LastSystem: "You are a Wizard.", - }, - }, - { - name: "Message History", - model: Model{ - Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - }, - msgs: []api.Message{ - { - Role: "system", - Content: "You are a Wizard.", - }, - { - Role: "user", - Content: "What are the potion ingredients?", - }, - { - Role: "assistant", - Content: "sugar", - }, - { - Role: "user", - Content: "Anything else?", - }, - }, - want: ChatHistory{ - Prompts: []PromptVars{ - { - System: "You are a Wizard.", - Prompt: "What are the potion ingredients?", - Response: "sugar", - First: true, - }, - { - Prompt: "Anything else?", - }, - }, - LastSystem: "You are a Wizard.", - }, - }, - { - name: "Assistant Only", - model: Model{ - Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - }, - msgs: []api.Message{ - { - Role: "assistant", - Content: "everything nice", - }, - }, - want: ChatHistory{ - Prompts: []PromptVars{ - { - Response: "everything nice", - First: true, - }, - }, - }, - }, - { - name: "Last system message is preserved from modelfile", - model: Model{ - Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - System: "You are Mojo Jojo.", - }, - msgs: []api.Message{ - { - Role: "user", - Content: "hi", - }, - }, - want: ChatHistory{ - Prompts: []PromptVars{ - { - System: "You are Mojo Jojo.", - Prompt: "hi", - First: true, - }, - }, - LastSystem: "You are Mojo Jojo.", - }, - }, - { - name: "Last system message is preserved from messages", - model: Model{ - Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - System: "You are Mojo Jojo.", - }, - msgs: []api.Message{ - { - Role: "system", - Content: "You are Professor Utonium.", - }, - }, - want: ChatHistory{ - Prompts: []PromptVars{ - { - System: "You are Professor Utonium.", - First: true, - }, - }, - LastSystem: "You are Professor Utonium.", - }, - }, - { - name: "Invalid Role", - msgs: []api.Message{ - { - Role: "not-a-role", - Content: "howdy", - }, - }, - wantErr: "invalid role: not-a-role", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.model.ChatPrompts(tt.msgs) - if tt.wantErr != "" { - if err == nil { - t.Errorf("ChatPrompt() expected error, got nil") - } - if !strings.Contains(err.Error(), tt.wantErr) { - t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr) - } - return - } - if !chatHistoryEqual(*got, tt.want) { - t.Errorf("ChatPrompt() got = %#v, want %#v", got, tt.want) - } - }) - } -} diff --git a/server/prompt.go b/server/prompt.go new file mode 100644 index 00000000..c83075d9 --- /dev/null +++ b/server/prompt.go @@ -0,0 +1,224 @@ +package server + +import ( + "fmt" + "log/slog" + "strings" + "text/template" + "text/template/parse" + + "github.com/jmorganca/ollama/api" +) + +// isResponseNode checks if the node contains .Response +func isResponseNode(node *parse.ActionNode) bool { + for _, cmd := range node.Pipe.Cmds { + for _, arg := range cmd.Args { + if fieldNode, ok := arg.(*parse.FieldNode); ok && len(fieldNode.Ident) > 0 { + if fieldNode.Ident[0] == "Response" { + return true + } + } + } + } + return false +} + +// formatTemplateForResponse formats the template AST to: +// 1. remove all nodes after the first .Response (if generate=true) +// 2. add a .Response node to the end if it doesn't exist +// TODO(jmorganca): this should recursively cut the template before the first .Response +func formatTemplateForResponse(tmpl *template.Template, generate bool) { + var found bool + for i, node := range tmpl.Tree.Root.Nodes { + if actionNode, ok := node.(*parse.ActionNode); ok { + if isResponseNode(actionNode) { + found = true + if generate { + tmpl.Tree.Root.Nodes = tmpl.Tree.Root.Nodes[:i+1] + break + } + } + } + } + + if !found { + // add the response node if it doesn't exist + responseFieldNode := &parse.FieldNode{NodeType: parse.NodeField, Ident: []string{"Response"}} + responsePipeNode := &parse.PipeNode{NodeType: parse.NodePipe, Cmds: []*parse.CommandNode{{NodeType: parse.NodeCommand, Args: []parse.Node{responseFieldNode}}}} + responseActionNode := &parse.ActionNode{NodeType: parse.NodeAction, Pipe: responsePipeNode} + tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, responseActionNode) + } +} + +// Prompt renders a prompt from a template. If generate is set to true, +// the response and parts of the template following it are not rendered +func Prompt(tmpl, system, prompt, response string, generate bool) (string, error) { + parsed, err := template.New("").Option("missingkey=zero").Parse(tmpl) + if err != nil { + return "", err + } + + formatTemplateForResponse(parsed, generate) + + vars := map[string]any{ + "System": system, + "Prompt": prompt, + "Response": response, + } + + var sb strings.Builder + if err := parsed.Execute(&sb, vars); err != nil { + return "", err + } + + return sb.String(), nil +} + +func countTokens(tmpl string, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) { + rendered, err := Prompt(tmpl, system, prompt, response, false) + if err != nil { + return 0, err + } + + tokens, err := encode(rendered) + if err != nil { + slog.Error("failed to encode prompt", "err", err) + return 0, err + } + + return len(tokens), err +} + +// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size +func ChatPrompt(tmpl string, system string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) { + type prompt struct { + System string + Prompt string + Response string + + images []int + tokens int + } + + var p prompt + + // Set the first system prompt to the model's system prompt + if system != "" { + p.System = system + } + + // iterate through messages to build up {system,user,response} prompts + var imgId int + var prompts []prompt + for _, msg := range messages { + switch strings.ToLower(msg.Role) { + case "system": + if p.System != "" || p.Prompt != "" || p.Response != "" { + prompts = append(prompts, p) + p = prompt{} + } + + p.System = msg.Content + case "user": + if p.Prompt != "" || p.Response != "" { + prompts = append(prompts, p) + p = prompt{} + } + + p.Prompt = msg.Content + + for range msg.Images { + p.Prompt += fmt.Sprintf(" [img-%d]", imgId) + p.images = append(p.images, imgId) + imgId += 1 + } + case "assistant": + if p.Response != "" { + prompts = append(prompts, p) + p = prompt{} + } + + p.Response = msg.Content + default: + return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role) + } + } + + // add final prompt + if p.System != "" || p.Prompt != "" || p.Response != "" { + prompts = append(prompts, p) + } + + // calculate token lengths for each prompt, estimating 768 tokens per images + for i, p := range prompts { + tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode) + if err != nil { + return "", err + } + + prompts[i].tokens = tokens + len(prompts[i].images)*768 + } + + // truncate images and prompts starting from the beginning of the list + // until either one prompt remains or the total tokens fits the context window + // TODO (jmorganca): this doesn't account for the context window room required for the response + for { + var required int + for _, p := range prompts { + required += p.tokens + } + + required += 1 // for bos token + + if required <= window { + slog.Debug("prompt now fits in context window", "required", required, "window", window) + break + } + + prompt := &prompts[0] + + if len(prompt.images) > 1 { + img := prompt.images[0] + slog.Debug("prompt longer than context window, removing image", "id", img, "required", required, "window", window) + prompt.images = prompt.images[1:] + prompt.Prompt = strings.Replace(prompt.Prompt, fmt.Sprintf(" [img-%d]", img), "", 1) + prompt.tokens -= 768 + continue + } + + if len(prompts) > 1 { + slog.Debug("required tokens longer than context window, removing first prompt", "prompt", prompts[0].tokens, "required", required, "window", window) + system := prompt.System + prompts = prompts[1:] + + if system != "" && prompts[0].System == "" { + prompts[0].System = system + + tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode) + if err != nil { + return "", err + } + + prompts[0].tokens = tokens + len(prompts[0].images)*768 + } + + continue + } + + // stop truncating if there's only one prompt left + break + } + + var sb strings.Builder + for i, p := range prompts { + // last prompt should leave the response unrendered (for completion) + rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1) + if err != nil { + return "", err + } + sb.WriteString(rendered) + } + + return sb.String(), nil +} diff --git a/server/prompt_test.go b/server/prompt_test.go new file mode 100644 index 00000000..0ac8e314 --- /dev/null +++ b/server/prompt_test.go @@ -0,0 +1,234 @@ +package server + +import ( + "strings" + "testing" + + "github.com/jmorganca/ollama/api" +) + +func TestPrompt(t *testing.T) { + tests := []struct { + name string + template string + system string + prompt string + response string + generate bool + want string + }{ + { + name: "simple prompt", + template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", + system: "You are a Wizard.", + prompt: "What are the potion ingredients?", + want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]", + }, + { + name: "implicit response", + template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", + system: "You are a Wizard.", + prompt: "What are the potion ingredients?", + response: "I don't know.", + want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.", + }, + { + name: "response", + template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}", + system: "You are a Wizard.", + prompt: "What are the potion ingredients?", + response: "I don't know.", + want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.", + }, + { + name: "cut", + template: "{{ .System }}{{ .Prompt }}{{ .Response }}", + system: "You are a Wizard.", + prompt: "What are the potion ingredients?", + response: "I don't know.", + generate: true, + want: "You are a Wizard.What are the potion ingredients?I don't know.", + }, + { + name: "nocut", + template: "{{ .System }}{{ .Prompt }}{{ .Response }}", + system: "You are a Wizard.", + prompt: "What are the potion ingredients?", + response: "I don't know.", + want: "You are a Wizard.What are the potion ingredients?I don't know.", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := Prompt(tc.template, tc.system, tc.prompt, tc.response, tc.generate) + if err != nil { + t.Errorf("error = %v", err) + } + + if got != tc.want { + t.Errorf("got = %v, want %v", got, tc.want) + } + }) + } +} + +func TestChatPrompt(t *testing.T) { + tests := []struct { + name string + template string + system string + messages []api.Message + window int + want string + }{ + { + name: "simple prompt", + template: "[INST] {{ .Prompt }} [/INST]", + messages: []api.Message{ + {Role: "user", Content: "Hello"}, + }, + window: 1024, + want: "[INST] Hello [/INST]", + }, + { + name: "with default system message", + system: "You are a Wizard.", + template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST]", + messages: []api.Message{ + {Role: "user", Content: "Hello"}, + }, + window: 1024, + want: "[INST] <>You are a Wizard.<> Hello [/INST]", + }, + { + name: "with system message", + template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST]", + messages: []api.Message{ + {Role: "system", Content: "You are a Wizard."}, + {Role: "user", Content: "Hello"}, + }, + window: 1024, + want: "[INST] <>You are a Wizard.<> Hello [/INST]", + }, + { + name: "with response", + template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}", + messages: []api.Message{ + {Role: "system", Content: "You are a Wizard."}, + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "I am?"}, + }, + window: 1024, + want: "[INST] <>You are a Wizard.<> Hello [/INST] I am?", + }, + { + name: "with implicit response", + template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST]", + messages: []api.Message{ + {Role: "system", Content: "You are a Wizard."}, + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "I am?"}, + }, + window: 1024, + want: "[INST] <>You are a Wizard.<> Hello [/INST]I am?", + }, + { + name: "with conversation", + template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ", + messages: []api.Message{ + {Role: "system", Content: "You are a Wizard."}, + {Role: "user", Content: "What are the potion ingredients?"}, + {Role: "assistant", Content: "sugar"}, + {Role: "user", Content: "Anything else?"}, + }, + window: 1024, + want: "[INST] <>You are a Wizard.<> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ", + }, + { + name: "with truncation", + template: "{{ .System }} {{ .Prompt }} {{ .Response }} ", + messages: []api.Message{ + {Role: "system", Content: "You are a Wizard."}, + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "I am?"}, + {Role: "user", Content: "Why is the sky blue?"}, + {Role: "assistant", Content: "The sky is blue from rayleigh scattering"}, + }, + window: 10, + want: "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering", + }, + { + name: "images", + template: "{{ .System }} {{ .Prompt }}", + messages: []api.Message{ + {Role: "system", Content: "You are a Wizard."}, + {Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}}, + }, + window: 1024, + want: "You are a Wizard. Hello [img-0]", + }, + { + name: "images truncated", + template: "{{ .System }} {{ .Prompt }}", + messages: []api.Message{ + {Role: "system", Content: "You are a Wizard."}, + {Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}}, + }, + window: 1024, + want: "You are a Wizard. Hello [img-1]", + }, + { + name: "empty list", + template: "{{ .System }} {{ .Prompt }}", + messages: []api.Message{}, + window: 1024, + want: "", + }, + { + name: "empty list default system", + system: "You are a Wizard.", + template: "{{ .System }} {{ .Prompt }}", + messages: []api.Message{}, + window: 1024, + want: "You are a Wizard. ", + }, + { + name: "empty user message", + system: "You are a Wizard.", + template: "{{ .System }} {{ .Prompt }}", + messages: []api.Message{ + {Role: "user", Content: ""}, + }, + window: 1024, + want: "You are a Wizard. ", + }, + { + name: "empty prompt", + template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ", + messages: []api.Message{ + {Role: "user", Content: ""}, + }, + window: 1024, + want: "", + }, + } + + encode := func(s string) ([]int, error) { + words := strings.Fields(s) + return make([]int, len(words)), nil + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := ChatPrompt(tc.template, tc.system, tc.messages, tc.window, encode) + if err != nil { + t.Errorf("error = %v", err) + } + + if got != tc.want { + t.Errorf("got = %v, want %v", got, tc.want) + } + }) + } +} diff --git a/server/routes.go b/server/routes.go index 9abaea42..bd943ee1 100644 --- a/server/routes.go +++ b/server/routes.go @@ -214,6 +214,8 @@ func GenerateHandler(c *gin.Context) { } // an empty request loads the model + // note: for a short while template was used in lieu + // of `raw` mode so we need to check for it too if req.Prompt == "" && req.Template == "" && req.System == "" { c.JSON(http.StatusOK, api.GenerateResponse{ CreatedAt: time.Now().UTC(), @@ -226,50 +228,48 @@ func GenerateHandler(c *gin.Context) { checkpointLoaded := time.Now() var prompt string - var promptVars PromptVars switch { case req.Raw: prompt = req.Prompt case req.Prompt != "": - if req.Template != "" { - // override the default model template - model.Template = req.Template + if req.Template == "" { + req.Template = model.Template } - var rebuild strings.Builder + if req.System == "" { + req.System = model.System + } + + slog.Debug("generate handler", "prompt", req.Prompt) + slog.Debug("generate handler", "template", req.Template) + slog.Debug("generate handler", "system", req.System) + + var sb strings.Builder if req.Context != nil { - // TODO: context is deprecated, at some point the context logic within this conditional should be removed - prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context) + prev, err := loaded.runner.Decode(c.Request.Context(), req.Context) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - // Remove leading spaces from prevCtx if present - prevCtx = strings.TrimPrefix(prevCtx, " ") - rebuild.WriteString(prevCtx) - } - promptVars = PromptVars{ - System: req.System, - Prompt: req.Prompt, - First: len(req.Context) == 0, - } - - if promptVars.System == "" { - promptVars.System = model.System + sb.WriteString(prev) } + // write image tags + // TODO: limit the number of images to fit in the context similar to the chat endpoint for i := range req.Images { - promptVars.Prompt += fmt.Sprintf(" [img-%d]", i) + req.Prompt += fmt.Sprintf(" [img-%d]", i) } - p, err := model.PreResponsePrompt(promptVars) + p, err := Prompt(req.Template, req.System, req.Prompt, "", true) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - rebuild.WriteString(p) - prompt = rebuild.String() + + sb.WriteString(p) + + prompt = sb.String() } slog.Debug("generate handler", "prompt", prompt) @@ -308,19 +308,20 @@ func GenerateHandler(c *gin.Context) { resp.LoadDuration = checkpointLoaded.Sub(checkpointStart) if !req.Raw { - // append the generated text to the history and template it if needed - promptVars.Response = generated.String() - result, err := model.PostResponseTemplate(promptVars) + p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // TODO (jmorganca): encode() should not strip special tokens + tokens, err := loaded.runner.Encode(c.Request.Context(), p) if err != nil { ch <- gin.H{"error": err.Error()} return } - embd, err := loaded.runner.Encode(c.Request.Context(), prompt+result) - if err != nil { - ch <- gin.H{"error": err.Error()} - return - } - resp.Context = embd + + resp.Context = append(req.Context, tokens...) } } @@ -1090,6 +1091,20 @@ func streamResponse(c *gin.Context, ch chan any) { }) } +// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model +func chatPrompt(ctx context.Context, messages []api.Message) (string, error) { + encode := func(s string) ([]int, error) { + return loaded.runner.Encode(ctx, s) + } + + prompt, err := ChatPrompt(loaded.Model.Template, loaded.Model.System, messages, loaded.Options.NumCtx, encode) + if err != nil { + return "", err + } + + return prompt, nil +} + func ChatHandler(c *gin.Context) { loaded.mu.Lock() defer loaded.mu.Unlock() @@ -1117,15 +1132,6 @@ func ChatHandler(c *gin.Context) { return } - for _, msg := range req.Messages { - for _, img := range msg.Images { - if !isSupportedImageType(img) { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"}) - return - } - } - } - model, err := GetModel(req.Model) if err != nil { var pErr *fs.PathError @@ -1161,20 +1167,14 @@ func ChatHandler(c *gin.Context) { checkpointLoaded := time.Now() - chat, err := model.ChatPrompts(req.Messages) + prompt, err := chatPrompt(c.Request.Context(), req.Messages) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - // an empty request loads the model - if len(prompt) == 0 { + if len(req.Messages) == 0 || prompt == "" { resp := api.ChatResponse{ CreatedAt: time.Now().UTC(), Model: req.Model, @@ -1185,7 +1185,24 @@ func ChatHandler(c *gin.Context) { return } - slog.Debug("chat handler", "prompt", prompt) + // only send images that are in the prompt + var i int + var images []llm.ImageData + for _, m := range req.Messages { + for _, img := range m.Images { + if !isSupportedImageType(img) { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"}) + return + } + + if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) { + images = append(images, llm.ImageData{Data: img, ID: i}) + } + i += 1 + } + } + + slog.Debug("chat handler", "prompt", prompt, "images", len(images)) ch := make(chan any) @@ -1260,115 +1277,3 @@ func ChatHandler(c *gin.Context) { streamResponse(c, ch) } - -// promptInfo stores the variables used to template a prompt, and the token length of the resulting template for some model -type promptInfo struct { - vars PromptVars - tokenLen int -} - -// trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length, -// while preserving the most recent system message. -func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, []llm.ImageData, error) { - if len(chat.Prompts) == 0 { - return "", nil, nil - } - - var promptsToAdd []promptInfo - var totalTokenLength int - var systemPromptIncluded bool - - var images []llm.ImageData - // reverse iterate through the prompts to build the prompt string in a way that fits the max context length - for i := len(chat.Prompts) - 1; i >= 0; i-- { - prompt := chat.Prompts[i] - promptText, err := promptString(model, prompt, i == len(chat.Prompts)-1) - if err != nil { - return "", nil, err - } - - encodedTokens, err := loaded.runner.Encode(ctx, promptText) - if err != nil { - return "", nil, err - } - - if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 { - break // reached max context length, stop adding more prompts - } - - for j := range prompt.Images { - if totalTokenLength+768 > loaded.NumCtx { - // this decreases the token length but overestimating is fine - prompt.Prompt = strings.ReplaceAll(prompt.Prompt, fmt.Sprintf(" [img-%d]", prompt.Images[j].ID), "") - continue - } - - totalTokenLength += 768 - images = append(images, prompt.Images[j]) - } - - totalTokenLength += len(encodedTokens) - systemPromptIncluded = systemPromptIncluded || prompt.System != "" - promptsToAdd = append(promptsToAdd, promptInfo{vars: prompt, tokenLen: len(encodedTokens)}) - } - - // ensure the system prompt is included, if not already - if chat.LastSystem != "" && !systemPromptIncluded { - var err error - promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd) - if err != nil { - return "", nil, err - } - } - - promptsToAdd[len(promptsToAdd)-1].vars.First = true - - // construct the final prompt string from the prompts which fit within the context window - var result string - for i, prompt := range promptsToAdd { - promptText, err := promptString(model, prompt.vars, i == 0) - if err != nil { - return "", nil, err - } - result = promptText + result - } - - return result, images, nil -} - -// promptString applies the model template to the prompt -func promptString(model *Model, vars PromptVars, isMostRecent bool) (string, error) { - if isMostRecent { - p, err := model.PreResponsePrompt(vars) - if err != nil { - return "", fmt.Errorf("pre-response template: %w", err) - } - return p, nil - } - p, err := Prompt(model.Template, vars) - if err != nil { - return "", err - } - return p, nil -} - -// includeSystemPrompt adjusts the prompts to include the system prompt. -func includeSystemPrompt(ctx context.Context, systemPrompt string, totalTokenLength int, promptsToAdd []promptInfo) ([]promptInfo, error) { - systemTokens, err := loaded.runner.Encode(ctx, systemPrompt) - if err != nil { - return nil, err - } - - for i := len(promptsToAdd) - 1; i >= 0; i-- { - if totalTokenLength+len(systemTokens) <= loaded.NumCtx { - promptsToAdd[i].vars.System = systemPrompt - return promptsToAdd[:i+1], nil - } - totalTokenLength -= promptsToAdd[i].tokenLen - } - - // if got here, system did not fit anywhere, so return the most recent prompt with the system message set - recent := promptsToAdd[len(promptsToAdd)-1] - recent.vars.System = systemPrompt - return []promptInfo{recent}, nil -} diff --git a/server/routes_test.go b/server/routes_test.go index 2a0308b8..9cf96f10 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -241,237 +241,6 @@ func Test_Routes(t *testing.T) { } } -func Test_ChatPrompt(t *testing.T) { - tests := []struct { - name string - template string - chat *ChatHistory - numCtx int - runner MockLLM - want string - wantErr string - }{ - { - name: "Single Message", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - chat: &ChatHistory{ - Prompts: []PromptVars{ - { - System: "You are a Wizard.", - Prompt: "What are the potion ingredients?", - First: true, - }, - }, - LastSystem: "You are a Wizard.", - }, - numCtx: 1, - runner: MockLLM{ - encoding: []int{1}, // fit the ctxLen - }, - want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]", - }, - { - name: "First Message", - template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]", - chat: &ChatHistory{ - Prompts: []PromptVars{ - { - System: "You are a Wizard.", - Prompt: "What are the potion ingredients?", - Response: "eye of newt", - First: true, - }, - { - Prompt: "Anything else?", - }, - }, - LastSystem: "You are a Wizard.", - }, - numCtx: 2, - runner: MockLLM{ - encoding: []int{1}, // fit the ctxLen - }, - want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]", - }, - { - name: "Message History", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - chat: &ChatHistory{ - Prompts: []PromptVars{ - { - System: "You are a Wizard.", - Prompt: "What are the potion ingredients?", - Response: "sugar", - First: true, - }, - { - Prompt: "Anything else?", - }, - }, - LastSystem: "You are a Wizard.", - }, - numCtx: 4, - runner: MockLLM{ - encoding: []int{1}, // fit the ctxLen, 1 for each message - }, - want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]", - }, - { - name: "Assistant Only", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - chat: &ChatHistory{ - Prompts: []PromptVars{ - { - Response: "everything nice", - First: true, - }, - }, - }, - numCtx: 1, - runner: MockLLM{ - encoding: []int{1}, - }, - want: "[INST] [/INST]everything nice", - }, - { - name: "Message History Truncated, No System", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - chat: &ChatHistory{ - Prompts: []PromptVars{ - { - Prompt: "What are the potion ingredients?", - Response: "sugar", - First: true, - }, - { - Prompt: "Anything else?", - Response: "spice", - }, - { - Prompt: "... and?", - }, - }, - }, - numCtx: 2, // only 1 message from history and most recent message - runner: MockLLM{ - encoding: []int{1}, - }, - want: "[INST] Anything else? [/INST]spice[INST] ... and? [/INST]", - }, - { - name: "System is Preserved when Truncated", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - chat: &ChatHistory{ - Prompts: []PromptVars{ - { - Prompt: "What are the magic words?", - Response: "abracadabra", - }, - { - Prompt: "What is the spell for invisibility?", - }, - }, - LastSystem: "You are a wizard.", - }, - numCtx: 2, - runner: MockLLM{ - encoding: []int{1}, - }, - want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]", - }, - { - name: "System is Preserved when Length Exceeded", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - chat: &ChatHistory{ - Prompts: []PromptVars{ - { - Prompt: "What are the magic words?", - Response: "abracadabra", - }, - { - Prompt: "What is the spell for invisibility?", - }, - }, - LastSystem: "You are a wizard.", - }, - numCtx: 1, - runner: MockLLM{ - encoding: []int{1}, - }, - want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]", - }, - { - name: "First is Preserved when Truncated", - template: "[INST] {{ if .First }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]", - - chat: &ChatHistory{ - Prompts: []PromptVars{ - // first message omitted for test - { - Prompt: "Do you have a magic hat?", - Response: "Of course.", - }, - { - Prompt: "What is the spell for invisibility?", - }, - }, - LastSystem: "You are a wizard.", - }, - numCtx: 3, // two most recent messages and room for system message - runner: MockLLM{ - encoding: []int{1}, - }, - want: "[INST] You are a wizard. Do you have a magic hat? [/INST]Of course.[INST] What is the spell for invisibility? [/INST]", - }, - { - name: "Most recent message is returned when longer than ctxLen", - template: "[INST] {{ .Prompt }} [/INST]", - - chat: &ChatHistory{ - Prompts: []PromptVars{ - { - Prompt: "What is the spell for invisibility?", - First: true, - }, - }, - }, - numCtx: 1, // two most recent messages - runner: MockLLM{ - encoding: []int{1, 2}, - }, - want: "[INST] What is the spell for invisibility? [/INST]", - }, - } - - for _, testCase := range tests { - tt := testCase - m := &Model{ - Template: tt.template, - } - t.Run(tt.name, func(t *testing.T) { - loaded.runner = &tt.runner - loaded.Options = &api.Options{ - Runner: api.Runner{ - NumCtx: tt.numCtx, - }, - } - // TODO: add tests for trimming images - got, _, err := trimmedPrompt(context.Background(), tt.chat, m) - if tt.wantErr != "" { - if err == nil { - t.Errorf("ChatPrompt() expected error, got nil") - } - if !strings.Contains(err.Error(), tt.wantErr) { - t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr) - } - } - if got != tt.want { - t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want) - } - }) - } -} - type MockLLM struct { encoding []int }