312d9de1d1
Check for NULL return values from llama.cpp in more places and convert them into Go errors, which should make debugging easier in the future rather than having hidden surprises in our data structures.
179 lines
3.8 KiB
Go
179 lines
3.8 KiB
Go
package main
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"hash/maphash"
|
|
"log/slog"
|
|
"slices"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/ollama/ollama/llama"
|
|
)
|
|
|
|
const imageCacheSize = 4
|
|
|
|
type ImageContext struct {
|
|
// mu is required to be held when generating embeddings or accessing the cache
|
|
mu sync.Mutex
|
|
|
|
clip *llama.ClipContext
|
|
mllama *llama.MllamaContext
|
|
|
|
// cache of images to embeddings
|
|
images []imageCache
|
|
imageHash maphash.Hash
|
|
}
|
|
|
|
func NewImageContext(llamaContext *llama.Context, modelPath string) (*ImageContext, error) {
|
|
arch, err := llama.GetModelArch(modelPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to determine vision architecture: %w (%s)", err, modelPath)
|
|
}
|
|
|
|
var c ImageContext
|
|
if arch == "clip" {
|
|
c.clip, err = llama.NewClipContext(llamaContext, modelPath)
|
|
} else if arch == "mllama" {
|
|
c.mllama, err = llama.NewMllamaContext(llamaContext, modelPath)
|
|
} else {
|
|
return nil, fmt.Errorf("unknown vision model architecture: %s", arch)
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
c.images = make([]imageCache, imageCacheSize)
|
|
|
|
return &c, nil
|
|
}
|
|
|
|
func (c *ImageContext) Free(modelPath string) {
|
|
if c == nil {
|
|
return
|
|
}
|
|
|
|
if c.clip != nil {
|
|
c.clip.Free()
|
|
}
|
|
if c.mllama != nil {
|
|
c.mllama.Free()
|
|
}
|
|
}
|
|
|
|
func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte, aspectRatioId int) ([][]float32, error) {
|
|
if c == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
hash := c.hashImage(data)
|
|
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
embed, err := c.findImage(hash)
|
|
if err != nil {
|
|
if c.mllama != nil {
|
|
embed, err = c.mllama.NewEmbed(llamaContext, data, aspectRatioId)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
} else if c.clip != nil {
|
|
embed, err = c.clip.NewEmbed(llamaContext, data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
return nil, errors.New("received image but vision model not loaded")
|
|
}
|
|
|
|
c.addImage(hash, embed)
|
|
}
|
|
|
|
return embed, nil
|
|
}
|
|
|
|
func (c *ImageContext) BatchSize(configuredBatchSize int) int {
|
|
// If images are not supported, we don't need to allocate embedding batches
|
|
if c == nil {
|
|
return 0
|
|
}
|
|
|
|
// Mllama maps an image to 1 embedding token (llava creates many tokens)
|
|
// and doesn't support more than a single image per request.
|
|
// The embeddings are large (100 MB), so allocating a big batch can fail
|
|
// on some systems
|
|
if c.mllama != nil {
|
|
return 1
|
|
}
|
|
|
|
return configuredBatchSize
|
|
}
|
|
|
|
func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
|
|
if c != nil && c.mllama != nil {
|
|
return c.mllama.EmbedSize(llamaContext)
|
|
} else {
|
|
return llamaContext.Model().NEmbd()
|
|
}
|
|
}
|
|
|
|
func (c *ImageContext) NeedCrossAttention(inputs ...input) bool {
|
|
if c == nil || c.mllama == nil {
|
|
return false
|
|
}
|
|
|
|
return slices.ContainsFunc(inputs, func(input input) bool {
|
|
return input.embed != nil
|
|
})
|
|
}
|
|
|
|
type imageCache struct {
|
|
key uint64
|
|
val [][]float32
|
|
lastUsed time.Time
|
|
}
|
|
|
|
func (c *ImageContext) hashImage(image []byte) uint64 {
|
|
c.imageHash.Reset()
|
|
_, _ = c.imageHash.Write(image)
|
|
return c.imageHash.Sum64()
|
|
}
|
|
|
|
var errImageNotFound = errors.New("image not found in cache")
|
|
|
|
func (c *ImageContext) findImage(hash uint64) ([][]float32, error) {
|
|
for i := range c.images {
|
|
if c.images[i].key == hash {
|
|
slog.Debug("loading image embeddings from cache", "entry", i)
|
|
c.images[i].lastUsed = time.Now()
|
|
return c.images[i].val, nil
|
|
}
|
|
}
|
|
|
|
return nil, errImageNotFound
|
|
}
|
|
|
|
func (c *ImageContext) addImage(hash uint64, embed [][]float32) {
|
|
best := time.Now()
|
|
var bestImage int
|
|
|
|
for i := range c.images {
|
|
if c.images[i].key == hash {
|
|
bestImage = i
|
|
break
|
|
}
|
|
|
|
if c.images[i].lastUsed.Compare(best) < 0 {
|
|
best = c.images[i].lastUsed
|
|
bestImage = i
|
|
}
|
|
}
|
|
|
|
slog.Debug("storing image embeddings in cache", "entry", bestImage, "used", c.images[bestImage].lastUsed)
|
|
c.images[bestImage].key = hash
|
|
c.images[bestImage].val = embed
|
|
c.images[bestImage].lastUsed = time.Now()
|
|
}
|