use llm.ImageData for chat
This commit is contained in:
parent
f11bf0740b
commit
d046bee790
2 changed files with 10 additions and 25 deletions
@@ -58,17 +58,12 @@ type Message struct {
 	Content string `json:"content"`
 }
 
-type ImageData struct {
-	Rank int
-	api.ImageData
-}
-
 type PromptVars struct {
 	System   string
 	Prompt   string
 	Response string
 	First    bool
-	Images   []ImageData
+	Images   []llm.ImageData
 }
 
 // extractParts extracts the parts of the template before and after the {{.Response}} node.
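For context: the llm.ImageData type that replaces the server-local struct is defined in the llm package and is not shown in this diff. A minimal sketch of the shape implied by the fields used below (ID, Data); the json tags and the []byte element type are assumptions inferred from usage, not confirmed by this commit:

    // In package llm (sketch inferred from usage in this commit).
    // ID is the index referenced by the prompt's [img-N] tags;
    // Data holds the raw image bytes.
    type ImageData struct {
    	Data []byte `json:"data"` // assumed tag
    	ID   int    `json:"id"`   // assumed tag
    }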
@@ -167,7 +162,7 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	}
 
 	prompts := []PromptVars{}
-	var images []ImageData
+	var images []llm.ImageData
 
 	for _, msg := range msgs {
 		switch strings.ToLower(msg.Role) {
@@ -188,9 +183,9 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 			currentVars.Prompt = msg.Content
 			for i := range msg.Images {
 				currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(images)+i)
-				currentVars.Images = append(currentVars.Images, ImageData{
-					Rank:      len(images) + i,
-					ImageData: msg.Images[i],
+				currentVars.Images = append(currentVars.Images, llm.ImageData{
+					ID:   i,
+					Data: msg.Images[i],
 				})
 			}
 
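To see what the rewritten loop produces: each image in a user message appends an [img-N] tag to the prompt text and a matching llm.ImageData entry. A standalone sketch (message content and image bytes are made up; the local imageData type stands in for llm.ImageData so this runs on its own):

    package main

    import "fmt"

    type imageData struct { // stand-in for llm.ImageData
    	ID   int
    	Data []byte
    }

    func main() {
    	prompt := "what is in these pictures?"
    	msgImages := [][]byte{[]byte("png-0"), []byte("png-1")} // fake image bytes
    	var images []imageData        // accumulated across earlier messages (empty here)
    	var currentImages []imageData // images attached to this prompt

    	for i := range msgImages {
    		prompt += fmt.Sprintf(" [img-%d]", len(images)+i)
    		currentImages = append(currentImages, imageData{ID: i, Data: msgImages[i]})
    	}

    	fmt.Println(prompt) // what is in these pictures? [img-0] [img-1]
    }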
@@ -1191,19 +1191,11 @@ func ChatHandler(c *gin.Context) {
 			ch <- resp
 		}
 
-		var imageData []llm.ImageData
-		for k, v := range images {
-			imageData = append(imageData, llm.ImageData{
-				ID:   k,
-				Data: v,
-			})
-		}
-
 		// Start prediction
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Format:  req.Format,
-			Images:  imageData,
+			Images:  images,
 			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -1250,7 +1242,7 @@ type promptInfo struct {
 
 // trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
 // while preserving the most recent system message.
-func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, map[int]api.ImageData, error) {
+func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, []llm.ImageData, error) {
 	if len(chat.Prompts) == 0 {
 		return "", nil, nil
 	}
@@ -1259,8 +1251,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 	var totalTokenLength int
 	var systemPromptIncluded bool
 
-	images := make(map[int]api.ImageData)
+	var images []llm.ImageData
 
 	// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
 	for i := len(chat.Prompts) - 1; i >= 0; i-- {
 		promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
@@ -1281,9 +1272,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 		systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
 		promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
 
-		for _, image := range chat.Prompts[i].Images {
-			images[image.Rank] = image.ImageData
-		}
+		images = append(images, chat.Prompts[i].Images...)
 	}
 
 	// ensure the system prompt is included, if not already
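A note on the map-to-slice change above: the old code kept a map[int]api.ImageData keyed by Rank because each image's position had to be recovered after the reverse iteration; now each llm.ImageData carries its ID inline, so a flat slice suffices and slice order no longer matters for matching [img-N] tags to images. A sketch of such an ID-based lookup (findImage is a hypothetical helper for illustration, not part of this commit):

    // findImage returns the bytes for the image whose ID matches an
    // [img-N] tag, regardless of its position in the slice.
    func findImage(images []llm.ImageData, id int) ([]byte, bool) {
    	for _, img := range images {
    		if img.ID == id {
    			return img.Data, true
    		}
    	}
    	return nil, false
    }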
@@ -1306,6 +1295,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 		}
 		result = promptText + result
 	}
+
 	return result, images, nil
 }