From b4e11be8ef3c4afc82a7357d51f93b336c1866a1 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Wed, 31 Jan 2024 16:31:29 -0800
Subject: [PATCH 1/8] append image tags to user content

---
 server/images.go | 7 ++++++-
 server/routes.go | 4 ++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/server/images.go b/server/images.go
index 503dd8e2..26b59e0d 100644
--- a/server/images.go
+++ b/server/images.go
@@ -179,8 +179,13 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
             prompts = append(prompts, currentVars)
             currentVars = PromptVars{}
         }
+
         currentVars.Prompt = msg.Content
-        currentImages = msg.Images
+        for i := range msg.Images {
+            currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(currentImages)+i)
+        }
+
+        currentImages = append(currentImages, msg.Images...)
     case "assistant":
         currentVars.Response = msg.Content
         prompts = append(prompts, currentVars)
diff --git a/server/routes.go b/server/routes.go
index 01a898a8..4dc1be5b 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -244,6 +244,10 @@ func GenerateHandler(c *gin.Context) {
         promptVars.System = model.System
     }
 
+    for i := range req.Images {
+        promptVars.Prompt += fmt.Sprintf(" [img-%d]", i)
+    }
+
     p, err := model.PreResponsePrompt(promptVars)
     if err != nil {
         c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
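
Notes on PATCH 1/8 (series commentary, not part of the patch):

This patch interleaves image references into the prompt text: each image
attached to a request gets a zero-based index, and a " [img-N]" marker is
appended to the user content so a multimodal runner (a LLaVA-style model
behind the llama.cpp server, which understands this marker syntax) can match
markers to image payloads. Below is a minimal, standalone sketch of the
tagging scheme; the function name tagImages and its signature are
illustrative, not from the patch.

    package main

    import "fmt"

    // tagImages appends one " [img-N]" marker per incoming image, numbering
    // from the count of images already referenced earlier in the history.
    func tagImages(prompt string, existing, incoming int) string {
        for i := 0; i < incoming; i++ {
            prompt += fmt.Sprintf(" [img-%d]", existing+i)
        }
        return prompt
    }

    func main() {
        fmt.Println(tagImages("what is in this picture?", 0, 1))
        // Output: what is in this picture? [img-0]
    }
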
From 8450bf66e60ab563552d31c0c69039cc12fe4603 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Wed, 31 Jan 2024 17:39:38 -0800
Subject: [PATCH 2/8] trim images

---
 llm/llama.go     |  2 +-
 server/images.go | 27 ++++++++++++++++++---------
 server/routes.go | 32 ++++++++++++++++++++++----------
 3 files changed, 41 insertions(+), 20 deletions(-)

diff --git a/llm/llama.go b/llm/llama.go
index 80b4f75b..a53df7c8 100644
--- a/llm/llama.go
+++ b/llm/llama.go
@@ -62,7 +62,7 @@ const maxRetries = 3
 
 type PredictOpts struct {
     Prompt  string
     Format  string
-    Images  []api.ImageData
+    Images  map[int]api.ImageData
     Options api.Options
 }
diff --git a/server/images.go b/server/images.go
index 26b59e0d..72fdef71 100644
--- a/server/images.go
+++ b/server/images.go
@@ -58,11 +58,17 @@ type Message struct {
     Content string `json:"content"`
 }
 
+type ImageData struct {
+    Rank int
+    api.ImageData
+}
+
 type PromptVars struct {
     System   string
     Prompt   string
     Response string
     First    bool
+    Images   []ImageData
 }
 
 // extractParts extracts the parts of the template before and after the {{.Response}} node.
@@ -147,15 +153,13 @@ func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
 }
 
 type ChatHistory struct {
-    Prompts       []PromptVars
-    CurrentImages []api.ImageData
-    LastSystem    string
+    Prompts    []PromptVars
+    LastSystem string
 }
 
 // ChatPrompts returns a list of formatted chat prompts from a list of messages
 func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
     // build the prompt from the list of messages
-    var currentImages []api.ImageData
     lastSystem := m.System
     currentVars := PromptVars{
         First: true,
@@ -163,6 +167,7 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
     }
 
     prompts := []PromptVars{}
+    var images []ImageData
 
     for _, msg := range msgs {
         switch strings.ToLower(msg.Role) {
@@ -182,10 +187,15 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 
         currentVars.Prompt = msg.Content
         for i := range msg.Images {
-            currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(currentImages)+i)
+            currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(images)+i)
+            currentVars.Images = append(currentVars.Images, ImageData{
+                Rank:      len(images) + i,
+                ImageData: msg.Images[i],
+            })
+
         }
 
-        currentImages = append(currentImages, msg.Images...)
+        images = append(images, currentVars.Images...)
     case "assistant":
         currentVars.Response = msg.Content
         prompts = append(prompts, currentVars)
@@ -201,9 +211,8 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
     }
 
     return &ChatHistory{
-        Prompts:       prompts,
-        CurrentImages: currentImages,
-        LastSystem:    lastSystem,
+        Prompts:    prompts,
+        LastSystem: lastSystem,
     }, nil
 }
 
diff --git a/server/routes.go b/server/routes.go
index 4dc1be5b..453a7b09 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -312,11 +312,16 @@ func GenerateHandler(c *gin.Context) {
         ch <- resp
     }
 
+    images := make(map[int]api.ImageData)
+    for i := range req.Images {
+        images[i] = req.Images[i]
+    }
+
     // Start prediction
     predictReq := llm.PredictOpts{
         Prompt:  prompt,
         Format:  req.Format,
-        Images:  req.Images,
+        Images:  images,
         Options: opts,
     }
     if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -1143,7 +1148,8 @@ func ChatHandler(c *gin.Context) {
         c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
         return
     }
-    prompt, err := trimmedPrompt(c.Request.Context(), chat, model)
+
+    prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model)
     if err != nil {
         c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
         return
@@ -1186,7 +1192,7 @@ func ChatHandler(c *gin.Context) {
     predictReq := llm.PredictOpts{
         Prompt:  prompt,
         Format:  req.Format,
-        Images:  chat.CurrentImages,
+        Images:  images,
         Options: opts,
     }
     if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -1233,25 +1239,27 @@ type promptInfo struct {
 
 // trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
 // while preserving the most recent system message.
-func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, error) {
+func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, map[int]api.ImageData, error) {
     if len(chat.Prompts) == 0 {
-        return "", nil
+        return "", nil, nil
     }
 
     var promptsToAdd []promptInfo
     var totalTokenLength int
     var systemPromptIncluded bool
 
+    images := make(map[int]api.ImageData)
+
     // reverse iterate through the prompts to build the prompt string in a way that fits the max context length
     for i := len(chat.Prompts) - 1; i >= 0; i-- {
         promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
         if err != nil {
-            return "", err
+            return "", nil, err
         }
 
         encodedTokens, err := loaded.runner.Encode(ctx, promptText)
         if err != nil {
-            return "", err
+            return "", nil, err
         }
 
         if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
@@ -1261,6 +1269,10 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
         totalTokenLength += len(encodedTokens)
         systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
         promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
+
+        for _, image := range chat.Prompts[i].Images {
+            images[image.Rank] = image.ImageData
+        }
     }
 
     // ensure the system prompt is included, if not already
@@ -1268,7 +1280,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
         var err error
         promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
         if err != nil {
-            return "", err
+            return "", nil, err
         }
     }
 
@@ -1279,11 +1291,11 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
     for i, prompt := range promptsToAdd {
         promptText, err := promptString(model, prompt.vars, i == 0)
         if err != nil {
-            return "", err
+            return "", nil, err
         }
         result = promptText + result
     }
-    return result, nil
+    return result, images, nil
 }
 
 // promptString applies the model template to the prompt
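
Notes on PATCH 2/8 (series commentary, not part of the patch):

This patch makes image trimming follow prompt trimming. Instead of carrying
a flat CurrentImages list on ChatHistory, images travel inside the
PromptVars they belong to, each tagged with a Rank (its global [img-N]
index), and trimmedPrompt returns only the images whose prompts survived the
context-window cut, keyed by rank so the remaining markers still resolve. A
rough sketch of the idea, using hypothetical local types in place of the
patch's PromptVars and ImageData:

    package main

    import "fmt"

    type image struct {
        rank int
        data []byte
    }

    type prompt struct {
        text   string
        images []image
    }

    // keepImages collects the images belonging to the prompts that were
    // kept, keyed by their global rank so [img-N] markers still resolve.
    func keepImages(kept []prompt) map[int][]byte {
        m := make(map[int][]byte)
        for _, p := range kept {
            for _, img := range p.images {
                m[img.rank] = img.data
            }
        }
        return m
    }

    func main() {
        kept := []prompt{{text: "hi [img-1]", images: []image{{rank: 1, data: []byte{0xff}}}}}
        fmt.Println(len(keepImages(kept))) // 1
    }
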
From f11bf0740bfc4a9653a4c59bf3cb9a00361654b1 Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan
Date: Wed, 31 Jan 2024 18:56:12 -0800
Subject: [PATCH 3/8] use `llm.ImageData`

---
 llm/dyn_ext_server.go |  9 +++------
 llm/llama.go          |  2 +-
 server/routes.go      | 17 ++++++++++++++---
 3 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/llm/dyn_ext_server.go b/llm/dyn_ext_server.go
index 782fd382..f7e19a7b 100644
--- a/llm/dyn_ext_server.go
+++ b/llm/dyn_ext_server.go
@@ -161,13 +161,10 @@ func newDynExtServer(library, model string, adapters, projectors []string, opts
 func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
     resp := newExtServerResp(128)
     defer freeExtServerResp(resp)
-    var imageData []ImageData
+
     if len(predict.Images) > 0 {
-        for cnt, i := range predict.Images {
-            imageData = append(imageData, ImageData{Data: i, ID: cnt})
-        }
+        slog.Info(fmt.Sprintf("loaded %d images", len(predict.Images)))
     }
-    slog.Info(fmt.Sprintf("loaded %d images", len(imageData)))
 
     request := map[string]any{
         "prompt": predict.Prompt,
@@ -189,7 +186,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
         "penalize_nl":  predict.Options.PenalizeNewline,
         "seed":         predict.Options.Seed,
         "stop":         predict.Options.Stop,
-        "image_data":   imageData,
+        "image_data":   predict.Images,
         "cache_prompt": true,
     }
 
diff --git a/llm/llama.go b/llm/llama.go
index a53df7c8..a5d2036a 100644
--- a/llm/llama.go
+++ b/llm/llama.go
@@ -62,7 +62,7 @@ const maxRetries = 3
 
 type PredictOpts struct {
     Prompt  string
     Format  string
-    Images  map[int]api.ImageData
+    Images  []ImageData
     Options api.Options
 }
diff --git a/server/routes.go b/server/routes.go
index 453a7b09..6d155aee 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -312,9 +312,12 @@ func GenerateHandler(c *gin.Context) {
         ch <- resp
     }
 
-    images := make(map[int]api.ImageData)
+    var images []llm.ImageData
     for i := range req.Images {
-        images[i] = req.Images[i]
+        images = append(images, llm.ImageData{
+            ID:   i,
+            Data: req.Images[i],
+        })
     }
 
     // Start prediction
@@ -1188,11 +1191,19 @@ func ChatHandler(c *gin.Context) {
         ch <- resp
     }
 
+    var imageData []llm.ImageData
+    for k, v := range images {
+        imageData = append(imageData, llm.ImageData{
+            ID:   k,
+            Data: v,
+        })
+    }
+
     // Start prediction
     predictReq := llm.PredictOpts{
         Prompt:  prompt,
         Format:  req.Format,
-        Images:  images,
+        Images:  imageData,
         Options: opts,
     }
     if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
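
Notes on PATCH 3/8 (series commentary, not part of the patch):

This patch moves the id/data pairing out of the predict path and into the
handlers, so Predict can pass predict.Images straight through as the
request's "image_data" field to the llama.cpp-based ext server. Assuming
llm.ImageData carries json tags along the lines of `json:"id"` and
`json:"data"` (the tags are not shown anywhere in this series), the wire
format would look roughly like this:

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // Stand-in for llm.ImageData; the json tags here are an assumption
    // based on the llama.cpp server's image_data request field.
    type imageData struct {
        ID   int    `json:"id"`
        Data []byte `json:"data"` // serialized as base64 by encoding/json
    }

    func main() {
        req := map[string]any{
            "prompt":     "what is in this picture? [img-0]",
            "image_data": []imageData{{ID: 0, Data: []byte{0x89, 0x50}}},
        }
        b, _ := json.Marshal(req)
        fmt.Println(string(b))
        // {"image_data":[{"id":0,"data":"iVA="}],"prompt":"what is in this picture? [img-0]"}
    }
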
From d046bee790fd8549d324d2558693722a21b897e8 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Wed, 31 Jan 2024 19:18:25 -0800
Subject: [PATCH 4/8] use llm.ImageData for chat

---
 server/images.go | 15 +++++----------
 server/routes.go | 20 +++++---------------
 2 files changed, 10 insertions(+), 25 deletions(-)

diff --git a/server/images.go b/server/images.go
index 72fdef71..a68a6699 100644
--- a/server/images.go
+++ b/server/images.go
@@ -58,17 +58,12 @@ type Message struct {
     Content string `json:"content"`
 }
 
-type ImageData struct {
-    Rank int
-    api.ImageData
-}
-
 type PromptVars struct {
     System   string
     Prompt   string
     Response string
     First    bool
-    Images   []ImageData
+    Images   []llm.ImageData
 }
 
 // extractParts extracts the parts of the template before and after the {{.Response}} node.
@@ -167,7 +162,7 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
     }
 
     prompts := []PromptVars{}
-    var images []ImageData
+    var images []llm.ImageData
 
     for _, msg := range msgs {
         switch strings.ToLower(msg.Role) {
@@ -188,9 +183,9 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
         currentVars.Prompt = msg.Content
         for i := range msg.Images {
             currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(images)+i)
-            currentVars.Images = append(currentVars.Images, ImageData{
-                Rank:      len(images) + i,
-                ImageData: msg.Images[i],
+            currentVars.Images = append(currentVars.Images, llm.ImageData{
+                ID:   i,
+                Data: msg.Images[i],
             })
 
         }
diff --git a/server/routes.go b/server/routes.go
index 6d155aee..f29d9b2b 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -1191,19 +1191,11 @@ func ChatHandler(c *gin.Context) {
         ch <- resp
     }
 
-    var imageData []llm.ImageData
-    for k, v := range images {
-        imageData = append(imageData, llm.ImageData{
-            ID:   k,
-            Data: v,
-        })
-    }
-
     // Start prediction
     predictReq := llm.PredictOpts{
         Prompt:  prompt,
         Format:  req.Format,
-        Images:  imageData,
+        Images:  images,
         Options: opts,
     }
     if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -1250,7 +1242,7 @@ type promptInfo struct {
 
 // trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
 // while preserving the most recent system message.
-func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, map[int]api.ImageData, error) {
+func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, []llm.ImageData, error) {
     if len(chat.Prompts) == 0 {
         return "", nil, nil
     }
@@ -1259,8 +1251,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
     var promptsToAdd []promptInfo
     var totalTokenLength int
     var systemPromptIncluded bool
 
-    images := make(map[int]api.ImageData)
-
+    var images []llm.ImageData
     // reverse iterate through the prompts to build the prompt string in a way that fits the max context length
     for i := len(chat.Prompts) - 1; i >= 0; i-- {
         promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
@@ -1281,9 +1272,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
         systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
         promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
 
-        for _, image := range chat.Prompts[i].Images {
-            images[image.Rank] = image.ImageData
-        }
+        images = append(images, chat.Prompts[i].Images...)
     }
 
     // ensure the system prompt is included, if not already
@@ -1306,6 +1295,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
         }
         result = promptText + result
     }
+
     return result, images, nil
 }
 
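
Notes on PATCH 4/8 (series commentary, not part of the patch):

This patch drops the interim server-side ImageData/Rank type and uses
llm.ImageData end to end. Note the subtle regression it introduces: the ID
is set from i, the index *within one message*, while the [img-N] tag uses
len(images)+i, the *global* index. With two single-image messages, both
images get ID 0, so the second message's " [img-1]" tag points at nothing;
PATCH 6/8 fixes the ID to the global index, and PATCH 8/8 hoists that
computation into a single id variable. A tiny reproduction of the mismatch
(the types are stand-ins, not the server's):

    package main

    import "fmt"

    func main() {
        msgs := [][]string{{"imgA"}, {"imgB"}} // two messages, one image each
        var images []int                       // collected IDs, as PATCH 4/8 assigns them
        tags := ""
        for _, msg := range msgs {
            for i := range msg {
                tags += fmt.Sprintf(" [img-%d]", len(images)+i) // global index in the tag
                images = append(images, i)                      // per-message index in the ID
            }
        }
        fmt.Println(tags, images) //  [img-0] [img-1] [0 0] — tag img-1 has no matching ID
    }
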
From fb5698801426f045e46dfc228f1adca70ed79bbc Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Thu, 1 Feb 2024 09:50:48 -0800
Subject: [PATCH 5/8] account for image projection in token count

---
 server/routes.go | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/server/routes.go b/server/routes.go
index f29d9b2b..d2c7323e 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -1273,6 +1273,10 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
         promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
 
         images = append(images, chat.Prompts[i].Images...)
+
+        // clip has a projection dimension of 768
+        // TODO: use kv['clip.vision.projection_dim'] from projection instead
+        totalTokenLength += 768 * len(chat.Prompts[i].Images)
     }
 
     // ensure the system prompt is included, if not already
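
Notes on PATCH 5/8 (series commentary, not part of the patch):

This patch charges each image against the context window. The hard-coded 768
is the CLIP projector's output dimension (the in-diff TODO notes it should
come from the projector's clip.vision.projection_dim metadata), used as a
proxy for the number of embedding positions an image occupies. Worked
example: with NumCtx = 2048, a history whose text encodes to 400 tokens plus
two images costs 400 + 2*768 = 1936 and fits; a third image pushes the total
to 2704 and forces trimming. A minimal sketch of the budget check, under the
same flat-768 assumption:

    package main

    import "fmt"

    // fits reports whether text plus a flat 768-token charge per image
    // (the series' assumed CLIP projection width) fits the context window.
    func fits(textTokens, nImages, numCtx int) bool {
        return textTokens+768*nImages <= numCtx
    }

    func main() {
        fmt.Println(fits(400, 2, 2048)) // true:  400 + 2*768 = 1936
        fmt.Println(fits(400, 3, 2048)) // false: 400 + 3*768 = 2704
    }
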
From d125510b4b7fef09b8a5795f30692da354a0d9cd Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Thu, 1 Feb 2024 11:21:17 -0800
Subject: [PATCH 6/8] remove image tags

---
 server/images.go |  3 +--
 server/routes.go | 24 +++++++++++++++---------
 2 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/server/images.go b/server/images.go
index a68a6699..68dae0fe 100644
--- a/server/images.go
+++ b/server/images.go
@@ -184,10 +184,9 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
         for i := range msg.Images {
             currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(images)+i)
             currentVars.Images = append(currentVars.Images, llm.ImageData{
-                ID:   i,
+                ID:   len(images) + i,
                 Data: msg.Images[i],
             })
-
         }
 
         images = append(images, currentVars.Images...)
diff --git a/server/routes.go b/server/routes.go
index d2c7323e..503f0fa1 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -1254,7 +1254,8 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
     var images []llm.ImageData
     // reverse iterate through the prompts to build the prompt string in a way that fits the max context length
     for i := len(chat.Prompts) - 1; i >= 0; i-- {
-        promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
+        prompt := chat.Prompts[i]
+        promptText, err := promptString(model, prompt, i == len(chat.Prompts)-1)
         if err != nil {
             return "", nil, err
         }
@@ -1268,15 +1269,20 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
             break // reached max context length, stop adding more prompts
         }
 
+        for j := range prompt.Images {
+            if totalTokenLength+768 > loaded.NumCtx {
+                // this decreases the token length but overestimating is fine
+                prompt.Prompt = strings.ReplaceAll(prompt.Prompt, fmt.Sprintf(" [img-%d]", prompt.Images[j].ID), "")
+                continue
+            }
+
+            totalTokenLength += 768
+            images = append(images, prompt.Images[j])
+        }
+
         totalTokenLength += len(encodedTokens)
-        systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
-        promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
-
-        images = append(images, chat.Prompts[i].Images...)
-
-        // clip has a projection dimension of 768
-        // TODO: use kv['clip.vision.projection_dim'] from projection instead
-        totalTokenLength += 768 * len(chat.Prompts[i].Images)
+        systemPromptIncluded = systemPromptIncluded || prompt.System != ""
+        promptsToAdd = append(promptsToAdd, promptInfo{vars: prompt, tokenLen: len(encodedTokens)})
     }
 
     // ensure the system prompt is included, if not already
From e49dc9f3d882ca5a4d56f9b4dea1987c39ab8aef Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Thu, 1 Feb 2024 11:48:11 -0800
Subject: [PATCH 7/8] fix tests

---
 server/images_test.go | 33 ++++++++++++++++++++++++++-------
 server/routes_test.go |  3 ++-
 2 files changed, 28 insertions(+), 8 deletions(-)

diff --git a/server/images_test.go b/server/images_test.go
index 0f63a19b..4c2a7cac 100644
--- a/server/images_test.go
+++ b/server/images_test.go
@@ -238,18 +238,37 @@ func chatHistoryEqual(a, b ChatHistory) bool {
     if len(a.Prompts) != len(b.Prompts) {
         return false
     }
-    if len(a.CurrentImages) != len(b.CurrentImages) {
-        return false
-    }
     for i, v := range a.Prompts {
-        if v != b.Prompts[i] {
+
+        if v.First != b.Prompts[i].First {
             return false
         }
-    }
-    for i, v := range a.CurrentImages {
-        if !bytes.Equal(v, b.CurrentImages[i]) {
+
+        if v.Response != b.Prompts[i].Response {
             return false
         }
+
+        if v.Prompt != b.Prompts[i].Prompt {
+            return false
+        }
+
+        if v.System != b.Prompts[i].System {
+            return false
+        }
+
+        if len(v.Images) != len(b.Prompts[i].Images) {
+            return false
+        }
+
+        for j, img := range v.Images {
+            if img.ID != b.Prompts[i].Images[j].ID {
+                return false
+            }
+
+            if !bytes.Equal(img.Data, b.Prompts[i].Images[j].Data) {
+                return false
+            }
+        }
     }
     return a.LastSystem == b.LastSystem
 }
diff --git a/server/routes_test.go b/server/routes_test.go
index 9c53dc20..2a0308b8 100644
--- a/server/routes_test.go
+++ b/server/routes_test.go
@@ -455,7 +455,8 @@ func Test_ChatPrompt(t *testing.T) {
             NumCtx: tt.numCtx,
         },
     }
-    got, err := trimmedPrompt(context.Background(), tt.chat, m)
+    // TODO: add tests for trimming images
+    got, _, err := trimmedPrompt(context.Background(), tt.chat, m)
     if tt.wantErr != "" {
         if err == nil {
             t.Errorf("ChatPrompt() expected error, got nil")
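
Notes on PATCH 6/8 and PATCH 7/8 (series commentary, not part of the patches):

Two things happen in the preceding patches. PATCH 6/8 makes trimming
image-aware in both directions: an image that no longer fits the 768-token
budget is dropped *and* its " [img-N]" tag is stripped from the prompt text
via strings.ReplaceAll, so the model never sees a dangling reference (it
also fixes the ID to the global len(images)+i index). PATCH 7/8 then
rewrites chatHistoryEqual field by field: PromptVars now contains a slice,
and Go forbids comparing structs with slice fields using ==, so the old
`v != b.Prompts[i]` check would no longer compile. A minimal illustration
(promptVars here is a stand-in type):

    package main

    import (
        "fmt"
        "reflect"
    )

    type promptVars struct {
        Prompt string
        Images [][]byte // slice field makes the struct non-comparable
    }

    func main() {
        a := promptVars{Prompt: "hi", Images: [][]byte{{1}}}
        b := promptVars{Prompt: "hi", Images: [][]byte{{1}}}
        // fmt.Println(a == b) // compile error: struct containing [][]byte cannot be compared
        fmt.Println(reflect.DeepEqual(a, b)) // true — one alternative to a field-by-field helper
    }
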
t.Errorf("ChatPrompt() expected error, got nil") From f3761405c88d36becaba7589362aa976a39aa59c Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 1 Feb 2024 11:52:42 -0800 Subject: [PATCH 8/8] use image id --- server/images.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/server/images.go b/server/images.go index 68dae0fe..6f59d72d 100644 --- a/server/images.go +++ b/server/images.go @@ -182,9 +182,10 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) { currentVars.Prompt = msg.Content for i := range msg.Images { - currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(images)+i) + id := len(images) + i + currentVars.Prompt += fmt.Sprintf(" [img-%d]", id) currentVars.Images = append(currentVars.Images, llm.ImageData{ - ID: len(images) + i, + ID: id, Data: msg.Images[i], }) }