diff --git a/server/images.go b/server/images.go
index 7f5a9211..39617c66 100644
--- a/server/images.go
+++ b/server/images.go
@@ -146,62 +146,66 @@ func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
 	return Prompt(post, p)
 }
 
-func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
+// ChatHistory groups a list of chat messages into per-turn template variables.
+// CurrentImages holds the images of the last user message and LastSystem the
+// system message that should be preserved if older turns are dropped.
+type ChatHistory struct {
+	Prompts       []PromptVars
+	CurrentImages []api.ImageData
+	LastSystem    string
+}
+
+// ChatPrompts groups a list of messages into the template variables for each
+// turn of the conversation. Templating is left to the caller. An error is
+// returned if a message has a role other than system, user or assistant.
+func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	// build the prompt from the list of messages
-	var prompt strings.Builder
 	var currentImages []api.ImageData
+	// start from the model's default system prompt so it is still preserved
+	// when the conversation itself never sets a system message
+	lastSystem := m.System
 	currentVars := PromptVars{
 		First:  true,
 		System: m.System,
 	}
 
-	writePrompt := func() error {
-		p, err := Prompt(m.Template, currentVars)
-		if err != nil {
-			return err
-		}
-		prompt.WriteString(p)
-		currentVars = PromptVars{}
-		return nil
-	}
+	prompts := []PromptVars{}
 
 	for _, msg := range msgs {
 		switch strings.ToLower(msg.Role) {
 		case "system":
 			if currentVars.System != "" {
-				if err := writePrompt(); err != nil {
-					return "", nil, err
-				}
+				prompts = append(prompts, currentVars)
+				currentVars = PromptVars{}
 			}
 			currentVars.System = msg.Content
+			lastSystem = msg.Content
 		case "user":
 			if currentVars.Prompt != "" {
-				if err := writePrompt(); err != nil {
-					return "", nil, err
-				}
+				prompts = append(prompts, currentVars)
+				currentVars = PromptVars{}
 			}
 			currentVars.Prompt = msg.Content
 			currentImages = msg.Images
 		case "assistant":
 			currentVars.Response = msg.Content
-			if err := writePrompt(); err != nil {
-				return "", nil, err
-			}
+			prompts = append(prompts, currentVars)
+			currentVars = PromptVars{}
 		default:
-			return "", nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
+			return nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
 		}
 	}
 
 	// Append the last set of vars if they are non-empty
 	if currentVars.Prompt != "" || currentVars.System != "" {
-		p, err := m.PreResponsePrompt(currentVars)
-		if err != nil {
-			return "", nil, fmt.Errorf("pre-response template: %w", err)
-		}
-		prompt.WriteString(p)
+		prompts = append(prompts, currentVars)
 	}
 
-	return prompt.String(), currentImages, nil
+	return &ChatHistory{
+		Prompts:       prompts,
+		CurrentImages: currentImages,
+		LastSystem:    lastSystem,
+	}, nil
 }
 
 type ManifestV2 struct {
diff --git a/server/images_test.go b/server/images_test.go
index 55408e7a..08e39998 100644
--- a/server/images_test.go
+++ b/server/images_test.go
@@ -1,6 +1,7 @@
 package server
 
 import (
+	"bytes"
 	"strings"
 	"testing"
 
@@ -233,12 +234,32 @@ func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
 	}
 }
 
+func chatHistoryEqual(a, b ChatHistory) bool {
+	if len(a.Prompts) != len(b.Prompts) {
+		return false
+	}
+	if len(a.CurrentImages) != len(b.CurrentImages) {
+		return false
+	}
+	for i, v := range a.Prompts {
+		if v != b.Prompts[i] {
+			return false
+		}
+	}
+	for i, v := range a.CurrentImages {
+		if !bytes.Equal(v, b.CurrentImages[i]) {
+			return false
+		}
+	}
+	return a.LastSystem == b.LastSystem
+}
+
 func TestChat(t *testing.T) {
 	tests := []struct {
 		name     string
 		template string
 		msgs     []api.Message
-		want     string
+		want     ChatHistory
 		wantErr  string
 	}{
 		{
@@ -254,30 +275,16 @@ func TestChat(t *testing.T) {
 					Content: "What are the potion ingredients?",
 				},
 			},
-			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
-		},
-		{
-			name:     "First Message",
-			template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
-			msgs: []api.Message{
-				{
-					Role:    "system",
-					Content: "You are a Wizard.",
-				},
-				{
-					Role:    "user",
-					Content: "What are the potion ingredients?",
-				},
-				{
-					Role:    "assistant",
-					Content: "eye of newt",
-				},
-				{
-					Role:    "user",
-					Content: "Anything else?",
+			want: ChatHistory{
+				Prompts: []PromptVars{
+					{
+						System: "You are a Wizard.",
+						Prompt: "What are the potion ingredients?",
+						First:  true,
+					},
 				},
+				LastSystem: "You are a Wizard.",
 			},
-			want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST]   Anything else? [/INST]",
 		},
 		{
 			name:     "Message History",
@@ -300,7 +307,20 @@ func TestChat(t *testing.T) {
 					Content: "Anything else?",
 				},
 			},
-			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST]  Anything else? [/INST]",
+			want: ChatHistory{
+				Prompts: []PromptVars{
+					{
+						System:   "You are a Wizard.",
+						Prompt:   "What are the potion ingredients?",
+						Response: "sugar",
+						First:    true,
+					},
+					{
+						Prompt: "Anything else?",
+					},
+				},
+				LastSystem: "You are a Wizard.",
+			},
 		},
 		{
 			name:     "Assistant Only",
@@ -311,7 +331,14 @@ func TestChat(t *testing.T) {
 					Content: "everything nice",
 				},
 			},
