trim images
This commit is contained in:
parent
b4e11be8ef
commit
8450bf66e6
3 changed files with 41 additions and 20 deletions
|
@ -62,7 +62,7 @@ const maxRetries = 3
|
|||
type PredictOpts struct {
|
||||
Prompt string
|
||||
Format string
|
||||
Images []api.ImageData
|
||||
Images map[int]api.ImageData
|
||||
Options api.Options
|
||||
}
|
||||
|
||||
|
|
|
@ -58,11 +58,17 @@ type Message struct {
|
|||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ImageData struct {
|
||||
Rank int
|
||||
api.ImageData
|
||||
}
|
||||
|
||||
type PromptVars struct {
|
||||
System string
|
||||
Prompt string
|
||||
Response string
|
||||
First bool
|
||||
Images []ImageData
|
||||
}
|
||||
|
||||
// extractParts extracts the parts of the template before and after the {{.Response}} node.
|
||||
|
@ -147,15 +153,13 @@ func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
|
|||
}
|
||||
|
||||
type ChatHistory struct {
|
||||
Prompts []PromptVars
|
||||
CurrentImages []api.ImageData
|
||||
LastSystem string
|
||||
Prompts []PromptVars
|
||||
LastSystem string
|
||||
}
|
||||
|
||||
// ChatPrompts returns a list of formatted chat prompts from a list of messages
|
||||
func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
|
||||
// build the prompt from the list of messages
|
||||
var currentImages []api.ImageData
|
||||
lastSystem := m.System
|
||||
currentVars := PromptVars{
|
||||
First: true,
|
||||
|
@ -163,6 +167,7 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
|
|||
}
|
||||
|
||||
prompts := []PromptVars{}
|
||||
var images []ImageData
|
||||
|
||||
for _, msg := range msgs {
|
||||
switch strings.ToLower(msg.Role) {
|
||||
|
@ -182,10 +187,15 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
|
|||
|
||||
currentVars.Prompt = msg.Content
|
||||
for i := range msg.Images {
|
||||
currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(currentImages)+i)
|
||||
currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(images)+i)
|
||||
currentVars.Images = append(currentVars.Images, ImageData{
|
||||
Rank: len(images) + i,
|
||||
ImageData: msg.Images[i],
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
currentImages = append(currentImages, msg.Images...)
|
||||
images = append(images, currentVars.Images...)
|
||||
case "assistant":
|
||||
currentVars.Response = msg.Content
|
||||
prompts = append(prompts, currentVars)
|
||||
|
@ -201,9 +211,8 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
|
|||
}
|
||||
|
||||
return &ChatHistory{
|
||||
Prompts: prompts,
|
||||
CurrentImages: currentImages,
|
||||
LastSystem: lastSystem,
|
||||
Prompts: prompts,
|
||||
LastSystem: lastSystem,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -312,11 +312,16 @@ func GenerateHandler(c *gin.Context) {
|
|||
ch <- resp
|
||||
}
|
||||
|
||||
images := make(map[int]api.ImageData)
|
||||
for i := range req.Images {
|
||||
images[i] = req.Images[i]
|
||||
}
|
||||
|
||||
// Start prediction
|
||||
predictReq := llm.PredictOpts{
|
||||
Prompt: prompt,
|
||||
Format: req.Format,
|
||||
Images: req.Images,
|
||||
Images: images,
|
||||
Options: opts,
|
||||
}
|
||||
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
||||
|
@ -1143,7 +1148,8 @@ func ChatHandler(c *gin.Context) {
|
|||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
prompt, err := trimmedPrompt(c.Request.Context(), chat, model)
|
||||
|
||||
prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
|
@ -1186,7 +1192,7 @@ func ChatHandler(c *gin.Context) {
|
|||
predictReq := llm.PredictOpts{
|
||||
Prompt: prompt,
|
||||
Format: req.Format,
|
||||
Images: chat.CurrentImages,
|
||||
Images: images,
|
||||
Options: opts,
|
||||
}
|
||||
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
||||
|
@ -1233,25 +1239,27 @@ type promptInfo struct {
|
|||
|
||||
// trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
|
||||
// while preserving the most recent system message.
|
||||
func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, error) {
|
||||
func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, map[int]api.ImageData, error) {
|
||||
if len(chat.Prompts) == 0 {
|
||||
return "", nil
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
var promptsToAdd []promptInfo
|
||||
var totalTokenLength int
|
||||
var systemPromptIncluded bool
|
||||
|
||||
images := make(map[int]api.ImageData)
|
||||
|
||||
// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
|
||||
for i := len(chat.Prompts) - 1; i >= 0; i-- {
|
||||
promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
encodedTokens, err := loaded.runner.Encode(ctx, promptText)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
|
||||
|
@ -1261,6 +1269,10 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
|
|||
totalTokenLength += len(encodedTokens)
|
||||
systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
|
||||
promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
|
||||
|
||||
for _, image := range chat.Prompts[i].Images {
|
||||
images[image.Rank] = image.ImageData
|
||||
}
|
||||
}
|
||||
|
||||
// ensure the system prompt is included, if not already
|
||||
|
@ -1268,7 +1280,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
|
|||
var err error
|
||||
promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1279,11 +1291,11 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
|
|||
for i, prompt := range promptsToAdd {
|
||||
promptText, err := promptString(model, prompt.vars, i == 0)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", nil, err
|
||||
}
|
||||
result = promptText + result
|
||||
}
|
||||
return result, nil
|
||||
return result, images, nil
|
||||
}
|
||||
|
||||
// promptString applies the model template to the prompt
|
||||
|
|
Loading…
Add table
Reference in a new issue