trim chat prompt based on llm context size (#1963)
This commit is contained in:
parent
509e2dec8a
commit
0632dff3f8
4 changed files with 440 additions and 57 deletions
|
@ -146,62 +146,59 @@ func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
|
||||||
return Prompt(post, p)
|
return Prompt(post, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
|
type ChatHistory struct {
|
||||||
|
Prompts []PromptVars
|
||||||
|
CurrentImages []api.ImageData
|
||||||
|
LastSystem string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatPrompts returns a list of formatted chat prompts from a list of messages
|
||||||
|
func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
|
||||||
// build the prompt from the list of messages
|
// build the prompt from the list of messages
|
||||||
var prompt strings.Builder
|
|
||||||
var currentImages []api.ImageData
|
var currentImages []api.ImageData
|
||||||
|
var lastSystem string
|
||||||
currentVars := PromptVars{
|
currentVars := PromptVars{
|
||||||
First: true,
|
First: true,
|
||||||
System: m.System,
|
System: m.System,
|
||||||
}
|
}
|
||||||
|
|
||||||
writePrompt := func() error {
|
prompts := []PromptVars{}
|
||||||
p, err := Prompt(m.Template, currentVars)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
prompt.WriteString(p)
|
|
||||||
currentVars = PromptVars{}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, msg := range msgs {
|
for _, msg := range msgs {
|
||||||
switch strings.ToLower(msg.Role) {
|
switch strings.ToLower(msg.Role) {
|
||||||
case "system":
|
case "system":
|
||||||
if currentVars.System != "" {
|
if currentVars.System != "" {
|
||||||
if err := writePrompt(); err != nil {
|
prompts = append(prompts, currentVars)
|
||||||
return "", nil, err
|
currentVars = PromptVars{}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
currentVars.System = msg.Content
|
currentVars.System = msg.Content
|
||||||
|
lastSystem = msg.Content
|
||||||
case "user":
|
case "user":
|
||||||
if currentVars.Prompt != "" {
|
if currentVars.Prompt != "" {
|
||||||
if err := writePrompt(); err != nil {
|
prompts = append(prompts, currentVars)
|
||||||
return "", nil, err
|
currentVars = PromptVars{}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
currentVars.Prompt = msg.Content
|
currentVars.Prompt = msg.Content
|
||||||
currentImages = msg.Images
|
currentImages = msg.Images
|
||||||
case "assistant":
|
case "assistant":
|
||||||
currentVars.Response = msg.Content
|
currentVars.Response = msg.Content
|
||||||
if err := writePrompt(); err != nil {
|
prompts = append(prompts, currentVars)
|
||||||
return "", nil, err
|
currentVars = PromptVars{}
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
return "", nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
|
return nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Append the last set of vars if they are non-empty
|
// Append the last set of vars if they are non-empty
|
||||||
if currentVars.Prompt != "" || currentVars.System != "" {
|
if currentVars.Prompt != "" || currentVars.System != "" {
|
||||||
p, err := m.PreResponsePrompt(currentVars)
|
prompts = append(prompts, currentVars)
|
||||||
if err != nil {
|
|
||||||
return "", nil, fmt.Errorf("pre-response template: %w", err)
|
|
||||||
}
|
|
||||||
prompt.WriteString(p)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return prompt.String(), currentImages, nil
|
return &ChatHistory{
|
||||||
|
Prompts: prompts,
|
||||||
|
CurrentImages: currentImages,
|
||||||
|
LastSystem: lastSystem,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type ManifestV2 struct {
|
type ManifestV2 struct {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -233,12 +234,32 @@ func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func chatHistoryEqual(a, b ChatHistory) bool {
|
||||||
|
if len(a.Prompts) != len(b.Prompts) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(a.CurrentImages) != len(b.CurrentImages) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i, v := range a.Prompts {
|
||||||
|
if v != b.Prompts[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i, v := range a.CurrentImages {
|
||||||
|
if !bytes.Equal(v, b.CurrentImages[i]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return a.LastSystem == b.LastSystem
|
||||||
|
}
|
||||||
|
|
||||||
func TestChat(t *testing.T) {
|
func TestChat(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
template string
|
template string
|
||||||
msgs []api.Message
|
msgs []api.Message
|
||||||
want string
|
want ChatHistory
|
||||||
wantErr string
|
wantErr string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
|
@ -254,30 +275,16 @@ func TestChat(t *testing.T) {
|
||||||
Content: "What are the potion ingredients?",
|
Content: "What are the potion ingredients?",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
|
want: ChatHistory{
|
||||||
},
|
Prompts: []PromptVars{
|
||||||
{
|
{
|
||||||
name: "First Message",
|
System: "You are a Wizard.",
|
||||||
template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
|
Prompt: "What are the potion ingredients?",
|
||||||
msgs: []api.Message{
|
First: true,
|
||||||
{
|
},
|
||||||
Role: "system",
|
|
||||||
Content: "You are a Wizard.",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: "user",
|
|
||||||
Content: "What are the potion ingredients?",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: "eye of newt",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: "user",
|
|
||||||
Content: "Anything else?",
|
|
||||||
},
|
},
|
||||||
|
LastSystem: "You are a Wizard.",
|
||||||
},
|
},
|
||||||
want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Message History",
|
name: "Message History",
|
||||||
|
@ -300,7 +307,20 @@ func TestChat(t *testing.T) {
|
||||||
Content: "Anything else?",
|
Content: "Anything else?",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
|
want: ChatHistory{
|
||||||
|
Prompts: []PromptVars{
|
||||||
|
{
|
||||||
|
System: "You are a Wizard.",
|
||||||
|
Prompt: "What are the potion ingredients?",
|
||||||
|
Response: "sugar",
|
||||||
|
First: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Prompt: "Anything else?",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
LastSystem: "You are a Wizard.",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Assistant Only",
|
name: "Assistant Only",
|
||||||
|
@ -311,7 +331,14 @@ func TestChat(t *testing.T) {
|
||||||
Content: "everything nice",
|
Content: "everything nice",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: "[INST] [/INST]everything nice",
|
want: ChatHistory{
|
||||||
|
Prompts: []PromptVars{
|
||||||
|
{
|
||||||
|
Response: "everything nice",
|
||||||
|
First: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Invalid Role",
|
name: "Invalid Role",
|
||||||
|
@ -330,7 +357,7 @@ func TestChat(t *testing.T) {
|
||||||
Template: tt.template,
|
Template: tt.template,
|
||||||
}
|
}
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, _, err := m.ChatPrompt(tt.msgs)
|
got, err := m.ChatPrompts(tt.msgs)
|
||||||
if tt.wantErr != "" {
|
if tt.wantErr != "" {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("ChatPrompt() expected error, got nil")
|
t.Errorf("ChatPrompt() expected error, got nil")
|
||||||
|
@ -338,9 +365,10 @@ func TestChat(t *testing.T) {
|
||||||
if !strings.Contains(err.Error(), tt.wantErr) {
|
if !strings.Contains(err.Error(), tt.wantErr) {
|
||||||
t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
}
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if got != tt.want {
|
if !chatHistoryEqual(*got, tt.want) {
|
||||||
t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
|
t.Errorf("ChatPrompt() got = %#v, want %#v", got, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
107
server/routes.go
107
server/routes.go
|
@ -1121,11 +1121,16 @@ func ChatHandler(c *gin.Context) {
|
||||||
|
|
||||||
checkpointLoaded := time.Now()
|
checkpointLoaded := time.Now()
|
||||||
|
|
||||||
prompt, images, err := model.ChatPrompt(req.Messages)
|
chat, err := model.ChatPrompts(req.Messages)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
prompt, err := trimmedPrompt(c.Request.Context(), chat, model)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
slog.Debug(fmt.Sprintf("prompt: %s", prompt))
|
slog.Debug(fmt.Sprintf("prompt: %s", prompt))
|
||||||
|
|
||||||
|
@ -1164,7 +1169,7 @@ func ChatHandler(c *gin.Context) {
|
||||||
predictReq := llm.PredictOpts{
|
predictReq := llm.PredictOpts{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Format: req.Format,
|
Format: req.Format,
|
||||||
Images: images,
|
Images: chat.CurrentImages,
|
||||||
Options: opts,
|
Options: opts,
|
||||||
}
|
}
|
||||||
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
||||||
|
@ -1202,3 +1207,101 @@ func ChatHandler(c *gin.Context) {
|
||||||
|
|
||||||
streamResponse(c, ch)
|
streamResponse(c, ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// promptInfo stores the variables used to template a prompt, and the token length of the resulting template for some model
|
||||||
|
type promptInfo struct {
|
||||||
|
vars PromptVars
|
||||||
|
tokenLen int
|
||||||
|
}
|
||||||
|
|
||||||
|
// trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
|
||||||
|
// while preserving the most recent system message.
|
||||||
|
func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, error) {
|
||||||
|
if len(chat.Prompts) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var promptsToAdd []promptInfo
|
||||||
|
var totalTokenLength int
|
||||||
|
var systemPromptIncluded bool
|
||||||
|
|
||||||
|
// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
|
||||||
|
for i := len(chat.Prompts) - 1; i >= 0; i-- {
|
||||||
|
promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
encodedTokens, err := loaded.runner.Encode(ctx, promptText)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
|
||||||
|
break // reached max context length, stop adding more prompts
|
||||||
|
}
|
||||||
|
|
||||||
|
totalTokenLength += len(encodedTokens)
|
||||||
|
systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
|
||||||
|
promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensure the system prompt is included, if not already
|
||||||
|
if chat.LastSystem != "" && !systemPromptIncluded {
|
||||||
|
var err error
|
||||||
|
promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
promptsToAdd[len(promptsToAdd)-1].vars.First = true
|
||||||
|
|
||||||
|
// construct the final prompt string from the prompts which fit within the context window
|
||||||
|
var result string
|
||||||
|
for i, prompt := range promptsToAdd {
|
||||||
|
promptText, err := promptString(model, prompt.vars, i == 0)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
result = promptText + result
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// promptString applies the model template to the prompt
|
||||||
|
func promptString(model *Model, vars PromptVars, isMostRecent bool) (string, error) {
|
||||||
|
if isMostRecent {
|
||||||
|
p, err := model.PreResponsePrompt(vars)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("pre-response template: %w", err)
|
||||||
|
}
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
p, err := Prompt(model.Template, vars)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// includeSystemPrompt adjusts the prompts to include the system prompt.
|
||||||
|
func includeSystemPrompt(ctx context.Context, systemPrompt string, totalTokenLength int, promptsToAdd []promptInfo) ([]promptInfo, error) {
|
||||||
|
systemTokens, err := loaded.runner.Encode(ctx, systemPrompt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := len(promptsToAdd) - 1; i >= 0; i-- {
|
||||||
|
if totalTokenLength+len(systemTokens) <= loaded.NumCtx {
|
||||||
|
promptsToAdd[i].vars.System = systemPrompt
|
||||||
|
return promptsToAdd[:i+1], nil
|
||||||
|
}
|
||||||
|
totalTokenLength -= promptsToAdd[i].tokenLen
|
||||||
|
}
|
||||||
|
|
||||||
|
// if got here, system did not fit anywhere, so return the most recent prompt with the system message set
|
||||||
|
recent := promptsToAdd[len(promptsToAdd)-1]
|
||||||
|
recent.vars.System = systemPrompt
|
||||||
|
return []promptInfo{recent}, nil
|
||||||
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/jmorganca/ollama/api"
|
"github.com/jmorganca/ollama/api"
|
||||||
|
"github.com/jmorganca/ollama/llm"
|
||||||
"github.com/jmorganca/ollama/parser"
|
"github.com/jmorganca/ollama/parser"
|
||||||
"github.com/jmorganca/ollama/version"
|
"github.com/jmorganca/ollama/version"
|
||||||
)
|
)
|
||||||
|
@ -239,3 +240,257 @@ func Test_Routes(t *testing.T) {
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_ChatPrompt(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
template string
|
||||||
|
chat *ChatHistory
|
||||||
|
numCtx int
|
||||||
|
runner MockLLM
|
||||||
|
want string
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Single Message",
|
||||||
|
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||||
|
chat: &ChatHistory{
|
||||||
|
Prompts: []PromptVars{
|
||||||
|
{
|
||||||
|
System: "You are a Wizard.",
|
||||||
|
Prompt: "What are the potion ingredients?",
|
||||||
|
First: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
LastSystem: "You are a Wizard.",
|
||||||
|
},
|
||||||
|
numCtx: 1,
|
||||||
|
runner: MockLLM{
|
||||||
|
encoding: []int{1}, // fit the ctxLen
|
||||||
|
},
|
||||||
|
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "First Message",
|
||||||
|
template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
|
||||||
|
chat: &ChatHistory{
|
||||||
|
Prompts: []PromptVars{
|
||||||
|
{
|
||||||
|
System: "You are a Wizard.",
|
||||||
|
Prompt: "What are the potion ingredients?",
|
||||||
|
Response: "eye of newt",
|
||||||
|
First: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Prompt: "Anything else?",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
LastSystem: "You are a Wizard.",
|
||||||
|
},
|
||||||
|
numCtx: 2,
|
||||||
|
runner: MockLLM{
|
||||||
|
encoding: []int{1}, // fit the ctxLen
|
||||||
|
},
|
||||||
|
want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Message History",
|
||||||
|
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||||
|
chat: &ChatHistory{
|
||||||
|
Prompts: []PromptVars{
|
||||||
|
{
|
||||||
|
System: "You are a Wizard.",
|
||||||
|
Prompt: "What are the potion ingredients?",
|
||||||
|
Response: "sugar",
|
||||||
|
First: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Prompt: "Anything else?",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
LastSystem: "You are a Wizard.",
|
||||||
|
},
|
||||||
|
numCtx: 4,
|
||||||
|
runner: MockLLM{
|
||||||
|
encoding: []int{1}, // fit the ctxLen, 1 for each message
|
||||||
|
},
|
||||||
|
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Assistant Only",
|
||||||
|
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||||
|
chat: &ChatHistory{
|
||||||
|
Prompts: []PromptVars{
|
||||||
|
{
|
||||||
|
Response: "everything nice",
|
||||||
|
First: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
numCtx: 1,
|
||||||
|
runner: MockLLM{
|
||||||
|
encoding: []int{1},
|
||||||
|
},
|
||||||
|
want: "[INST] [/INST]everything nice",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Message History Truncated, No System",
|
||||||
|
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||||
|
chat: &ChatHistory{
|
||||||
|
Prompts: []PromptVars{
|
||||||
|
{
|
||||||
|
Prompt: "What are the potion ingredients?",
|
||||||
|
Response: "sugar",
|
||||||
|
First: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Prompt: "Anything else?",
|
||||||
|
Response: "spice",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Prompt: "... and?",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
numCtx: 2, // only 1 message from history and most recent message
|
||||||
|
runner: MockLLM{
|
||||||
|
encoding: []int{1},
|
||||||
|
},
|
||||||
|
want: "[INST] Anything else? [/INST]spice[INST] ... and? [/INST]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "System is Preserved when Truncated",
|
||||||
|
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||||
|
chat: &ChatHistory{
|
||||||
|
Prompts: []PromptVars{
|
||||||
|
{
|
||||||
|
Prompt: "What are the magic words?",
|
||||||
|
Response: "abracadabra",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Prompt: "What is the spell for invisibility?",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
LastSystem: "You are a wizard.",
|
||||||
|
},
|
||||||
|
numCtx: 2,
|
||||||
|
runner: MockLLM{
|
||||||
|
encoding: []int{1},
|
||||||
|
},
|
||||||
|
want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "System is Preserved when Length Exceeded",
|
||||||
|
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||||
|
chat: &ChatHistory{
|
||||||
|
Prompts: []PromptVars{
|
||||||
|
{
|
||||||
|
Prompt: "What are the magic words?",
|
||||||
|
Response: "abracadabra",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Prompt: "What is the spell for invisibility?",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
LastSystem: "You are a wizard.",
|
||||||
|
},
|
||||||
|
numCtx: 1,
|
||||||
|
runner: MockLLM{
|
||||||
|
encoding: []int{1},
|
||||||
|
},
|
||||||
|
want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "First is Preserved when Truncated",
|
||||||
|
template: "[INST] {{ if .First }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]",
|
||||||
|
|
||||||
|
chat: &ChatHistory{
|
||||||
|
Prompts: []PromptVars{
|
||||||
|
// first message omitted for test
|
||||||
|
{
|
||||||
|
Prompt: "Do you have a magic hat?",
|
||||||
|
Response: "Of course.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Prompt: "What is the spell for invisibility?",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
LastSystem: "You are a wizard.",
|
||||||
|
},
|
||||||
|
numCtx: 3, // two most recent messages and room for system message
|
||||||
|
runner: MockLLM{
|
||||||
|
encoding: []int{1},
|
||||||
|
},
|
||||||
|
want: "[INST] You are a wizard. Do you have a magic hat? [/INST]Of course.[INST] What is the spell for invisibility? [/INST]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Most recent message is returned when longer than ctxLen",
|
||||||
|
template: "[INST] {{ .Prompt }} [/INST]",
|
||||||
|
|
||||||
|
chat: &ChatHistory{
|
||||||
|
Prompts: []PromptVars{
|
||||||
|
{
|
||||||
|
Prompt: "What is the spell for invisibility?",
|
||||||
|
First: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
numCtx: 1, // two most recent messages
|
||||||
|
runner: MockLLM{
|
||||||
|
encoding: []int{1, 2},
|
||||||
|
},
|
||||||
|
want: "[INST] What is the spell for invisibility? [/INST]",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range tests {
|
||||||
|
tt := testCase
|
||||||
|
m := &Model{
|
||||||
|
Template: tt.template,
|
||||||
|
}
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
loaded.runner = &tt.runner
|
||||||
|
loaded.Options = &api.Options{
|
||||||
|
Runner: api.Runner{
|
||||||
|
NumCtx: tt.numCtx,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
got, err := trimmedPrompt(context.Background(), tt.chat, m)
|
||||||
|
if tt.wantErr != "" {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("ChatPrompt() expected error, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), tt.wantErr) {
|
||||||
|
t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockLLM struct {
|
||||||
|
encoding []int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (llm *MockLLM) Predict(ctx context.Context, pred llm.PredictOpts, fn func(llm.PredictResult)) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (llm *MockLLM) Encode(ctx context.Context, prompt string) ([]int, error) {
|
||||||
|
return llm.encoding, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (llm *MockLLM) Decode(ctx context.Context, tokens []int) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (llm *MockLLM) Embedding(ctx context.Context, input string) ([]float64, error) {
|
||||||
|
return []float64{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (llm *MockLLM) Close() {
|
||||||
|
// do nothing
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue