use llm.ImageData for chat

Michael Yang 2024-01-31 19:18:25 -08:00
parent f11bf0740b
commit d046bee790
2 changed files with 10 additions and 25 deletions
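The change replaces the server package's private ImageData wrapper (an embedded api.ImageData plus a Rank field) with the llm package's own image type, so each image keeps its placeholder index from prompt assembly all the way to the runner. The type itself is not shown in this diff; judging from the ID and Data fields constructed below, it is presumably shaped like this sketch (field order and json tags are assumptions):

// Sketch of the llm-side type, inferred from its construction in the
// diff below; the exact field order and json tags are assumptions.
type ImageData struct {
	Data []byte `json:"data"`
	ID   int    `json:"id"`
}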

View file

@@ -58,17 +58,12 @@ type Message struct {
 	Content string `json:"content"`
 }
 
-type ImageData struct {
-	Rank int
-	api.ImageData
-}
-
 type PromptVars struct {
 	System   string
 	Prompt   string
 	Response string
 	First    bool
-	Images   []ImageData
+	Images   []llm.ImageData
 }
 
 // extractParts extracts the parts of the template before and after the {{.Response}} node.
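With the wrapper struct gone, PromptVars carries runner-ready images directly. A hypothetical value for a single user turn might look like this (imageBytes is a stand-in for decoded image data, not a name from the diff):

// Hypothetical example, not part of the diff: one user turn with one
// image; the ID matches the " [img-0]" placeholder in the prompt text.
vars := PromptVars{
	Prompt: "Describe this image. [img-0]",
	Images: []llm.ImageData{{ID: 0, Data: imageBytes}},
}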
@@ -167,7 +162,7 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	}
 
 	prompts := []PromptVars{}
-	var images []ImageData
+	var images []llm.ImageData
 
 	for _, msg := range msgs {
 		switch strings.ToLower(msg.Role) {
@@ -188,9 +183,9 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 			currentVars.Prompt = msg.Content
 			for i := range msg.Images {
 				currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(images)+i)
-				currentVars.Images = append(currentVars.Images, ImageData{
-					Rank:      len(images) + i,
-					ImageData: msg.Images[i],
+				currentVars.Images = append(currentVars.Images, llm.ImageData{
+					ID:   len(images) + i,
+					Data: msg.Images[i],
 				})
 			}
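The ID has to use the same len(images)+i offset as the [img-%d] placeholder, because the runner resolves placeholders by ID rather than by slice position; numbering per message (plain i) would desynchronize as soon as a second message carried images. A worked illustration, assuming one image already accumulated from an earlier message:

// Assumed state: len(images) == 1 from an earlier message; i == 0 for the
// next message's first image. Placeholder and ID must both come out as 1;
// storing plain i here would yield ID 0 and collide with the earlier image.
i := 0
currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(images)+i) // appends " [img-1]"
currentVars.Images = append(currentVars.Images, llm.ImageData{
	ID:   len(images) + i, // 1, matching the placeholder
	Data: msg.Images[i],
})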

View file

@@ -1191,19 +1191,11 @@ func ChatHandler(c *gin.Context) {
 			ch <- resp
 		}
 
-		var imageData []llm.ImageData
-		for k, v := range images {
-			imageData = append(imageData, llm.ImageData{
-				ID:   k,
-				Data: v,
-			})
-		}
-
 		// Start prediction
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Format:  req.Format,
-			Images:  imageData,
+			Images:  images,
 			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
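Since prompt assembly now produces []llm.ImageData with IDs already attached, the handler's map-to-slice conversion disappears and the result of trimmedPrompt feeds the runner unchanged. A minimal sketch of the resulting flow; the glue around the two calls shown in the diff is assumed, not taken from the handler:

// Sketch of the resulting call flow (surrounding variables assumed):
prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model)
if err != nil {
	return // the real handler reports the error to the client
}
predictReq := llm.PredictOpts{
	Prompt:  prompt,
	Format:  req.Format,
	Images:  images, // passed through as-is
	Options: opts,
}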
@@ -1250,7 +1242,7 @@ type promptInfo struct {
 
 // trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
 // while preserving the most recent system message.
-func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, map[int]api.ImageData, error) {
+func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, []llm.ImageData, error) {
 	if len(chat.Prompts) == 0 {
 		return "", nil, nil
 	}
@@ -1259,8 +1251,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 	var totalTokenLength int
 	var systemPromptIncluded bool
 
-	images := make(map[int]api.ImageData)
-
+	var images []llm.ImageData
 	// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
 	for i := len(chat.Prompts) - 1; i >= 0; i-- {
 		promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
@@ -1281,9 +1272,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 		systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
 		promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
 
-		for _, image := range chat.Prompts[i].Images {
-			images[image.Rank] = image.ImageData
-		}
+		images = append(images, chat.Prompts[i].Images...)
 	}
 
 	// ensure the system prompt is included, if not already
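Because the loop walks prompts newest-to-oldest, the returned slice is not in conversation order, and that is fine: with IDs embedded in each element, nothing downstream depends on position, which is what made the Rank-keyed map redundant. A small sketch of the lookup this design assumes on the runner side (findImage is illustrative, not a function from the codebase):

// Sketch (assumption: [img-N] placeholders are resolved by ID, not by
// slice position, so the reverse-iteration order above is harmless).
func findImage(images []llm.ImageData, id int) ([]byte, bool) {
	for _, img := range images {
		if img.ID == id {
			return img.Data, true
		}
	}
	return nil, false
}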
@@ -1306,6 +1295,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 		}
 		result = promptText + result
 	}
+
 	return result, images, nil
 }