diff --git a/llama/runner/cache.go b/llama/runner/cache.go
index 190ccdff..b487fe25 100644
--- a/llama/runner/cache.go
+++ b/llama/runner/cache.go
@@ -203,7 +203,11 @@ func countCommonPrefix(a []input, b []input) int {
 // the newest half into that space (saving numKeep inputs at the beginning).
 //
 // Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
-func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) {
+func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
+	if numKeep >= c.numCtx {
+		return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
+	}
+
 	targetFree := (c.numCtx - numKeep) / 2
 	targetFree = max(targetFree, 1)
 
@@ -211,18 +215,22 @@
 	discard := targetFree - currentFree
 
 	if discard <= 0 {
-		return
+		return nil
 	}
 
-	slog.Debug("context limit hit - shifting", "limit", c.numCtx, "input", len(slot.Inputs),
+	slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
 		"keep", numKeep, "discard", discard)
 
 	// TODO (jessegross): KV cache removal can fail for certain types of models
-	c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard)
+	if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
+		return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard)
+	}
 	c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
 
 	for i := numKeep + discard; i < len(slot.Inputs); i++ {
 		slot.Inputs[i-discard] = slot.Inputs[i]
 	}
 	slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
+
+	return nil
 }
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index a38cce91..1ed25c27 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -45,6 +45,9 @@ type Sequence struct {
 	// prompt inputs left to evaluate
 	inputs []input
 
+	// inputs that have been added to a batch but not yet submitted to Decode
+	pendingInputs []input
+
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
 	pendingResponses []string
 
@@ -367,14 +370,13 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			continue
 		}
 
-		var numInputsProcessed int
-		shifted := false
-
 		for i, input := range seq.inputs {
-			if len(seq.cache.Inputs)+1 > s.cache.numCtx {
-				if !shifted {
-					s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
-					shifted = true
+			if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
+				if len(seq.pendingInputs) == 0 {
+					err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
+					if err != nil {
+						return err
+					}
 				} else {
 					break
 				}
@@ -403,15 +405,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			}
 
 			crossAttention = seq.crossAttention
-			batch.Add(input.token, input.embed, len(seq.cache.Inputs), i+1 == len(seq.inputs), seq.cache.Id)
-			seq.cache.Inputs = append(seq.cache.Inputs, input)
-			numInputsProcessed++
-		}
-
-		if numInputsProcessed > 0 {
-			seq.inputs = seq.inputs[numInputsProcessed:]
+			batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), i+1 == len(seq.inputs), seq.cache.Id)
+			seq.pendingInputs = append(seq.pendingInputs, input)
 			seq.iBatch = batch.NumTokens() - 1
 		}
+
+		seq.inputs = seq.inputs[len(seq.pendingInputs):]
 	}
 
 	if batch == nil || batch.NumTokens() == 0 {
@@ -444,6 +443,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			continue
 		}
 
+		// After calling Decode, pending inputs are now in the cache
+		if len(seq.pendingInputs) > 0 {
+			seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
+			seq.pendingInputs = []input{}
+		}
+
 		// don't sample prompt processing
 		if len(seq.inputs) != 0 {
 			continue