use llm.ImageData for chat
commit d046bee790
parent f11bf0740b

2 changed files with 10 additions and 25 deletions
@@ -58,17 +58,12 @@ type Message struct {
 	Content string `json:"content"`
 }
 
-type ImageData struct {
-	Rank int
-	api.ImageData
-}
-
 type PromptVars struct {
 	System   string
 	Prompt   string
 	Response string
 	First    bool
-	Images   []ImageData
+	Images   []llm.ImageData
 }
 
 // extractParts extracts the parts of the template before and after the {{.Response}} node.
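The hunk above deletes the server's local ImageData wrapper (a Rank plus an embedded api.ImageData) in favor of the shared llm.ImageData. A minimal runnable sketch of the two shapes, with placeholder names — the field names come from this diff, but the byte-slice types are assumptions, since the type definitions themselves are not part of the hunk:

package main

import "fmt"

// apiImageData stands in for api.ImageData, assumed here to be raw bytes.
type apiImageData []byte

// oldImageData mirrors the server-side wrapper deleted by this commit.
type oldImageData struct {
	Rank int // global position of the image across the whole chat
	apiImageData
}

// llmImageData mirrors llm.ImageData as used by the new code.
type llmImageData struct {
	ID   int    // index referenced as "[img-<ID>]" in the prompt text
	Data []byte // raw image bytes
}

func main() {
	old := oldImageData{Rank: 0, apiImageData: []byte{0x89}}
	now := llmImageData{ID: 0, Data: []byte(old.apiImageData)}
	fmt.Println(old.Rank, now.ID, len(now.Data)) // 0 0 1
}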
@@ -167,7 +162,7 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	}
 
 	prompts := []PromptVars{}
-	var images []ImageData
+	var images []llm.ImageData
 
 	for _, msg := range msgs {
 		switch strings.ToLower(msg.Role) {
@@ -188,9 +183,9 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 			currentVars.Prompt = msg.Content
 			for i := range msg.Images {
 				currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(images)+i)
-				currentVars.Images = append(currentVars.Images, ImageData{
-					Rank:      len(images) + i,
-					ImageData: msg.Images[i],
+				currentVars.Images = append(currentVars.Images, llm.ImageData{
+					ID:   i,
+					Data: msg.Images[i],
 				})
 			}
 
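Worth noting in the hunk above: the prompt tag still uses the global index (len(images)+i), while the new ID field records only the per-message index i — the old Rank carried the global index. A self-contained sketch of the tagging loop, with a hypothetical helper name (tagImages is not in the commit):

package main

import "fmt"

type ImageData struct {
	ID   int
	Data []byte
}

// tagImages reproduces the loop in isolation: offset is len(images) at call
// time, i.e. the count of images seen in earlier messages.
func tagImages(prompt string, msgImages [][]byte, offset int) (string, []ImageData) {
	var tagged []ImageData
	for i, data := range msgImages {
		// the prompt text references the image by its global index ...
		prompt += fmt.Sprintf(" [img-%d]", offset+i)
		// ... while, per the diff, ID records the per-message index
		tagged = append(tagged, ImageData{ID: i, Data: data})
	}
	return prompt, tagged
}

func main() {
	prompt, imgs := tagImages("describe these:", [][]byte{{0x89}, {0xff}}, 0)
	fmt.Println(prompt)    // describe these: [img-0] [img-1]
	fmt.Println(len(imgs)) // 2
}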
@@ -1191,19 +1191,11 @@ func ChatHandler(c *gin.Context) {
 		ch <- resp
 	}
 
-	var imageData []llm.ImageData
-	for k, v := range images {
-		imageData = append(imageData, llm.ImageData{
-			ID:   k,
-			Data: v,
-		})
-	}
-
 	// Start prediction
 	predictReq := llm.PredictOpts{
 		Prompt:  prompt,
 		Format:  req.Format,
-		Images:  imageData,
+		Images:  images,
 		Options: opts,
 	}
 	if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
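The conversion loop could be dropped because trimmedPrompt (changed below) now returns []llm.ImageData directly. Dropping the map also removes a source of nondeterminism: Go map iteration order is unspecified, so the removed loop assembled imageData in a different order on each run. A tiny demonstration of that Go behavior, with made-up values unrelated to the commit's code:

package main

import "fmt"

func main() {
	images := map[int][]byte{0: {1}, 1: {2}, 2: {3}}
	for k := range images {
		fmt.Print(k, " ") // keys 0 1 2 arrive in an arbitrary order each run
	}
	fmt.Println()
}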
@@ -1250,7 +1242,7 @@ type promptInfo struct {
 
 // trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
 // while preserving the most recent system message.
-func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, map[int]api.ImageData, error) {
+func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, []llm.ImageData, error) {
 	if len(chat.Prompts) == 0 {
 		return "", nil, nil
 	}
@@ -1259,8 +1251,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 	var totalTokenLength int
 	var systemPromptIncluded bool
 
-	images := make(map[int]api.ImageData)
-
+	var images []llm.ImageData
 	// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
 	for i := len(chat.Prompts) - 1; i >= 0; i-- {
 		promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
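Switching from images := make(map[int]api.ImageData) to var images []llm.ImageData also drops the make call. That works because a nil slice is a valid append target, whereas a nil map panics on write — which is why the old map needed make. A quick self-contained check of that language rule:

package main

import "fmt"

func main() {
	var s [][]byte           // nil slice: a valid append target
	s = append(s, []byte{1}) // append allocates on first use
	fmt.Println(len(s))      // 1

	var m map[int][]byte // nil map: reads are fine, writes panic
	_ = m[0]             // ok, yields the zero value
	// m[0] = []byte{1}  // would panic: assignment to entry in nil map
	fmt.Println(len(m)) // 0
}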
@@ -1281,9 +1272,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 		systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
 		promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
-
-		for _, image := range chat.Prompts[i].Images {
-			images[image.Rank] = image.ImageData
-		}
+		images = append(images, chat.Prompts[i].Images...)
 	}
 
 	// ensure the system prompt is included, if not already
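With the map gone, each prompt's images are flattened into the slice in a single variadic append. A standalone sketch of the pattern with placeholder types (not the commit's code); note that because the surrounding loop iterates prompts in reverse, images from the most recent prompt land first in the slice:

package main

import "fmt"

type ImageData struct {
	ID   int
	Data []byte
}

type PromptVars struct {
	Images []ImageData
}

func main() {
	prompts := []PromptVars{
		{Images: []ImageData{{ID: 0}}},          // older prompt
		{Images: []ImageData{{ID: 0}, {ID: 1}}}, // most recent prompt
	}
	var images []ImageData
	// mirror the diff: reverse-iterate prompts, flattening each image slice
	for i := len(prompts) - 1; i >= 0; i-- {
		images = append(images, prompts[i].Images...)
	}
	// the most recent prompt's images come first in the result
	fmt.Println(len(images), images[0].ID, images[1].ID, images[2].ID) // 3 0 1 0
}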
@@ -1306,6 +1295,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 		}
 		result = promptText + result
 	}
 
 	return result, images, nil
 }