Multimodal support (#1216)

---------

Co-authored-by: Matt Apperson <mattapperson@Matts-MacBook-Pro.local>
Patrick Devine 2023-12-11 13:56:22 -08:00 committed by GitHub
parent 7a1b37ac64
commit 910e9401d0
6 changed files with 235 additions and 28 deletions

View file

@ -31,6 +31,8 @@ func (e StatusError) Error() string {
}
}
type ImageData []byte
type GenerateRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
@ -40,6 +42,7 @@ type GenerateRequest struct {
Stream *bool `json:"stream,omitempty"`
Raw bool `json:"raw,omitempty"`
Format string `json:"format"`
Images []ImageData `json:"images,omitempty"`
Options map[string]interface{} `json:"options"`
}
@ -153,6 +156,7 @@ type ShowResponse struct {
Parameters string `json:"parameters,omitempty"`
Template string `json:"template,omitempty"`
System string `json:"system,omitempty"`
Details ModelDetails `json:"details,omitempty"`
}
type CopyRequest struct {
@ -192,6 +196,7 @@ type ModelResponse struct {
ModifiedAt time.Time `json:"modified_at"`
Size int64 `json:"size"`
Digest string `json:"digest"`
Details ModelDetails `json:"details,omitempty"`
}
type TokenResponse struct {
@ -209,6 +214,14 @@ type GenerateResponse struct {
Metrics
}
type ModelDetails struct {
Format string `json:"format"`
Family string `json:"family"`
Families []string `json:"families"`
ParameterSize string `json:"parameter_size"`
QuantizationLevel string `json:"quantization_level"`
}
func (m *Metrics) Summary() {
if m.TotalDuration > 0 {
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
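
The new `Images` field rides along on the existing generate request; because Go marshals `[]byte` as a base64 string, each image travels as one base64-encoded entry in a JSON array. A minimal client-side sketch using the package's own `ClientFromEnvironment` and `Generate` helpers (the model name and file path are illustrative, and a multimodal model is assumed to be installed):

```go
package main

import (
	"context"
	"fmt"
	"log"
	"os"

	"github.com/jmorganca/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// raw bytes; encoding/json turns []byte (api.ImageData) into base64 on the wire
	img, err := os.ReadFile("photo.jpg")
	if err != nil {
		log.Fatal(err)
	}

	req := &api.GenerateRequest{
		Model:  "llava", // illustrative; any model whose families include "clip"
		Prompt: "What is in this picture?",
		Images: []api.ImageData{img},
	}

	fn := func(resp api.GenerateResponse) error {
		fmt.Print(resp.Response)
		return nil
	}
	if err := client.Generate(context.Background(), req, fn); err != nil {
		log.Fatal(err)
	}
}
```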

View file

@ -17,7 +17,9 @@ import (
"os/exec"
"os/signal"
"path/filepath"
"regexp"
"runtime"
"slices"
"strings"
"syscall"
"time"
@ -36,6 +38,8 @@ import (
"github.com/jmorganca/ollama/version"
)
type ImageData []byte
func CreateHandler(cmd *cobra.Command, args []string) error {
filename, _ := cmd.Flags().GetString("file")
filename, err := filepath.Abs(filename)
@ -418,6 +422,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
Model: args[0],
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]interface{}{},
Images: []ImageData{},
}
format, err := cmd.Flags().GetString("format")
@ -427,7 +432,6 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
opts.Format = format
prompts := args[1:]
// prepend stdin to the prompt if provided
if !term.IsTerminal(int(os.Stdin.Fd())) {
in, err := io.ReadAll(os.Stdin)
@ -466,6 +470,7 @@ type generateOptions struct {
Format string
System string
Template string
Images []ImageData
Options map[string]interface{}
}
@ -551,6 +556,10 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
return nil
}
images := make([]api.ImageData, 0)
for _, i := range opts.Images {
images = append(images, api.ImageData(i))
}
request := api.GenerateRequest{
Model: opts.Model,
Prompt: opts.Prompt,
@ -559,6 +568,7 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
System: opts.System,
Template: opts.Template,
Options: opts.Options,
Images: images,
}
if err := client.Generate(ctx, &request, fn); err != nil {
@ -585,7 +595,9 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
latest.Summary()
}
cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context))
ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)
cmd.SetContext(ctx)
return nil
}
@ -598,11 +610,31 @@ const (
MultilineTemplate
)
func modelIsMultiModal(cmd *cobra.Command, name string) bool {
// get model details
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
return false
}
req := api.ShowRequest{Name: name}
resp, err := client.Show(cmd.Context(), &req)
if err != nil {
return false
}
return slices.Contains(resp.Details.Families, "clip")
}
func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
multiModal := modelIsMultiModal(cmd, opts.Model)
// load the model
loadOpts := generateOptions{
Model: opts.Model,
Prompt: "",
Images: []ImageData{},
}
if err := generate(cmd, loadOpts); err != nil {
return err
@ -902,6 +934,26 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
if len(prompt) > 0 && multiline == MultilineNone {
opts.Prompt = prompt
if multiModal {
newPrompt, images, err := extractFileNames(prompt)
if err != nil {
return err
}
opts.Prompt = newPrompt
// reset the context if we find another image
if len(images) > 0 {
opts.Images = images
ctx := cmd.Context()
ctx = context.WithValue(ctx, generateContextKey("context"), []int{})
cmd.SetContext(ctx)
}
if len(opts.Images) == 0 {
fmt.Println("This model requires you to add a jpeg, png, or svg image.\n")
prompt = ""
continue
}
}
if err := generate(cmd, opts); err != nil {
return err
}
@ -911,6 +963,57 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
}
}
func normalizeFilePath(fp string) string {
// Define a map of escaped characters and their replacements
replacements := map[string]string{
"\\ ": " ", // Escaped space
"\\(": "(", // Escaped left parenthesis
"\\)": ")", // Escaped right parenthesis
"\\[": "[", // Escaped left square bracket
"\\]": "]", // Escaped right square bracket
"\\{": "{", // Escaped left curly brace
"\\}": "}", // Escaped right curly brace
"\\$": "$", // Escaped dollar sign
"\\&": "&", // Escaped ampersand
"\\;": ";", // Escaped semicolon
"\\'": "'", // Escaped single quote
"\\\\": "\\", // Escaped backslash
"\\*": "*", // Escaped asterisk
"\\?": "?", // Escaped question mark
}
for escaped, actual := range replacements {
fp = strings.ReplaceAll(fp, escaped, actual)
}
return fp
}
func extractFileNames(input string) (string, []ImageData, error) {
// Regex to match file paths starting with / or ./, allowing escaped spaces,
// and ending in a supported image extension (jpg, jpeg, png, or svg)
regexPattern := `(?:\./|/)[\S\\ ]+?\.(?i:jpg|jpeg|png|svg)\b`
re := regexp.MustCompile(regexPattern)
filePaths := re.FindAllString(input, -1)
var imgs []ImageData
for _, fp := range filePaths {
nfp := normalizeFilePath(fp)
data, err := getImageData(nfp)
if err != nil {
if os.IsNotExist(err) {
continue
}
fmt.Printf("Couldn't process image: %q\n", err)
return "", imgs, err
}
fmt.Printf("Added image '%s'\n", nfp)
input = strings.ReplaceAll(input, fp, "")
imgs = append(imgs, data)
}
return input, imgs, nil
}
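
To make the extraction behaviour concrete, here is a small self-contained sketch that runs the same pattern over a made-up interactive prompt; the file paths are invented, and only the escaped-space replacement from `normalizeFilePath` is shown:

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

func main() {
	input := `describe ./cat\ photo.jpg and /tmp/diagram.png please`

	// same pattern as extractFileNames above
	re := regexp.MustCompile(`(?:\./|/)[\S\\ ]+?\.(?i:jpg|jpeg|png|svg)\b`)
	for _, fp := range re.FindAllString(input, -1) {
		clean := strings.ReplaceAll(fp, `\ `, " ") // undo shell-style escaping
		fmt.Printf("image path: %q\n", clean)
		input = strings.ReplaceAll(input, fp, "")
	}
	fmt.Printf("remaining prompt: %q\n", input)

	// Output:
	// image path: "./cat photo.jpg"
	// image path: "/tmp/diagram.png"
	// remaining prompt: "describe  and  please"
}
```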
func RunServer(cmd *cobra.Command, _ []string) error {
host, port, err := net.SplitHostPort(os.Getenv("OLLAMA_HOST"))
if err != nil {
@ -937,6 +1040,50 @@ func RunServer(cmd *cobra.Command, _ []string) error {
return server.Serve(ln, origins)
}
func getImageData(filePath string) ([]byte, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, err
}
defer file.Close()
buf := make([]byte, 512)
_, err = file.Read(buf)
if err != nil {
return nil, err
}
contentType := http.DetectContentType(buf)
allowedTypes := []string{"image/jpeg", "image/jpg", "image/svg+xml", "image/png"}
if !slices.Contains(allowedTypes, contentType) {
return nil, fmt.Errorf("invalid image type: %s", contentType)
}
info, err := file.Stat()
if err != nil {
return nil, err
}
// Check if the file size exceeds 100MB
var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes
if info.Size() > maxSize {
return nil, fmt.Errorf("file size exceeds maximum limit (100MB).")
}
buf = make([]byte, info.Size())
_, err = file.Seek(0, 0)
if err != nil {
return nil, err
}
_, err = io.ReadFull(file, buf)
if err != nil {
return nil, err
}
return buf, nil
}
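
`getImageData` sniffs the first 512 bytes with `http.DetectContentType` before seeking back and reading the whole file. An alternative, shown only as a sketch and not what this commit does, is to read the file once and sniff in memory, trading an up-front full read for simpler control flow:

```go
package main

import (
	"fmt"
	"net/http"
	"os"
	"slices"
)

// getImageDataSimple is a sketch of a single-read variant; not the implementation above.
func getImageDataSimple(filePath string) ([]byte, error) {
	buf, err := os.ReadFile(filePath)
	if err != nil {
		return nil, err
	}
	if int64(len(buf)) > 100*1024*1024 { // same 100MB cap
		return nil, fmt.Errorf("file size exceeds maximum limit (100MB)")
	}
	// DetectContentType only inspects at most the first 512 bytes
	contentType := http.DetectContentType(buf)
	allowedTypes := []string{"image/jpeg", "image/jpg", "image/svg+xml", "image/png"}
	if !slices.Contains(allowedTypes, contentType) {
		return nil, fmt.Errorf("invalid image type: %s", contentType)
	}
	return buf, nil
}
```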
func initializeKeypair() error {
home, err := os.UserHomeDir()
if err != nil {

View file

@ -150,6 +150,7 @@ PARAMETER <parameter> <parametervalue>
| top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 |
| top_p | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | top_p 0.9 |
### TEMPLATE
`TEMPLATE` sets the full prompt template to be passed into the model. It may optionally include a system prompt and a user's prompt. This is used to create a fully custom prompt, and the syntax may be model-specific. You can usually find the template for a given model in that model's readme.
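
Modelfile templates use Go's text/template syntax, with `{{ .System }}` and `{{ .Prompt }}` as the usual variables. A minimal sketch of how such a template renders; the template string itself is illustrative and not tied to any particular model:

```go
package main

import (
	"os"
	"text/template"
)

func main() {
	// illustrative template; real templates are model-specific
	tmpl := template.Must(template.New("prompt").Parse(
		"[INST] {{ .System }} {{ .Prompt }} [/INST]"))

	_ = tmpl.Execute(os.Stdout, map[string]string{
		"System": "You are a helpful assistant.",
		"Prompt": "Why is the sky blue?",
	})
	// prints: [INST] You are a helpful assistant. Why is the sky blue? [/INST]
}
```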

View file

@ -223,8 +223,14 @@ type Running struct {
*StatusWriter // captures error messages from the llama runner process
}
type ImageData struct {
Data []byte `json:"data"`
ID int `json:"id"`
}
type llama struct {
api.Options
ImageData []ImageData
Running
}
@ -547,6 +553,7 @@ const maxBufferSize = 512 * format.KiloByte
type PredictOpts struct {
Prompt string
Format string
Images []api.ImageData
CheckpointStart time.Time
CheckpointLoaded time.Time
}
@ -564,6 +571,14 @@ type PredictResult struct {
}
func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
imageData := llm.ImageData
if len(predict.Images) > 0 {
for cnt, i := range predict.Images {
imageData = append(imageData, ImageData{Data: i, ID: cnt})
}
}
log.Printf("loaded %d images", len(imageData))
request := map[string]any{
"prompt": predict.Prompt,
"stream": true,
@ -585,6 +600,7 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
"penalize_nl": llm.PenalizeNewline,
"seed": llm.Seed,
"stop": llm.Stop,
"image_data": imageData,
}
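
On the wire, each `ImageData` entry marshals as an object with a base64 `data` field and an integer `id`, so the completion request gains an `image_data` array next to the prompt. A small sketch of that shape; the bytes are made up, and the map mirrors only a fragment of the real request:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// ImageData mirrors the struct added in this file.
type ImageData struct {
	Data []byte `json:"data"`
	ID   int    `json:"id"`
}

func main() {
	request := map[string]any{
		"prompt":     "describe the image",
		"stream":     true,
		"image_data": []ImageData{{Data: []byte{0xFF, 0xD8, 0xFF}, ID: 0}},
	}
	out, _ := json.Marshal(request)
	fmt.Println(string(out))
	// {"image_data":[{"data":"/9j/","id":0}],"prompt":"describe the image","stream":true}
}
```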
if predict.Format == "json" {

View file

@ -46,6 +46,7 @@ type Model struct {
System string
License []string
Digest string
Size int64
Options map[string]interface{}
}
@ -242,6 +243,7 @@ func GetModel(name string) (*Model, error) {
Digest: digest,
Template: "{{ .Prompt }}",
License: []string{},
Size: manifest.GetTotalSize(),
}
filename, err := GetBlobsPath(manifest.Config.Digest)
@ -545,6 +547,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
}
}
// xxx - can this be removed?
if config.ModelType == "65B" {
if gqa, ok := formattedParams["gqa"].(int); ok && gqa == 8 {
config.ModelType = "70B"

View file

@ -156,9 +156,9 @@ func GenerateHandler(c *gin.Context) {
defer loaded.mu.Unlock()
checkpointStart := time.Now()
var req api.GenerateRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@ -292,6 +292,7 @@ func GenerateHandler(c *gin.Context) {
Format: req.Format,
CheckpointStart: checkpointStart,
CheckpointLoaded: checkpointLoaded,
Images: req.Images,
}
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
ch <- gin.H{"error": err.Error()}
@ -614,10 +615,19 @@ func GetModelInfo(name string) (*api.ShowResponse, error) {
return nil, err
}
modelDetails := api.ModelDetails{
Format: model.Config.ModelFormat,
Family: model.Config.ModelFamily,
Families: model.Config.ModelFamilies,
ParameterSize: model.Config.ModelType,
QuantizationLevel: model.Config.FileType,
}
resp := &api.ShowResponse{
License: strings.Join(model.License, "\n"),
System: model.System,
Template: model.Template,
Details: modelDetails,
}
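
With `Details` attached, `/api/show` (and the list endpoint below) now reports the model's format, family, families, parameter size, and quantization level; the `Families` slice is what the CLI's `modelIsMultiModal` check inspects for "clip". A sketch of the JSON this block produces, with illustrative values:

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/jmorganca/ollama/api"
)

func main() {
	details := api.ModelDetails{
		Format:            "gguf", // illustrative values, not from a real model
		Family:            "llama",
		Families:          []string{"llama", "clip"},
		ParameterSize:     "7B",
		QuantizationLevel: "Q4_0",
	}
	b, _ := json.Marshal(details)
	fmt.Println(string(b))
	// {"format":"gguf","family":"llama","families":["llama","clip"],"parameter_size":"7B","quantization_level":"Q4_0"}
}
```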
mf, err := ShowModelfile(model)
@ -667,25 +677,42 @@ func ListModelsHandler(c *gin.Context) {
return
}
modelResponse := func(modelName string) (api.ModelResponse, error) {
model, err := GetModel(modelName)
if err != nil {
return api.ModelResponse{}, err
}
modelDetails := api.ModelDetails{
Format: model.Config.ModelFormat,
Family: model.Config.ModelFamily,
Families: model.Config.ModelFamilies,
ParameterSize: model.Config.ModelType,
QuantizationLevel: model.Config.FileType,
}
return api.ModelResponse{
Name: model.ShortName,
Size: model.Size,
Digest: model.Digest,
Details: modelDetails,
}, nil
}
walkFunc := func(path string, info os.FileInfo, _ error) error {
if !info.IsDir() {
dir, file := filepath.Split(path)
dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator))
tag := strings.Join([]string{dir, file}, ":")
mp := ParseModelPath(tag)
manifest, digest, err := GetManifest(mp)
resp, err := modelResponse(tag)
if err != nil {
log.Printf("skipping file: %s", fp)
return nil
}
models = append(models, api.ModelResponse{
Name: mp.GetShortTagname(),
Size: manifest.GetTotalSize(),
Digest: digest,
ModifiedAt: info.ModTime(),
})
resp.ModifiedAt = info.ModTime()
models = append(models, resp)
}
return nil