runner.go: Don't add inputs to cache view until actually processed
We need to track which tokens are in the cache ourselves. We currently add tokens to the cache tracker when we add them to a batch, but they are not actually in the cache until we call Decode. This mismatch can cause confusion when we are shifting the cache, and leads to "could not find a KV slot for the batch" errors. Bug #7545
This commit is contained in:
parent 3fc1dc0e6f
commit c3ff916431

2 changed files with 31 additions and 18 deletions
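For orientation before the diff, here is a minimal sketch of the bookkeeping pattern this commit introduces, using simplified hypothetical types rather than the runner's actual structs: an input is "pending" from the moment it is added to a batch until Decode succeeds, and only then joins the tracked cache view.

package main

import "fmt"

// sequence mirrors the idea behind the change (not the runner's real
// field types): batched inputs sit in a pending stage until Decode runs.
type sequence struct {
    inputs        []string // prompt inputs left to evaluate
    pendingInputs []string // batched, but not yet submitted to Decode
    cacheInputs   []string // actually present in the KV cache
}

// addToBatch moves the next n inputs into the pending stage. Before this
// commit they went straight into the cache view here, ahead of Decode.
func (s *sequence) addToBatch(n int) {
    s.pendingInputs = append(s.pendingInputs, s.inputs[:n]...)
    s.inputs = s.inputs[n:]
}

// commitAfterDecode runs once Decode has succeeded: the pending inputs
// are now really in the KV cache, so the tracker can say so.
func (s *sequence) commitAfterDecode() {
    s.cacheInputs = append(s.cacheInputs, s.pendingInputs...)
    s.pendingInputs = nil
}

func main() {
    s := sequence{inputs: []string{"t0", "t1", "t2"}}
    s.addToBatch(2)
    fmt.Println(s.pendingInputs, s.cacheInputs) // [t0 t1] []
    s.commitAfterDecode()
    fmt.Println(s.pendingInputs, s.cacheInputs) // [] [t0 t1]
}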
@@ -203,7 +203,11 @@ func countCommonPrefix(a []input, b []input) int {
 // the newest half into that space (saving numKeep inputs at the beginning).
 //
 // Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
-func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) {
+func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
+    if numKeep >= c.numCtx {
+        return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
+    }
+
     targetFree := (c.numCtx - numKeep) / 2
     targetFree = max(targetFree, 1)
 
@@ -211,18 +215,22 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) {
     discard := targetFree - currentFree
 
     if discard <= 0 {
-        return
+        return nil
     }
 
-    slog.Debug("context limit hit - shifting", "limit", c.numCtx, "input", len(slot.Inputs),
+    slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
         "keep", numKeep, "discard", discard)
 
     // TODO (jessegross): KV cache removal can fail for certain types of models
-    c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard)
+    if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
+        return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard)
+    }
     c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
 
     for i := numKeep + discard; i < len(slot.Inputs); i++ {
         slot.Inputs[i-discard] = slot.Inputs[i]
     }
     slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
+
+    return nil
 }
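As a worked example of the shift arithmetic in the hunk above (the numbers are mine, not from the commit): with numCtx = 8, numKeep = 2, and a full slot, targetFree = (8-2)/2 = 3, currentFree = 0, so discard = 3; entries 2-4 are removed and entries 5-7 slide down. A standalone rerun of just the slice compaction (the real code also moves the underlying KV cache entries via KvCacheSeqRm/KvCacheSeqAdd):

package main

import "fmt"

func main() {
    const numCtx, numKeep = 8, 2
    inputs := []int{0, 1, 2, 3, 4, 5, 6, 7} // a full slot

    targetFree := max((numCtx-numKeep)/2, 1)
    currentFree := numCtx - len(inputs)
    discard := targetFree - currentFree
    fmt.Println("discard:", discard) // discard: 3

    // Same compaction as the loop in ShiftCacheSlot: keep the first
    // numKeep inputs, drop the next discard, shift the rest down.
    for i := numKeep + discard; i < len(inputs); i++ {
        inputs[i-discard] = inputs[i]
    }
    inputs = inputs[:len(inputs)-discard]
    fmt.Println(inputs) // [0 1 5 6 7]
}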
@@ -45,6 +45,9 @@ type Sequence struct {
     // prompt inputs left to evaluate
     inputs []input
 
+    // inputs that have been added to a batch but not yet submitted to Decode
+    pendingInputs []input
+
     // tokens that have been generated but not returned yet (e.g. for stop sequences)
     pendingResponses []string
 
@@ -367,14 +370,13 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
             continue
         }
 
-        var numInputsProcessed int
-        shifted := false
-
         for i, input := range seq.inputs {
-            if len(seq.cache.Inputs)+1 > s.cache.numCtx {
-                if !shifted {
-                    s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
-                    shifted = true
+            if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
+                if len(seq.pendingInputs) == 0 {
+                    err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
+                    if err != nil {
+                        return err
+                    }
                 } else {
                     break
                 }
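A detail worth calling out in the hunk above: the shift is now attempted only when seq.pendingInputs is empty. Shifting rewrites KV cache positions, and entries already sitting in the in-flight batch were added against the old positions, so the loop stops filling the batch instead and lets the shift happen on a later pass, once Decode has drained the pending inputs. A small simulation of that decision, with hypothetical numbers:

package main

import "fmt"

func main() {
    const numCtx = 4
    cached := 3  // inputs already decoded into the KV cache
    pending := 1 // inputs batched but not yet decoded

    // Adding one more input would overflow the context window.
    if cached+pending+1 > numCtx {
        if pending == 0 {
            fmt.Println("safe to shift the cache, then add the next input")
        } else {
            fmt.Println("break: decode pending inputs first, shift on the next pass")
        }
    }
}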
@@ -403,15 +405,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
             }
 
             crossAttention = seq.crossAttention
-            batch.Add(input.token, input.embed, len(seq.cache.Inputs), i+1 == len(seq.inputs), seq.cache.Id)
-            seq.cache.Inputs = append(seq.cache.Inputs, input)
-            numInputsProcessed++
-        }
-
-        if numInputsProcessed > 0 {
-            seq.inputs = seq.inputs[numInputsProcessed:]
+            batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), i+1 == len(seq.inputs), seq.cache.Id)
+            seq.pendingInputs = append(seq.pendingInputs, input)
             seq.iBatch = batch.NumTokens() - 1
         }
 
+        seq.inputs = seq.inputs[len(seq.pendingInputs):]
     }
 
     if batch == nil || batch.NumTokens() == 0 {
@@ -444,6 +443,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
             continue
         }
 
+        // After calling Decode, pending inputs are now in the cache
+        if len(seq.pendingInputs) > 0 {
+            seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
+            seq.pendingInputs = []input{}
+        }
+
         // don't sample prompt processing
         if len(seq.inputs) != 0 {
             continue
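Putting the pieces together, here is a rough simulation of several processBatch passes under the new bookkeeping. This is my own simplified model (no real Decode, and a cruder shift than ShiftCacheSlot), meant only to show that the tracked cache view never exceeds numCtx:

package main

import "fmt"

func main() {
    const numCtx, numKeep, batchSize = 4, 1, 2
    inputs := []int{0, 1, 2, 3, 4, 5}
    var pending, cache []int

    for len(inputs) > 0 {
        // Fill a batch, honoring the context guard from the diff.
        for _, in := range inputs {
            if len(cache)+len(pending)+1 > numCtx {
                if len(pending) != 0 {
                    break // decode what we have first
                }
                // Simplified shift: keep numKeep, drop half of the rest.
                discard := (numCtx - numKeep) / 2
                cache = append(cache[:numKeep], cache[numKeep+discard:]...)
            }
            if len(pending) == batchSize {
                break
            }
            pending = append(pending, in)
        }

        // "Decode" succeeded: only now does pending join the cache view.
        cache = append(cache, pending...)
        inputs = inputs[len(pending):]
        pending = pending[:0]
        fmt.Println("cache:", cache)
    }
}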