Fix issues with templating prompt in chat mode (#2460)

This commit is contained in:
Jeffrey Morgan 2024-02-12 15:06:57 -08:00 committed by GitHub
parent 939c60473f
commit 48a273f80b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 538 additions and 1013 deletions

View file

@ -86,7 +86,7 @@ There are two ways to view `Modelfile`s underlying the models in [ollama.com/lib
# FROM llama2:13b
FROM /root/.ollama/models/blobs/sha256:123abc
TEMPLATE """[INST] {{ if and .First .System }}<<SYS>>{{ .System }}<</SYS>>
TEMPLATE """[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>>
{{ end }}{{ .Prompt }} [/INST] """
SYSTEM """"""
@ -154,31 +154,23 @@ PARAMETER <parameter> <parametervalue>
### TEMPLATE
`TEMPLATE` of the full prompt template to be passed into the model. It may include (optionally) a system message and a user's prompt. This is used to create a full custom prompt, and syntax may be model specific. You can usually find the template for a given model in the readme for that model.
`TEMPLATE` of the full prompt template to be passed into the model. It may include (optionally) a system message, a user's message and the response from the model. Note: syntax may be model specific. Templates use Go [template syntax](https://pkg.go.dev/text/template).
#### Template Variables
| Variable | Description |
| ----------------- | ------------------------------------------------------------------------------------------------------------- |
| `{{ .System }}` | The system message used to specify custom behavior, this must also be set in the Modelfile as an instruction. |
| `{{ .Prompt }}` | The incoming prompt, this is not specified in the model file and will be set based on input. |
| `{{ .Response }}` | The response from the LLM, if not specified response is appended to the end of the template. |
| `{{ .First }}` | A boolean value used to render specific template information for the first generation of a session. |
| ----------------- | --------------------------------------------------------------------------------------------- |
| `{{ .System }}` | The system message used to specify custom behavior. |
| `{{ .Prompt }}` | The user prompt message. |
| `{{ .Response }}` | The response from the model. When generating a response, text after this variable is omitted. |
```modelfile
TEMPLATE """
{{- if .First }}
### System:
{{ .System }}
{{- end }}
### User:
{{ .Prompt }}
### Response:
```
TEMPLATE """{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
"""
SYSTEM """<system message>"""
```
### SYSTEM

View file

@ -19,7 +19,6 @@ import (
"strconv"
"strings"
"text/template"
"text/template/parse"
"golang.org/x/exp/slices"
@ -58,162 +57,6 @@ type Message struct {
Content string `json:"content"`
}
type PromptVars struct {
System string
Prompt string
Response string
First bool
Images []llm.ImageData
}
// extractParts extracts the parts of the template before and after the {{.Response}} node.
func extractParts(tmplStr string) (pre string, post string, err error) {
tmpl, err := template.New("").Parse(tmplStr)
if err != nil {
return "", "", err
}
var foundResponse bool
for _, node := range tmpl.Tree.Root.Nodes {
if node.Type() == parse.NodeAction && node.String() == "{{.Response}}" {
foundResponse = true
}
if !foundResponse {
pre += node.String()
} else {
post += node.String()
}
}
return pre, post, nil
}
func Prompt(promptTemplate string, p PromptVars) (string, error) {
var prompt strings.Builder
// Use the "missingkey=zero" option to handle missing variables without panicking
tmpl, err := template.New("").Option("missingkey=zero").Parse(promptTemplate)
if err != nil {
return "", err
}
vars := map[string]any{
"System": p.System,
"Prompt": p.Prompt,
"Response": p.Response,
"First": p.First,
}
var sb strings.Builder
if err := tmpl.Execute(&sb, vars); err != nil {
return "", err
}
prompt.WriteString(sb.String())
if !strings.Contains(prompt.String(), p.Response) {
// if the response is not in the prompt template, append it to the end
prompt.WriteString(p.Response)
}
return prompt.String(), nil
}
// PreResponsePrompt returns the prompt before the response tag
func (m *Model) PreResponsePrompt(p PromptVars) (string, error) {
pre, _, err := extractParts(m.Template)
if err != nil {
return "", err
}
return Prompt(pre, p)
}
// PostResponseTemplate returns the template after the response tag
func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
if p.System == "" {
// use the default system prompt for this model if one is not specified
p.System = m.System
}
_, post, err := extractParts(m.Template)
if err != nil {
return "", err
}
if post == "" {
// if there is no post-response template, return the provided response
return p.Response, nil
}
return Prompt(post, p)
}
type ChatHistory struct {
Prompts []PromptVars
LastSystem string
}
// ChatPrompts returns a list of formatted chat prompts from a list of messages
func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
// build the prompt from the list of messages
lastSystem := m.System
currentVars := PromptVars{
First: true,
System: m.System,
}
prompts := []PromptVars{}
var images []llm.ImageData
for _, msg := range msgs {
switch strings.ToLower(msg.Role) {
case "system":
// if this is the first message it overrides the system prompt in the modelfile
if !currentVars.First && currentVars.System != "" {
prompts = append(prompts, currentVars)
currentVars = PromptVars{}
}
currentVars.System = msg.Content
lastSystem = msg.Content
case "user":
if currentVars.Prompt != "" {
prompts = append(prompts, currentVars)
currentVars = PromptVars{}
}
currentVars.Prompt = msg.Content
if len(m.ProjectorPaths) > 0 {
for i := range msg.Images {
id := len(images) + i
currentVars.Prompt += fmt.Sprintf(" [img-%d]", id)
currentVars.Images = append(currentVars.Images, llm.ImageData{
ID: id,
Data: msg.Images[i],
})
}
images = append(images, currentVars.Images...)
}
case "assistant":
currentVars.Response = msg.Content
prompts = append(prompts, currentVars)
currentVars = PromptVars{}
default:
return nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
}
}
// Append the last set of vars if they are non-empty
if currentVars.Prompt != "" || currentVars.System != "" {
prompts = append(prompts, currentVars)
}
return &ChatHistory{
Prompts: prompts,
LastSystem: lastSystem,
}, nil
}
type ManifestV2 struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`

View file

@ -1,442 +0,0 @@
package server
import (
"bytes"
"strings"
"testing"
"github.com/jmorganca/ollama/api"
)
func TestPrompt(t *testing.T) {
tests := []struct {
name string
template string
vars PromptVars
want string
wantErr bool
}{
{
name: "System Prompt",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
vars: PromptVars{
System: "You are a Wizard.",
Prompt: "What are the potion ingredients?",
},
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
},
{
name: "System Prompt with Response",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
vars: PromptVars{
System: "You are a Wizard.",
Prompt: "What are the potion ingredients?",
Response: "I don't know.",
},
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
},
{
name: "Conditional Logic Nodes",
template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
vars: PromptVars{
First: true,
System: "You are a Wizard.",
Prompt: "What are the potion ingredients?",
Response: "I don't know.",
},
want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Prompt(tt.template, tt.vars)
if (err != nil) != tt.wantErr {
t.Errorf("Prompt() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("Prompt() got = %v, want %v", got, tt.want)
}
})
}
}
func TestModel_PreResponsePrompt(t *testing.T) {
tests := []struct {
name string
template string
vars PromptVars
want string
wantErr bool
}{
{
name: "No Response in Template",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
vars: PromptVars{
System: "You are a Wizard.",
Prompt: "What are the potion ingredients?",
},
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
},
{
name: "Response in Template",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
vars: PromptVars{
System: "You are a Wizard.",
Prompt: "What are the potion ingredients?",
},
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] ",
},
{
name: "Response in Template with Trailing Formatting",
template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
vars: PromptVars{
Prompt: "What are the potion ingredients?",
},
want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
},
{
name: "Response in Template with Alternative Formatting",
template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
vars: PromptVars{
Prompt: "What are the potion ingredients?",
},
want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
},
}
for _, tt := range tests {
m := Model{Template: tt.template}
t.Run(tt.name, func(t *testing.T) {
got, err := m.PreResponsePrompt(tt.vars)
if (err != nil) != tt.wantErr {
t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("PreResponsePrompt() got = %v, want %v", got, tt.want)
}
})
}
}
func TestModel_PostResponsePrompt(t *testing.T) {
tests := []struct {
name string
template string
vars PromptVars
want string
wantErr bool
}{
{
name: "No Response in Template",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
vars: PromptVars{
Response: "I don't know.",
},
want: "I don't know.",
},
{
name: "Response in Template",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
vars: PromptVars{
Response: "I don't know.",
},
want: "I don't know.",
},
{
name: "Response in Template with Trailing Formatting",
template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
vars: PromptVars{
Response: "I don't know.",
},
want: "I don't know.<|im_end|>",
},
{
name: "Response in Template with Alternative Formatting",
template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
vars: PromptVars{
Response: "I don't know.",
},
want: "I don't know.<|im_end|>",
},
}
for _, tt := range tests {
m := Model{Template: tt.template}
t.Run(tt.name, func(t *testing.T) {
got, err := m.PostResponseTemplate(tt.vars)
if (err != nil) != tt.wantErr {
t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("PostResponseTemplate() got = %v, want %v", got, tt.want)
}
})
}
}
func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
tests := []struct {
name string
template string
preVars PromptVars
postVars PromptVars
want string
wantErr bool
}{
{
name: "Response in Template",
template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
preVars: PromptVars{
Prompt: "What are the potion ingredients?",
},
postVars: PromptVars{
Prompt: "What are the potion ingredients?",
Response: "Sugar.",
},
want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSugar.<|im_end|>",
},
{
name: "No Response in Template",
template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n",
preVars: PromptVars{
Prompt: "What are the potion ingredients?",
},
postVars: PromptVars{
Prompt: "What are the potion ingredients?",
Response: "Spice.",
},
want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSpice.",
},
}
for _, tt := range tests {
m := Model{Template: tt.template}
t.Run(tt.name, func(t *testing.T) {
pre, err := m.PreResponsePrompt(tt.preVars)
if (err != nil) != tt.wantErr {
t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
return
}
post, err := m.PostResponseTemplate(tt.postVars)
if err != nil {
t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
return
}
result := pre + post
if result != tt.want {
t.Errorf("Prompt() got = %v, want %v", result, tt.want)
}
})
}
}
func chatHistoryEqual(a, b ChatHistory) bool {
if len(a.Prompts) != len(b.Prompts) {
return false
}
for i, v := range a.Prompts {
if v.First != b.Prompts[i].First {
return false
}
if v.Response != b.Prompts[i].Response {
return false
}
if v.Prompt != b.Prompts[i].Prompt {
return false
}
if v.System != b.Prompts[i].System {
return false
}
if len(v.Images) != len(b.Prompts[i].Images) {
return false
}
for j, img := range v.Images {
if img.ID != b.Prompts[i].Images[j].ID {
return false
}
if !bytes.Equal(img.Data, b.Prompts[i].Images[j].Data) {
return false
}
}
}
return a.LastSystem == b.LastSystem
}
func TestChat(t *testing.T) {
tests := []struct {
name string
model Model
msgs []api.Message
want ChatHistory
wantErr string
}{
{
name: "Single Message",
model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
},
msgs: []api.Message{
{
Role: "system",
Content: "You are a Wizard.",
},
{
Role: "user",
Content: "What are the potion ingredients?",
},
},
want: ChatHistory{
Prompts: []PromptVars{
{
System: "You are a Wizard.",
Prompt: "What are the potion ingredients?",
First: true,
},
},
LastSystem: "You are a Wizard.",
},
},
{
name: "Message History",
model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
},
msgs: []api.Message{
{
Role: "system",
Content: "You are a Wizard.",
},
{
Role: "user",
Content: "What are the potion ingredients?",
},
{
Role: "assistant",
Content: "sugar",
},
{
Role: "user",
Content: "Anything else?",
},
},
want: ChatHistory{
Prompts: []PromptVars{
{
System: "You are a Wizard.",
Prompt: "What are the potion ingredients?",
Response: "sugar",
First: true,
},
{
Prompt: "Anything else?",
},
},
LastSystem: "You are a Wizard.",
},
},
{
name: "Assistant Only",
model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
},
msgs: []api.Message{
{
Role: "assistant",
Content: "everything nice",
},
},
want: ChatHistory{
Prompts: []PromptVars{
{
Response: "everything nice",
First: true,
},
},
},
},
{
name: "Last system message is preserved from modelfile",
model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
System: "You are Mojo Jojo.",
},
msgs: []api.Message{
{
Role: "user",
Content: "hi",
},
},
want: ChatHistory{
Prompts: []PromptVars{
{
System: "You are Mojo Jojo.",
Prompt: "hi",
First: true,
},
},
LastSystem: "You are Mojo Jojo.",
},
},
{
name: "Last system message is preserved from messages",
model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
System: "You are Mojo Jojo.",
},
msgs: []api.Message{
{
Role: "system",
Content: "You are Professor Utonium.",
},
},
want: ChatHistory{
Prompts: []PromptVars{
{
System: "You are Professor Utonium.",
First: true,
},
},
LastSystem: "You are Professor Utonium.",
},
},
{
name: "Invalid Role",
msgs: []api.Message{
{
Role: "not-a-role",
Content: "howdy",
},
},
wantErr: "invalid role: not-a-role",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.model.ChatPrompts(tt.msgs)
if tt.wantErr != "" {
if err == nil {
t.Errorf("ChatPrompt() expected error, got nil")
}
if !strings.Contains(err.Error(), tt.wantErr) {
t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
}
return
}
if !chatHistoryEqual(*got, tt.want) {
t.Errorf("ChatPrompt() got = %#v, want %#v", got, tt.want)
}
})
}
}

224
server/prompt.go Normal file
View file

@ -0,0 +1,224 @@
package server
import (
"fmt"
"log/slog"
"strings"
"text/template"
"text/template/parse"
"github.com/jmorganca/ollama/api"
)
// isResponseNode checks if the node contains .Response
func isResponseNode(node *parse.ActionNode) bool {
for _, cmd := range node.Pipe.Cmds {
for _, arg := range cmd.Args {
if fieldNode, ok := arg.(*parse.FieldNode); ok && len(fieldNode.Ident) > 0 {
if fieldNode.Ident[0] == "Response" {
return true
}
}
}
}
return false
}
// formatTemplateForResponse formats the template AST to:
// 1. remove all nodes after the first .Response (if generate=true)
// 2. add a .Response node to the end if it doesn't exist
// TODO(jmorganca): this should recursively cut the template before the first .Response
func formatTemplateForResponse(tmpl *template.Template, generate bool) {
var found bool
for i, node := range tmpl.Tree.Root.Nodes {
if actionNode, ok := node.(*parse.ActionNode); ok {
if isResponseNode(actionNode) {
found = true
if generate {
tmpl.Tree.Root.Nodes = tmpl.Tree.Root.Nodes[:i+1]
break
}
}
}
}
if !found {
// add the response node if it doesn't exist
responseFieldNode := &parse.FieldNode{NodeType: parse.NodeField, Ident: []string{"Response"}}
responsePipeNode := &parse.PipeNode{NodeType: parse.NodePipe, Cmds: []*parse.CommandNode{{NodeType: parse.NodeCommand, Args: []parse.Node{responseFieldNode}}}}
responseActionNode := &parse.ActionNode{NodeType: parse.NodeAction, Pipe: responsePipeNode}
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, responseActionNode)
}
}
// Prompt renders a prompt from a template. If generate is set to true,
// the response and parts of the template following it are not rendered
func Prompt(tmpl, system, prompt, response string, generate bool) (string, error) {
parsed, err := template.New("").Option("missingkey=zero").Parse(tmpl)
if err != nil {
return "", err
}
formatTemplateForResponse(parsed, generate)
vars := map[string]any{
"System": system,
"Prompt": prompt,
"Response": response,
}
var sb strings.Builder
if err := parsed.Execute(&sb, vars); err != nil {
return "", err
}
return sb.String(), nil
}
func countTokens(tmpl string, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
rendered, err := Prompt(tmpl, system, prompt, response, false)
if err != nil {
return 0, err
}
tokens, err := encode(rendered)
if err != nil {
slog.Error("failed to encode prompt", "err", err)
return 0, err
}
return len(tokens), err
}
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
func ChatPrompt(tmpl string, system string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
type prompt struct {
System string
Prompt string
Response string
images []int
tokens int
}
var p prompt
// Set the first system prompt to the model's system prompt
if system != "" {
p.System = system
}
// iterate through messages to build up {system,user,response} prompts
var imgId int
var prompts []prompt
for _, msg := range messages {
switch strings.ToLower(msg.Role) {
case "system":
if p.System != "" || p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}
p.System = msg.Content
case "user":
if p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}
p.Prompt = msg.Content
for range msg.Images {
p.Prompt += fmt.Sprintf(" [img-%d]", imgId)
p.images = append(p.images, imgId)
imgId += 1
}
case "assistant":
if p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}
p.Response = msg.Content
default:
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
}
}
// add final prompt
if p.System != "" || p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
}
// calculate token lengths for each prompt, estimating 768 tokens per images
for i, p := range prompts {
tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
if err != nil {
return "", err
}
prompts[i].tokens = tokens + len(prompts[i].images)*768
}
// truncate images and prompts starting from the beginning of the list
// until either one prompt remains or the total tokens fits the context window
// TODO (jmorganca): this doesn't account for the context window room required for the response
for {
var required int
for _, p := range prompts {
required += p.tokens
}
required += 1 // for bos token
if required <= window {
slog.Debug("prompt now fits in context window", "required", required, "window", window)
break
}
prompt := &prompts[0]
if len(prompt.images) > 1 {
img := prompt.images[0]
slog.Debug("prompt longer than context window, removing image", "id", img, "required", required, "window", window)
prompt.images = prompt.images[1:]
prompt.Prompt = strings.Replace(prompt.Prompt, fmt.Sprintf(" [img-%d]", img), "", 1)
prompt.tokens -= 768
continue
}
if len(prompts) > 1 {
slog.Debug("required tokens longer than context window, removing first prompt", "prompt", prompts[0].tokens, "required", required, "window", window)
system := prompt.System
prompts = prompts[1:]
if system != "" && prompts[0].System == "" {
prompts[0].System = system
tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
if err != nil {
return "", err
}
prompts[0].tokens = tokens + len(prompts[0].images)*768
}
continue
}
// stop truncating if there's only one prompt left
break
}
var sb strings.Builder
for i, p := range prompts {
// last prompt should leave the response unrendered (for completion)
rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1)
if err != nil {
return "", err
}
sb.WriteString(rendered)
}
return sb.String(), nil
}

234
server/prompt_test.go Normal file
View file

@ -0,0 +1,234 @@
package server
import (
"strings"
"testing"
"github.com/jmorganca/ollama/api"
)
func TestPrompt(t *testing.T) {
tests := []struct {
name string
template string
system string
prompt string
response string
generate bool
want string
}{
{
name: "simple prompt",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
},
{
name: "implicit response",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.",
},
{
name: "response",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
},
{
name: "cut",
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
generate: true,
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.",
},
{
name: "nocut",
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.</assistant>",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := Prompt(tc.template, tc.system, tc.prompt, tc.response, tc.generate)
if err != nil {
t.Errorf("error = %v", err)
}
if got != tc.want {
t.Errorf("got = %v, want %v", got, tc.want)
}
})
}
}
func TestChatPrompt(t *testing.T) {
tests := []struct {
name string
template string
system string
messages []api.Message
window int
want string
}{
{
name: "simple prompt",
template: "[INST] {{ .Prompt }} [/INST]",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
window: 1024,
want: "[INST] Hello [/INST]",
},
{
name: "with default system message",
system: "You are a Wizard.",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
},
{
name: "with system message",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
},
{
name: "with response",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "I am?"},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST] I am?",
},
{
name: "with implicit response",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "I am?"},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]I am?",
},
{
name: "with conversation",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "What are the potion ingredients?"},
{Role: "assistant", Content: "sugar"},
{Role: "user", Content: "Anything else?"},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ",
},
{
name: "with truncation",
template: "{{ .System }} {{ .Prompt }} {{ .Response }} ",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "I am?"},
{Role: "user", Content: "Why is the sky blue?"},
{Role: "assistant", Content: "The sky is blue from rayleigh scattering"},
},
window: 10,
want: "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering",
},
{
name: "images",
template: "{{ .System }} {{ .Prompt }}",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
},
window: 1024,
want: "You are a Wizard. Hello [img-0]",
},
{
name: "images truncated",
template: "{{ .System }} {{ .Prompt }}",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
},
window: 1024,
want: "You are a Wizard. Hello [img-1]",
},
{
name: "empty list",
template: "{{ .System }} {{ .Prompt }}",
messages: []api.Message{},
window: 1024,
want: "",
},
{
name: "empty list default system",
system: "You are a Wizard.",
template: "{{ .System }} {{ .Prompt }}",
messages: []api.Message{},
window: 1024,
want: "You are a Wizard. ",
},
{
name: "empty user message",
system: "You are a Wizard.",
template: "{{ .System }} {{ .Prompt }}",
messages: []api.Message{
{Role: "user", Content: ""},
},
window: 1024,
want: "You are a Wizard. ",
},
{
name: "empty prompt",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
messages: []api.Message{
{Role: "user", Content: ""},
},
window: 1024,
want: "",
},
}
encode := func(s string) ([]int, error) {
words := strings.Fields(s)
return make([]int, len(words)), nil
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := ChatPrompt(tc.template, tc.system, tc.messages, tc.window, encode)
if err != nil {
t.Errorf("error = %v", err)
}
if got != tc.want {
t.Errorf("got = %v, want %v", got, tc.want)
}
})
}
}

View file

@ -214,6 +214,8 @@ func GenerateHandler(c *gin.Context) {
}
// an empty request loads the model
// note: for a short while template was used in lieu
// of `raw` mode so we need to check for it too
if req.Prompt == "" && req.Template == "" && req.System == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
CreatedAt: time.Now().UTC(),
@ -226,50 +228,48 @@ func GenerateHandler(c *gin.Context) {
checkpointLoaded := time.Now()
var prompt string
var promptVars PromptVars
switch {
case req.Raw:
prompt = req.Prompt
case req.Prompt != "":
if req.Template != "" {
// override the default model template
model.Template = req.Template
if req.Template == "" {
req.Template = model.Template
}
var rebuild strings.Builder
if req.System == "" {
req.System = model.System
}
slog.Debug("generate handler", "prompt", req.Prompt)
slog.Debug("generate handler", "template", req.Template)
slog.Debug("generate handler", "system", req.System)
var sb strings.Builder
if req.Context != nil {
// TODO: context is deprecated, at some point the context logic within this conditional should be removed
prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context)
prev, err := loaded.runner.Decode(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Remove leading spaces from prevCtx if present
prevCtx = strings.TrimPrefix(prevCtx, " ")
rebuild.WriteString(prevCtx)
}
promptVars = PromptVars{
System: req.System,
Prompt: req.Prompt,
First: len(req.Context) == 0,
}
if promptVars.System == "" {
promptVars.System = model.System
sb.WriteString(prev)
}
// write image tags
// TODO: limit the number of images to fit in the context similar to the chat endpoint
for i := range req.Images {
promptVars.Prompt += fmt.Sprintf(" [img-%d]", i)
req.Prompt += fmt.Sprintf(" [img-%d]", i)
}
p, err := model.PreResponsePrompt(promptVars)
p, err := Prompt(req.Template, req.System, req.Prompt, "", true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
rebuild.WriteString(p)
prompt = rebuild.String()
sb.WriteString(p)
prompt = sb.String()
}
slog.Debug("generate handler", "prompt", prompt)
@ -308,19 +308,20 @@ func GenerateHandler(c *gin.Context) {
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
// append the generated text to the history and template it if needed
promptVars.Response = generated.String()
result, err := model.PostResponseTemplate(promptVars)
p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// TODO (jmorganca): encode() should not strip special tokens
tokens, err := loaded.runner.Encode(c.Request.Context(), p)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
embd, err := loaded.runner.Encode(c.Request.Context(), prompt+result)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
resp.Context = embd
resp.Context = append(req.Context, tokens...)
}
}
@ -1090,6 +1091,20 @@ func streamResponse(c *gin.Context, ch chan any) {
})
}
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
func chatPrompt(ctx context.Context, messages []api.Message) (string, error) {
encode := func(s string) ([]int, error) {
return loaded.runner.Encode(ctx, s)
}
prompt, err := ChatPrompt(loaded.Model.Template, loaded.Model.System, messages, loaded.Options.NumCtx, encode)
if err != nil {
return "", err
}
return prompt, nil
}
func ChatHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
@ -1117,15 +1132,6 @@ func ChatHandler(c *gin.Context) {
return
}
for _, msg := range req.Messages {
for _, img := range msg.Images {
if !isSupportedImageType(img) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
return
}
}
}
model, err := GetModel(req.Model)
if err != nil {
var pErr *fs.PathError
@ -1161,20 +1167,14 @@ func ChatHandler(c *gin.Context) {
checkpointLoaded := time.Now()
chat, err := model.ChatPrompts(req.Messages)
prompt, err := chatPrompt(c.Request.Context(), req.Messages)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// an empty request loads the model
if len(prompt) == 0 {
if len(req.Messages) == 0 || prompt == "" {
resp := api.ChatResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model,
@ -1185,7 +1185,24 @@ func ChatHandler(c *gin.Context) {
return
}
slog.Debug("chat handler", "prompt", prompt)
// only send images that are in the prompt
var i int
var images []llm.ImageData
for _, m := range req.Messages {
for _, img := range m.Images {
if !isSupportedImageType(img) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
return
}
if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
images = append(images, llm.ImageData{Data: img, ID: i})
}
i += 1
}
}
slog.Debug("chat handler", "prompt", prompt, "images", len(images))
ch := make(chan any)
@ -1260,115 +1277,3 @@ func ChatHandler(c *gin.Context) {
streamResponse(c, ch)
}
// promptInfo stores the variables used to template a prompt, and the token length of the resulting template for some model
type promptInfo struct {
vars PromptVars
tokenLen int
}
// trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
// while preserving the most recent system message.
func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, []llm.ImageData, error) {
if len(chat.Prompts) == 0 {
return "", nil, nil
}
var promptsToAdd []promptInfo
var totalTokenLength int
var systemPromptIncluded bool
var images []llm.ImageData
// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
for i := len(chat.Prompts) - 1; i >= 0; i-- {
prompt := chat.Prompts[i]
promptText, err := promptString(model, prompt, i == len(chat.Prompts)-1)
if err != nil {
return "", nil, err
}
encodedTokens, err := loaded.runner.Encode(ctx, promptText)
if err != nil {
return "", nil, err
}
if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
break // reached max context length, stop adding more prompts
}
for j := range prompt.Images {
if totalTokenLength+768 > loaded.NumCtx {
// this decreases the token length but overestimating is fine
prompt.Prompt = strings.ReplaceAll(prompt.Prompt, fmt.Sprintf(" [img-%d]", prompt.Images[j].ID), "")
continue
}
totalTokenLength += 768
images = append(images, prompt.Images[j])
}
totalTokenLength += len(encodedTokens)
systemPromptIncluded = systemPromptIncluded || prompt.System != ""
promptsToAdd = append(promptsToAdd, promptInfo{vars: prompt, tokenLen: len(encodedTokens)})
}
// ensure the system prompt is included, if not already
if chat.LastSystem != "" && !systemPromptIncluded {
var err error
promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
if err != nil {
return "", nil, err
}
}
promptsToAdd[len(promptsToAdd)-1].vars.First = true
// construct the final prompt string from the prompts which fit within the context window
var result string
for i, prompt := range promptsToAdd {
promptText, err := promptString(model, prompt.vars, i == 0)
if err != nil {
return "", nil, err
}
result = promptText + result
}
return result, images, nil
}
// promptString applies the model template to the prompt
func promptString(model *Model, vars PromptVars, isMostRecent bool) (string, error) {
if isMostRecent {
p, err := model.PreResponsePrompt(vars)
if err != nil {
return "", fmt.Errorf("pre-response template: %w", err)
}
return p, nil
}
p, err := Prompt(model.Template, vars)
if err != nil {
return "", err
}
return p, nil
}
// includeSystemPrompt adjusts the prompts to include the system prompt.
func includeSystemPrompt(ctx context.Context, systemPrompt string, totalTokenLength int, promptsToAdd []promptInfo) ([]promptInfo, error) {
systemTokens, err := loaded.runner.Encode(ctx, systemPrompt)
if err != nil {
return nil, err
}
for i := len(promptsToAdd) - 1; i >= 0; i-- {
if totalTokenLength+len(systemTokens) <= loaded.NumCtx {
promptsToAdd[i].vars.System = systemPrompt
return promptsToAdd[:i+1], nil
}
totalTokenLength -= promptsToAdd[i].tokenLen
}
// if got here, system did not fit anywhere, so return the most recent prompt with the system message set
recent := promptsToAdd[len(promptsToAdd)-1]
recent.vars.System = systemPrompt
return []promptInfo{recent}, nil
}

View file

@ -241,237 +241,6 @@ func Test_Routes(t *testing.T) {
}
}
func Test_ChatPrompt(t *testing.T) {
tests := []struct {
name string
template string
chat *ChatHistory
numCtx int
runner MockLLM
want string
wantErr string
}{
{
name: "Single Message",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
chat: &ChatHistory{
Prompts: []PromptVars{
{
System: "You are a Wizard.",
Prompt: "What are the potion ingredients?",
First: true,
},
},
LastSystem: "You are a Wizard.",
},
numCtx: 1,
runner: MockLLM{
encoding: []int{1}, // fit the ctxLen
},
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
},
{
name: "First Message",
template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
chat: &ChatHistory{
Prompts: []PromptVars{
{
System: "You are a Wizard.",
Prompt: "What are the potion ingredients?",
Response: "eye of newt",
First: true,
},
{
Prompt: "Anything else?",
},
},
LastSystem: "You are a Wizard.",
},
numCtx: 2,
runner: MockLLM{
encoding: []int{1}, // fit the ctxLen
},
want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]",
},
{
name: "Message History",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
chat: &ChatHistory{
Prompts: []PromptVars{
{
System: "You are a Wizard.",
Prompt: "What are the potion ingredients?",
Response: "sugar",
First: true,
},
{
Prompt: "Anything else?",
},
},
LastSystem: "You are a Wizard.",
},
numCtx: 4,
runner: MockLLM{
encoding: []int{1}, // fit the ctxLen, 1 for each message
},
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
},
{
name: "Assistant Only",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
chat: &ChatHistory{
Prompts: []PromptVars{
{
Response: "everything nice",
First: true,
},
},
},
numCtx: 1,
runner: MockLLM{
encoding: []int{1},
},
want: "[INST] [/INST]everything nice",
},
{
name: "Message History Truncated, No System",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
chat: &ChatHistory{
Prompts: []PromptVars{
{
Prompt: "What are the potion ingredients?",
Response: "sugar",
First: true,
},
{
Prompt: "Anything else?",
Response: "spice",
},
{
Prompt: "... and?",
},
},
},
numCtx: 2, // only 1 message from history and most recent message
runner: MockLLM{
encoding: []int{1},
},
want: "[INST] Anything else? [/INST]spice[INST] ... and? [/INST]",
},
{
name: "System is Preserved when Truncated",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
chat: &ChatHistory{
Prompts: []PromptVars{
{
Prompt: "What are the magic words?",
Response: "abracadabra",
},
{
Prompt: "What is the spell for invisibility?",
},
},
LastSystem: "You are a wizard.",
},
numCtx: 2,
runner: MockLLM{
encoding: []int{1},
},
want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
},
{
name: "System is Preserved when Length Exceeded",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
chat: &ChatHistory{
Prompts: []PromptVars{
{
Prompt: "What are the magic words?",
Response: "abracadabra",
},
{
Prompt: "What is the spell for invisibility?",
},
},
LastSystem: "You are a wizard.",
},
numCtx: 1,
runner: MockLLM{
encoding: []int{1},
},
want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
},
{
name: "First is Preserved when Truncated",
template: "[INST] {{ if .First }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]",
chat: &ChatHistory{
Prompts: []PromptVars{
// first message omitted for test
{
Prompt: "Do you have a magic hat?",
Response: "Of course.",
},
{
Prompt: "What is the spell for invisibility?",
},
},
LastSystem: "You are a wizard.",
},
numCtx: 3, // two most recent messages and room for system message
runner: MockLLM{
encoding: []int{1},
},
want: "[INST] You are a wizard. Do you have a magic hat? [/INST]Of course.[INST] What is the spell for invisibility? [/INST]",
},
{
name: "Most recent message is returned when longer than ctxLen",
template: "[INST] {{ .Prompt }} [/INST]",
chat: &ChatHistory{
Prompts: []PromptVars{
{
Prompt: "What is the spell for invisibility?",
First: true,
},
},
},
numCtx: 1, // two most recent messages
runner: MockLLM{
encoding: []int{1, 2},
},
want: "[INST] What is the spell for invisibility? [/INST]",
},
}
for _, testCase := range tests {
tt := testCase
m := &Model{
Template: tt.template,
}
t.Run(tt.name, func(t *testing.T) {
loaded.runner = &tt.runner
loaded.Options = &api.Options{
Runner: api.Runner{
NumCtx: tt.numCtx,
},
}
// TODO: add tests for trimming images
got, _, err := trimmedPrompt(context.Background(), tt.chat, m)
if tt.wantErr != "" {
if err == nil {
t.Errorf("ChatPrompt() expected error, got nil")
}
if !strings.Contains(err.Error(), tt.wantErr) {
t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
}
}
if got != tt.want {
t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
}
})
}
}
type MockLLM struct {
encoding []int
}