Check image filetype in api handlers (#2467)

This commit is contained in:
Jeffrey Morgan 2024-02-12 11:16:20 -08:00 committed by GitHub
parent 26b13fc33c
commit 1f9078d6ae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 24 additions and 1 deletions

View file

@ -625,7 +625,7 @@ func getImageData(filePath string) ([]byte, error) {
} }
contentType := http.DetectContentType(buf) contentType := http.DetectContentType(buf)
allowedTypes := []string{"image/jpeg", "image/jpg", "image/svg+xml", "image/png"} allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
if !slices.Contains(allowedTypes, contentType) { if !slices.Contains(allowedTypes, contentType) {
return nil, fmt.Errorf("invalid image type: %s", contentType) return nil, fmt.Errorf("invalid image type: %s", contentType)
} }

View file

@ -22,6 +22,7 @@ import (
"github.com/gin-contrib/cors" "github.com/gin-contrib/cors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"golang.org/x/exp/slices"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/gpu" "github.com/jmorganca/ollama/gpu"
@ -136,6 +137,12 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
return opts, nil return opts, nil
} }
func isSupportedImageType(image []byte) bool {
contentType := http.DetectContentType(image)
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
return slices.Contains(allowedTypes, contentType)
}
func GenerateHandler(c *gin.Context) { func GenerateHandler(c *gin.Context) {
loaded.mu.Lock() loaded.mu.Lock()
defer loaded.mu.Unlock() defer loaded.mu.Unlock()
@ -166,6 +173,13 @@ func GenerateHandler(c *gin.Context) {
return return
} }
for _, img := range req.Images {
if !isSupportedImageType(img) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
return
}
}
model, err := GetModel(req.Model) model, err := GetModel(req.Model)
if err != nil { if err != nil {
var pErr *fs.PathError var pErr *fs.PathError
@ -1103,6 +1117,15 @@ func ChatHandler(c *gin.Context) {
return return
} }
for _, msg := range req.Messages {
for _, img := range msg.Images {
if !isSupportedImageType(img) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
return
}
}
}
model, err := GetModel(req.Model) model, err := GetModel(req.Model)
if err != nil { if err != nil {
var pErr *fs.PathError var pErr *fs.PathError