Multimodal support (#1216)
--------- Co-authored-by: Matt Apperson <mattapperson@Matts-MacBook-Pro.local>
This commit is contained in:
parent
7a1b37ac64
commit
910e9401d0
6 changed files with 235 additions and 28 deletions
47
api/types.go
47
api/types.go
|
@ -31,15 +31,18 @@ func (e StatusError) Error() string {
|
|||
}
|
||||
}
|
||||
|
||||
type ImageData []byte
|
||||
|
||||
type GenerateRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
System string `json:"system"`
|
||||
Template string `json:"template"`
|
||||
Context []int `json:"context,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Raw bool `json:"raw,omitempty"`
|
||||
Format string `json:"format"`
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
System string `json:"system"`
|
||||
Template string `json:"template"`
|
||||
Context []int `json:"context,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Raw bool `json:"raw,omitempty"`
|
||||
Format string `json:"format"`
|
||||
Images []ImageData `json:"images,omitempty"`
|
||||
|
||||
Options map[string]interface{} `json:"options"`
|
||||
}
|
||||
|
@ -148,11 +151,12 @@ type ShowRequest struct {
|
|||
}
|
||||
|
||||
type ShowResponse struct {
|
||||
License string `json:"license,omitempty"`
|
||||
Modelfile string `json:"modelfile,omitempty"`
|
||||
Parameters string `json:"parameters,omitempty"`
|
||||
Template string `json:"template,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
License string `json:"license,omitempty"`
|
||||
Modelfile string `json:"modelfile,omitempty"`
|
||||
Parameters string `json:"parameters,omitempty"`
|
||||
Template string `json:"template,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
type CopyRequest struct {
|
||||
|
@ -188,10 +192,11 @@ type ListResponse struct {
|
|||
}
|
||||
|
||||
type ModelResponse struct {
|
||||
Name string `json:"name"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Name string `json:"name"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
type TokenResponse struct {
|
||||
|
@ -209,6 +214,14 @@ type GenerateResponse struct {
|
|||
Metrics
|
||||
}
|
||||
|
||||
type ModelDetails struct {
|
||||
Format string `json:"format"`
|
||||
Family string `json:"family"`
|
||||
Families []string `json:"families"`
|
||||
ParameterSize string `json:"parameter_size"`
|
||||
QuantizationLevel string `json:"quantization_level"`
|
||||
}
|
||||
|
||||
func (m *Metrics) Summary() {
|
||||
if m.TotalDuration > 0 {
|
||||
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
|
||||
|
|
151
cmd/cmd.go
151
cmd/cmd.go
|
@ -17,7 +17,9 @@ import (
|
|||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
@ -36,6 +38,8 @@ import (
|
|||
"github.com/jmorganca/ollama/version"
|
||||
)
|
||||
|
||||
type ImageData []byte
|
||||
|
||||
func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
filename, _ := cmd.Flags().GetString("file")
|
||||
filename, err := filepath.Abs(filename)
|
||||
|
@ -418,6 +422,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
|
|||
Model: args[0],
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]interface{}{},
|
||||
Images: []ImageData{},
|
||||
}
|
||||
|
||||
format, err := cmd.Flags().GetString("format")
|
||||
|
@ -427,7 +432,6 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
|
|||
opts.Format = format
|
||||
|
||||
prompts := args[1:]
|
||||
|
||||
// prepend stdin to the prompt if provided
|
||||
if !term.IsTerminal(int(os.Stdin.Fd())) {
|
||||
in, err := io.ReadAll(os.Stdin)
|
||||
|
@ -466,6 +470,7 @@ type generateOptions struct {
|
|||
Format string
|
||||
System string
|
||||
Template string
|
||||
Images []ImageData
|
||||
Options map[string]interface{}
|
||||
}
|
||||
|
||||
|
@ -551,6 +556,10 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
images := make([]api.ImageData, 0)
|
||||
for _, i := range opts.Images {
|
||||
images = append(images, api.ImageData(i))
|
||||
}
|
||||
request := api.GenerateRequest{
|
||||
Model: opts.Model,
|
||||
Prompt: opts.Prompt,
|
||||
|
@ -559,6 +568,7 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
|
|||
System: opts.System,
|
||||
Template: opts.Template,
|
||||
Options: opts.Options,
|
||||
Images: images,
|
||||
}
|
||||
|
||||
if err := client.Generate(ctx, &request, fn); err != nil {
|
||||
|
@ -585,7 +595,9 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
|
|||
latest.Summary()
|
||||
}
|
||||
|
||||
cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context))
|
||||
ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)
|
||||
cmd.SetContext(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -598,11 +610,31 @@ const (
|
|||
MultilineTemplate
|
||||
)
|
||||
|
||||
func modelIsMultiModal(cmd *cobra.Command, name string) bool {
|
||||
// get model details
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
return false
|
||||
}
|
||||
|
||||
req := api.ShowRequest{Name: name}
|
||||
resp, err := client.Show(cmd.Context(), &req)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return slices.Contains(resp.Details.Families, "clip")
|
||||
}
|
||||
|
||||
func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
|
||||
multiModal := modelIsMultiModal(cmd, opts.Model)
|
||||
|
||||
// load the model
|
||||
loadOpts := generateOptions{
|
||||
Model: opts.Model,
|
||||
Prompt: "",
|
||||
Images: []ImageData{},
|
||||
}
|
||||
if err := generate(cmd, loadOpts); err != nil {
|
||||
return err
|
||||
|
@ -902,6 +934,26 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
|
|||
|
||||
if len(prompt) > 0 && multiline == MultilineNone {
|
||||
opts.Prompt = prompt
|
||||
if multiModal {
|
||||
newPrompt, images, err := extractFileNames(prompt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.Prompt = newPrompt
|
||||
|
||||
// reset the context if we find another image
|
||||
if len(images) > 0 {
|
||||
opts.Images = images
|
||||
ctx := cmd.Context()
|
||||
ctx = context.WithValue(ctx, generateContextKey("context"), []int{})
|
||||
cmd.SetContext(ctx)
|
||||
}
|
||||
if len(opts.Images) == 0 {
|
||||
fmt.Println("This model requires you to add a jpeg, png, or svg image.\n")
|
||||
prompt = ""
|
||||
continue
|
||||
}
|
||||
}
|
||||
if err := generate(cmd, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -911,6 +963,57 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
|
|||
}
|
||||
}
|
||||
|
||||
func normalizeFilePath(fp string) string {
|
||||
// Define a map of escaped characters and their replacements
|
||||
replacements := map[string]string{
|
||||
"\\ ": " ", // Escaped space
|
||||
"\\(": "(", // Escaped left parenthesis
|
||||
"\\)": ")", // Escaped right parenthesis
|
||||
"\\[": "[", // Escaped left square bracket
|
||||
"\\]": "]", // Escaped right square bracket
|
||||
"\\{": "{", // Escaped left curly brace
|
||||
"\\}": "}", // Escaped right curly brace
|
||||
"\\$": "$", // Escaped dollar sign
|
||||
"\\&": "&", // Escaped ampersand
|
||||
"\\;": ";", // Escaped semicolon
|
||||
"\\'": "'", // Escaped single quote
|
||||
"\\\\": "\\", // Escaped backslash
|
||||
"\\*": "*", // Escaped asterisk
|
||||
"\\?": "?", // Escaped question mark
|
||||
}
|
||||
|
||||
for escaped, actual := range replacements {
|
||||
fp = strings.ReplaceAll(fp, escaped, actual)
|
||||
}
|
||||
return fp
|
||||
}
|
||||
|
||||
func extractFileNames(input string) (string, []ImageData, error) {
|
||||
// Regex to match file paths starting with / or ./ and include escaped spaces (\ or %20)
|
||||
// and followed by more characters and a file extension
|
||||
regexPattern := `(?:\./|/)[\S\\ ]+?\.(?i:jpg|jpeg|png|svg)\b`
|
||||
re := regexp.MustCompile(regexPattern)
|
||||
|
||||
filePaths := re.FindAllString(input, -1)
|
||||
var imgs []ImageData
|
||||
|
||||
for _, fp := range filePaths {
|
||||
nfp := normalizeFilePath(fp)
|
||||
data, err := getImageData(nfp)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
fmt.Printf("Couldn't process image: %q\n", err)
|
||||
return "", imgs, err
|
||||
}
|
||||
fmt.Printf("Added image '%s'\n", nfp)
|
||||
input = strings.ReplaceAll(input, fp, "")
|
||||
imgs = append(imgs, data)
|
||||
}
|
||||
return input, imgs, nil
|
||||
}
|
||||
|
||||
func RunServer(cmd *cobra.Command, _ []string) error {
|
||||
host, port, err := net.SplitHostPort(os.Getenv("OLLAMA_HOST"))
|
||||
if err != nil {
|
||||
|
@ -937,6 +1040,50 @@ func RunServer(cmd *cobra.Command, _ []string) error {
|
|||
return server.Serve(ln, origins)
|
||||
}
|
||||
|
||||
func getImageData(filePath string) ([]byte, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
buf := make([]byte, 512)
|
||||
_, err = file.Read(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
contentType := http.DetectContentType(buf)
|
||||
allowedTypes := []string{"image/jpeg", "image/jpg", "image/svg+xml", "image/png"}
|
||||
if !slices.Contains(allowedTypes, contentType) {
|
||||
return nil, fmt.Errorf("invalid image type: %s", contentType)
|
||||
}
|
||||
|
||||
info, err := file.Stat()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if the file size exceeds 100MB
|
||||
var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes
|
||||
if info.Size() > maxSize {
|
||||
return nil, fmt.Errorf("file size exceeds maximum limit (100MB).")
|
||||
}
|
||||
|
||||
buf = make([]byte, info.Size())
|
||||
_, err = file.Seek(0, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = io.ReadFull(file, buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func initializeKeypair() error {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
|
|
|
@ -150,6 +150,7 @@ PARAMETER <parameter> <parametervalue>
|
|||
| top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 |
|
||||
| top_p | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | top_p 0.9 |
|
||||
|
||||
|
||||
### TEMPLATE
|
||||
|
||||
`TEMPLATE` of the full prompt template to be passed into the model. It may include (optionally) a system prompt and a user's prompt. This is used to create a full custom prompt, and syntax may be model specific. You can usually find the template for a given model in the readme for that model.
|
||||
|
|
16
llm/llama.go
16
llm/llama.go
|
@ -223,8 +223,14 @@ type Running struct {
|
|||
*StatusWriter // captures error messages from the llama runner process
|
||||
}
|
||||
|
||||
type ImageData struct {
|
||||
Data []byte `json:"data"`
|
||||
ID int `json:"id"`
|
||||
}
|
||||
|
||||
type llama struct {
|
||||
api.Options
|
||||
ImageData []ImageData
|
||||
Running
|
||||
}
|
||||
|
||||
|
@ -547,6 +553,7 @@ const maxBufferSize = 512 * format.KiloByte
|
|||
type PredictOpts struct {
|
||||
Prompt string
|
||||
Format string
|
||||
Images []api.ImageData
|
||||
CheckpointStart time.Time
|
||||
CheckpointLoaded time.Time
|
||||
}
|
||||
|
@ -564,6 +571,14 @@ type PredictResult struct {
|
|||
}
|
||||
|
||||
func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
|
||||
imageData := llm.ImageData
|
||||
if len(predict.Images) > 0 {
|
||||
for cnt, i := range predict.Images {
|
||||
imageData = append(imageData, ImageData{Data: i, ID: cnt})
|
||||
}
|
||||
}
|
||||
log.Printf("loaded %d images", len(imageData))
|
||||
|
||||
request := map[string]any{
|
||||
"prompt": predict.Prompt,
|
||||
"stream": true,
|
||||
|
@ -585,6 +600,7 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
|
|||
"penalize_nl": llm.PenalizeNewline,
|
||||
"seed": llm.Seed,
|
||||
"stop": llm.Stop,
|
||||
"image_data": imageData,
|
||||
}
|
||||
|
||||
if predict.Format == "json" {
|
||||
|
|
|
@ -46,6 +46,7 @@ type Model struct {
|
|||
System string
|
||||
License []string
|
||||
Digest string
|
||||
Size int64
|
||||
Options map[string]interface{}
|
||||
}
|
||||
|
||||
|
@ -242,6 +243,7 @@ func GetModel(name string) (*Model, error) {
|
|||
Digest: digest,
|
||||
Template: "{{ .Prompt }}",
|
||||
License: []string{},
|
||||
Size: manifest.GetTotalSize(),
|
||||
}
|
||||
|
||||
filename, err := GetBlobsPath(manifest.Config.Digest)
|
||||
|
@ -545,6 +547,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
|||
}
|
||||
}
|
||||
|
||||
// xxx - can this be removed?
|
||||
if config.ModelType == "65B" {
|
||||
if gqa, ok := formattedParams["gqa"].(int); ok && gqa == 8 {
|
||||
config.ModelType = "70B"
|
||||
|
|
|
@ -156,9 +156,9 @@ func GenerateHandler(c *gin.Context) {
|
|||
defer loaded.mu.Unlock()
|
||||
|
||||
checkpointStart := time.Now()
|
||||
|
||||
var req api.GenerateRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
|
@ -292,6 +292,7 @@ func GenerateHandler(c *gin.Context) {
|
|||
Format: req.Format,
|
||||
CheckpointStart: checkpointStart,
|
||||
CheckpointLoaded: checkpointLoaded,
|
||||
Images: req.Images,
|
||||
}
|
||||
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
|
@ -614,10 +615,19 @@ func GetModelInfo(name string) (*api.ShowResponse, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
modelDetails := api.ModelDetails{
|
||||
Format: model.Config.ModelFormat,
|
||||
Family: model.Config.ModelFamily,
|
||||
Families: model.Config.ModelFamilies,
|
||||
ParameterSize: model.Config.ModelType,
|
||||
QuantizationLevel: model.Config.FileType,
|
||||
}
|
||||
|
||||
resp := &api.ShowResponse{
|
||||
License: strings.Join(model.License, "\n"),
|
||||
System: model.System,
|
||||
Template: model.Template,
|
||||
Details: modelDetails,
|
||||
}
|
||||
|
||||
mf, err := ShowModelfile(model)
|
||||
|
@ -667,25 +677,42 @@ func ListModelsHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
modelResponse := func(modelName string) (api.ModelResponse, error) {
|
||||
model, err := GetModel(modelName)
|
||||
if err != nil {
|
||||
return api.ModelResponse{}, err
|
||||
}
|
||||
|
||||
modelDetails := api.ModelDetails{
|
||||
Format: model.Config.ModelFormat,
|
||||
Family: model.Config.ModelFamily,
|
||||
Families: model.Config.ModelFamilies,
|
||||
ParameterSize: model.Config.ModelType,
|
||||
QuantizationLevel: model.Config.FileType,
|
||||
}
|
||||
|
||||
return api.ModelResponse{
|
||||
Name: model.ShortName,
|
||||
Size: model.Size,
|
||||
Digest: model.Digest,
|
||||
Details: modelDetails,
|
||||
}, nil
|
||||
}
|
||||
|
||||
walkFunc := func(path string, info os.FileInfo, _ error) error {
|
||||
if !info.IsDir() {
|
||||
dir, file := filepath.Split(path)
|
||||
dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator))
|
||||
tag := strings.Join([]string{dir, file}, ":")
|
||||
|
||||
mp := ParseModelPath(tag)
|
||||
manifest, digest, err := GetManifest(mp)
|
||||
resp, err := modelResponse(tag)
|
||||
if err != nil {
|
||||
log.Printf("skipping file: %s", fp)
|
||||
return nil
|
||||
}
|
||||
|
||||
models = append(models, api.ModelResponse{
|
||||
Name: mp.GetShortTagname(),
|
||||
Size: manifest.GetTotalSize(),
|
||||
Digest: digest,
|
||||
ModifiedAt: info.ModTime(),
|
||||
})
|
||||
resp.ModifiedAt = info.ModTime()
|
||||
models = append(models, resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
Loading…
Add table
Reference in a new issue