From c3ff9164317940ec09534fd2370ec604a0de32ad Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Tue, 19 Nov 2024 10:51:47 -0800 Subject: [PATCH] runner.go: Don't add inputs to cache view until actually processed We need to track which tokens are in the cache ourselves. We currently add tokens to the cache tracker when we add them to the batch but they are not actually in the cache until we call Decode. This can cause confusion when we are shifting the cache. Avoids "could not find a KV slot for the batch" issues. Bug #7545 --- llama/runner/cache.go | 16 ++++++++++++---- llama/runner/runner.go | 33 +++++++++++++++++++-------------- 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/llama/runner/cache.go b/llama/runner/cache.go index 190ccdff..b487fe25 100644 --- a/llama/runner/cache.go +++ b/llama/runner/cache.go @@ -203,7 +203,11 @@ func countCommonPrefix(a []input, b []input) int { // the newest half into that space (saving numKeep inputs at the beginning). // // Assumes that at least 1 entry can be freed up by shifting (i.e. 
numKeep < numCtx) -func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) { +func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error { + if numKeep >= c.numCtx { + return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx) + } + targetFree := (c.numCtx - numKeep) / 2 targetFree = max(targetFree, 1) @@ -211,18 +215,22 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) { discard := targetFree - currentFree if discard <= 0 { - return + return nil } - slog.Debug("context limit hit - shifting", "limit", c.numCtx, "input", len(slot.Inputs), + slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs), "keep", numKeep, "discard", discard) // TODO (jessegross): KV cache removal can fail for certain types of models - c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) + if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) { + return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard) + } c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard) for i := numKeep + discard; i < len(slot.Inputs); i++ { slot.Inputs[i-discard] = slot.Inputs[i] } slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard] + + return nil } diff --git a/llama/runner/runner.go b/llama/runner/runner.go index a38cce91..1ed25c27 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -45,6 +45,9 @@ type Sequence struct { // prompt inputs left to evaluate inputs []input + // inputs that have been added to a batch but not yet submitted to Decode + pendingInputs []input + // tokens that have been generated but not returned yet (e.g. 
for stop sequences) pendingResponses []string @@ -367,14 +370,13 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) continue } - var numInputsProcessed int - shifted := false - for i, input := range seq.inputs { - if len(seq.cache.Inputs)+1 > s.cache.numCtx { - if !shifted { - s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) - shifted = true + if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx { + if len(seq.pendingInputs) == 0 { + err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) + if err != nil { + return err + } } else { break } @@ -403,15 +405,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) } crossAttention = seq.crossAttention - batch.Add(input.token, input.embed, len(seq.cache.Inputs), i+1 == len(seq.inputs), seq.cache.Id) - seq.cache.Inputs = append(seq.cache.Inputs, input) - numInputsProcessed++ - } - - if numInputsProcessed > 0 { - seq.inputs = seq.inputs[numInputsProcessed:] + batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), i+1 == len(seq.inputs), seq.cache.Id) + seq.pendingInputs = append(seq.pendingInputs, input) seq.iBatch = batch.NumTokens() - 1 } + + seq.inputs = seq.inputs[len(seq.pendingInputs):] } if batch == nil || batch.NumTokens() == 0 { @@ -444,6 +443,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) continue } + // After calling Decode, pending inputs are now in the cache + if len(seq.pendingInputs) > 0 { + seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...) + seq.pendingInputs = []input{} + } + // don't sample prompt processing if len(seq.inputs) != 0 { continue