trim chat prompt based on llm context size (#1963)
This commit is contained in:
parent
509e2dec8a
commit
0632dff3f8
4 changed files with 440 additions and 57 deletions
|
@ -146,62 +146,59 @@ func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
|
|||
return Prompt(post, p)
|
||||
}
|
||||
|
||||
func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
|
||||
type ChatHistory struct {
|
||||
Prompts []PromptVars
|
||||
CurrentImages []api.ImageData
|
||||
LastSystem string
|
||||
}
|
||||
|
||||
// ChatPrompts returns a list of formatted chat prompts from a list of messages
|
||||
func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
|
||||
// build the prompt from the list of messages
|
||||
var prompt strings.Builder
|
||||
var currentImages []api.ImageData
|
||||
var lastSystem string
|
||||
currentVars := PromptVars{
|
||||
First: true,
|
||||
System: m.System,
|
||||
}
|
||||
|
||||
writePrompt := func() error {
|
||||
p, err := Prompt(m.Template, currentVars)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
prompt.WriteString(p)
|
||||
currentVars = PromptVars{}
|
||||
return nil
|
||||
}
|
||||
prompts := []PromptVars{}
|
||||
|
||||
for _, msg := range msgs {
|
||||
switch strings.ToLower(msg.Role) {
|
||||
case "system":
|
||||
if currentVars.System != "" {
|
||||
if err := writePrompt(); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
prompts = append(prompts, currentVars)
|
||||
currentVars = PromptVars{}
|
||||
}
|
||||
currentVars.System = msg.Content
|
||||
lastSystem = msg.Content
|
||||
case "user":
|
||||
if currentVars.Prompt != "" {
|
||||
if err := writePrompt(); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
prompts = append(prompts, currentVars)
|
||||
currentVars = PromptVars{}
|
||||
}
|
||||
currentVars.Prompt = msg.Content
|
||||
currentImages = msg.Images
|
||||
case "assistant":
|
||||
currentVars.Response = msg.Content
|
||||
if err := writePrompt(); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
prompts = append(prompts, currentVars)
|
||||
currentVars = PromptVars{}
|
||||
default:
|
||||
return "", nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
|
||||
return nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
|
||||
}
|
||||
}
|
||||
|
||||
// Append the last set of vars if they are non-empty
|
||||
if currentVars.Prompt != "" || currentVars.System != "" {
|
||||
p, err := m.PreResponsePrompt(currentVars)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("pre-response template: %w", err)
|
||||
}
|
||||
prompt.WriteString(p)
|
||||
prompts = append(prompts, currentVars)
|
||||
}
|
||||
|
||||
return prompt.String(), currentImages, nil
|
||||
return &ChatHistory{
|
||||
Prompts: prompts,
|
||||
CurrentImages: currentImages,
|
||||
LastSystem: lastSystem,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type ManifestV2 struct {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
|
@ -233,12 +234,32 @@ func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func chatHistoryEqual(a, b ChatHistory) bool {
|
||||
if len(a.Prompts) != len(b.Prompts) {
|
||||
return false
|
||||
}
|
||||
if len(a.CurrentImages) != len(b.CurrentImages) {
|
||||
return false
|
||||
}
|
||||
for i, v := range a.Prompts {
|
||||
if v != b.Prompts[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for i, v := range a.CurrentImages {
|
||||
if !bytes.Equal(v, b.CurrentImages[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return a.LastSystem == b.LastSystem
|
||||
}
|
||||
|
||||
func TestChat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
template string
|
||||
msgs []api.Message
|
||||
want string
|
||||
want ChatHistory
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
|
@ -254,30 +275,16 @@ func TestChat(t *testing.T) {
|
|||
Content: "What are the potion ingredients?",
|
||||
},
|
||||
},
|
||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "First Message",
|
||||
template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
|
||||
msgs: []api.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "You are a Wizard.",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What are the potion ingredients?",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "eye of newt",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Anything else?",
|
||||
want: ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
System: "You are a Wizard.",
|
||||
Prompt: "What are the potion ingredients?",
|
||||
First: true,
|
||||
},
|
||||
},
|
||||
LastSystem: "You are a Wizard.",
|
||||
},
|
||||
want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "Message History",
|
||||
|
@ -300,7 +307,20 @@ func TestChat(t *testing.T) {
|
|||
Content: "Anything else?",
|
||||
},
|
||||
},
|
||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
|
||||
want: ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
System: "You are a Wizard.",
|
||||
Prompt: "What are the potion ingredients?",
|
||||
Response: "sugar",
|
||||
First: true,
|
||||
},
|
||||
{
|
||||
Prompt: "Anything else?",
|
||||
},
|
||||
},
|
||||
LastSystem: "You are a Wizard.",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Assistant Only",
|
||||
|
@ -311,7 +331,14 @@ func TestChat(t *testing.T) {
|
|||
Content: "everything nice",
|
||||
},
|
||||
},
|
||||
want: "[INST] [/INST]everything nice",
|
||||
want: ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
Response: "everything nice",
|
||||
First: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid Role",
|
||||
|
@ -330,7 +357,7 @@ func TestChat(t *testing.T) {
|
|||
Template: tt.template,
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, _, err := m.ChatPrompt(tt.msgs)
|
||||
got, err := m.ChatPrompts(tt.msgs)
|
||||
if tt.wantErr != "" {
|
||||
if err == nil {
|
||||
t.Errorf("ChatPrompt() expected error, got nil")
|
||||
|
@ -338,9 +365,10 @@ func TestChat(t *testing.T) {
|
|||
if !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
|
||||
if !chatHistoryEqual(*got, tt.want) {
|
||||
t.Errorf("ChatPrompt() got = %#v, want %#v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
107
server/routes.go
107
server/routes.go
|
@ -1121,11 +1121,16 @@ func ChatHandler(c *gin.Context) {
|
|||
|
||||
checkpointLoaded := time.Now()
|
||||
|
||||
prompt, images, err := model.ChatPrompt(req.Messages)
|
||||
chat, err := model.ChatPrompts(req.Messages)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
prompt, err := trimmedPrompt(c.Request.Context(), chat, model)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug(fmt.Sprintf("prompt: %s", prompt))
|
||||
|
||||
|
@ -1164,7 +1169,7 @@ func ChatHandler(c *gin.Context) {
|
|||
predictReq := llm.PredictOpts{
|
||||
Prompt: prompt,
|
||||
Format: req.Format,
|
||||
Images: images,
|
||||
Images: chat.CurrentImages,
|
||||
Options: opts,
|
||||
}
|
||||
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
||||
|
@ -1202,3 +1207,101 @@ func ChatHandler(c *gin.Context) {
|
|||
|
||||
streamResponse(c, ch)
|
||||
}
|
||||
|
||||
// promptInfo stores the variables used to template a prompt, and the token length of the resulting template for some model
|
||||
type promptInfo struct {
|
||||
vars PromptVars
|
||||
tokenLen int
|
||||
}
|
||||
|
||||
// trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
|
||||
// while preserving the most recent system message.
|
||||
func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, error) {
|
||||
if len(chat.Prompts) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var promptsToAdd []promptInfo
|
||||
var totalTokenLength int
|
||||
var systemPromptIncluded bool
|
||||
|
||||
// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
|
||||
for i := len(chat.Prompts) - 1; i >= 0; i-- {
|
||||
promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
encodedTokens, err := loaded.runner.Encode(ctx, promptText)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
|
||||
break // reached max context length, stop adding more prompts
|
||||
}
|
||||
|
||||
totalTokenLength += len(encodedTokens)
|
||||
systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
|
||||
promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
|
||||
}
|
||||
|
||||
// ensure the system prompt is included, if not already
|
||||
if chat.LastSystem != "" && !systemPromptIncluded {
|
||||
var err error
|
||||
promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
promptsToAdd[len(promptsToAdd)-1].vars.First = true
|
||||
|
||||
// construct the final prompt string from the prompts which fit within the context window
|
||||
var result string
|
||||
for i, prompt := range promptsToAdd {
|
||||
promptText, err := promptString(model, prompt.vars, i == 0)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
result = promptText + result
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// promptString applies the model template to the prompt
|
||||
func promptString(model *Model, vars PromptVars, isMostRecent bool) (string, error) {
|
||||
if isMostRecent {
|
||||
p, err := model.PreResponsePrompt(vars)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("pre-response template: %w", err)
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
p, err := Prompt(model.Template, vars)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// includeSystemPrompt adjusts the prompts to include the system prompt.
|
||||
func includeSystemPrompt(ctx context.Context, systemPrompt string, totalTokenLength int, promptsToAdd []promptInfo) ([]promptInfo, error) {
|
||||
systemTokens, err := loaded.runner.Encode(ctx, systemPrompt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := len(promptsToAdd) - 1; i >= 0; i-- {
|
||||
if totalTokenLength+len(systemTokens) <= loaded.NumCtx {
|
||||
promptsToAdd[i].vars.System = systemPrompt
|
||||
return promptsToAdd[:i+1], nil
|
||||
}
|
||||
totalTokenLength -= promptsToAdd[i].tokenLen
|
||||
}
|
||||
|
||||
// if got here, system did not fit anywhere, so return the most recent prompt with the system message set
|
||||
recent := promptsToAdd[len(promptsToAdd)-1]
|
||||
recent.vars.System = systemPrompt
|
||||
return []promptInfo{recent}, nil
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/llm"
|
||||
"github.com/jmorganca/ollama/parser"
|
||||
"github.com/jmorganca/ollama/version"
|
||||
)
|
||||
|
@ -239,3 +240,257 @@ func Test_Routes(t *testing.T) {
|
|||
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ChatPrompt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
template string
|
||||
chat *ChatHistory
|
||||
numCtx int
|
||||
runner MockLLM
|
||||
want string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "Single Message",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
chat: &ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
System: "You are a Wizard.",
|
||||
Prompt: "What are the potion ingredients?",
|
||||
First: true,
|
||||
},
|
||||
},
|
||||
LastSystem: "You are a Wizard.",
|
||||
},
|
||||
numCtx: 1,
|
||||
runner: MockLLM{
|
||||
encoding: []int{1}, // fit the ctxLen
|
||||
},
|
||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "First Message",
|
||||
template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
|
||||
chat: &ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
System: "You are a Wizard.",
|
||||
Prompt: "What are the potion ingredients?",
|
||||
Response: "eye of newt",
|
||||
First: true,
|
||||
},
|
||||
{
|
||||
Prompt: "Anything else?",
|
||||
},
|
||||
},
|
||||
LastSystem: "You are a Wizard.",
|
||||
},
|
||||
numCtx: 2,
|
||||
runner: MockLLM{
|
||||
encoding: []int{1}, // fit the ctxLen
|
||||
},
|
||||
want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "Message History",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
chat: &ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
System: "You are a Wizard.",
|
||||
Prompt: "What are the potion ingredients?",
|
||||
Response: "sugar",
|
||||
First: true,
|
||||
},
|
||||
{
|
||||
Prompt: "Anything else?",
|
||||
},
|
||||
},
|
||||
LastSystem: "You are a Wizard.",
|
||||
},
|
||||
numCtx: 4,
|
||||
runner: MockLLM{
|
||||
encoding: []int{1}, // fit the ctxLen, 1 for each message
|
||||
},
|
||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "Assistant Only",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
chat: &ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
Response: "everything nice",
|
||||
First: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
numCtx: 1,
|
||||
runner: MockLLM{
|
||||
encoding: []int{1},
|
||||
},
|
||||
want: "[INST] [/INST]everything nice",
|
||||
},
|
||||
{
|
||||
name: "Message History Truncated, No System",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
chat: &ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
Prompt: "What are the potion ingredients?",
|
||||
Response: "sugar",
|
||||
First: true,
|
||||
},
|
||||
{
|
||||
Prompt: "Anything else?",
|
||||
Response: "spice",
|
||||
},
|
||||
{
|
||||
Prompt: "... and?",
|
||||
},
|
||||
},
|
||||
},
|
||||
numCtx: 2, // only 1 message from history and most recent message
|
||||
runner: MockLLM{
|
||||
encoding: []int{1},
|
||||
},
|
||||
want: "[INST] Anything else? [/INST]spice[INST] ... and? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "System is Preserved when Truncated",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
chat: &ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
Prompt: "What are the magic words?",
|
||||
Response: "abracadabra",
|
||||
},
|
||||
{
|
||||
Prompt: "What is the spell for invisibility?",
|
||||
},
|
||||
},
|
||||
LastSystem: "You are a wizard.",
|
||||
},
|
||||
numCtx: 2,
|
||||
runner: MockLLM{
|
||||
encoding: []int{1},
|
||||
},
|
||||
want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "System is Preserved when Length Exceeded",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
chat: &ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
Prompt: "What are the magic words?",
|
||||
Response: "abracadabra",
|
||||
},
|
||||
{
|
||||
Prompt: "What is the spell for invisibility?",
|
||||
},
|
||||
},
|
||||
LastSystem: "You are a wizard.",
|
||||
},
|
||||
numCtx: 1,
|
||||
runner: MockLLM{
|
||||
encoding: []int{1},
|
||||
},
|
||||
want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "First is Preserved when Truncated",
|
||||
template: "[INST] {{ if .First }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]",
|
||||
|
||||
chat: &ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
// first message omitted for test
|
||||
{
|
||||
Prompt: "Do you have a magic hat?",
|
||||
Response: "Of course.",
|
||||
},
|
||||
{
|
||||
Prompt: "What is the spell for invisibility?",
|
||||
},
|
||||
},
|
||||
LastSystem: "You are a wizard.",
|
||||
},
|
||||
numCtx: 3, // two most recent messages and room for system message
|
||||
runner: MockLLM{
|
||||
encoding: []int{1},
|
||||
},
|
||||
want: "[INST] You are a wizard. Do you have a magic hat? [/INST]Of course.[INST] What is the spell for invisibility? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "Most recent message is returned when longer than ctxLen",
|
||||
template: "[INST] {{ .Prompt }} [/INST]",
|
||||
|
||||
chat: &ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
Prompt: "What is the spell for invisibility?",
|
||||
First: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
numCtx: 1, // two most recent messages
|
||||
runner: MockLLM{
|
||||
encoding: []int{1, 2},
|
||||
},
|
||||
want: "[INST] What is the spell for invisibility? [/INST]",
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range tests {
|
||||
tt := testCase
|
||||
m := &Model{
|
||||
Template: tt.template,
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
loaded.runner = &tt.runner
|
||||
loaded.Options = &api.Options{
|
||||
Runner: api.Runner{
|
||||
NumCtx: tt.numCtx,
|
||||
},
|
||||
}
|
||||
got, err := trimmedPrompt(context.Background(), tt.chat, m)
|
||||
if tt.wantErr != "" {
|
||||
if err == nil {
|
||||
t.Errorf("ChatPrompt() expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type MockLLM struct {
|
||||
encoding []int
|
||||
}
|
||||
|
||||
func (llm *MockLLM) Predict(ctx context.Context, pred llm.PredictOpts, fn func(llm.PredictResult)) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (llm *MockLLM) Encode(ctx context.Context, prompt string) ([]int, error) {
|
||||
return llm.encoding, nil
|
||||
}
|
||||
|
||||
func (llm *MockLLM) Decode(ctx context.Context, tokens []int) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (llm *MockLLM) Embedding(ctx context.Context, input string) ([]float64, error) {
|
||||
return []float64{}, nil
|
||||
}
|
||||
|
||||
func (llm *MockLLM) Close() {
|
||||
// do nothing
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue