Merge pull request #2296 from ollama/mxyng/img-tags
append image tags to user content
This commit is contained in:
commit
bfbf2f7cf7
6 changed files with 89 additions and 36 deletions
|
@@ -161,13 +161,10 @@ func newDynExtServer(library, model string, adapters, projectors []string, opts
|
||||||
func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
|
func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
|
||||||
resp := newExtServerResp(128)
|
resp := newExtServerResp(128)
|
||||||
defer freeExtServerResp(resp)
|
defer freeExtServerResp(resp)
|
||||||
var imageData []ImageData
|
|
||||||
if len(predict.Images) > 0 {
|
if len(predict.Images) > 0 {
|
||||||
for cnt, i := range predict.Images {
|
slog.Info(fmt.Sprintf("loaded %d images", len(predict.Images)))
|
||||||
imageData = append(imageData, ImageData{Data: i, ID: cnt})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
slog.Info(fmt.Sprintf("loaded %d images", len(imageData)))
|
|
||||||
|
|
||||||
request := map[string]any{
|
request := map[string]any{
|
||||||
"prompt": predict.Prompt,
|
"prompt": predict.Prompt,
|
||||||
|
@@ -189,7 +186,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
|
||||||
"penalize_nl": predict.Options.PenalizeNewline,
|
"penalize_nl": predict.Options.PenalizeNewline,
|
||||||
"seed": predict.Options.Seed,
|
"seed": predict.Options.Seed,
|
||||||
"stop": predict.Options.Stop,
|
"stop": predict.Options.Stop,
|
||||||
"image_data": imageData,
|
"image_data": predict.Images,
|
||||||
"cache_prompt": true,
|
"cache_prompt": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@@ -62,7 +62,7 @@ const maxRetries = 3
|
||||||
type PredictOpts struct {
|
type PredictOpts struct {
|
||||||
Prompt string
|
Prompt string
|
||||||
Format string
|
Format string
|
||||||
Images []api.ImageData
|
Images []ImageData
|
||||||
Options api.Options
|
Options api.Options
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@@ -63,6 +63,7 @@ type PromptVars struct {
|
||||||
Prompt string
|
Prompt string
|
||||||
Response string
|
Response string
|
||||||
First bool
|
First bool
|
||||||
|
Images []llm.ImageData
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractParts extracts the parts of the template before and after the {{.Response}} node.
|
// extractParts extracts the parts of the template before and after the {{.Response}} node.
|
||||||
|
@@ -147,15 +148,13 @@ func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatHistory struct {
|
type ChatHistory struct {
|
||||||
Prompts []PromptVars
|
Prompts []PromptVars
|
||||||
CurrentImages []api.ImageData
|
LastSystem string
|
||||||
LastSystem string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChatPrompts returns a list of formatted chat prompts from a list of messages
|
// ChatPrompts returns a list of formatted chat prompts from a list of messages
|
||||||
func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
|
func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
|
||||||
// build the prompt from the list of messages
|
// build the prompt from the list of messages
|
||||||
var currentImages []api.ImageData
|
|
||||||
lastSystem := m.System
|
lastSystem := m.System
|
||||||
currentVars := PromptVars{
|
currentVars := PromptVars{
|
||||||
First: true,
|
First: true,
|
||||||
|
@@ -163,6 +162,7 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
prompts := []PromptVars{}
|
prompts := []PromptVars{}
|
||||||
|
var images []llm.ImageData
|
||||||
|
|
||||||
for _, msg := range msgs {
|
for _, msg := range msgs {
|
||||||
switch strings.ToLower(msg.Role) {
|
switch strings.ToLower(msg.Role) {
|
||||||
|
@@ -179,8 +179,18 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
|
||||||
prompts = append(prompts, currentVars)
|
prompts = append(prompts, currentVars)
|
||||||
currentVars = PromptVars{}
|
currentVars = PromptVars{}
|
||||||
}
|
}
|
||||||
|
|
||||||
currentVars.Prompt = msg.Content
|
currentVars.Prompt = msg.Content
|
||||||
currentImages = msg.Images
|
for i := range msg.Images {
|
||||||
|
id := len(images) + i
|
||||||
|
currentVars.Prompt += fmt.Sprintf(" [img-%d]", id)
|
||||||
|
currentVars.Images = append(currentVars.Images, llm.ImageData{
|
||||||
|
ID: id,
|
||||||
|
Data: msg.Images[i],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
images = append(images, currentVars.Images...)
|
||||||
case "assistant":
|
case "assistant":
|
||||||
currentVars.Response = msg.Content
|
currentVars.Response = msg.Content
|
||||||
prompts = append(prompts, currentVars)
|
prompts = append(prompts, currentVars)
|
||||||
|
@@ -196,9 +206,8 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ChatHistory{
|
return &ChatHistory{
|
||||||
Prompts: prompts,
|
Prompts: prompts,
|
||||||
CurrentImages: currentImages,
|
LastSystem: lastSystem,
|
||||||
LastSystem: lastSystem,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@@ -238,18 +238,37 @@ func chatHistoryEqual(a, b ChatHistory) bool {
|
||||||
if len(a.Prompts) != len(b.Prompts) {
|
if len(a.Prompts) != len(b.Prompts) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if len(a.CurrentImages) != len(b.CurrentImages) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for i, v := range a.Prompts {
|
for i, v := range a.Prompts {
|
||||||
if v != b.Prompts[i] {
|
|
||||||
|
if v.First != b.Prompts[i].First {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
|
||||||
for i, v := range a.CurrentImages {
|
if v.Response != b.Prompts[i].Response {
|
||||||
if !bytes.Equal(v, b.CurrentImages[i]) {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if v.Prompt != b.Prompts[i].Prompt {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if v.System != b.Prompts[i].System {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(v.Images) != len(b.Prompts[i].Images) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for j, img := range v.Images {
|
||||||
|
if img.ID != b.Prompts[i].Images[j].ID {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(img.Data, b.Prompts[i].Images[j].Data) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return a.LastSystem == b.LastSystem
|
return a.LastSystem == b.LastSystem
|
||||||
}
|
}
|
||||||
|
|
|
@@ -244,6 +244,10 @@ func GenerateHandler(c *gin.Context) {
|
||||||
promptVars.System = model.System
|
promptVars.System = model.System
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for i := range req.Images {
|
||||||
|
promptVars.Prompt += fmt.Sprintf(" [img-%d]", i)
|
||||||
|
}
|
||||||
|
|
||||||
p, err := model.PreResponsePrompt(promptVars)
|
p, err := model.PreResponsePrompt(promptVars)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
@@ -308,11 +312,19 @@ func GenerateHandler(c *gin.Context) {
|
||||||
ch <- resp
|
ch <- resp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var images []llm.ImageData
|
||||||
|
for i := range req.Images {
|
||||||
|
images = append(images, llm.ImageData{
|
||||||
|
ID: i,
|
||||||
|
Data: req.Images[i],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Start prediction
|
// Start prediction
|
||||||
predictReq := llm.PredictOpts{
|
predictReq := llm.PredictOpts{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Format: req.Format,
|
Format: req.Format,
|
||||||
Images: req.Images,
|
Images: images,
|
||||||
Options: opts,
|
Options: opts,
|
||||||
}
|
}
|
||||||
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
||||||
|
@@ -1139,7 +1151,8 @@ func ChatHandler(c *gin.Context) {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
prompt, err := trimmedPrompt(c.Request.Context(), chat, model)
|
|
||||||
|
prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
|
@@ -1182,7 +1195,7 @@ func ChatHandler(c *gin.Context) {
|
||||||
predictReq := llm.PredictOpts{
|
predictReq := llm.PredictOpts{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Format: req.Format,
|
Format: req.Format,
|
||||||
Images: chat.CurrentImages,
|
Images: images,
|
||||||
Options: opts,
|
Options: opts,
|
||||||
}
|
}
|
||||||
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
||||||
|
@@ -1229,34 +1242,47 @@ type promptInfo struct {
|
||||||
|
|
||||||
// trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
|
// trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
|
||||||
// while preserving the most recent system message.
|
// while preserving the most recent system message.
|
||||||
func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, error) {
|
func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, []llm.ImageData, error) {
|
||||||
if len(chat.Prompts) == 0 {
|
if len(chat.Prompts) == 0 {
|
||||||
return "", nil
|
return "", nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var promptsToAdd []promptInfo
|
var promptsToAdd []promptInfo
|
||||||
var totalTokenLength int
|
var totalTokenLength int
|
||||||
var systemPromptIncluded bool
|
var systemPromptIncluded bool
|
||||||
|
|
||||||
|
var images []llm.ImageData
|
||||||
// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
|
// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
|
||||||
for i := len(chat.Prompts) - 1; i >= 0; i-- {
|
for i := len(chat.Prompts) - 1; i >= 0; i-- {
|
||||||
promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
|
prompt := chat.Prompts[i]
|
||||||
|
promptText, err := promptString(model, prompt, i == len(chat.Prompts)-1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
encodedTokens, err := loaded.runner.Encode(ctx, promptText)
|
encodedTokens, err := loaded.runner.Encode(ctx, promptText)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
|
if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
|
||||||
break // reached max context length, stop adding more prompts
|
break // reached max context length, stop adding more prompts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for j := range prompt.Images {
|
||||||
|
if totalTokenLength+768 > loaded.NumCtx {
|
||||||
|
// this decreases the token length but overestimating is fine
|
||||||
|
prompt.Prompt = strings.ReplaceAll(prompt.Prompt, fmt.Sprintf(" [img-%d]", prompt.Images[j].ID), "")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
totalTokenLength += 768
|
||||||
|
images = append(images, prompt.Images[j])
|
||||||
|
}
|
||||||
|
|
||||||
totalTokenLength += len(encodedTokens)
|
totalTokenLength += len(encodedTokens)
|
||||||
systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
|
systemPromptIncluded = systemPromptIncluded || prompt.System != ""
|
||||||
promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
|
promptsToAdd = append(promptsToAdd, promptInfo{vars: prompt, tokenLen: len(encodedTokens)})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensure the system prompt is included, if not already
|
// ensure the system prompt is included, if not already
|
||||||
|
@@ -1264,7 +1290,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
|
||||||
var err error
|
var err error
|
||||||
promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
|
promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1275,11 +1301,12 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
|
||||||
for i, prompt := range promptsToAdd {
|
for i, prompt := range promptsToAdd {
|
||||||
promptText, err := promptString(model, prompt.vars, i == 0)
|
promptText, err := promptString(model, prompt.vars, i == 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
result = promptText + result
|
result = promptText + result
|
||||||
}
|
}
|
||||||
return result, nil
|
|
||||||
|
return result, images, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// promptString applies the model template to the prompt
|
// promptString applies the model template to the prompt
|
||||||
|
|
|
@@ -455,7 +455,8 @@ func Test_ChatPrompt(t *testing.T) {
|
||||||
NumCtx: tt.numCtx,
|
NumCtx: tt.numCtx,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
got, err := trimmedPrompt(context.Background(), tt.chat, m)
|
// TODO: add tests for trimming images
|
||||||
|
got, _, err := trimmedPrompt(context.Background(), tt.chat, m)
|
||||||
if tt.wantErr != "" {
|
if tt.wantErr != "" {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("ChatPrompt() expected error, got nil")
|
t.Errorf("ChatPrompt() expected error, got nil")
|
||||||
|
|
Loading…
Add table
Reference in a new issue