trim chat prompt based on llm context size (#1963)

Bruce MacDonald 2024-01-30 15:59:29 -05:00 committed by GitHub
parent 509e2dec8a
commit 0632dff3f8
4 changed files with 440 additions and 57 deletions
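
In outline: Model.ChatPrompt, which rendered the whole history into a single string, becomes Model.ChatPrompts, which returns one PromptVars per exchange plus the most recent system message; the chat handler then calls a new trimmedPrompt helper that walks those exchanges newest-to-oldest, encodes each rendered turn with the loaded runner, and stops once the model's context window (NumCtx) would be overrun, always keeping the newest turn and re-attaching the last system message if trimming dropped it. A minimal, self-contained sketch of that policy (countTokens and the hard-coded numCtx below are illustrative stand-ins, not code from this patch):

	package main

	import "fmt"

	// countTokens stands in for the runner's Encode call; the real code asks
	// the loaded model to tokenize each rendered turn.
	func countTokens(s string) int { return len(s) / 4 }

	func main() {
		turns := []string{
			"[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar",
			"[INST] Anything else? [/INST]spice",
			"[INST] ... and? [/INST]",
		}
		numCtx := 18 // illustrative token budget

		total := 0
		var kept []string
		// Walk newest-to-oldest; the newest turn is always kept, even if it
		// alone exceeds the budget.
		for i := len(turns) - 1; i >= 0; i-- {
			n := countTokens(turns[i])
			if total+n > numCtx && i != len(turns)-1 {
				break // budget exhausted: drop this turn and everything older
			}
			total += n
			kept = append([]string{turns[i]}, kept...)
		}
		fmt.Println(kept) // the oldest turn no longer fits and is dropped
	}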


@@ -146,62 +146,59 @@ func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
 	return Prompt(post, p)
 }
 
-func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
+type ChatHistory struct {
+	Prompts       []PromptVars
+	CurrentImages []api.ImageData
+	LastSystem    string
+}
+
+// ChatPrompts returns a list of formatted chat prompts from a list of messages
+func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	// build the prompt from the list of messages
-	var prompt strings.Builder
 	var currentImages []api.ImageData
+	var lastSystem string
 	currentVars := PromptVars{
 		First:  true,
 		System: m.System,
 	}
 
-	writePrompt := func() error {
-		p, err := Prompt(m.Template, currentVars)
-		if err != nil {
-			return err
-		}
-		prompt.WriteString(p)
-		currentVars = PromptVars{}
-		return nil
-	}
+	prompts := []PromptVars{}
 
 	for _, msg := range msgs {
 		switch strings.ToLower(msg.Role) {
 		case "system":
 			if currentVars.System != "" {
-				if err := writePrompt(); err != nil {
-					return "", nil, err
-				}
+				prompts = append(prompts, currentVars)
+				currentVars = PromptVars{}
 			}
 			currentVars.System = msg.Content
+			lastSystem = msg.Content
 		case "user":
 			if currentVars.Prompt != "" {
-				if err := writePrompt(); err != nil {
-					return "", nil, err
-				}
+				prompts = append(prompts, currentVars)
+				currentVars = PromptVars{}
 			}
 			currentVars.Prompt = msg.Content
 			currentImages = msg.Images
 		case "assistant":
 			currentVars.Response = msg.Content
-			if err := writePrompt(); err != nil {
-				return "", nil, err
-			}
+			prompts = append(prompts, currentVars)
+			currentVars = PromptVars{}
 		default:
-			return "", nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
+			return nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
 		}
 	}
 
 	// Append the last set of vars if they are non-empty
 	if currentVars.Prompt != "" || currentVars.System != "" {
-		p, err := m.PreResponsePrompt(currentVars)
-		if err != nil {
-			return "", nil, fmt.Errorf("pre-response template: %w", err)
-		}
-		prompt.WriteString(p)
+		prompts = append(prompts, currentVars)
	}
 
-	return prompt.String(), currentImages, nil
+	return &ChatHistory{
+		Prompts:       prompts,
+		CurrentImages: currentImages,
+		LastSystem:    lastSystem,
+	}, nil
 }
 
 type ManifestV2 struct {
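
For orientation, here is the shape ChatPrompts now produces for a short conversation, written as a Go example test; the ExampleModel_ChatPrompts wrapper is illustrative and not part of this patch (it mirrors the "Message History" case in the tests below):

	func ExampleModel_ChatPrompts() {
		m := &Model{Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]"}
		chat, _ := m.ChatPrompts([]api.Message{
			{Role: "system", Content: "You are a Wizard."},
			{Role: "user", Content: "What are the potion ingredients?"},
			{Role: "assistant", Content: "sugar"},
			{Role: "user", Content: "Anything else?"},
		})
		// One PromptVars per exchange: the completed system/user/assistant
		// round becomes Prompts[0] (with First set), the trailing user
		// message becomes Prompts[1], and the last system message is kept
		// separately so it can be re-attached after trimming.
		fmt.Printf("%d prompts, LastSystem=%q\n", len(chat.Prompts), chat.LastSystem)
		// Output: 2 prompts, LastSystem="You are a Wizard."
	}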


@@ -1,6 +1,7 @@
 package server
 
 import (
+	"bytes"
 	"strings"
 	"testing"
@@ -233,12 +234,32 @@ func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
 	}
 }
 
+func chatHistoryEqual(a, b ChatHistory) bool {
+	if len(a.Prompts) != len(b.Prompts) {
+		return false
+	}
+	if len(a.CurrentImages) != len(b.CurrentImages) {
+		return false
+	}
+	for i, v := range a.Prompts {
+		if v != b.Prompts[i] {
+			return false
+		}
+	}
+	for i, v := range a.CurrentImages {
+		if !bytes.Equal(v, b.CurrentImages[i]) {
+			return false
+		}
+	}
+	return a.LastSystem == b.LastSystem
+}
+
 func TestChat(t *testing.T) {
 	tests := []struct {
 		name     string
 		template string
 		msgs     []api.Message
-		want     string
+		want     ChatHistory
 		wantErr  string
 	}{
 		{
@@ -254,30 +275,16 @@ func TestChat(t *testing.T) {
 					Content: "What are the potion ingredients?",
 				},
 			},
-			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
-		},
-		{
-			name:     "First Message",
-			template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
-			msgs: []api.Message{
-				{
-					Role:    "system",
-					Content: "You are a Wizard.",
-				},
-				{
-					Role:    "user",
-					Content: "What are the potion ingredients?",
-				},
-				{
-					Role:    "assistant",
-					Content: "eye of newt",
-				},
-				{
-					Role:    "user",
-					Content: "Anything else?",
-				},
-			},
-			want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]",
+			want: ChatHistory{
+				Prompts: []PromptVars{
+					{
+						System: "You are a Wizard.",
+						Prompt: "What are the potion ingredients?",
+						First:  true,
+					},
+				},
+				LastSystem: "You are a Wizard.",
+			},
 		},
 		{
 			name: "Message History",
@@ -300,7 +307,20 @@ func TestChat(t *testing.T) {
 					Content: "Anything else?",
 				},
 			},
-			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
+			want: ChatHistory{
+				Prompts: []PromptVars{
+					{
+						System:   "You are a Wizard.",
+						Prompt:   "What are the potion ingredients?",
+						Response: "sugar",
+						First:    true,
+					},
+					{
+						Prompt: "Anything else?",
+					},
+				},
+				LastSystem: "You are a Wizard.",
+			},
 		},
 		{
 			name: "Assistant Only",
@@ -311,7 +331,14 @@ func TestChat(t *testing.T) {
 					Content: "everything nice",
 				},
 			},
-			want: "[INST] [/INST]everything nice",
+			want: ChatHistory{
+				Prompts: []PromptVars{
+					{
+						Response: "everything nice",
+						First:    true,
+					},
+				},
+			},
 		},
 		{
 			name: "Invalid Role",
@@ -330,7 +357,7 @@ func TestChat(t *testing.T) {
 			Template: tt.template,
 		}
 		t.Run(tt.name, func(t *testing.T) {
-			got, _, err := m.ChatPrompt(tt.msgs)
+			got, err := m.ChatPrompts(tt.msgs)
 			if tt.wantErr != "" {
 				if err == nil {
 					t.Errorf("ChatPrompt() expected error, got nil")
@@ -338,9 +365,10 @@ func TestChat(t *testing.T) {
 				if !strings.Contains(err.Error(), tt.wantErr) {
 					t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
 				}
+				return
 			}
-			if got != tt.want {
-				t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
+			if !chatHistoryEqual(*got, tt.want) {
+				t.Errorf("ChatPrompt() got = %#v, want %#v", got, tt.want)
 			}
 		})
 	}
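
Why the helper instead of a direct comparison: ChatHistory carries slice fields ([]PromptVars and []api.ImageData), so Go's == operator is not defined for it, whereas PromptVars holds only strings and a bool, which is what makes the element-wise v != b.Prompts[i] check legal. An illustrative fragment, which deliberately does not compile:

	var a, b ChatHistory
	_ = a == b // compile error: ChatHistory contains slice fields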


@@ -1121,11 +1121,16 @@ func ChatHandler(c *gin.Context) {
 	checkpointLoaded := time.Now()
 
-	prompt, images, err := model.ChatPrompt(req.Messages)
+	chat, err := model.ChatPrompts(req.Messages)
 	if err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
+	prompt, err := trimmedPrompt(c.Request.Context(), chat, model)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
 
 	slog.Debug(fmt.Sprintf("prompt: %s", prompt))
@@ -1164,7 +1169,7 @@ func ChatHandler(c *gin.Context) {
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Format:  req.Format,
-			Images:  images,
+			Images:  chat.CurrentImages,
 			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -1202,3 +1207,101 @@ func ChatHandler(c *gin.Context) {
 
 	streamResponse(c, ch)
 }
+
+// promptInfo stores the variables used to template a prompt, and the token length of the resulting template for some model
+type promptInfo struct {
+	vars     PromptVars
+	tokenLen int
+}
+
+// trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
+// while preserving the most recent system message.
+func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, error) {
+	if len(chat.Prompts) == 0 {
+		return "", nil
+	}
+
+	var promptsToAdd []promptInfo
+	var totalTokenLength int
+	var systemPromptIncluded bool
+
+	// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
+	for i := len(chat.Prompts) - 1; i >= 0; i-- {
+		promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
+		if err != nil {
+			return "", err
+		}
+
+		encodedTokens, err := loaded.runner.Encode(ctx, promptText)
+		if err != nil {
+			return "", err
+		}
+
+		if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
+			break // reached max context length, stop adding more prompts
+		}
+
+		totalTokenLength += len(encodedTokens)
+		systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
+		promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
+	}
+
+	// ensure the system prompt is included, if not already
+	if chat.LastSystem != "" && !systemPromptIncluded {
+		var err error
+		promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
+		if err != nil {
+			return "", err
+		}
+	}
+
+	promptsToAdd[len(promptsToAdd)-1].vars.First = true
+
+	// construct the final prompt string from the prompts which fit within the context window
+	var result string
+	for i, prompt := range promptsToAdd {
+		promptText, err := promptString(model, prompt.vars, i == 0)
+		if err != nil {
+			return "", err
+		}
+		result = promptText + result
+	}
+	return result, nil
+}
+
+// promptString applies the model template to the prompt
+func promptString(model *Model, vars PromptVars, isMostRecent bool) (string, error) {
+	if isMostRecent {
+		p, err := model.PreResponsePrompt(vars)
+		if err != nil {
+			return "", fmt.Errorf("pre-response template: %w", err)
+		}
+		return p, nil
+	}
+	p, err := Prompt(model.Template, vars)
+	if err != nil {
+		return "", err
+	}
+	return p, nil
+}
+
+// includeSystemPrompt adjusts the prompts to include the system prompt.
+func includeSystemPrompt(ctx context.Context, systemPrompt string, totalTokenLength int, promptsToAdd []promptInfo) ([]promptInfo, error) {
+	systemTokens, err := loaded.runner.Encode(ctx, systemPrompt)
+	if err != nil {
+		return nil, err
+	}
+
+	for i := len(promptsToAdd) - 1; i >= 0; i-- {
+		if totalTokenLength+len(systemTokens) <= loaded.NumCtx {
+			promptsToAdd[i].vars.System = systemPrompt
+			return promptsToAdd[:i+1], nil
+		}
+		totalTokenLength -= promptsToAdd[i].tokenLen
+	}
+
+	// if got here, system did not fit anywhere, so return the most recent prompt with the system message set
+	recent := promptsToAdd[len(promptsToAdd)-1]
+	recent.vars.System = systemPrompt
+	return []promptInfo{recent}, nil
+}
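
To make the trimming concrete, here is the "Message History Truncated, No System" case from the tests that follow, traced through trimmedPrompt; the mocked runner encodes every prompt to one token, so NumCtx effectively counts turns:

	// NumCtx = 2; each rendered turn costs 1 token via the mocked Encode.
	// Reverse pass over three turns:
	//   i=2  "... and?"                  1 token, total=1, kept (the most recent turn is always kept)
	//   i=1  "Anything else?" / "spice"  1 token, total=2, kept
	//   i=0  the first turn              would push total to 3 > 2, so the loop breaks
	// chat.LastSystem is empty, so includeSystemPrompt is skipped, giving:
	//   "[INST] Anything else? [/INST]spice[INST] ... and? [/INST]"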


@@ -16,6 +16,7 @@ import (
 	"github.com/stretchr/testify/assert"
 
 	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/llm"
 	"github.com/jmorganca/ollama/parser"
 	"github.com/jmorganca/ollama/version"
 )
@@ -239,3 +240,257 @@ func Test_Routes(t *testing.T) {
 		}
 	}
 }
+
+func Test_ChatPrompt(t *testing.T) {
+	tests := []struct {
+		name     string
+		template string
+		chat     *ChatHistory
+		numCtx   int
+		runner   MockLLM
+		want     string
+		wantErr  string
+	}{
+		{
+			name:     "Single Message",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						System: "You are a Wizard.",
+						Prompt: "What are the potion ingredients?",
+						First:  true,
+					},
+				},
+				LastSystem: "You are a Wizard.",
+			},
+			numCtx: 1,
+			runner: MockLLM{
+				encoding: []int{1}, // fit the ctxLen
+			},
+			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
+		},
+		{
+			name:     "First Message",
+			template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						System:   "You are a Wizard.",
+						Prompt:   "What are the potion ingredients?",
+						Response: "eye of newt",
+						First:    true,
+					},
+					{
+						Prompt: "Anything else?",
+					},
+				},
+				LastSystem: "You are a Wizard.",
+			},
+			numCtx: 2,
+			runner: MockLLM{
+				encoding: []int{1}, // fit the ctxLen
+			},
+			want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]",
+		},
+		{
+			name:     "Message History",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						System:   "You are a Wizard.",
+						Prompt:   "What are the potion ingredients?",
+						Response: "sugar",
+						First:    true,
+					},
+					{
+						Prompt: "Anything else?",
+					},
+				},
+				LastSystem: "You are a Wizard.",
+			},
+			numCtx: 4,
+			runner: MockLLM{
+				encoding: []int{1}, // fit the ctxLen, 1 for each message
+			},
+			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
+		},
+		{
+			name:     "Assistant Only",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						Response: "everything nice",
+						First:    true,
+					},
+				},
+			},
+			numCtx: 1,
+			runner: MockLLM{
+				encoding: []int{1},
+			},
+			want: "[INST] [/INST]everything nice",
+		},
+		{
+			name:     "Message History Truncated, No System",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						Prompt:   "What are the potion ingredients?",
+						Response: "sugar",
+						First:    true,
+					},
+					{
+						Prompt:   "Anything else?",
+						Response: "spice",
+					},
+					{
+						Prompt: "... and?",
+					},
+				},
+			},
+			numCtx: 2, // only 1 message from history and most recent message
+			runner: MockLLM{
+				encoding: []int{1},
+			},
+			want: "[INST] Anything else? [/INST]spice[INST] ... and? [/INST]",
+		},
+		{
+			name:     "System is Preserved when Truncated",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						Prompt:   "What are the magic words?",
+						Response: "abracadabra",
+					},
+					{
+						Prompt: "What is the spell for invisibility?",
+					},
+				},
+				LastSystem: "You are a wizard.",
+			},
+			numCtx: 2,
+			runner: MockLLM{
+				encoding: []int{1},
+			},
+			want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
+		},
+		{
+			name:     "System is Preserved when Length Exceeded",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						Prompt:   "What are the magic words?",
+						Response: "abracadabra",
+					},
+					{
+						Prompt: "What is the spell for invisibility?",
+					},
+				},
+				LastSystem: "You are a wizard.",
+			},
+			numCtx: 1,
+			runner: MockLLM{
+				encoding: []int{1},
+			},
+			want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
+		},
+		{
+			name:     "First is Preserved when Truncated",
+			template: "[INST] {{ if .First }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					// first message omitted for test
+					{
+						Prompt:   "Do you have a magic hat?",
+						Response: "Of course.",
+					},
+					{
+						Prompt: "What is the spell for invisibility?",
+					},
+				},
+				LastSystem: "You are a wizard.",
+			},
+			numCtx: 3, // two most recent messages and room for system message
+			runner: MockLLM{
+				encoding: []int{1},
+			},
+			want: "[INST] You are a wizard. Do you have a magic hat? [/INST]Of course.[INST] What is the spell for invisibility? [/INST]",
+		},
+		{
+			name:     "Most recent message is returned when longer than ctxLen",
+			template: "[INST] {{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						Prompt: "What is the spell for invisibility?",
+						First:  true,
+					},
+				},
+			},
+			numCtx: 1, // the single prompt encodes to 2 tokens, exceeding the context window
+			runner: MockLLM{
+				encoding: []int{1, 2},
+			},
+			want: "[INST] What is the spell for invisibility? [/INST]",
+		},
+	}
+
+	for _, testCase := range tests {
+		tt := testCase
+		m := &Model{
+			Template: tt.template,
+		}
+		t.Run(tt.name, func(t *testing.T) {
+			loaded.runner = &tt.runner
+			loaded.Options = &api.Options{
+				Runner: api.Runner{
+					NumCtx: tt.numCtx,
+				},
+			}
+			got, err := trimmedPrompt(context.Background(), tt.chat, m)
+			if tt.wantErr != "" {
+				if err == nil {
+					t.Errorf("ChatPrompt() expected error, got nil")
+				}
+				if !strings.Contains(err.Error(), tt.wantErr) {
+					t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
+				}
+			}
+			if got != tt.want {
+				t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+type MockLLM struct {
+	encoding []int
+}
+
+func (llm *MockLLM) Predict(ctx context.Context, pred llm.PredictOpts, fn func(llm.PredictResult)) error {
+	return nil
+}
+
+func (llm *MockLLM) Encode(ctx context.Context, prompt string) ([]int, error) {
+	return llm.encoding, nil
+}
+
+func (llm *MockLLM) Decode(ctx context.Context, tokens []int) (string, error) {
+	return "", nil
+}
+
+func (llm *MockLLM) Embedding(ctx context.Context, input string) ([]float64, error) {
+	return []float64{}, nil
+}
+
+func (llm *MockLLM) Close() {
+	// do nothing
+}
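
A note on the mock: Encode ignores the prompt text and always returns the fixed encoding slice, so every rendered prompt costs len(encoding) tokens. That is how the cases above steer trimmedPrompt: encoding: []int{1} makes NumCtx count whole turns, while encoding: []int{1, 2} with numCtx: 1 makes even the lone most recent turn oversized, exercising the keep-the-newest-turn fallback. A hypothetical extra row in the same table, not part of this patch:

	{
		// every turn costs 3 tokens, so numCtx 4 keeps only the newest turn
		name:     "Each Turn Costs Three Tokens",
		template: "[INST] {{ .Prompt }} [/INST]",
		chat: &ChatHistory{
			Prompts: []PromptVars{
				{Prompt: "What are the magic words?", Response: "abracadabra"},
				{Prompt: "What is the spell for invisibility?"},
			},
		},
		numCtx: 4,
		runner: MockLLM{encoding: []int{1, 2, 3}},
		want:   "[INST] What is the spell for invisibility? [/INST]",
	},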