ollama/server/prompt.go
Patrick Devine c7cb0f0602
image processing for llama3.2 (#6963)
Co-authored-by: jmorganca <jmorganca@gmail.com>
Co-authored-by: Michael Yang <mxyng@pm.me>
Co-authored-by: Jesse Gross <jesse@ollama.com>
2024-10-18 16:12:35 -07:00

142 lines
3.5 KiB
Go

package server
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"log/slog"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/server/imageproc"
"github.com/ollama/ollama/template"
)
type tokenizeFunc func(context.Context, string) ([]int, error)
var errTooManyImages = errors.New("vision model only supports a single image per message")
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
var system []api.Message
isMllama := checkMllamaModelFamily(m)
n := len(msgs) - 1
// in reverse, find all messages that fit into context window
for i := n; i >= 0; i-- {
if isMllama && len(msgs[i].Images) > 1 {
return "", nil, errTooManyImages
}
// always include the last message
if i == n {
continue
}
system = make([]api.Message, 0)
for j := range i {
if msgs[j].Role == "system" {
system = append(system, msgs[j])
}
}
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
return "", nil, err
}
s, err := tokenize(ctx, b.String())
if err != nil {
return "", nil, err
}
ctxLen := len(s)
if m.ProjectorPaths != nil {
for _, m := range msgs[i:] {
// images are represented as 768 sized embeddings
// TODO: get embedding length from project metadata
ctxLen += 768 * len(m.Images)
}
}
if ctxLen > opts.NumCtx {
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
break
} else {
n = i
}
}
currMsgIdx := n
if isMllama {
lastMsgIdx := len(msgs) - 1
for i := lastMsgIdx; i >= currMsgIdx; i-- {
if len(msgs[i].Images) > 0 {
data, aspectRatioID, err := imageproc.Preprocess(msgs[i].Images[0])
if err != nil {
return "", nil, err
}
buf := new(bytes.Buffer)
err = binary.Write(buf, binary.LittleEndian, data)
if err != nil {
return "", nil, err
}
imgData := llm.ImageData{
Data: buf.Bytes(),
AspectRatioID: aspectRatioID,
}
msgs[i].Content = strings.TrimSpace("<|image|>" + msgs[i].Content)
images = append(images, imgData)
break
}
}
} else {
for cnt, msg := range msgs[currMsgIdx:] {
prefix := ""
prompt := msg.Content
for _, i := range msg.Images {
imgData := llm.ImageData{
ID: len(images),
Data: i,
}
imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
if !strings.Contains(prompt, "[img]") {
prefix += imgTag
} else {
prompt = strings.Replace(prompt, "[img]", imgTag, 1)
}
images = append(images, imgData)
}
msgs[currMsgIdx+cnt].Content = strings.TrimSpace(prefix + " " + prompt)
}
}
// truncate any messages that do not fit into the context window
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
return "", nil, err
}
return b.String(), images, nil
}
func checkMllamaModelFamily(m *Model) bool {
for _, arch := range m.Config.ModelFamilies {
if arch == "mllama" {
return true
}
}
return false
}