diff --git a/llm/llama.go b/llm/llama.go
index 80b4f75b..a53df7c8 100644
--- a/llm/llama.go
+++ b/llm/llama.go
@@ -62,7 +62,7 @@ const maxRetries = 3
 type PredictOpts struct {
     Prompt  string
     Format  string
-    Images  []api.ImageData
+    Images  map[int]api.ImageData
     Options api.Options
 }
 
diff --git a/server/images.go b/server/images.go
index 26b59e0d..72fdef71 100644
--- a/server/images.go
+++ b/server/images.go
@@ -58,11 +58,17 @@ type Message struct {
     Content string `json:"content"`
 }
 
+type ImageData struct {
+    Rank int
+    api.ImageData
+}
+
 type PromptVars struct {
     System   string
     Prompt   string
     Response string
     First    bool
+    Images   []ImageData
 }
 
 // extractParts extracts the parts of the template before and after the {{.Response}} node.
@@ -147,15 +153,13 @@ func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
 }
 
 type ChatHistory struct {
-    Prompts       []PromptVars
-    CurrentImages []api.ImageData
-    LastSystem    string
+    Prompts    []PromptVars
+    LastSystem string
 }
 
 // ChatPrompts returns a list of formatted chat prompts from a list of messages
 func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
     // build the prompt from the list of messages
-    var currentImages []api.ImageData
     lastSystem := m.System
     currentVars := PromptVars{
         First: true,
@@ -163,6 +167,7 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
     }
 
     prompts := []PromptVars{}
+    var images []ImageData
 
     for _, msg := range msgs {
         switch strings.ToLower(msg.Role) {
@@ -182,10 +187,15 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
             currentVars.Prompt = msg.Content
 
             for i := range msg.Images {
-                currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(currentImages)+i)
+                currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(images)+i)
+                currentVars.Images = append(currentVars.Images, ImageData{
+                    Rank:      len(images) + i,
+                    ImageData: msg.Images[i],
+                })
+
             }
 
-            currentImages = append(currentImages, msg.Images...)
+            images = append(images, currentVars.Images...)
         case "assistant":
             currentVars.Response = msg.Content
             prompts = append(prompts, currentVars)
@@ -201,9 +211,8 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
     }
 
     return &ChatHistory{
-        Prompts:       prompts,
-        CurrentImages: currentImages,
-        LastSystem:    lastSystem,
+        Prompts:    prompts,
+        LastSystem: lastSystem,
     }, nil
 }
 
diff --git a/server/routes.go b/server/routes.go
index 4dc1be5b..453a7b09 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -312,11 +312,16 @@ func GenerateHandler(c *gin.Context) {
             ch <- resp
         }
 
+        images := make(map[int]api.ImageData)
+        for i := range req.Images {
+            images[i] = req.Images[i]
+        }
+
         // Start prediction
         predictReq := llm.PredictOpts{
             Prompt:  prompt,
             Format:  req.Format,
-            Images:  req.Images,
+            Images:  images,
             Options: opts,
         }
         if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -1143,7 +1148,8 @@ func ChatHandler(c *gin.Context) {
         c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
         return
     }
-    prompt, err := trimmedPrompt(c.Request.Context(), chat, model)
+
+    prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model)
     if err != nil {
         c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
         return
@@ -1186,7 +1192,7 @@ func ChatHandler(c *gin.Context) {
         predictReq := llm.PredictOpts{
             Prompt:  prompt,
             Format:  req.Format,
-            Images:  chat.CurrentImages,
+            Images:  images,
             Options: opts,
         }
         if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -1233,25 +1239,27 @@ type promptInfo struct {
 
 // trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
 // while preserving the most recent system message.
-func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, error) {
+func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, map[int]api.ImageData, error) {
     if len(chat.Prompts) == 0 {
-        return "", nil
+        return "", nil, nil
     }
 
     var promptsToAdd []promptInfo
     var totalTokenLength int
     var systemPromptIncluded bool
 
+    images := make(map[int]api.ImageData)
+
     // reverse iterate through the prompts to build the prompt string in a way that fits the max context length
     for i := len(chat.Prompts) - 1; i >= 0; i-- {
         promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
         if err != nil {
-            return "", err
+            return "", nil, err
         }
 
         encodedTokens, err := loaded.runner.Encode(ctx, promptText)
         if err != nil {
-            return "", err
+            return "", nil, err
         }
 
         if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
@@ -1261,6 +1269,10 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
         totalTokenLength += len(encodedTokens)
         systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
         promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
+
+        for _, image := range chat.Prompts[i].Images {
+            images[image.Rank] = image.ImageData
+        }
     }
 
     // ensure the system prompt is included, if not already
@@ -1268,7 +1280,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
         var err error
         promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
         if err != nil {
-            return "", err
+            return "", nil, err
         }
     }
 
@@ -1279,11 +1291,11 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
     for i, prompt := range promptsToAdd {
         promptText, err := promptString(model, prompt.vars, i == 0)
         if err != nil {
-            return "", err
+            return "", nil, err
         }
         result = promptText + result
     }
-    return result, nil
+    return result, images, nil
 }
 
 // promptString applies the model template to the prompt
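Why a map keyed by rank rather than a plain slice: the `[img-N]` tags baked into each prompt refer to an image's position across the whole conversation, and `trimmedPrompt` may drop older prompts (and their images) to fit the context window, so the surviving tags would no longer line up with slice indices. The following is a minimal standalone sketch of that bookkeeping, not part of the patch: `ImageData`, `RankedImage`, `buildPrompts`, and `collectImages` are illustrative stand-ins for `api.ImageData`, `server.ImageData`, `ChatPrompts`, and `trimmedPrompt`.

```go
// Sketch of rank-keyed image handling: assign each image a global rank when
// building prompts, then collect only the images of the prompts that survive
// trimming, keyed by that original rank.
package main

import "fmt"

// ImageData stands in for api.ImageData (raw image bytes).
type ImageData []byte

// RankedImage mirrors server.ImageData: an image plus the rank used in [img-N] tags.
type RankedImage struct {
	Rank int
	ImageData
}

// PromptVars is a trimmed-down version of server.PromptVars.
type PromptVars struct {
	Prompt string
	Images []RankedImage
}

// buildPrompts assigns each image a monotonically increasing global rank,
// the way ChatPrompts does with len(images)+i.
func buildPrompts(msgs [][]ImageData) []PromptVars {
	var prompts []PromptVars
	rank := 0
	for _, imgs := range msgs {
		var p PromptVars
		for _, img := range imgs {
			p.Prompt += fmt.Sprintf(" [img-%d]", rank)
			p.Images = append(p.Images, RankedImage{Rank: rank, ImageData: img})
			rank++
		}
		prompts = append(prompts, p)
	}
	return prompts
}

// collectImages gathers the images of the prompts that survived trimming,
// keyed by their original rank, as trimmedPrompt does.
func collectImages(kept []PromptVars) map[int]ImageData {
	images := make(map[int]ImageData)
	for _, p := range kept {
		for _, img := range p.Images {
			images[img.Rank] = img.ImageData
		}
	}
	return images
}

func main() {
	prompts := buildPrompts([][]ImageData{
		{ImageData("first image")},                // message 1 -> [img-0]
		{ImageData("second"), ImageData("third")}, // message 2 -> [img-1] [img-2]
	})

	// Pretend the first message was trimmed to fit the context window.
	kept := prompts[1:]

	// Ranks 1 and 2 survive with their original tags; rank 0 is gone.
	fmt.Println(collectImages(kept))
}
```

With a slice, dropping the first message would shift the remaining images to indices 0 and 1 while the kept prompts still say `[img-1]` and `[img-2]`; the rank-keyed `map[int]api.ImageData` passed through `PredictOpts.Images` avoids that mismatch.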