2024-02-12 15:06:57 -08:00
|
|
|
package server
|
|
|
|
|
|
|
|
import (
|
2024-06-17 10:38:55 -07:00
|
|
|
"bytes"
|
|
|
|
"context"
|
2024-02-12 15:06:57 -08:00
|
|
|
"log/slog"
|
2024-06-17 10:38:55 -07:00
|
|
|
"slices"
|
2024-02-12 15:06:57 -08:00
|
|
|
|
2024-03-26 13:04:17 -07:00
|
|
|
"github.com/ollama/ollama/api"
|
2024-06-17 10:38:55 -07:00
|
|
|
"github.com/ollama/ollama/llm"
|
2024-06-10 14:54:42 -07:00
|
|
|
"github.com/ollama/ollama/template"
|
2024-02-12 15:06:57 -08:00
|
|
|
)
|
|
|
|
|
2024-06-20 11:00:08 -07:00
|
|
|
// tokenizeFunc tokenizes text and returns the resulting tokens as ints.
// chatPrompt calls it on each candidate rendered prompt and compares the
// number of returned tokens against the model's context size.
type tokenizeFunc func(ctx context.Context, text string) ([]int, error)
|
|
|
|
|
|
|
|
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
|
|
|
|
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
|
|
|
|
// latest message and 2) system messages
|
|
|
|
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
|
|
|
|
// pull out any system messages which should always be included in the prompt
|
2024-06-17 10:38:55 -07:00
|
|
|
var system []api.Message
|
|
|
|
msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
|
|
|
|
if m.Role == "system" {
|
|
|
|
system = append(system, m)
|
|
|
|
return true
|
2024-02-12 15:06:57 -08:00
|
|
|
}
|
|
|
|
|
2024-06-17 10:38:55 -07:00
|
|
|
return false
|
|
|
|
})
|
2024-02-12 15:06:57 -08:00
|
|
|
|
2024-06-20 11:00:08 -07:00
|
|
|
if len(system) == 0 && m.System != "" {
|
2024-06-17 10:38:55 -07:00
|
|
|
// add model system prompt since it wasn't provided
|
2024-06-20 11:00:08 -07:00
|
|
|
system = append(system, api.Message{Role: "system", Content: m.System})
|
2024-02-12 15:06:57 -08:00
|
|
|
}
|
|
|
|
|
2024-06-20 11:00:08 -07:00
|
|
|
// always include the last message
|
2024-06-17 10:38:55 -07:00
|
|
|
n := len(msgs) - 1
|
2024-06-20 11:00:08 -07:00
|
|
|
// in reverse, find all messages that fit into context window
|
2024-06-17 10:38:55 -07:00
|
|
|
for i := n - 1; i >= 0; i-- {
|
|
|
|
var b bytes.Buffer
|
2024-06-20 11:00:08 -07:00
|
|
|
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
|
2024-06-17 10:38:55 -07:00
|
|
|
return "", nil, err
|
2024-02-12 15:06:57 -08:00
|
|
|
}
|
|
|
|
|
2024-06-20 11:00:08 -07:00
|
|
|
s, err := tokenize(ctx, b.String())
|
2024-02-12 15:06:57 -08:00
|
|
|
if err != nil {
|
2024-06-17 10:38:55 -07:00
|
|
|
return "", nil, err
|
2024-02-12 15:06:57 -08:00
|
|
|
}
|
|
|
|
|
2024-06-17 10:38:55 -07:00
|
|
|
c := len(s)
|
2024-06-20 11:00:08 -07:00
|
|
|
if m.ProjectorPaths != nil {
|
2024-06-17 10:38:55 -07:00
|
|
|
for _, m := range msgs[i:] {
|
2024-06-20 11:00:08 -07:00
|
|
|
// images are represented as 768 sized embeddings
|
|
|
|
// TODO: get embedding length from project metadata
|
2024-06-17 10:38:55 -07:00
|
|
|
c += 768 * len(m.Images)
|
|
|
|
}
|
2024-02-12 15:06:57 -08:00
|
|
|
}
|
|
|
|
|
2024-06-20 11:00:08 -07:00
|
|
|
if c > opts.NumCtx {
|
2024-06-17 10:38:55 -07:00
|
|
|
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
|
2024-02-12 15:06:57 -08:00
|
|
|
break
|
2024-06-17 10:38:55 -07:00
|
|
|
} else {
|
|
|
|
n = i
|
2024-02-12 15:06:57 -08:00
|
|
|
}
|
2024-06-17 10:38:55 -07:00
|
|
|
}
|
2024-02-12 15:06:57 -08:00
|
|
|
|
2024-06-20 11:00:08 -07:00
|
|
|
// truncate any messages that do not fit into the context window
|
2024-06-17 10:38:55 -07:00
|
|
|
var b bytes.Buffer
|
2024-06-20 11:00:08 -07:00
|
|
|
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
|
2024-06-17 10:38:55 -07:00
|
|
|
return "", nil, err
|
2024-02-12 15:06:57 -08:00
|
|
|
}
|
|
|
|
|
2024-06-17 10:38:55 -07:00
|
|
|
for _, m := range msgs[n:] {
|
|
|
|
for _, i := range m.Images {
|
|
|
|
images = append(images, llm.ImageData{
|
|
|
|
ID: len(images),
|
|
|
|
Data: i,
|
|
|
|
})
|
2024-02-12 15:06:57 -08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-06-17 10:38:55 -07:00
|
|
|
return b.String(), images, nil
|
2024-02-12 15:06:57 -08:00
|
|
|
}
|