diff --git a/server/images.go b/server/images.go index 72fdef71..a68a6699 100644 --- a/server/images.go +++ b/server/images.go @@ -58,17 +58,12 @@ type Message struct { Content string `json:"content"` } -type ImageData struct { - Rank int - api.ImageData -} - type PromptVars struct { System string Prompt string Response string First bool - Images []ImageData + Images []llm.ImageData } // extractParts extracts the parts of the template before and after the {{.Response}} node. @@ -167,7 +162,7 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) { } prompts := []PromptVars{} - var images []ImageData + var images []llm.ImageData for _, msg := range msgs { switch strings.ToLower(msg.Role) { @@ -188,9 +183,9 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) { currentVars.Prompt = msg.Content for i := range msg.Images { currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(images)+i) - currentVars.Images = append(currentVars.Images, ImageData{ - Rank: len(images) + i, - ImageData: msg.Images[i], + currentVars.Images = append(currentVars.Images, llm.ImageData{ + ID: len(images) + i, + Data: msg.Images[i], }) } diff --git a/server/routes.go b/server/routes.go index 6d155aee..f29d9b2b 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1191,19 +1191,11 @@ func ChatHandler(c *gin.Context) { ch <- resp } - var imageData []llm.ImageData - for k, v := range images { - imageData = append(imageData, llm.ImageData{ - ID: k, - Data: v, - }) - } - // Start prediction predictReq := llm.PredictOpts{ Prompt: prompt, Format: req.Format, - Images: imageData, + Images: images, Options: opts, } if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil { @@ -1250,7 +1242,7 @@ type promptInfo struct { // trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length, // while preserving the most recent system message. 
-func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, map[int]api.ImageData, error) { +func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, []llm.ImageData, error) { if len(chat.Prompts) == 0 { return "", nil, nil } @@ -1259,8 +1251,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string var totalTokenLength int var systemPromptIncluded bool - images := make(map[int]api.ImageData) - + var images []llm.ImageData // reverse iterate through the prompts to build the prompt string in a way that fits the max context length for i := len(chat.Prompts) - 1; i >= 0; i-- { promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1) @@ -1281,9 +1272,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != "" promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)}) - for _, image := range chat.Prompts[i].Images { - images[image.Rank] = image.ImageData - } + images = append(images, chat.Prompts[i].Images...) } // ensure the system prompt is included, if not already @@ -1306,6 +1295,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string } result = promptText + result } + return result, images, nil }