From f11bf0740bfc4a9653a4c59bf3cb9a00361654b1 Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan
Date: Wed, 31 Jan 2024 18:56:12 -0800
Subject: [PATCH] use `llm.ImageData`

---
 llm/dyn_ext_server.go |  9 +++------
 llm/llama.go          |  2 +-
 server/routes.go      | 17 ++++++++++++++---
 3 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/llm/dyn_ext_server.go b/llm/dyn_ext_server.go
index 782fd382..f7e19a7b 100644
--- a/llm/dyn_ext_server.go
+++ b/llm/dyn_ext_server.go
@@ -161,13 +161,10 @@ func newDynExtServer(library, model string, adapters, projectors []string, opts
 func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
 	resp := newExtServerResp(128)
 	defer freeExtServerResp(resp)
-	var imageData []ImageData
+
 	if len(predict.Images) > 0 {
-		for cnt, i := range predict.Images {
-			imageData = append(imageData, ImageData{Data: i, ID: cnt})
-		}
+		slog.Info(fmt.Sprintf("loaded %d images", len(predict.Images)))
 	}
-	slog.Info(fmt.Sprintf("loaded %d images", len(imageData)))
 
 	request := map[string]any{
 		"prompt": predict.Prompt,
@@ -189,7 +186,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
 		"penalize_nl": predict.Options.PenalizeNewline,
 		"seed": predict.Options.Seed,
 		"stop": predict.Options.Stop,
-		"image_data": imageData,
+		"image_data": predict.Images,
 		"cache_prompt": true,
 	}
 
diff --git a/llm/llama.go b/llm/llama.go
index a53df7c8..a5d2036a 100644
--- a/llm/llama.go
+++ b/llm/llama.go
@@ -62,7 +62,7 @@ const maxRetries = 3
 type PredictOpts struct {
 	Prompt string
 	Format string
-	Images map[int]api.ImageData
+	Images []ImageData
 	Options api.Options
 }
 
diff --git a/server/routes.go b/server/routes.go
index 453a7b09..6d155aee 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -312,9 +312,12 @@ func GenerateHandler(c *gin.Context) {
 			ch <- resp
 		}
 
-		images := make(map[int]api.ImageData)
+		var images []llm.ImageData
 		for i := range req.Images {
-			images[i] = req.Images[i]
+			images = append(images, llm.ImageData{
+				ID: i,
+				Data: req.Images[i],
+			})
 		}
 
 		// Start prediction
@@ -1188,11 +1191,19 @@ func ChatHandler(c *gin.Context) {
 			ch <- resp
 		}
 
+		var imageData []llm.ImageData
+		for k, v := range images {
+			imageData = append(imageData, llm.ImageData{
+				ID: k,
+				Data: v,
+			})
+		}
+
 		// Start prediction
 		predictReq := llm.PredictOpts{
 			Prompt: prompt,
 			Format: req.Format,
-			Images: images,
+			Images: imageData,
 			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
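
Note (not part of the patch above): a minimal, self-contained sketch of how a caller now builds the []llm.ImageData slice that PredictOpts expects after this change. The ImageData struct shape (Data []byte, ID int) is inferred from the composite literals in the diff; reqImages is a hypothetical stand-in for req.Images (raw image bytes), used here only for illustration.

package main

import "fmt"

// Assumed shape of llm.ImageData, inferred from the literals in the diff.
type ImageData struct {
	Data []byte
	ID   int
}

func main() {
	// Hypothetical stand-in for req.Images: one raw byte slice per attached image.
	reqImages := [][]byte{
		[]byte("first-image-bytes"),
		[]byte("second-image-bytes"),
	}

	// Mirrors the conversion loop added in server/routes.go: the slice
	// index becomes the image ID passed through to the runner.
	var images []ImageData
	for i := range reqImages {
		images = append(images, ImageData{ID: i, Data: reqImages[i]})
	}

	fmt.Printf("loaded %d images\n", len(images))
}

One practical effect of switching from map[int]api.ImageData to an ID-tagged slice is that the order of entries in the serialized "image_data" field becomes deterministic, since Go does not guarantee map iteration order.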