ollama/server/prompt.go

75 lines
2 KiB
Go
Raw Normal View History

package server
import (
2024-06-17 17:38:55 +00:00
"bytes"
"context"
"log/slog"
"github.com/ollama/ollama/api"
2024-06-17 17:38:55 +00:00
"github.com/ollama/ollama/llm"
2024-06-10 21:54:42 +00:00
"github.com/ollama/ollama/template"
)
2024-06-20 18:00:08 +00:00
type tokenizeFunc func(context.Context, string) ([]int, error)
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
2024-06-17 17:38:55 +00:00
var system []api.Message
2024-06-20 18:00:08 +00:00
// always include the last message
2024-06-17 17:38:55 +00:00
n := len(msgs) - 1
2024-06-20 18:00:08 +00:00
// in reverse, find all messages that fit into context window
2024-06-17 17:38:55 +00:00
for i := n - 1; i >= 0; i-- {
system = make([]api.Message, 0)
for j := range i {
if msgs[j].Role == "system" {
system = append(system, msgs[j])
}
}
2024-06-17 17:38:55 +00:00
var b bytes.Buffer
2024-06-20 18:00:08 +00:00
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
2024-06-17 17:38:55 +00:00
return "", nil, err
}
2024-06-20 18:00:08 +00:00
s, err := tokenize(ctx, b.String())
if err != nil {
2024-06-17 17:38:55 +00:00
return "", nil, err
}
2024-06-17 17:38:55 +00:00
c := len(s)
2024-06-20 18:00:08 +00:00
if m.ProjectorPaths != nil {
2024-06-17 17:38:55 +00:00
for _, m := range msgs[i:] {
2024-06-20 18:00:08 +00:00
// images are represented as 768 sized embeddings
// TODO: get embedding length from project metadata
2024-06-17 17:38:55 +00:00
c += 768 * len(m.Images)
}
}
2024-06-20 18:00:08 +00:00
if c > opts.NumCtx {
2024-06-17 17:38:55 +00:00
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
break
2024-06-17 17:38:55 +00:00
} else {
n = i
}
2024-06-17 17:38:55 +00:00
}
2024-06-20 18:00:08 +00:00
// truncate any messages that do not fit into the context window
2024-06-17 17:38:55 +00:00
var b bytes.Buffer
2024-06-20 18:00:08 +00:00
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
2024-06-17 17:38:55 +00:00
return "", nil, err
}
2024-06-17 17:38:55 +00:00
for _, m := range msgs[n:] {
for _, i := range m.Images {
images = append(images, llm.ImageData{
ID: len(images),
Data: i,
})
}
}
2024-06-17 17:38:55 +00:00
return b.String(), images, nil
}