-			want: "[INST]   [/INST]everything nice",
+			want: ChatHistory{
+				Prompts: []PromptVars{
+					{
+						Response: "everything nice",
+						First:    true,
+					},
+				},
+			},
 		},
 		{
 			name: "Invalid Role",
@@ -330,7 +357,7 @@ func TestChat(t *testing.T) {
 			Template: tt.template,
 		}
 		t.Run(tt.name, func(t *testing.T) {
-			got, _, err := m.ChatPrompt(tt.msgs)
+			got, err := m.ChatPrompts(tt.msgs)
 			if tt.wantErr != "" {
 				if err == nil {
 					t.Errorf("ChatPrompt() expected error, got nil")
@@ -338,9 +365,13 @@ func TestChat(t *testing.T) {
 				if !strings.Contains(err.Error(), tt.wantErr) {
 					t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
 				}
+				return
 			}
-			if got != tt.want {
-				t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
+			if err != nil {
+				t.Fatalf("ChatPrompts() unexpected error: %v", err)
+			}
+			if !chatHistoryEqual(*got, tt.want) {
+				t.Errorf("ChatPrompts() got = %#v, want %#v", *got, tt.want)
 			}
 		})
 	}
diff --git a/server/routes.go b/server/routes.go
index b7c3d496..797c67a2 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -1121,11 +1121,16 @@ func ChatHandler(c *gin.Context) {
 
 	checkpointLoaded := time.Now()
 
-	prompt, images, err := model.ChatPrompt(req.Messages)
+	chat, err := model.ChatPrompts(req.Messages)
 	if err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
+	prompt, err := trimmedPrompt(c.Request.Context(), chat, model)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
 
 	slog.Debug(fmt.Sprintf("prompt: %s", prompt))
 
@@ -1164,7 +1169,7 @@ func ChatHandler(c *gin.Context) {
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Format:  req.Format,
-			Images:  images,
+			Images:  chat.CurrentImages,
 			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -1202,3 +1207,111 @@ func ChatHandler(c *gin.Context) {
 
 	streamResponse(c, ch)
 }
+
+// promptInfo stores the variables used to template a prompt, and the token length of the resulting template for some model
+type promptInfo struct {
+	vars     PromptVars
+	tokenLen int
+}
+
+// trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
+// while preserving the most recent system message.
+//
+// The most recent prompt is always included, even if it alone exceeds the context length.
+func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, error) {
+	if len(chat.Prompts) == 0 {
+		return "", nil
+	}
+
+	// promptsToAdd is ordered newest-first: index 0 is the most recent prompt
+	var promptsToAdd []promptInfo
+	var totalTokenLength int
+	var systemPromptIncluded bool
+
+	// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
+	for i := len(chat.Prompts) - 1; i >= 0; i-- {
+		promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
+		if err != nil {
+			return "", err
+		}
+
+		encodedTokens, err := loaded.runner.Encode(ctx, promptText)
+		if err != nil {
+			return "", err
+		}
+
+		if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
+			break // reached max context length, stop adding more prompts
+		}
+
+		totalTokenLength += len(encodedTokens)
+		systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
+		promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
+	}
+
+	// ensure the system prompt is included, if not already
+	if chat.LastSystem != "" && !systemPromptIncluded {
+		var err error
+		promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
+		if err != nil {
+			return "", err
+		}
+	}
+
+	// the oldest prompt that was kept is the first one the model will see
+	promptsToAdd[len(promptsToAdd)-1].vars.First = true
+
+	// construct the final prompt string from the prompts which fit within the context window
+	var result string
+	for i, prompt := range promptsToAdd {
+		promptText, err := promptString(model, prompt.vars, i == 0)
+		if err != nil {
+			return "", err
+		}
+		result = promptText + result
+	}
+	return result, nil
+}
+
+// promptString applies the model template to the prompt
+func promptString(model *Model, vars PromptVars, isMostRecent bool) (string, error) {
+	if isMostRecent {
+		p, err := model.PreResponsePrompt(vars)
+		if err != nil {
+			return "", fmt.Errorf("pre-response template: %w", err)
+		}
+		return p, nil
+	}
+	p, err := Prompt(model.Template, vars)
+	if err != nil {
+		return "", err
+	}
+	return p, nil
+}
+
+// includeSystemPrompt adjusts the prompts to include the system prompt.
+//
+// promptsToAdd must be ordered newest-first and totalTokenLength must be the sum of its token lengths.
+// The oldest prompts are dropped until the system prompt fits in the context length, and the system
+// prompt is then set on the oldest prompt that remains. If it never fits, only the most recent
+// prompt is returned, with the system prompt set on it.
+func includeSystemPrompt(ctx context.Context, systemPrompt string, totalTokenLength int, promptsToAdd []promptInfo) ([]promptInfo, error) {
+	systemTokens, err := loaded.runner.Encode(ctx, systemPrompt)
+	if err != nil {
+		return nil, err
+	}
+
+	for i := len(promptsToAdd) - 1; i >= 0; i-- {
+		if totalTokenLength+len(systemTokens) <= loaded.NumCtx {
+			promptsToAdd[i].vars.System = systemPrompt
+			return promptsToAdd[:i+1], nil
+		}
+		totalTokenLength -= promptsToAdd[i].tokenLen
+	}
+
+	// if got here, system did not fit anywhere, so return the most recent prompt with the system message set
+	// (promptsToAdd is newest-first, so the most recent prompt is at index 0)
+	recent := promptsToAdd[0]
+	recent.vars.System = systemPrompt
+	return []promptInfo{recent}, nil
+}
diff --git a/server/routes_test.go b/server/routes_test.go
index b2d93958..9c53dc20 100644
--- a/server/routes_test.go
+++ b/server/routes_test.go
@@ -16,6 +16,7 @@ import (
 	"github.com/stretchr/testify/assert"
 
 	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/llm"
 	"github.com/jmorganca/ollama/parser"
 	"github.com/jmorganca/ollama/version"
 )
@@ -239,3 +240,263 @@ func Test_Routes(t *testing.T) {
 	}
 
 }
+
+// Test_ChatPrompt checks that trimmedPrompt templates a chat history and trims it to fit the context length.
+func Test_ChatPrompt(t *testing.T) {
+	tests := []struct {
+		name     string
+		template string
+		chat     *ChatHistory
+		numCtx   int
+		runner   MockLLM
+		want     string
+		wantErr  string
+	}{
+		{
+			name:     "Single Message",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						System: "You are a Wizard.",
+						Prompt: "What are the potion ingredients?",
+						First:  true,
+					},
+				},
+				LastSystem: "You are a Wizard.",
+			},
+			numCtx: 1,
+			runner: MockLLM{
+				encoding: []int{1}, // fit the ctxLen
+			},
+			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
+		},
+		{
+			name:     "First Message",
+			template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						System:   "You are a Wizard.",
+						Prompt:   "What are the potion ingredients?",
+						Response: "eye of newt",
+						First:    true,
+					},
+					{
+						Prompt: "Anything else?",
+					},
+				},
+				LastSystem: "You are a Wizard.",
+			},
+			numCtx: 2,
+			runner: MockLLM{
+				encoding: []int{1}, // fit the ctxLen
+			},
+			want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST]   Anything else? [/INST]",
+		},
+		{
+			name:     "Message History",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						System:   "You are a Wizard.",
+						Prompt:   "What are the potion ingredients?",
+						Response: "sugar",
+						First:    true,
+					},
+					{
+						Prompt: "Anything else?",
+					},
+				},
+				LastSystem: "You are a Wizard.",
+			},
+			numCtx: 4,
+			runner: MockLLM{
+				encoding: []int{1}, // fit the ctxLen, 1 for each message
+			},
+			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST]  Anything else? [/INST]",
+		},
+		{
+			name:     "Assistant Only",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						Response: "everything nice",
+						First:    true,
+					},
+				},
+			},
+			numCtx: 1,
+			runner: MockLLM{
+				encoding: []int{1},
+			},
+			want: "[INST]   [/INST]everything nice",
+		},
+		{
+			name:     "Message History Truncated, No System",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						Prompt:   "What are the potion ingredients?",
+						Response: "sugar",
+						First:    true,
+					},
+					{
+						Prompt:   "Anything else?",
+						Response: "spice",
+					},
+					{
+						Prompt: "... and?",
+					},
+				},
+			},
+			numCtx: 2, // only 1 message from history and most recent message
+			runner: MockLLM{
+				encoding: []int{1},
+			},
+			want: "[INST]  Anything else? [/INST]spice[INST]  ... and? [/INST]",
+		},
+		{
+			name:     "System is Preserved when Truncated",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						Prompt:   "What are the magic words?",
+						Response: "abracadabra",
+					},
+					{
+						Prompt: "What is the spell for invisibility?",
+					},
+				},
+				LastSystem: "You are a wizard.",
+			},
+			numCtx: 2,
+			runner: MockLLM{
+				encoding: []int{1},
+			},
+			want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
+		},
+		{
+			name:     "System is Preserved when Length Exceeded",
+			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						Prompt:   "What are the magic words?",
+						Response: "abracadabra",
+					},
+					{
+						Prompt: "What is the spell for invisibility?",
+					},
+				},
+				LastSystem: "You are a wizard.",
+			},
+			numCtx: 1,
+			runner: MockLLM{
+				encoding: []int{1},
+			},
+			want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
+		},
+		{
+			name:     "First is Preserved when Truncated",
+			template: "[INST] {{ if .First }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]",
+
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					// first message omitted for test
+					{
+						Prompt:   "Do you have a magic hat?",
+						Response: "Of course.",
+					},
+					{
+						Prompt: "What is the spell for invisibility?",
+					},
+				},
+				LastSystem: "You are a wizard.",
+			},
+			numCtx: 3, // two most recent messages and room for system message
+			runner: MockLLM{
+				encoding: []int{1},
+			},
+			want: "[INST] You are a wizard. Do you have a magic hat? [/INST]Of course.[INST] What is the spell for invisibility? [/INST]",
+		},
+		{
+			name:     "Most recent message is returned when longer than ctxLen",
+			template: "[INST] {{ .Prompt }} [/INST]",
+
+			chat: &ChatHistory{
+				Prompts: []PromptVars{
+					{
+						Prompt: "What is the spell for invisibility?",
+						First:  true,
+					},
+				},
+			},
+			numCtx: 1, // smaller than the 2 tokens the only message encodes to
+			runner: MockLLM{
+				encoding: []int{1, 2},
+			},
+			want: "[INST] What is the spell for invisibility? [/INST]",
+		},
+	}
+
+	for _, testCase := range tests {
+		tt := testCase
+		m := &Model{
+			Template: tt.template,
+		}
+		t.Run(tt.name, func(t *testing.T) {
+			loaded.runner = &tt.runner
+			loaded.Options = &api.Options{
+				Runner: api.Runner{
+					NumCtx: tt.numCtx,
+				},
+			}
+			got, err := trimmedPrompt(context.Background(), tt.chat, m)
+			if tt.wantErr != "" {
+				if err == nil {
+					t.Fatalf("trimmedPrompt() expected error, got nil")
+				}
+				if !strings.Contains(err.Error(), tt.wantErr) {
+					t.Errorf("trimmedPrompt() error = %v, wantErr %v", err, tt.wantErr)
+				}
+				return
+			}
+			if err != nil {
+				t.Fatalf("trimmedPrompt() unexpected error: %v", err)
+			}
+			if got != tt.want {
+				t.Errorf("trimmedPrompt() got = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+// MockLLM is a stub runner whose Encode returns the same fixed tokens for every prompt.
+type MockLLM struct {
+	encoding []int
+}
+
+func (m *MockLLM) Predict(ctx context.Context, pred llm.PredictOpts, fn func(llm.PredictResult)) error {
+	return nil
+}
+
+func (m *MockLLM) Encode(ctx context.Context, prompt string) ([]int, error) {
+	return m.encoding, nil
+}
+
+func (m *MockLLM) Decode(ctx context.Context, tokens []int) (string, error) {
+	return "", nil
+}
+
+func (m *MockLLM) Embedding(ctx context.Context, input string) ([]float64, error) {
+	return []float64{}, nil
+}
+
+func (m *MockLLM) Close() {
+	// do nothing
+}