From 65973ceb6417c2e2796fa59bd3225bc7bd79b403 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Fri, 8 Nov 2024 11:10:56 -0800 Subject: [PATCH] runner.go: Make KV entry accounting more robust The structure of the accounting for KV cache shifting was carried over from the old runner but it now doesn't feel natural with the new runner. There are a number of invariants that should hold true but are difficult to reason about. There is at least one bug report that would imply that the invariants are not holding. This reduces the number of implicit assumptions and is more forgiving of unexpected situations. It also improves behavior around which input tokens are kept when truncation occurs. Bug #7545 --- llama/runner/cache.go | 47 ++++++++++++++++++++-------- llama/runner/runner.go | 69 +++++++++++++++--------------------------- 2 files changed, 59 insertions(+), 57 deletions(-) diff --git a/llama/runner/cache.go b/llama/runner/cache.go index 75c1d874..190ccdff 100644 --- a/llama/runner/cache.go +++ b/llama/runner/cache.go @@ -2,6 +2,7 @@ package main import ( "errors" + "fmt" "log/slog" "reflect" "time" @@ -22,7 +23,11 @@ type InputCache struct { lc *llama.Context } -func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache bool) *InputCache { +func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache bool) (*InputCache, error) { + if kvSize/numSlots < 1 { + return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots) + } + slots := make([]InputCacheSlot, numSlots) for i := range slots { @@ -37,7 +42,7 @@ func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache b slots: slots, multiUserCache: multiUserCache, lc: lc, - } + }, nil } // Locking: Operations on InputCacheSlot (including finding one @@ -58,7 +63,7 @@ type InputCacheSlot struct { lastUsed time.Time } -func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, int, error) { +func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) { var slot *InputCacheSlot var numPast int var err error @@ -75,7 +80,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach slot, numPast, err = c.findBestCacheSlot(prompt) } if err != nil { - return nil, nil, 0, err + return nil, nil, err } if !cachePrompt { @@ -102,7 +107,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach prompt = prompt[numPast:] slot.Inputs = slot.Inputs[:numPast] - return slot, prompt, numPast, nil + return slot, prompt, nil } func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int, error) { @@ -194,14 +199,30 @@ func countCommonPrefix(a []input, b []input) int { return count } -func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int, numDiscard int, numPast int) { - // TODO (jessegross): KV cache removal can fail for certain types of models - // server.cpp doesn't handle this, though we can be more graceful - c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+numDiscard) - c.lc.KvCacheSeqAdd(slot.Id, numKeep+numDiscard, numPast, -numDiscard) +// Frees up space in the KV cache by deleting the oldest half of history and shifting +// the newest half into that space (saving numKeep inputs at the beginning). +// +// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx) +func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) { + targetFree := (c.numCtx - numKeep) / 2 + targetFree = max(targetFree, 1) - for i := numKeep + numDiscard; i < len(slot.Inputs); i++ { - slot.Inputs[i-numDiscard] = slot.Inputs[i] + currentFree := c.numCtx - len(slot.Inputs) + discard := targetFree - currentFree + + if discard <= 0 { + return } - slot.Inputs = slot.Inputs[:len(slot.Inputs)-numDiscard] + + slog.Debug("context limit hit - shifting", "limit", c.numCtx, "input", len(slot.Inputs), + "keep", numKeep, "discard", discard) + + // TODO (jessegross): KV cache removal can fail for certain types of models + c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) + c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard) + + for i := numKeep + discard; i < len(slot.Inputs); i++ { + slot.Inputs[i-discard] = slot.Inputs[i] + } + slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard] } diff --git a/llama/runner/runner.go b/llama/runner/runner.go index 0a37dee0..b680f060 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -34,9 +34,6 @@ type input struct { } type Sequence struct { - // number of inputs evaluated - numPast int - // batch index iBatch int @@ -112,21 +109,15 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen params.numKeep = len(inputs) } - if !params.embedding { - // Subtracting 4 ensures that at least 1 input can be discarded during shift - params.numKeep = min(params.numKeep, s.cache.numCtx-4) - params.numKeep += s.bosToken - } else { - // Embeddings are 1 shot - just truncate to the context window, without ever shifting - params.numKeep = min(params.numKeep, s.cache.numCtx) + if s.model.AddBOSToken() { + params.numKeep += 1 } - // truncate to fit in context window + // Ensure that at least 1 input can be discarded during shift + params.numKeep = min(params.numKeep, s.cache.numCtx-1) + if len(inputs) > s.cache.numCtx { - slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "numKeep", params.numKeep) - newInputs := inputs[:params.numKeep] - newInputs = append(newInputs, inputs[len(inputs)-s.cache.numCtx+params.numKeep:]...) - inputs = newInputs + slog.Warn("input exceeds context length", "prompt", len(inputs), "limit", s.cache.numCtx) } var sc *llama.SamplingContext @@ -231,9 +222,6 @@ type Server struct { // KV cache cache *InputCache - // does this model require a beginning of sequence token? - bosToken int - // next sequence for prompt processing to avoid starvation nextSeq int @@ -258,18 +246,6 @@ func (s *Server) allNil() bool { return true } -func (s *Server) shiftContext(seq *Sequence) { - numLeft := seq.numPast - seq.numKeep - numDiscard := numLeft / 2 - - slog.Debug("context limit hit - shifting", "limit", s.cache.numCtx, "numPast", seq.numPast, - "numKeep", seq.numKeep, "numLeft", numLeft, "numDiscard", numDiscard) - - s.cache.ShiftCacheSlot(seq.cache, seq.numKeep, numDiscard, seq.numPast) - - seq.numPast -= numDiscard -} - func flushPending(seq *Sequence) bool { joined := strings.Join(seq.pendingResponses, "") seq.pendingResponses = []string{} @@ -374,12 +350,19 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) continue } - if seq.numPast+len(seq.inputs) > s.cache.numCtx { - s.shiftContext(seq) - } - var numInputsProcessed int + shifted := false + for i, input := range seq.inputs { + if len(seq.cache.Inputs)+1 > s.cache.numCtx { + if !shifted { + s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) + shifted = true + } else { + break + } + } + embedding := input.embed != nil // If we don't currently have a batch, use one of the correct type and @@ -403,13 +386,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) } crossAttention = seq.crossAttention - batch.Add(input.token, input.embed, seq.numPast, numInputsProcessed+1 == len(seq.inputs), seq.cache.Id) - seq.numPast++ + batch.Add(input.token, input.embed, len(seq.cache.Inputs), i+1 == len(seq.inputs), seq.cache.Id) + seq.cache.Inputs = append(seq.cache.Inputs, input) numInputsProcessed++ } if numInputsProcessed > 0 { - seq.cache.Inputs = append(seq.cache.Inputs, seq.inputs[:numInputsProcessed]...) seq.inputs = seq.inputs[numInputsProcessed:] seq.iBatch = batch.NumTokens() - 1 } @@ -632,7 +614,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { s.mu.Lock() for i, sq := range s.seqs { if sq == nil { - seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) if err != nil { s.mu.Unlock() http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) @@ -715,7 +697,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { s.mu.Lock() for i, sq := range s.seqs { if sq == nil { - seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) if err != nil { s.mu.Unlock() http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) @@ -802,10 +784,6 @@ func (s *Server) loadModel( } } - if s.model.AddBOSToken() { - s.bosToken = 1 - } - if ppath != "" { var err error s.image, err = NewImageContext(s.lc, ppath) @@ -814,7 +792,10 @@ func (s *Server) loadModel( } } - s.cache = NewInputCache(s.lc, kvSize, s.parallel, multiUserCache) + s.cache, err = NewInputCache(s.lc, kvSize, s.parallel, multiUserCache) + if err != nil { + panic(err) + } s.status = ServerStatusReady s.ready.Done()