package main

import (
	"errors"
	"hash/maphash"
	"log/slog"
	"reflect"
	"time"

	"github.com/ollama/ollama/llama"
)

type InputCache struct {
	// context window size (per slot)
	numCtx int

	// individual KV caches
	slots []InputCacheSlot

	// optimize cache eviction for multiple users
	multiUserCache bool

	// cache of images to embeddings
	images    []imageCache
	imageHash maphash.Hash

	lc *llama.Context
}

func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache bool) *InputCache {
	slots := make([]InputCacheSlot, numSlots)

	for i := range slots {
		slots[i] = InputCacheSlot{
			Id:     i,
			Inputs: make([]input, 0),
		}
	}

	return &InputCache{
		numCtx:         kvSize / numSlots,
		slots:          slots,
		multiUserCache: multiUserCache,
		images:         make([]imageCache, numSlots),
		lc:             lc,
	}
}

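// Example (sketch): a runner with an 8192-entry KV cache split across four
// slots gives each sequence a 2048-entry context window; lc is assumed to be
// an already-initialized *llama.Context.
//
//	cache := NewInputCache(lc, 8192, 4, false)
//	// cache.numCtx == 8192/4 == 2048; all four slots start empty
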
// Locking: Operations on InputCacheSlot (including finding one
// through LoadCacheSlot) require a lock to be held that serializes
// these operations with each other and with llama.Decode.

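// Example (sketch): one way for a caller to satisfy the requirement above,
// assuming a mutex shared by all sequences. The mutex, the batch, and the
// decode call are assumptions about the caller, not part of this file.
//
//	mu.Lock()
//	slot, remaining, numPast, err := cache.LoadCacheSlot(prompt, true)
//	mu.Unlock()
//
//	// ... later, decoding happens under the same lock:
//	mu.Lock()
//	err = cache.lc.Decode(batch)
//	mu.Unlock()
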
type InputCacheSlot struct {
	// Index in the KV cache
	Id int

	// Inputs that are stored in the KV cache
	Inputs []input

	// is this cache actively being processed as part of a sequence?
	InUse bool

	// last time this cache was used (as of start of processing)
	lastUsed time.Time
}

// LoadCacheSlot finds a cache slot for the prompt, marks it in use, and trims
// the KV cache to the reusable prefix. It returns the slot, the uncached
// suffix of the prompt that still needs to be processed, and the number of
// cached inputs that were reused.
func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, int, error) {
	var slot *InputCacheSlot
	var numPast int
	var err error

	// In single-user scenarios, the longest cache slot works fine for getting good input
	// cache hit rates and it reuses the same VRAM over and over again, which is good for
	// GPU performance in situations where we miss the input cache.
	// For multiple users, the "best" cache slot produces better input cache hit rates
	// at the cost of worse performance when we miss the input cache (because it causes
	// GPU L2 cache misses due to spreading out accesses across VRAM).
	if !c.multiUserCache {
		slot, numPast, err = c.findLongestCacheSlot(prompt)
	} else {
		slot, numPast, err = c.findBestCacheSlot(prompt)
	}
	if err != nil {
		return nil, nil, 0, err
	}

	if !cachePrompt {
		numPast = 0
	}

	slot.InUse = true
	slot.lastUsed = time.Now()

	if numPast == len(prompt) {
		// Leave one input to sample so we can get a response
		numPast--
	}

	if !c.lc.KvCacheSeqRm(slot.Id, numPast, -1) {
		// Some models don't support partial erasure
		c.lc.KvCacheSeqRm(slot.Id, 0, -1)
		numPast = 0
	}

	slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt),
		"used", numPast, "remaining", len(prompt)-numPast)

	prompt = prompt[numPast:]
	slot.Inputs = slot.Inputs[:numPast]

	return slot, prompt, numPast, nil
}

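// Example (sketch): typical caller flow once LoadCacheSlot returns, with the
// lock from the note above held. Only the uncached suffix still needs
// decoding, and the caller appends each processed input back onto slot.Inputs:
//
//	slot, remaining, numPast, err := cache.LoadCacheSlot(prompt, true)
//	// numPast inputs were reused; decode the rest starting at position numPast
//	for _, inp := range remaining {
//		// ... decode inp at position len(slot.Inputs) ...
//		slot.Inputs = append(slot.Inputs, inp)
//	}
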
// findLongestCacheSlot returns the slot not currently in use whose contents
// share the longest common prefix with the prompt.
func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int, error) {
	longest := -1
	var longestSlot *InputCacheSlot

	for i, s := range c.slots {
		if s.InUse {
			continue
		}

		count := countCommonPrefix(s.Inputs, prompt)
		if count > longest {
			longest = count
			longestSlot = &c.slots[i]
		}
	}

	if longestSlot == nil {
		return nil, 0, errors.New("no available cache slots")
	}

	return longestSlot, longest, nil
}

// findBestCacheSlot returns the slot with the longest common prefix if the
// prompt exactly extends its contents and it is idle; otherwise it evicts the
// least recently used idle slot, forking the longest prefix into it when one
// exists.
func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int, error) {
	oldest := time.Now()
	var oldestSlot *InputCacheSlot

	longest := -1
	var longestSlot *InputCacheSlot

	for i, s := range c.slots {
		count := countCommonPrefix(s.Inputs, prompt)
		if count > longest {
			longest = count
			longestSlot = &c.slots[i]
		}

		if s.lastUsed.Compare(oldest) < 0 && !s.InUse {
			oldest = s.lastUsed
			oldestSlot = &c.slots[i]
		}
	}

	if longest == len(longestSlot.Inputs) && !longestSlot.InUse {
		return longestSlot, longest, nil
	}

	// oldestSlot is nil when every slot is in use
	if oldestSlot == nil || oldestSlot.InUse {
		return nil, 0, errors.New("no available cache slots")
	}

	if len(oldestSlot.Inputs) != 0 {
		slog.Debug("evicting cache slot", "id", oldestSlot.Id, "inputs", len(oldestSlot.Inputs),
			"used", oldestSlot.lastUsed)
	}

	if longest > 0 && longestSlot != oldestSlot {
		slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
			len(longestSlot.Inputs))
		oldestSlot.Inputs = make([]input, longest)
		copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
		// This is only nil for unit tests
		if c.lc != nil {
			c.lc.KvCacheSeqRm(oldestSlot.Id, 0, -1)
			c.lc.KvCacheSeqCp(longestSlot.Id, oldestSlot.Id, 0, longest)
		}
	}

	return oldestSlot, longest, nil
}

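// Worked example (sketch): given prompt [A B C D], an idle slot 0 holding
// [A B C], and an idle least-recently-used slot 1 holding [X Y]: longest is 3
// and equals len(slot 0's inputs), so slot 0 is returned directly. If slot 0
// instead held [A B C E], the prompt no longer extends its contents, so
// slot 1 is evicted and the three-input prefix is forked into it via
// KvCacheSeqCp.
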
// countCommonPrefix returns the length of the longest shared prefix of a and b.
func countCommonPrefix(a []input, b []input) int {
	var count int

	for i := range a {
		if i >= len(b) {
			break
		}

		if !reflect.DeepEqual(a[i], b[i]) {
			break
		}

		count++
	}

	return count
}

// ShiftCacheSlot discards numDiscard inputs following the first numKeep from
// the slot, shifting the remaining entries down to free space at the end of
// the context window. numPast is the current length of the sequence.
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int, numDiscard int, numPast int) {
	// TODO (jessegross): KV cache removal can fail for certain types of models
	// server.cpp doesn't handle this, though we can be more graceful
	c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+numDiscard)
	c.lc.KvCacheSeqAdd(slot.Id, numKeep+numDiscard, numPast, -numDiscard)

	for i := numKeep + numDiscard; i < len(slot.Inputs); i++ {
		slot.Inputs[i-numDiscard] = slot.Inputs[i]
	}
	slot.Inputs = slot.Inputs[:len(slot.Inputs)-numDiscard]
}

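// Worked example (sketch): with 10 inputs in the slot, numKeep=2, numDiscard=4,
// and numPast=10, positions 2..5 are removed from the KV cache, positions 6..9
// shift down to 2..5, and slot.Inputs shrinks from 10 to 6 entries, freeing
// four positions at the end of the context window.
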
// Locking: Lookup and store operations on imageCache require a lock to be
// held that serializes these with each other. Hash does not require a lock,
// nor do these operations need to be serialized with operations on
// InputCacheSlot.

type imageCache struct {
	key      uint64
	val      [][]float32
	lastUsed time.Time
}

// HashImage returns a hash of the image bytes for use as an imageCache key.
func (c *InputCache) HashImage(image []byte) uint64 {
	c.imageHash.Reset()
	_, _ = c.imageHash.Write(image)
	return c.imageHash.Sum64()
}

var ErrImageNotFound = errors.New("image not found in cache")

// FindImage returns the cached embeddings for hash, or ErrImageNotFound.
func (c *InputCache) FindImage(hash uint64) ([][]float32, error) {
	for i := range c.images {
		if c.images[i].key == hash {
			slog.Debug("loading image embeddings from cache", "entry", i)
			c.images[i].lastUsed = time.Now()
			return c.images[i].val, nil
		}
	}

	return nil, ErrImageNotFound
}

// AddImage stores embeddings for hash, overwriting the entry with a matching
// key if one exists, otherwise the least recently used entry.
func (c *InputCache) AddImage(hash uint64, embed [][]float32) {
	best := time.Now()
	var bestImage int

	for i := range c.images {
		if c.images[i].key == hash {
			bestImage = i
			break
		}

		if c.images[i].lastUsed.Compare(best) < 0 {
			best = c.images[i].lastUsed
			bestImage = i
		}
	}

	slog.Debug("storing image embeddings in cache", "entry", bestImage, "used", c.images[bestImage].lastUsed)
	c.images[bestImage].key = hash
	c.images[bestImage].val = embed
	c.images[bestImage].lastUsed = time.Now()
}
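// Example (sketch): find-or-compute flow for image embeddings, holding the
// imageCache lock described above around FindImage and AddImage
// (computeEmbedding is a hypothetical helper, not defined in this file):
//
//	hash := cache.HashImage(imageBytes)
//	embed, err := cache.FindImage(hash)
//	if errors.Is(err, ErrImageNotFound) {
//		embed = computeEmbedding(imageBytes) // hypothetical
//		cache.AddImage(hash, embed)
//	}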