trim images

Michael Yang 2024-01-31 17:39:38 -08:00
parent b4e11be8ef
commit 8450bf66e6
3 changed files with 41 additions and 20 deletions
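In short: each image is tagged with a stable rank at the point where a prompt references it, so that when ChatHandler trims old prompts to fit the context window, only the images belonging to surviving prompts are forwarded to the runner.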

View file

@@ -62,7 +62,7 @@ const maxRetries = 3
 type PredictOpts struct {
 	Prompt  string
 	Format  string
-	Images  []api.ImageData
+	Images  map[int]api.ImageData
 	Options api.Options
 }
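Switching Images from a slice to a map keyed by rank lets callers pass a sparse set of images whose [img-N] tags still resolve after trimming. A minimal sketch of why the key matters; the ImageData alias below is a stand-in for api.ImageData, assumed for this sketch:

package main

import "fmt"

// ImageData stands in for api.ImageData ([]byte); a hypothetical alias for this sketch.
type ImageData []byte

func main() {
	// Three images were tagged [img-0], [img-1], [img-2] across the chat,
	// but trimming dropped the prompts that referenced the first two.
	kept := map[int]ImageData{
		2: ImageData("...png bytes..."),
	}

	// Keyed by rank, [img-2] still resolves to its payload. A re-packed
	// slice would have shifted the survivor down to index 0, breaking the tag.
	for rank, img := range kept {
		fmt.Printf("[img-%d] -> %d bytes\n", rank, len(img))
	}
}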

View file

@@ -58,11 +58,17 @@ type Message struct {
 	Content string `json:"content"`
 }
 
+type ImageData struct {
+	Rank int
+	api.ImageData
+}
+
 type PromptVars struct {
 	System   string
 	Prompt   string
 	Response string
 	First    bool
+	Images   []ImageData
 }
 
 // extractParts extracts the parts of the template before and after the {{.Response}} node.
@@ -148,14 +154,12 @@ func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
 type ChatHistory struct {
 	Prompts       []PromptVars
-	CurrentImages []api.ImageData
 	LastSystem    string
 }
 
 // ChatPrompts returns a list of formatted chat prompts from a list of messages
 func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	// build the prompt from the list of messages
-	var currentImages []api.ImageData
 	lastSystem := m.System
 	currentVars := PromptVars{
 		First: true,
@@ -163,6 +167,7 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	}
 
 	prompts := []PromptVars{}
+	var images []ImageData
 
 	for _, msg := range msgs {
 		switch strings.ToLower(msg.Role) {
@@ -182,10 +187,15 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 			currentVars.Prompt = msg.Content
 			for i := range msg.Images {
-				currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(currentImages)+i)
+				currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(images)+i)
+				currentVars.Images = append(currentVars.Images, ImageData{
+					Rank:      len(images) + i,
+					ImageData: msg.Images[i],
+				})
 			}
-			currentImages = append(currentImages, msg.Images...)
+			images = append(images, currentVars.Images...)
 		case "assistant":
 			currentVars.Response = msg.Content
 			prompts = append(prompts, currentVars)
@@ -202,7 +212,6 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	return &ChatHistory{
 		Prompts:       prompts,
-		CurrentImages: currentImages,
 		LastSystem:    lastSystem,
 	}, nil
 }
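ChatPrompts now assigns each image a conversation-wide rank and attaches it to the PromptVars of the message that references it, instead of accumulating a flat CurrentImages slice. A self-contained sketch of the numbering, with simplified stand-in types (Data []byte replaces the embedded api.ImageData):

package main

import "fmt"

// Stand-ins for the server types above; simplified for this sketch.
type ImageData struct {
	Rank int
	Data []byte
}

type PromptVars struct {
	Prompt string
	Images []ImageData
}

func main() {
	// Two user turns: the first attaches two images, the second one more.
	turns := [][][]byte{
		{[]byte("a"), []byte("b")},
		{[]byte("c")},
	}

	var images []ImageData
	var prompts []PromptVars
	for _, msgImages := range turns {
		var vars PromptVars
		for i, data := range msgImages {
			// ranks count up across the whole conversation, so tags stay unique
			vars.Prompt += fmt.Sprintf(" [img-%d]", len(images)+i)
			vars.Images = append(vars.Images, ImageData{Rank: len(images) + i, Data: data})
		}
		images = append(images, vars.Images...)
		prompts = append(prompts, vars)
	}

	for _, p := range prompts {
		fmt.Printf("%q ->", p.Prompt)
		for _, img := range p.Images {
			fmt.Printf(" rank %d", img.Rank)
		}
		fmt.Println()
	}
	// " [img-0] [img-1]" -> rank 0 rank 1
	// " [img-2]" -> rank 2
}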

View file

@@ -312,11 +312,16 @@ func GenerateHandler(c *gin.Context) {
 			ch <- resp
 		}
 
+		images := make(map[int]api.ImageData)
+		for i := range req.Images {
+			images[i] = req.Images[i]
+		}
+
 		// Start prediction
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Format:  req.Format,
-			Images:  req.Images,
+			Images:  images,
 			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
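GenerateHandler never trims its prompt, so every request image is kept; re-keying the slice by index just satisfies the new map-typed field, with rank i mapping to req.Images[i].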
@@ -1143,7 +1148,8 @@ func ChatHandler(c *gin.Context) {
 			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 			return
 		}
-		prompt, err := trimmedPrompt(c.Request.Context(), chat, model)
+
+		prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
@@ -1186,7 +1192,7 @@ func ChatHandler(c *gin.Context) {
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Format:  req.Format,
-			Images:  chat.CurrentImages,
+			Images:  images,
 			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -1233,25 +1239,27 @@ type promptInfo struct {
 // trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
 // while preserving the most recent system message.
-func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, error) {
+func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, map[int]api.ImageData, error) {
 	if len(chat.Prompts) == 0 {
-		return "", nil
+		return "", nil, nil
 	}
 
 	var promptsToAdd []promptInfo
 	var totalTokenLength int
 	var systemPromptIncluded bool
+	images := make(map[int]api.ImageData)
 
 	// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
 	for i := len(chat.Prompts) - 1; i >= 0; i-- {
 		promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 
 		encodedTokens, err := loaded.runner.Encode(ctx, promptText)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 
 		if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
@@ -1261,6 +1269,10 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 		totalTokenLength += len(encodedTokens)
 		systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
 		promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
+
+		for _, image := range chat.Prompts[i].Images {
+			images[image.Rank] = image.ImageData
+		}
 	}
 
 	// ensure the system prompt is included, if not already
@@ -1268,7 +1280,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 		var err error
 		promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 	}
@@ -1279,11 +1291,11 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 	for i, prompt := range promptsToAdd {
 		promptText, err := promptString(model, prompt.vars, i == 0)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 		result = promptText + result
 	}
 
-	return result, nil
+	return result, images, nil
 }
 
 // promptString applies the model template to the prompt
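The net effect: trimmedPrompt walks the history newest-first, and only prompts that fit contribute their images to the returned map, so trimmed-away [img-N] references are never sent to the runner. A runnable sketch of that shape, using hypothetical stand-in types and made-up token counts:

package main

import "fmt"

// Simplified stand-ins for this sketch: each prompt remembers the ranked
// images it referenced, mirroring PromptVars.Images in the diff above.
type rankedImage struct {
	rank int
	data []byte
}

type promptVars struct {
	text   string
	tokens int
	images []rankedImage
}

func main() {
	history := []promptVars{
		{text: "describe [img-0]\n", tokens: 600, images: []rankedImage{{0, []byte("a")}}},
		{text: "now [img-1]\n", tokens: 500, images: []rankedImage{{1, []byte("b")}}},
	}
	const numCtx = 512 // pretend context window, in tokens

	var result string
	var total int
	kept := make(map[int][]byte)

	// Newest-to-oldest: drop prompts that overflow the window, but always
	// keep the most recent one — the same shape as trimmedPrompt above.
	for i := len(history) - 1; i >= 0; i-- {
		p := history[i]
		if total+p.tokens > numCtx && i != len(history)-1 {
			break
		}
		total += p.tokens
		result = p.text + result
		for _, img := range p.images {
			kept[img.rank] = img.data // only surviving prompts contribute images
		}
	}

	fmt.Printf("prompt: %q\n", result) // only the newest turn survives
	for rank := range kept {
		fmt.Printf("kept image rank %d\n", rank) // rank 1; [img-0] was trimmed away
	}
}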