use llm.ImageData

This commit is contained in:
Jeffrey Morgan 2024-01-31 18:56:12 -08:00 committed by Michael Yang
parent 8450bf66e6
commit f11bf0740b
3 changed files with 18 additions and 10 deletions

View file

@ -161,13 +161,10 @@ func newDynExtServer(library, model string, adapters, projectors []string, opts
func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error { func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
resp := newExtServerResp(128) resp := newExtServerResp(128)
defer freeExtServerResp(resp) defer freeExtServerResp(resp)
var imageData []ImageData
if len(predict.Images) > 0 { if len(predict.Images) > 0 {
for cnt, i := range predict.Images { slog.Info(fmt.Sprintf("loaded %d images", len(predict.Images)))
imageData = append(imageData, ImageData{Data: i, ID: cnt})
} }
}
slog.Info(fmt.Sprintf("loaded %d images", len(imageData)))
request := map[string]any{ request := map[string]any{
"prompt": predict.Prompt, "prompt": predict.Prompt,
@ -189,7 +186,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
"penalize_nl": predict.Options.PenalizeNewline, "penalize_nl": predict.Options.PenalizeNewline,
"seed": predict.Options.Seed, "seed": predict.Options.Seed,
"stop": predict.Options.Stop, "stop": predict.Options.Stop,
"image_data": imageData, "image_data": predict.Images,
"cache_prompt": true, "cache_prompt": true,
} }

View file

@ -62,7 +62,7 @@ const maxRetries = 3
type PredictOpts struct { type PredictOpts struct {
Prompt string Prompt string
Format string Format string
Images map[int]api.ImageData Images []ImageData
Options api.Options Options api.Options
} }

View file

@ -312,9 +312,12 @@ func GenerateHandler(c *gin.Context) {
ch <- resp ch <- resp
} }
images := make(map[int]api.ImageData) var images []llm.ImageData
for i := range req.Images { for i := range req.Images {
images[i] = req.Images[i] images = append(images, llm.ImageData{
ID: i,
Data: req.Images[i],
})
} }
// Start prediction // Start prediction
@ -1188,11 +1191,19 @@ func ChatHandler(c *gin.Context) {
ch <- resp ch <- resp
} }
var imageData []llm.ImageData
for k, v := range images {
imageData = append(imageData, llm.ImageData{
ID: k,
Data: v,
})
}
// Start prediction // Start prediction
predictReq := llm.PredictOpts{ predictReq := llm.PredictOpts{
Prompt: prompt, Prompt: prompt,
Format: req.Format, Format: req.Format,
Images: images, Images: imageData,
Options: opts, Options: opts,
} }
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil { if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {