diff --git a/llama/runner/cache.go b/llama/runner/cache.go index 75c1d874..190ccdff 100644 --- a/llama/runner/cache.go +++ b/llama/runner/cache.go @@ -2,6 +2,7 @@ package main import ( "errors" + "fmt" "log/slog" "reflect" "time" @@ -22,7 +23,11 @@ type InputCache struct { lc *llama.Context } -func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache bool) *InputCache { +func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache bool) (*InputCache, error) { + if kvSize/numSlots < 1 { + return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots) + } + slots := make([]InputCacheSlot, numSlots) for i := range slots { @@ -37,7 +42,7 @@ func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache b slots: slots, multiUserCache: multiUserCache, lc: lc, - } + }, nil } // Locking: Operations on InputCacheSlot (including finding one @@ -58,7 +63,7 @@ type InputCacheSlot struct { lastUsed time.Time } -func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, int, error) { +func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) { var slot *InputCacheSlot var numPast int var err error @@ -75,7 +80,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach slot, numPast, err = c.findBestCacheSlot(prompt) } if err != nil { - return nil, nil, 0, err + return nil, nil, err } if !cachePrompt { @@ -102,7 +107,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach prompt = prompt[numPast:] slot.Inputs = slot.Inputs[:numPast] - return slot, prompt, numPast, nil + return slot, prompt, nil } func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int, error) { @@ -194,14 +199,30 @@ func countCommonPrefix(a []input, b []input) int { return count } -func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int, numDiscard int, numPast int) { - // TODO (jessegross): KV cache removal can fail for certain types of models - // server.cpp doesn't handle this, though we can be more graceful - c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+numDiscard) - c.lc.KvCacheSeqAdd(slot.Id, numKeep+numDiscard, numPast, -numDiscard) +// Frees up space in the KV cache by deleting the oldest half of history and shifting +// the newest half into that space (saving numKeep inputs at the beginning). +// +// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx) +func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) { + targetFree := (c.numCtx - numKeep) / 2 + targetFree = max(targetFree, 1) - for i := numKeep + numDiscard; i < len(slot.Inputs); i++ { - slot.Inputs[i-numDiscard] = slot.Inputs[i] + currentFree := c.numCtx - len(slot.Inputs) + discard := targetFree - currentFree + + if discard <= 0 { + return } - slot.Inputs = slot.Inputs[:len(slot.Inputs)-numDiscard] + + slog.Debug("context limit hit - shifting", "limit", c.numCtx, "input", len(slot.Inputs), + "keep", numKeep, "discard", discard) + + // TODO (jessegross): KV cache removal can fail for certain types of models + c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) + c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard) + + for i := numKeep + discard; i < len(slot.Inputs); i++ { + slot.Inputs[i-discard] = slot.Inputs[i] + } + slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard] } diff --git a/llama/runner/runner.go b/llama/runner/runner.go index 0a37dee0..b680f060 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -34,9 +34,6 @@ type input struct { } type Sequence struct { - // number of inputs evaluated - numPast int - // batch index iBatch int @@ -112,21 +109,15 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen params.numKeep = len(inputs) } - if !params.embedding { - // Subtracting 4 ensures that at least 1 input can be discarded during shift - params.numKeep = min(params.numKeep, s.cache.numCtx-4) - params.numKeep += s.bosToken - } else { - // Embeddings are 1 shot - just truncate to the context window, without ever shifting - params.numKeep = min(params.numKeep, s.cache.numCtx) + if s.model.AddBOSToken() { + params.numKeep += 1 } - // truncate to fit in context window + // Ensure that at least 1 input can be discarded during shift + params.numKeep = min(params.numKeep, s.cache.numCtx-1) + if len(inputs) > s.cache.numCtx { - slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "numKeep", params.numKeep) - newInputs := inputs[:params.numKeep] - newInputs = append(newInputs, inputs[len(inputs)-s.cache.numCtx+params.numKeep:]...) - inputs = newInputs + slog.Warn("input exceeds context length", "prompt", len(inputs), "limit", s.cache.numCtx) } var sc *llama.SamplingContext @@ -231,9 +222,6 @@ type Server struct { // KV cache cache *InputCache - // does this model require a beginning of sequence token? - bosToken int - // next sequence for prompt processing to avoid starvation nextSeq int @@ -258,18 +246,6 @@ func (s *Server) allNil() bool { return true } -func (s *Server) shiftContext(seq *Sequence) { - numLeft := seq.numPast - seq.numKeep - numDiscard := numLeft / 2 - - slog.Debug("context limit hit - shifting", "limit", s.cache.numCtx, "numPast", seq.numPast, - "numKeep", seq.numKeep, "numLeft", numLeft, "numDiscard", numDiscard) - - s.cache.ShiftCacheSlot(seq.cache, seq.numKeep, numDiscard, seq.numPast) - - seq.numPast -= numDiscard -} - func flushPending(seq *Sequence) bool { joined := strings.Join(seq.pendingResponses, "") seq.pendingResponses = []string{} @@ -374,12 +350,19 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) continue } - if seq.numPast+len(seq.inputs) > s.cache.numCtx { - s.shiftContext(seq) - } - var numInputsProcessed int + shifted := false + for i, input := range seq.inputs { + if len(seq.cache.Inputs)+1 > s.cache.numCtx { + if !shifted { + s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) + shifted = true + } else { + break + } + } + embedding := input.embed != nil // If we don't currently have a batch, use one of the correct type and @@ -403,13 +386,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) } crossAttention = seq.crossAttention - batch.Add(input.token, input.embed, seq.numPast, numInputsProcessed+1 == len(seq.inputs), seq.cache.Id) - seq.numPast++ + batch.Add(input.token, input.embed, len(seq.cache.Inputs), i+1 == len(seq.inputs), seq.cache.Id) + seq.cache.Inputs = append(seq.cache.Inputs, input) numInputsProcessed++ } if numInputsProcessed > 0 { - seq.cache.Inputs = append(seq.cache.Inputs, seq.inputs[:numInputsProcessed]...) seq.inputs = seq.inputs[numInputsProcessed:] seq.iBatch = batch.NumTokens() - 1 } @@ -632,7 +614,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { s.mu.Lock() for i, sq := range s.seqs { if sq == nil { - seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) if err != nil { s.mu.Unlock() http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) @@ -715,7 +697,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { s.mu.Lock() for i, sq := range s.seqs { if sq == nil { - seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) if err != nil { s.mu.Unlock() http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) @@ -802,10 +784,6 @@ func (s *Server) loadModel( } } - if s.model.AddBOSToken() { - s.bosToken = 1 - } - if ppath != "" { var err error s.image, err = NewImageContext(s.lc, ppath) @@ -814,7 +792,10 @@ func (s *Server) loadModel( } } - s.cache = NewInputCache(s.lc, kvSize, s.parallel, multiUserCache) + s.cache, err = NewInputCache(s.lc, kvSize, s.parallel, multiUserCache) + if err != nil { + panic(err) + } s.status = ServerStatusReady s.ready.Done()