26acdcf44e
Currently if an input has embeddings at any point then we will set cross attention to true from the beginning. This means that any tokens before the embeddings are sent will incorrectly have cross attention layers applied. This only sets cross attention when we have an embedding, either previously in this sequence or in the cache. It also makes cross attention capable of supporting parallelism at the runner level, though the mllama implementation doesn't support that yet.
156 lines
3.2 KiB
Go
156 lines
3.2 KiB
Go
package main
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"hash/maphash"
|
|
"log/slog"
|
|
"slices"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/ollama/ollama/llama"
|
|
)
|
|
|
|
const imageCacheSize = 4
|
|
|
|
type ImageContext struct {
|
|
// mu is required to be held when generating embeddings or accessing the cache
|
|
mu sync.Mutex
|
|
|
|
clip *llama.ClipContext
|
|
mllama *llama.MllamaContext
|
|
|
|
// cache of images to embeddings
|
|
images []imageCache
|
|
imageHash maphash.Hash
|
|
}
|
|
|
|
func NewImageContext(llamaContext *llama.Context, modelPath string) (*ImageContext, error) {
|
|
arch, err := llama.GetModelArch(modelPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to determine vision architecture: %w (%s)", err, modelPath)
|
|
}
|
|
|
|
var c ImageContext
|
|
if arch == "clip" {
|
|
c.clip, err = llama.NewClipContext(llamaContext, modelPath)
|
|
} else if arch == "mllama" {
|
|
c.mllama, err = llama.NewMllamaContext(llamaContext, modelPath)
|
|
} else {
|
|
return nil, fmt.Errorf("unknown vision model architecture: %s", arch)
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
c.images = make([]imageCache, imageCacheSize)
|
|
|
|
return &c, nil
|
|
}
|
|
|
|
func (c *ImageContext) Free(modelPath string) {
|
|
if c == nil {
|
|
return
|
|
}
|
|
|
|
if c.clip != nil {
|
|
c.clip.Free()
|
|
}
|
|
if c.mllama != nil {
|
|
c.mllama.Free()
|
|
}
|
|
}
|
|
|
|
func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte, aspectRatioId int) [][]float32 {
|
|
if c == nil {
|
|
return nil
|
|
}
|
|
|
|
hash := c.hashImage(data)
|
|
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
embed, err := c.findImage(hash)
|
|
if err != nil {
|
|
if c.mllama != nil {
|
|
embed = c.mllama.NewEmbed(llamaContext, data, aspectRatioId)
|
|
} else if c.clip != nil {
|
|
embed = c.clip.NewEmbed(llamaContext, data)
|
|
} else {
|
|
return nil
|
|
}
|
|
|
|
c.addImage(hash, embed)
|
|
}
|
|
|
|
return embed
|
|
}
|
|
|
|
func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
|
|
if c != nil && c.mllama != nil {
|
|
return c.mllama.EmbedSize(llamaContext)
|
|
} else {
|
|
return llamaContext.Model().NEmbd()
|
|
}
|
|
}
|
|
|
|
func (c *ImageContext) NeedCrossAttention(inputs ...input) bool {
|
|
if c == nil || c.mllama == nil {
|
|
return false
|
|
}
|
|
|
|
return slices.ContainsFunc(inputs, func(input input) bool {
|
|
return input.embed != nil
|
|
})
|
|
}
|
|
|
|
type imageCache struct {
|
|
key uint64
|
|
val [][]float32
|
|
lastUsed time.Time
|
|
}
|
|
|
|
func (c *ImageContext) hashImage(image []byte) uint64 {
|
|
c.imageHash.Reset()
|
|
_, _ = c.imageHash.Write(image)
|
|
return c.imageHash.Sum64()
|
|
}
|
|
|
|
var errImageNotFound = errors.New("image not found in cache")
|
|
|
|
func (c *ImageContext) findImage(hash uint64) ([][]float32, error) {
|
|
for i := range c.images {
|
|
if c.images[i].key == hash {
|
|
slog.Debug("loading image embeddings from cache", "entry", i)
|
|
c.images[i].lastUsed = time.Now()
|
|
return c.images[i].val, nil
|
|
}
|
|
}
|
|
|
|
return nil, errImageNotFound
|
|
}
|
|
|
|
func (c *ImageContext) addImage(hash uint64, embed [][]float32) {
|
|
best := time.Now()
|
|
var bestImage int
|
|
|
|
for i := range c.images {
|
|
if c.images[i].key == hash {
|
|
bestImage = i
|
|
break
|
|
}
|
|
|
|
if c.images[i].lastUsed.Compare(best) < 0 {
|
|
best = c.images[i].lastUsed
|
|
bestImage = i
|
|
}
|
|
}
|
|
|
|
slog.Debug("storing image embeddings in cache", "entry", bestImage, "used", c.images[bestImage].lastUsed)
|
|
c.images[bestImage].key = hash
|
|
c.images[bestImage].val = embed
|
|
c.images[bestImage].lastUsed = time.Now()
|
|
}
|