Fix issues with templating prompt in chat mode (#2460)
This commit is contained in:
parent
939c60473f
commit
48a273f80b
7 changed files with 538 additions and 1013 deletions
|
@ -86,7 +86,7 @@ There are two ways to view `Modelfile`s underlying the models in [ollama.com/lib
|
||||||
# FROM llama2:13b
|
# FROM llama2:13b
|
||||||
|
|
||||||
FROM /root/.ollama/models/blobs/sha256:123abc
|
FROM /root/.ollama/models/blobs/sha256:123abc
|
||||||
TEMPLATE """[INST] {{ if and .First .System }}<<SYS>>{{ .System }}<</SYS>>
|
TEMPLATE """[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>>
|
||||||
|
|
||||||
{{ end }}{{ .Prompt }} [/INST] """
|
{{ end }}{{ .Prompt }} [/INST] """
|
||||||
SYSTEM """"""
|
SYSTEM """"""
|
||||||
|
@ -154,31 +154,23 @@ PARAMETER <parameter> <parametervalue>
|
||||||
|
|
||||||
### TEMPLATE
|
### TEMPLATE
|
||||||
|
|
||||||
`TEMPLATE` of the full prompt template to be passed into the model. It may include (optionally) a system message and a user's prompt. This is used to create a full custom prompt, and syntax may be model specific. You can usually find the template for a given model in the readme for that model.
|
`TEMPLATE` defines the full prompt template to be passed into the model. It may include (optionally) a system message, a user's message and the response from the model. Note: syntax may be model specific. Templates use Go [template syntax](https://pkg.go.dev/text/template).
|
||||||
|
|
||||||
#### Template Variables
|
#### Template Variables
|
||||||
|
|
||||||
| Variable | Description |
|
| Variable | Description |
|
||||||
| ----------------- | ------------------------------------------------------------------------------------------------------------- |
|
| ----------------- | --------------------------------------------------------------------------------------------- |
|
||||||
| `{{ .System }}` | The system message used to specify custom behavior, this must also be set in the Modelfile as an instruction. |
|
| `{{ .System }}` | The system message used to specify custom behavior. |
|
||||||
| `{{ .Prompt }}` | The incoming prompt, this is not specified in the model file and will be set based on input. |
|
| `{{ .Prompt }}` | The user prompt message. |
|
||||||
| `{{ .Response }}` | The response from the LLM, if not specified response is appended to the end of the template. |
|
| `{{ .Response }}` | The response from the model. When generating a response, text after this variable is omitted. |
|
||||||
| `{{ .First }}` | A boolean value used to render specific template information for the first generation of a session. |
|
|
||||||
|
|
||||||
```modelfile
|
```
|
||||||
TEMPLATE """
|
TEMPLATE """{{ if .System }}<|im_start|>system
|
||||||
{{- if .First }}
|
{{ .System }}<|im_end|>
|
||||||
### System:
|
{{ end }}{{ if .Prompt }}<|im_start|>user
|
||||||
{{ .System }}
|
{{ .Prompt }}<|im_end|>
|
||||||
{{- end }}
|
{{ end }}<|im_start|>assistant
|
||||||
|
|
||||||
### User:
|
|
||||||
{{ .Prompt }}
|
|
||||||
|
|
||||||
### Response:
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SYSTEM """<system message>"""
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### SYSTEM
|
### SYSTEM
|
||||||
|
|
157
server/images.go
157
server/images.go
|
@ -19,7 +19,6 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
"text/template/parse"
|
|
||||||
|
|
||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
|
|
||||||
|
@ -58,162 +57,6 @@ type Message struct {
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PromptVars struct {
|
|
||||||
System string
|
|
||||||
Prompt string
|
|
||||||
Response string
|
|
||||||
First bool
|
|
||||||
Images []llm.ImageData
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractParts extracts the parts of the template before and after the {{.Response}} node.
|
|
||||||
func extractParts(tmplStr string) (pre string, post string, err error) {
|
|
||||||
tmpl, err := template.New("").Parse(tmplStr)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
var foundResponse bool
|
|
||||||
|
|
||||||
for _, node := range tmpl.Tree.Root.Nodes {
|
|
||||||
if node.Type() == parse.NodeAction && node.String() == "{{.Response}}" {
|
|
||||||
foundResponse = true
|
|
||||||
}
|
|
||||||
if !foundResponse {
|
|
||||||
pre += node.String()
|
|
||||||
} else {
|
|
||||||
post += node.String()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return pre, post, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func Prompt(promptTemplate string, p PromptVars) (string, error) {
|
|
||||||
var prompt strings.Builder
|
|
||||||
// Use the "missingkey=zero" option to handle missing variables without panicking
|
|
||||||
tmpl, err := template.New("").Option("missingkey=zero").Parse(promptTemplate)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
vars := map[string]any{
|
|
||||||
"System": p.System,
|
|
||||||
"Prompt": p.Prompt,
|
|
||||||
"Response": p.Response,
|
|
||||||
"First": p.First,
|
|
||||||
}
|
|
||||||
|
|
||||||
var sb strings.Builder
|
|
||||||
if err := tmpl.Execute(&sb, vars); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
prompt.WriteString(sb.String())
|
|
||||||
|
|
||||||
if !strings.Contains(prompt.String(), p.Response) {
|
|
||||||
// if the response is not in the prompt template, append it to the end
|
|
||||||
prompt.WriteString(p.Response)
|
|
||||||
}
|
|
||||||
|
|
||||||
return prompt.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PreResponsePrompt returns the prompt before the response tag
|
|
||||||
func (m *Model) PreResponsePrompt(p PromptVars) (string, error) {
|
|
||||||
pre, _, err := extractParts(m.Template)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return Prompt(pre, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
// PostResponseTemplate returns the template after the response tag
|
|
||||||
func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
|
|
||||||
if p.System == "" {
|
|
||||||
// use the default system prompt for this model if one is not specified
|
|
||||||
p.System = m.System
|
|
||||||
}
|
|
||||||
_, post, err := extractParts(m.Template)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if post == "" {
|
|
||||||
// if there is no post-response template, return the provided response
|
|
||||||
return p.Response, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return Prompt(post, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatHistory struct {
|
|
||||||
Prompts []PromptVars
|
|
||||||
LastSystem string
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatPrompts returns a list of formatted chat prompts from a list of messages
|
|
||||||
func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
|
|
||||||
// build the prompt from the list of messages
|
|
||||||
lastSystem := m.System
|
|
||||||
currentVars := PromptVars{
|
|
||||||
First: true,
|
|
||||||
System: m.System,
|
|
||||||
}
|
|
||||||
|
|
||||||
prompts := []PromptVars{}
|
|
||||||
var images []llm.ImageData
|
|
||||||
|
|
||||||
for _, msg := range msgs {
|
|
||||||
switch strings.ToLower(msg.Role) {
|
|
||||||
case "system":
|
|
||||||
// if this is the first message it overrides the system prompt in the modelfile
|
|
||||||
if !currentVars.First && currentVars.System != "" {
|
|
||||||
prompts = append(prompts, currentVars)
|
|
||||||
currentVars = PromptVars{}
|
|
||||||
}
|
|
||||||
currentVars.System = msg.Content
|
|
||||||
lastSystem = msg.Content
|
|
||||||
case "user":
|
|
||||||
if currentVars.Prompt != "" {
|
|
||||||
prompts = append(prompts, currentVars)
|
|
||||||
currentVars = PromptVars{}
|
|
||||||
}
|
|
||||||
|
|
||||||
currentVars.Prompt = msg.Content
|
|
||||||
|
|
||||||
if len(m.ProjectorPaths) > 0 {
|
|
||||||
for i := range msg.Images {
|
|
||||||
id := len(images) + i
|
|
||||||
currentVars.Prompt += fmt.Sprintf(" [img-%d]", id)
|
|
||||||
currentVars.Images = append(currentVars.Images, llm.ImageData{
|
|
||||||
ID: id,
|
|
||||||
Data: msg.Images[i],
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
images = append(images, currentVars.Images...)
|
|
||||||
}
|
|
||||||
case "assistant":
|
|
||||||
currentVars.Response = msg.Content
|
|
||||||
prompts = append(prompts, currentVars)
|
|
||||||
currentVars = PromptVars{}
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Append the last set of vars if they are non-empty
|
|
||||||
if currentVars.Prompt != "" || currentVars.System != "" {
|
|
||||||
prompts = append(prompts, currentVars)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ChatHistory{
|
|
||||||
Prompts: prompts,
|
|
||||||
LastSystem: lastSystem,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type ManifestV2 struct {
|
type ManifestV2 struct {
|
||||||
SchemaVersion int `json:"schemaVersion"`
|
SchemaVersion int `json:"schemaVersion"`
|
||||||
MediaType string `json:"mediaType"`
|
MediaType string `json:"mediaType"`
|
||||||
|
|
|
@ -1,442 +0,0 @@
|
||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/jmorganca/ollama/api"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestPrompt(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
template string
|
|
||||||
vars PromptVars
|
|
||||||
want string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "System Prompt",
|
|
||||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
||||||
vars: PromptVars{
|
|
||||||
System: "You are a Wizard.",
|
|
||||||
Prompt: "What are the potion ingredients?",
|
|
||||||
},
|
|
||||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "System Prompt with Response",
|
|
||||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
|
|
||||||
vars: PromptVars{
|
|
||||||
System: "You are a Wizard.",
|
|
||||||
Prompt: "What are the potion ingredients?",
|
|
||||||
Response: "I don't know.",
|
|
||||||
},
|
|
||||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Conditional Logic Nodes",
|
|
||||||
template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
|
|
||||||
vars: PromptVars{
|
|
||||||
First: true,
|
|
||||||
System: "You are a Wizard.",
|
|
||||||
Prompt: "What are the potion ingredients?",
|
|
||||||
Response: "I don't know.",
|
|
||||||
},
|
|
||||||
want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := Prompt(tt.template, tt.vars)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("Prompt() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("Prompt() got = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestModel_PreResponsePrompt(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
template string
|
|
||||||
vars PromptVars
|
|
||||||
want string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "No Response in Template",
|
|
||||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
||||||
vars: PromptVars{
|
|
||||||
System: "You are a Wizard.",
|
|
||||||
Prompt: "What are the potion ingredients?",
|
|
||||||
},
|
|
||||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Response in Template",
|
|
||||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
|
|
||||||
vars: PromptVars{
|
|
||||||
System: "You are a Wizard.",
|
|
||||||
Prompt: "What are the potion ingredients?",
|
|
||||||
},
|
|
||||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] ",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Response in Template with Trailing Formatting",
|
|
||||||
template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
|
|
||||||
vars: PromptVars{
|
|
||||||
Prompt: "What are the potion ingredients?",
|
|
||||||
},
|
|
||||||
want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Response in Template with Alternative Formatting",
|
|
||||||
template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
|
|
||||||
vars: PromptVars{
|
|
||||||
Prompt: "What are the potion ingredients?",
|
|
||||||
},
|
|
||||||
want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
m := Model{Template: tt.template}
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := m.PreResponsePrompt(tt.vars)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("PreResponsePrompt() got = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestModel_PostResponsePrompt(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
template string
|
|
||||||
vars PromptVars
|
|
||||||
want string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "No Response in Template",
|
|
||||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
||||||
vars: PromptVars{
|
|
||||||
Response: "I don't know.",
|
|
||||||
},
|
|
||||||
want: "I don't know.",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Response in Template",
|
|
||||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
|
|
||||||
vars: PromptVars{
|
|
||||||
Response: "I don't know.",
|
|
||||||
},
|
|
||||||
want: "I don't know.",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Response in Template with Trailing Formatting",
|
|
||||||
template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
|
|
||||||
vars: PromptVars{
|
|
||||||
Response: "I don't know.",
|
|
||||||
},
|
|
||||||
want: "I don't know.<|im_end|>",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Response in Template with Alternative Formatting",
|
|
||||||
template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
|
|
||||||
vars: PromptVars{
|
|
||||||
Response: "I don't know.",
|
|
||||||
},
|
|
||||||
want: "I don't know.<|im_end|>",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
m := Model{Template: tt.template}
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := m.PostResponseTemplate(tt.vars)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("PostResponseTemplate() got = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
template string
|
|
||||||
preVars PromptVars
|
|
||||||
postVars PromptVars
|
|
||||||
want string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Response in Template",
|
|
||||||
template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
|
|
||||||
preVars: PromptVars{
|
|
||||||
Prompt: "What are the potion ingredients?",
|
|
||||||
},
|
|
||||||
postVars: PromptVars{
|
|
||||||
Prompt: "What are the potion ingredients?",
|
|
||||||
Response: "Sugar.",
|
|
||||||
},
|
|
||||||
want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSugar.<|im_end|>",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "No Response in Template",
|
|
||||||
template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n",
|
|
||||||
preVars: PromptVars{
|
|
||||||
Prompt: "What are the potion ingredients?",
|
|
||||||
},
|
|
||||||
postVars: PromptVars{
|
|
||||||
Prompt: "What are the potion ingredients?",
|
|
||||||
Response: "Spice.",
|
|
||||||
},
|
|
||||||
want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSpice.",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
m := Model{Template: tt.template}
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
pre, err := m.PreResponsePrompt(tt.preVars)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
post, err := m.PostResponseTemplate(tt.postVars)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
result := pre + post
|
|
||||||
if result != tt.want {
|
|
||||||
t.Errorf("Prompt() got = %v, want %v", result, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func chatHistoryEqual(a, b ChatHistory) bool {
|
|
||||||
if len(a.Prompts) != len(b.Prompts) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for i, v := range a.Prompts {
|
|
||||||
|
|
||||||
if v.First != b.Prompts[i].First {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if v.Response != b.Prompts[i].Response {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if v.Prompt != b.Prompts[i].Prompt {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if v.System != b.Prompts[i].System {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(v.Images) != len(b.Prompts[i].Images) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
for j, img := range v.Images {
|
|
||||||
if img.ID != b.Prompts[i].Images[j].ID {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if !bytes.Equal(img.Data, b.Prompts[i].Images[j].Data) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return a.LastSystem == b.LastSystem
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestChat(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
model Model
|
|
||||||
msgs []api.Message
|
|
||||||
want ChatHistory
|
|
||||||
wantErr string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Single Message",
|
|
||||||
model: Model{
|
|
||||||
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
||||||
},
|
|
||||||
msgs: []api.Message{
|
|
||||||
{
|
|
||||||
Role: "system",
|
|
||||||
Content: "You are a Wizard.",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: "user",
|
|
||||||
Content: "What are the potion ingredients?",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
want: ChatHistory{
|
|
||||||
Prompts: []PromptVars{
|
|
||||||
{
|
|
||||||
System: "You are a Wizard.",
|
|
||||||
Prompt: "What are the potion ingredients?",
|
|
||||||
First: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
LastSystem: "You are a Wizard.",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Message History",
|
|
||||||
model: Model{
|
|
||||||
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
||||||
},
|
|
||||||
msgs: []api.Message{
|
|
||||||
{
|
|
||||||
Role: "system",
|
|
||||||
Content: "You are a Wizard.",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: "user",
|
|
||||||
Content: "What are the potion ingredients?",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: "sugar",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: "user",
|
|
||||||
Content: "Anything else?",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
want: ChatHistory{
|
|
||||||
Prompts: []PromptVars{
|
|
||||||
{
|
|
||||||
System: "You are a Wizard.",
|
|
||||||
Prompt: "What are the potion ingredients?",
|
|
||||||
Response: "sugar",
|
|
||||||
First: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Prompt: "Anything else?",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
LastSystem: "You are a Wizard.",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Assistant Only",
|
|
||||||
model: Model{
|
|
||||||
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
||||||
},
|
|
||||||
msgs: []api.Message{
|
|
||||||
{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: "everything nice",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
want: ChatHistory{
|
|
||||||
Prompts: []PromptVars{
|
|
||||||
{
|
|
||||||
Response: "everything nice",
|
|
||||||
First: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Last system message is preserved from modelfile",
|
|
||||||
model: Model{
|
|
||||||
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
||||||
System: "You are Mojo Jojo.",
|
|
||||||
},
|
|
||||||
msgs: []api.Message{
|
|
||||||
{
|
|
||||||
Role: "user",
|
|
||||||
Content: "hi",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
want: ChatHistory{
|
|
||||||
Prompts: []PromptVars{
|
|
||||||
{
|
|
||||||
System: "You are Mojo Jojo.",
|
|
||||||
Prompt: "hi",
|
|
||||||
First: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
LastSystem: "You are Mojo Jojo.",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Last system message is preserved from messages",
|
|
||||||
model: Model{
|
|
||||||
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
||||||
System: "You are Mojo Jojo.",
|
|
||||||
},
|
|
||||||
msgs: []api.Message{
|
|
||||||
{
|
|
||||||
Role: "system",
|
|
||||||
Content: "You are Professor Utonium.",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
want: ChatHistory{
|
|
||||||
Prompts: []PromptVars{
|
|
||||||
{
|
|
||||||
System: "You are Professor Utonium.",
|
|
||||||
First: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
LastSystem: "You are Professor Utonium.",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Invalid Role",
|
|
||||||
msgs: []api.Message{
|
|
||||||
{
|
|
||||||
Role: "not-a-role",
|
|
||||||
Content: "howdy",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: "invalid role: not-a-role",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := tt.model.ChatPrompts(tt.msgs)
|
|
||||||
if tt.wantErr != "" {
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("ChatPrompt() expected error, got nil")
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), tt.wantErr) {
|
|
||||||
t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !chatHistoryEqual(*got, tt.want) {
|
|
||||||
t.Errorf("ChatPrompt() got = %#v, want %#v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
224
server/prompt.go
Normal file
224
server/prompt.go
Normal file
|
@ -0,0 +1,224 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"strings"
|
||||||
|
"text/template"
|
||||||
|
"text/template/parse"
|
||||||
|
|
||||||
|
"github.com/jmorganca/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// isResponseNode reports whether node is an action whose pipeline takes the
// .Response field as a direct command argument, for example {{ .Response }}
// or {{ printf "%s" .Response }}. A nil node or a node without a pipeline is
// never a response node.
func isResponseNode(node *parse.ActionNode) bool {
	if node == nil || node.Pipe == nil {
		return false
	}

	for _, cmd := range node.Pipe.Cmds {
		for _, arg := range cmd.Args {
			field, ok := arg.(*parse.FieldNode)
			if ok && len(field.Ident) > 0 && field.Ident[0] == "Response" {
				return true
			}
		}
	}

	return false
}

// cutListAtResponse searches list, in template order, for the first action
// that renders .Response. Bodies of {{ if }} / {{ else }} blocks are searched
// recursively; {{ with }} and {{ range }} bodies are not, because they rebind
// dot and a .Response inside them does not refer to the top-level field.
//
// When a response action is found and generate is true, list is truncated
// right after the node that contains it (the nested list has already been
// truncated by the recursive call). The return value reports whether a
// response action was found at any depth.
func cutListAtResponse(list *parse.ListNode, generate bool) bool {
	if list == nil {
		return false
	}

	for i, node := range list.Nodes {
		found := false
		switch n := node.(type) {
		case *parse.ActionNode:
			found = isResponseNode(n)
		case *parse.IfNode:
			// Visit both branches unconditionally so that each one is cut
			// independently; only one of them executes at render time.
			inThen := cutListAtResponse(n.List, generate)
			inElse := cutListAtResponse(n.ElseList, generate)
			found = inThen || inElse
		}

		if !found {
			continue
		}

		if generate {
			list.Nodes = list.Nodes[:i+1]
		}
		return true
	}

	return false
}

// formatTemplateForResponse rewrites the parsed template tree in place so
// that it is ready to be rendered around a model response:
//
//  1. If generate is true, every node after the first .Response action is
//     removed, including nodes that follow it inside an enclosing {{ if }}
//     block. The .Response action itself is kept.
//  2. If the template never renders .Response (at the top level or inside an
//     {{ if }} block), a {{ .Response }} action is appended to the end.
//
// Searching inside {{ if }} blocks prevents a template such as
// "{{ if .Response }}{{ .Response }}{{ end }}" from having a second
// .Response appended, which would render the response twice.
//
// A nil template, or one that has not been parsed, is left untouched.
func formatTemplateForResponse(tmpl *template.Template, generate bool) {
	if tmpl == nil || tmpl.Tree == nil || tmpl.Tree.Root == nil {
		return
	}

	if cutListAtResponse(tmpl.Tree.Root, generate) {
		return
	}

	// The template does not mention .Response: build the equivalent of a
	// trailing {{ .Response }} action by hand and append it to the root list.
	responseFieldNode := &parse.FieldNode{NodeType: parse.NodeField, Ident: []string{"Response"}}
	responsePipeNode := &parse.PipeNode{
		NodeType: parse.NodePipe,
		Cmds: []*parse.CommandNode{
			{NodeType: parse.NodeCommand, Args: []parse.Node{responseFieldNode}},
		},
	}
	responseActionNode := &parse.ActionNode{NodeType: parse.NodeAction, Pipe: responsePipeNode}
	tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, responseActionNode)
}
|
||||||
|
|
||||||
|
// Prompt renders a prompt from a template. If generate is set to true,
|
||||||
|
// the response and parts of the template following it are not rendered
|
||||||
|
func Prompt(tmpl, system, prompt, response string, generate bool) (string, error) {
|
||||||
|
parsed, err := template.New("").Option("missingkey=zero").Parse(tmpl)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
formatTemplateForResponse(parsed, generate)
|
||||||
|
|
||||||
|
vars := map[string]any{
|
||||||
|
"System": system,
|
||||||
|
"Prompt": prompt,
|
||||||
|
"Response": response,
|
||||||
|
}
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
if err := parsed.Execute(&sb, vars); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func countTokens(tmpl string, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
|
||||||
|
rendered, err := Prompt(tmpl, system, prompt, response, false)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens, err := encode(rendered)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to encode prompt", "err", err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(tokens), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
|
||||||
|
func ChatPrompt(tmpl string, system string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
|
||||||
|
type prompt struct {
|
||||||
|
System string
|
||||||
|
Prompt string
|
||||||
|
Response string
|
||||||
|
|
||||||
|
images []int
|
||||||
|
tokens int
|
||||||
|
}
|
||||||
|
|
||||||
|
var p prompt
|
||||||
|
|
||||||
|
// Set the first system prompt to the model's system prompt
|
||||||
|
if system != "" {
|
||||||
|
p.System = system
|
||||||
|
}
|
||||||
|
|
||||||
|
// iterate through messages to build up {system,user,response} prompts
|
||||||
|
var imgId int
|
||||||
|
var prompts []prompt
|
||||||
|
for _, msg := range messages {
|
||||||
|
switch strings.ToLower(msg.Role) {
|
||||||
|
case "system":
|
||||||
|
if p.System != "" || p.Prompt != "" || p.Response != "" {
|
||||||
|
prompts = append(prompts, p)
|
||||||
|
p = prompt{}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.System = msg.Content
|
||||||
|
case "user":
|
||||||
|
if p.Prompt != "" || p.Response != "" {
|
||||||
|
prompts = append(prompts, p)
|
||||||
|
p = prompt{}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.Prompt = msg.Content
|
||||||
|
|
||||||
|
for range msg.Images {
|
||||||
|
p.Prompt += fmt.Sprintf(" [img-%d]", imgId)
|
||||||
|
p.images = append(p.images, imgId)
|
||||||
|
imgId += 1
|
||||||
|
}
|
||||||
|
case "assistant":
|
||||||
|
if p.Response != "" {
|
||||||
|
prompts = append(prompts, p)
|
||||||
|
p = prompt{}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.Response = msg.Content
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add final prompt
|
||||||
|
if p.System != "" || p.Prompt != "" || p.Response != "" {
|
||||||
|
prompts = append(prompts, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculate token lengths for each prompt, estimating 768 tokens per images
|
||||||
|
for i, p := range prompts {
|
||||||
|
tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
prompts[i].tokens = tokens + len(prompts[i].images)*768
|
||||||
|
}
|
||||||
|
|
||||||
|
// truncate images and prompts starting from the beginning of the list
|
||||||
|
// until either one prompt remains or the total tokens fits the context window
|
||||||
|
// TODO (jmorganca): this doesn't account for the context window room required for the response
|
||||||
|
for {
|
||||||
|
var required int
|
||||||
|
for _, p := range prompts {
|
||||||
|
required += p.tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
required += 1 // for bos token
|
||||||
|
|
||||||
|
if required <= window {
|
||||||
|
slog.Debug("prompt now fits in context window", "required", required, "window", window)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt := &prompts[0]
|
||||||
|
|
||||||
|
if len(prompt.images) > 1 {
|
||||||
|
img := prompt.images[0]
|
||||||
|
slog.Debug("prompt longer than context window, removing image", "id", img, "required", required, "window", window)
|
||||||
|
prompt.images = prompt.images[1:]
|
||||||
|
prompt.Prompt = strings.Replace(prompt.Prompt, fmt.Sprintf(" [img-%d]", img), "", 1)
|
||||||
|
prompt.tokens -= 768
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(prompts) > 1 {
|
||||||
|
slog.Debug("required tokens longer than context window, removing first prompt", "prompt", prompts[0].tokens, "required", required, "window", window)
|
||||||
|
system := prompt.System
|
||||||
|
prompts = prompts[1:]
|
||||||
|
|
||||||
|
if system != "" && prompts[0].System == "" {
|
||||||
|
prompts[0].System = system
|
||||||
|
|
||||||
|
tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
prompts[0].tokens = tokens + len(prompts[0].images)*768
|
||||||
|
}
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// stop truncating if there's only one prompt left
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
for i, p := range prompts {
|
||||||
|
// last prompt should leave the response unrendered (for completion)
|
||||||
|
rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
sb.WriteString(rendered)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String(), nil
|
||||||
|
}
|
234
server/prompt_test.go
Normal file
234
server/prompt_test.go
Normal file
|
@ -0,0 +1,234 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jmorganca/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPrompt(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
template string
|
||||||
|
system string
|
||||||
|
prompt string
|
||||||
|
response string
|
||||||
|
generate bool
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple prompt",
|
||||||
|
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||||
|
system: "You are a Wizard.",
|
||||||
|
prompt: "What are the potion ingredients?",
|
||||||
|
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "implicit response",
|
||||||
|
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||||
|
system: "You are a Wizard.",
|
||||||
|
prompt: "What are the potion ingredients?",
|
||||||
|
response: "I don't know.",
|
||||||
|
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "response",
|
||||||
|
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
|
||||||
|
system: "You are a Wizard.",
|
||||||
|
prompt: "What are the potion ingredients?",
|
||||||
|
response: "I don't know.",
|
||||||
|
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cut",
|
||||||
|
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
|
||||||
|
system: "You are a Wizard.",
|
||||||
|
prompt: "What are the potion ingredients?",
|
||||||
|
response: "I don't know.",
|
||||||
|
generate: true,
|
||||||
|
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nocut",
|
||||||
|
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
|
||||||
|
system: "You are a Wizard.",
|
||||||
|
prompt: "What are the potion ingredients?",
|
||||||
|
response: "I don't know.",
|
||||||
|
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.</assistant>",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got, err := Prompt(tc.template, tc.system, tc.prompt, tc.response, tc.generate)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got != tc.want {
|
||||||
|
t.Errorf("got = %v, want %v", got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatPrompt(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
template string
|
||||||
|
system string
|
||||||
|
messages []api.Message
|
||||||
|
window int
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple prompt",
|
||||||
|
template: "[INST] {{ .Prompt }} [/INST]",
|
||||||
|
messages: []api.Message{
|
||||||
|
{Role: "user", Content: "Hello"},
|
||||||
|
},
|
||||||
|
window: 1024,
|
||||||
|
want: "[INST] Hello [/INST]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with default system message",
|
||||||
|
system: "You are a Wizard.",
|
||||||
|
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
|
||||||
|
messages: []api.Message{
|
||||||
|
{Role: "user", Content: "Hello"},
|
||||||
|
},
|
||||||
|
window: 1024,
|
||||||
|
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with system message",
|
||||||
|
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
|
||||||
|
messages: []api.Message{
|
||||||
|
{Role: "system", Content: "You are a Wizard."},
|
||||||
|
{Role: "user", Content: "Hello"},
|
||||||
|
},
|
||||||
|
window: 1024,
|
||||||
|
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with response",
|
||||||
|
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}",
|
||||||
|
messages: []api.Message{
|
||||||
|
{Role: "system", Content: "You are a Wizard."},
|
||||||
|
{Role: "user", Content: "Hello"},
|
||||||
|
{Role: "assistant", Content: "I am?"},
|
||||||
|
},
|
||||||
|
window: 1024,
|
||||||
|
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST] I am?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with implicit response",
|
||||||
|
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
|
||||||
|
messages: []api.Message{
|
||||||
|
{Role: "system", Content: "You are a Wizard."},
|
||||||
|
{Role: "user", Content: "Hello"},
|
||||||
|
{Role: "assistant", Content: "I am?"},
|
||||||
|
},
|
||||||
|
window: 1024,
|
||||||
|
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]I am?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with conversation",
|
||||||
|
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
|
||||||
|
messages: []api.Message{
|
||||||
|
{Role: "system", Content: "You are a Wizard."},
|
||||||
|
{Role: "user", Content: "What are the potion ingredients?"},
|
||||||
|
{Role: "assistant", Content: "sugar"},
|
||||||
|
{Role: "user", Content: "Anything else?"},
|
||||||
|
},
|
||||||
|
window: 1024,
|
||||||
|
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with truncation",
|
||||||
|
template: "{{ .System }} {{ .Prompt }} {{ .Response }} ",
|
||||||
|
messages: []api.Message{
|
||||||
|
{Role: "system", Content: "You are a Wizard."},
|
||||||
|
{Role: "user", Content: "Hello"},
|
||||||
|
{Role: "assistant", Content: "I am?"},
|
||||||
|
{Role: "user", Content: "Why is the sky blue?"},
|
||||||
|
{Role: "assistant", Content: "The sky is blue from rayleigh scattering"},
|
||||||
|
},
|
||||||
|
window: 10,
|
||||||
|
want: "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "images",
|
||||||
|
template: "{{ .System }} {{ .Prompt }}",
|
||||||
|
messages: []api.Message{
|
||||||
|
{Role: "system", Content: "You are a Wizard."},
|
||||||
|
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
|
||||||
|
},
|
||||||
|
window: 1024,
|
||||||
|
want: "You are a Wizard. Hello [img-0]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "images truncated",
|
||||||
|
template: "{{ .System }} {{ .Prompt }}",
|
||||||
|
messages: []api.Message{
|
||||||
|
{Role: "system", Content: "You are a Wizard."},
|
||||||
|
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
|
||||||
|
},
|
||||||
|
window: 1024,
|
||||||
|
want: "You are a Wizard. Hello [img-1]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty list",
|
||||||
|
template: "{{ .System }} {{ .Prompt }}",
|
||||||
|
messages: []api.Message{},
|
||||||
|
window: 1024,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty list default system",
|
||||||
|
system: "You are a Wizard.",
|
||||||
|
template: "{{ .System }} {{ .Prompt }}",
|
||||||
|
messages: []api.Message{},
|
||||||
|
window: 1024,
|
||||||
|
want: "You are a Wizard. ",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty user message",
|
||||||
|
system: "You are a Wizard.",
|
||||||
|
template: "{{ .System }} {{ .Prompt }}",
|
||||||
|
messages: []api.Message{
|
||||||
|
{Role: "user", Content: ""},
|
||||||
|
},
|
||||||
|
window: 1024,
|
||||||
|
want: "You are a Wizard. ",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty prompt",
|
||||||
|
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
|
||||||
|
messages: []api.Message{
|
||||||
|
{Role: "user", Content: ""},
|
||||||
|
},
|
||||||
|
window: 1024,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
encode := func(s string) ([]int, error) {
|
||||||
|
words := strings.Fields(s)
|
||||||
|
return make([]int, len(words)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got, err := ChatPrompt(tc.template, tc.system, tc.messages, tc.window, encode)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got != tc.want {
|
||||||
|
t.Errorf("got = %v, want %v", got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
229
server/routes.go
229
server/routes.go
|
@ -214,6 +214,8 @@ func GenerateHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// an empty request loads the model
|
// an empty request loads the model
|
||||||
|
// note: for a short while template was used in lieu
|
||||||
|
// of `raw` mode so we need to check for it too
|
||||||
if req.Prompt == "" && req.Template == "" && req.System == "" {
|
if req.Prompt == "" && req.Template == "" && req.System == "" {
|
||||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||||
CreatedAt: time.Now().UTC(),
|
CreatedAt: time.Now().UTC(),
|
||||||
|
@ -226,50 +228,48 @@ func GenerateHandler(c *gin.Context) {
|
||||||
checkpointLoaded := time.Now()
|
checkpointLoaded := time.Now()
|
||||||
|
|
||||||
var prompt string
|
var prompt string
|
||||||
var promptVars PromptVars
|
|
||||||
switch {
|
switch {
|
||||||
case req.Raw:
|
case req.Raw:
|
||||||
prompt = req.Prompt
|
prompt = req.Prompt
|
||||||
case req.Prompt != "":
|
case req.Prompt != "":
|
||||||
if req.Template != "" {
|
if req.Template == "" {
|
||||||
// override the default model template
|
req.Template = model.Template
|
||||||
model.Template = req.Template
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var rebuild strings.Builder
|
if req.System == "" {
|
||||||
|
req.System = model.System
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Debug("generate handler", "prompt", req.Prompt)
|
||||||
|
slog.Debug("generate handler", "template", req.Template)
|
||||||
|
slog.Debug("generate handler", "system", req.System)
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
if req.Context != nil {
|
if req.Context != nil {
|
||||||
// TODO: context is deprecated, at some point the context logic within this conditional should be removed
|
prev, err := loaded.runner.Decode(c.Request.Context(), req.Context)
|
||||||
prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove leading spaces from prevCtx if present
|
sb.WriteString(prev)
|
||||||
prevCtx = strings.TrimPrefix(prevCtx, " ")
|
|
||||||
rebuild.WriteString(prevCtx)
|
|
||||||
}
|
|
||||||
promptVars = PromptVars{
|
|
||||||
System: req.System,
|
|
||||||
Prompt: req.Prompt,
|
|
||||||
First: len(req.Context) == 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
if promptVars.System == "" {
|
|
||||||
promptVars.System = model.System
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// write image tags
|
||||||
|
// TODO: limit the number of images to fit in the context similar to the chat endpoint
|
||||||
for i := range req.Images {
|
for i := range req.Images {
|
||||||
promptVars.Prompt += fmt.Sprintf(" [img-%d]", i)
|
req.Prompt += fmt.Sprintf(" [img-%d]", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := model.PreResponsePrompt(promptVars)
|
p, err := Prompt(req.Template, req.System, req.Prompt, "", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
rebuild.WriteString(p)
|
|
||||||
prompt = rebuild.String()
|
sb.WriteString(p)
|
||||||
|
|
||||||
|
prompt = sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Debug("generate handler", "prompt", prompt)
|
slog.Debug("generate handler", "prompt", prompt)
|
||||||
|
@ -308,19 +308,20 @@ func GenerateHandler(c *gin.Context) {
|
||||||
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||||
|
|
||||||
if !req.Raw {
|
if !req.Raw {
|
||||||
// append the generated text to the history and template it if needed
|
p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false)
|
||||||
promptVars.Response = generated.String()
|
if err != nil {
|
||||||
result, err := model.PostResponseTemplate(promptVars)
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO (jmorganca): encode() should not strip special tokens
|
||||||
|
tokens, err := loaded.runner.Encode(c.Request.Context(), p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
embd, err := loaded.runner.Encode(c.Request.Context(), prompt+result)
|
|
||||||
if err != nil {
|
resp.Context = append(req.Context, tokens...)
|
||||||
ch <- gin.H{"error": err.Error()}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
resp.Context = embd
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1090,6 +1091,20 @@ func streamResponse(c *gin.Context, ch chan any) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
|
||||||
|
func chatPrompt(ctx context.Context, messages []api.Message) (string, error) {
|
||||||
|
encode := func(s string) ([]int, error) {
|
||||||
|
return loaded.runner.Encode(ctx, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt, err := ChatPrompt(loaded.Model.Template, loaded.Model.System, messages, loaded.Options.NumCtx, encode)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return prompt, nil
|
||||||
|
}
|
||||||
|
|
||||||
func ChatHandler(c *gin.Context) {
|
func ChatHandler(c *gin.Context) {
|
||||||
loaded.mu.Lock()
|
loaded.mu.Lock()
|
||||||
defer loaded.mu.Unlock()
|
defer loaded.mu.Unlock()
|
||||||
|
@ -1117,15 +1132,6 @@ func ChatHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, msg := range req.Messages {
|
|
||||||
for _, img := range msg.Images {
|
|
||||||
if !isSupportedImageType(img) {
|
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
model, err := GetModel(req.Model)
|
model, err := GetModel(req.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var pErr *fs.PathError
|
var pErr *fs.PathError
|
||||||
|
@ -1161,20 +1167,14 @@ func ChatHandler(c *gin.Context) {
|
||||||
|
|
||||||
checkpointLoaded := time.Now()
|
checkpointLoaded := time.Now()
|
||||||
|
|
||||||
chat, err := model.ChatPrompts(req.Messages)
|
prompt, err := chatPrompt(c.Request.Context(), req.Messages)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// an empty request loads the model
|
// an empty request loads the model
|
||||||
if len(prompt) == 0 {
|
if len(req.Messages) == 0 || prompt == "" {
|
||||||
resp := api.ChatResponse{
|
resp := api.ChatResponse{
|
||||||
CreatedAt: time.Now().UTC(),
|
CreatedAt: time.Now().UTC(),
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
|
@ -1185,7 +1185,24 @@ func ChatHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Debug("chat handler", "prompt", prompt)
|
// only send images that are in the prompt
|
||||||
|
var i int
|
||||||
|
var images []llm.ImageData
|
||||||
|
for _, m := range req.Messages {
|
||||||
|
for _, img := range m.Images {
|
||||||
|
if !isSupportedImageType(img) {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
|
||||||
|
images = append(images, llm.ImageData{Data: img, ID: i})
|
||||||
|
}
|
||||||
|
i += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Debug("chat handler", "prompt", prompt, "images", len(images))
|
||||||
|
|
||||||
ch := make(chan any)
|
ch := make(chan any)
|
||||||
|
|
||||||
|
@ -1260,115 +1277,3 @@ func ChatHandler(c *gin.Context) {
|
||||||
|
|
||||||
streamResponse(c, ch)
|
streamResponse(c, ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
// promptInfo stores the variables used to template a prompt, and the token length of the resulting template for some model
|
|
||||||
type promptInfo struct {
|
|
||||||
vars PromptVars
|
|
||||||
tokenLen int
|
|
||||||
}
|
|
||||||
|
|
||||||
// trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
|
|
||||||
// while preserving the most recent system message.
|
|
||||||
func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, []llm.ImageData, error) {
|
|
||||||
if len(chat.Prompts) == 0 {
|
|
||||||
return "", nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var promptsToAdd []promptInfo
|
|
||||||
var totalTokenLength int
|
|
||||||
var systemPromptIncluded bool
|
|
||||||
|
|
||||||
var images []llm.ImageData
|
|
||||||
// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
|
|
||||||
for i := len(chat.Prompts) - 1; i >= 0; i-- {
|
|
||||||
prompt := chat.Prompts[i]
|
|
||||||
promptText, err := promptString(model, prompt, i == len(chat.Prompts)-1)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
encodedTokens, err := loaded.runner.Encode(ctx, promptText)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
|
|
||||||
break // reached max context length, stop adding more prompts
|
|
||||||
}
|
|
||||||
|
|
||||||
for j := range prompt.Images {
|
|
||||||
if totalTokenLength+768 > loaded.NumCtx {
|
|
||||||
// this decreases the token length but overestimating is fine
|
|
||||||
prompt.Prompt = strings.ReplaceAll(prompt.Prompt, fmt.Sprintf(" [img-%d]", prompt.Images[j].ID), "")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
totalTokenLength += 768
|
|
||||||
images = append(images, prompt.Images[j])
|
|
||||||
}
|
|
||||||
|
|
||||||
totalTokenLength += len(encodedTokens)
|
|
||||||
systemPromptIncluded = systemPromptIncluded || prompt.System != ""
|
|
||||||
promptsToAdd = append(promptsToAdd, promptInfo{vars: prompt, tokenLen: len(encodedTokens)})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ensure the system prompt is included, if not already
|
|
||||||
if chat.LastSystem != "" && !systemPromptIncluded {
|
|
||||||
var err error
|
|
||||||
promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
promptsToAdd[len(promptsToAdd)-1].vars.First = true
|
|
||||||
|
|
||||||
// construct the final prompt string from the prompts which fit within the context window
|
|
||||||
var result string
|
|
||||||
for i, prompt := range promptsToAdd {
|
|
||||||
promptText, err := promptString(model, prompt.vars, i == 0)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
result = promptText + result
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, images, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// promptString applies the model template to the prompt
|
|
||||||
func promptString(model *Model, vars PromptVars, isMostRecent bool) (string, error) {
|
|
||||||
if isMostRecent {
|
|
||||||
p, err := model.PreResponsePrompt(vars)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("pre-response template: %w", err)
|
|
||||||
}
|
|
||||||
return p, nil
|
|
||||||
}
|
|
||||||
p, err := Prompt(model.Template, vars)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return p, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// includeSystemPrompt adjusts the prompts to include the system prompt.
|
|
||||||
func includeSystemPrompt(ctx context.Context, systemPrompt string, totalTokenLength int, promptsToAdd []promptInfo) ([]promptInfo, error) {
|
|
||||||
systemTokens, err := loaded.runner.Encode(ctx, systemPrompt)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := len(promptsToAdd) - 1; i >= 0; i-- {
|
|
||||||
if totalTokenLength+len(systemTokens) <= loaded.NumCtx {
|
|
||||||
promptsToAdd[i].vars.System = systemPrompt
|
|
||||||
return promptsToAdd[:i+1], nil
|
|
||||||
}
|
|
||||||
totalTokenLength -= promptsToAdd[i].tokenLen
|
|
||||||
}
|
|
||||||
|
|
||||||
// if got here, system did not fit anywhere, so return the most recent prompt with the system message set
|
|
||||||
recent := promptsToAdd[len(promptsToAdd)-1]
|
|
||||||
recent.vars.System = systemPrompt
|
|
||||||
return []promptInfo{recent}, nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -241,237 +241,6 @@ func Test_Routes(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_ChatPrompt(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
template string
|
|
||||||
chat *ChatHistory
|
|
||||||
numCtx int
|
|
||||||
runner MockLLM
|
|
||||||
want string
|
|
||||||
wantErr string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Single Message",
|
|
||||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
||||||
chat: &ChatHistory{
|
|
||||||
Prompts: []PromptVars{
|
|
||||||
{
|
|
||||||
System: "You are a Wizard.",
|
|
||||||
Prompt: "What are the potion ingredients?",
|
|
||||||
First: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
LastSystem: "You are a Wizard.",
|
|
||||||
},
|
|
||||||
numCtx: 1,
|
|
||||||
runner: MockLLM{
|
|
||||||
encoding: []int{1}, // fit the ctxLen
|
|
||||||
},
|
|
||||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "First Message",
|
|
||||||
template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
|
|
||||||
chat: &ChatHistory{
|
|
||||||
Prompts: []PromptVars{
|
|
||||||
{
|
|
||||||
System: "You are a Wizard.",
|
|
||||||
Prompt: "What are the potion ingredients?",
|
|
||||||
Response: "eye of newt",
|
|
||||||
First: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Prompt: "Anything else?",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
LastSystem: "You are a Wizard.",
|
|
||||||
},
|
|
||||||
numCtx: 2,
|
|
||||||
runner: MockLLM{
|
|
||||||
encoding: []int{1}, // fit the ctxLen
|
|
||||||
},
|
|
||||||
want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Message History",
|
|
||||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
||||||
chat: &ChatHistory{
|
|
||||||
Prompts: []PromptVars{
|
|
||||||
{
|
|
||||||
System: "You are a Wizard.",
|
|
||||||
Prompt: "What are the potion ingredients?",
|
|
||||||
Response: "sugar",
|
|
||||||
First: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Prompt: "Anything else?",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
LastSystem: "You are a Wizard.",
|
|
||||||
},
|
|
||||||
numCtx: 4,
|
|
||||||
runner: MockLLM{
|
|
||||||
encoding: []int{1}, // fit the ctxLen, 1 for each message
|
|
||||||
},
|
|
||||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Assistant Only",
|
|
||||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
||||||
chat: &ChatHistory{
|
|
||||||
Prompts: []PromptVars{
|
|
||||||
{
|
|
||||||
Response: "everything nice",
|
|
||||||
First: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
numCtx: 1,
|
|
||||||
runner: MockLLM{
|
|
||||||
encoding: []int{1},
|
|
||||||
},
|
|
||||||
want: "[INST] [/INST]everything nice",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Message History Truncated, No System",
|
|
||||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
||||||
chat: &ChatHistory{
|
|
||||||
Prompts: []PromptVars{
|
|
||||||
{
|
|
||||||
Prompt: "What are the potion ingredients?",
|
|
||||||
Response: "sugar",
|
|
||||||
First: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Prompt: "Anything else?",
|
|
||||||
Response: "spice",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Prompt: "... and?",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
numCtx: 2, // only 1 message from history and most recent message
|
|
||||||
runner: MockLLM{
|
|
||||||
encoding: []int{1},
|
|
||||||
},
|
|
||||||
want: "[INST] Anything else? [/INST]spice[INST] ... and? [/INST]",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "System is Preserved when Truncated",
|
|
||||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
||||||
chat: &ChatHistory{
|
|
||||||
Prompts: []PromptVars{
|
|
||||||
{
|
|
||||||
Prompt: "What are the magic words?",
|
|
||||||
Response: "abracadabra",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Prompt: "What is the spell for invisibility?",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
LastSystem: "You are a wizard.",
|
|
||||||
},
|
|
||||||
numCtx: 2,
|
|
||||||
runner: MockLLM{
|
|
||||||
encoding: []int{1},
|
|
||||||
},
|
|
||||||
want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "System is Preserved when Length Exceeded",
|
|
||||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
||||||
chat: &ChatHistory{
|
|
||||||
Prompts: []PromptVars{
|
|
||||||
{
|
|
||||||
Prompt: "What are the magic words?",
|
|
||||||
Response: "abracadabra",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Prompt: "What is the spell for invisibility?",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
LastSystem: "You are a wizard.",
|
|
||||||
},
|
|
||||||
numCtx: 1,
|
|
||||||
runner: MockLLM{
|
|
||||||
encoding: []int{1},
|
|
||||||
},
|
|
||||||
want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "First is Preserved when Truncated",
|
|
||||||
template: "[INST] {{ if .First }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]",
|
|
||||||
|
|
||||||
chat: &ChatHistory{
|
|
||||||
Prompts: []PromptVars{
|
|
||||||
// first message omitted for test
|
|
||||||
{
|
|
||||||
Prompt: "Do you have a magic hat?",
|
|
||||||
Response: "Of course.",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Prompt: "What is the spell for invisibility?",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
LastSystem: "You are a wizard.",
|
|
||||||
},
|
|
||||||
numCtx: 3, // two most recent messages and room for system message
|
|
||||||
runner: MockLLM{
|
|
||||||
encoding: []int{1},
|
|
||||||
},
|
|
||||||
want: "[INST] You are a wizard. Do you have a magic hat? [/INST]Of course.[INST] What is the spell for invisibility? [/INST]",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Most recent message is returned when longer than ctxLen",
|
|
||||||
template: "[INST] {{ .Prompt }} [/INST]",
|
|
||||||
|
|
||||||
chat: &ChatHistory{
|
|
||||||
Prompts: []PromptVars{
|
|
||||||
{
|
|
||||||
Prompt: "What is the spell for invisibility?",
|
|
||||||
First: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
numCtx: 1, // two most recent messages
|
|
||||||
runner: MockLLM{
|
|
||||||
encoding: []int{1, 2},
|
|
||||||
},
|
|
||||||
want: "[INST] What is the spell for invisibility? [/INST]",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range tests {
|
|
||||||
tt := testCase
|
|
||||||
m := &Model{
|
|
||||||
Template: tt.template,
|
|
||||||
}
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
loaded.runner = &tt.runner
|
|
||||||
loaded.Options = &api.Options{
|
|
||||||
Runner: api.Runner{
|
|
||||||
NumCtx: tt.numCtx,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
// TODO: add tests for trimming images
|
|
||||||
got, _, err := trimmedPrompt(context.Background(), tt.chat, m)
|
|
||||||
if tt.wantErr != "" {
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("ChatPrompt() expected error, got nil")
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), tt.wantErr) {
|
|
||||||
t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MockLLM is a test double that the tests in this file install as the loaded
// runner.
type MockLLM struct {
	// encoding is the canned token slice configured per test case.
	// NOTE(review): presumably returned by the mock's Encode method, which is
	// defined outside this block — confirm.
	encoding []int
}
|
||||||
|
|
Loading…
Reference in a new issue