runner.go: Add unit tests for context shifting
This also makes it easier to truncate long inputs in the same way as shifting, but does not actually implement it. This type of truncation has a trade-off between quality and time to first token.
This commit is contained in:
parent
52bbad12f9
commit
2cd11ae365
3 changed files with 82 additions and 7 deletions
|
@ -199,6 +199,20 @@ func countCommonPrefix(a []input, b []input) int {
|
||||||
return count
|
return count
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
|
||||||
|
targetFree := (c.numCtx - numKeep) / 2
|
||||||
|
targetFree = max(targetFree, 1)
|
||||||
|
|
||||||
|
currentFree := c.numCtx - inputLen
|
||||||
|
discard := targetFree - currentFree
|
||||||
|
|
||||||
|
if discard < 0 {
|
||||||
|
discard = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return discard
|
||||||
|
}
|
||||||
|
|
||||||
// Frees up space in the KV cache by deleting the oldest half of history and shifting
|
// Frees up space in the KV cache by deleting the oldest half of history and shifting
|
||||||
// the newest half into that space (saving numKeep inputs at the beginning).
|
// the newest half into that space (saving numKeep inputs at the beginning).
|
||||||
//
|
//
|
||||||
|
@ -208,11 +222,7 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
||||||
return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
|
return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
|
||||||
}
|
}
|
||||||
|
|
||||||
targetFree := (c.numCtx - numKeep) / 2
|
discard := c.ShiftDiscard(len(slot.Inputs), numKeep)
|
||||||
targetFree = max(targetFree, 1)
|
|
||||||
|
|
||||||
currentFree := c.numCtx - len(slot.Inputs)
|
|
||||||
discard := targetFree - currentFree
|
|
||||||
|
|
||||||
if discard <= 0 {
|
if discard <= 0 {
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -227,3 +227,66 @@ func TestFindCacheSlot(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestShiftDiscard(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
numCtx int
|
||||||
|
numKeep int
|
||||||
|
inputLen int
|
||||||
|
expected int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Shift",
|
||||||
|
numCtx: 2048,
|
||||||
|
numKeep: 5,
|
||||||
|
inputLen: 2048,
|
||||||
|
expected: 1021,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Max Keep",
|
||||||
|
numCtx: 2048,
|
||||||
|
numKeep: 2047,
|
||||||
|
inputLen: 2048,
|
||||||
|
expected: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No Keep",
|
||||||
|
numCtx: 2048,
|
||||||
|
numKeep: 0,
|
||||||
|
inputLen: 2048,
|
||||||
|
expected: 1024,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Truncate",
|
||||||
|
numCtx: 2048,
|
||||||
|
numKeep: 5,
|
||||||
|
inputLen: 5000,
|
||||||
|
expected: 3973,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Truncate Keep",
|
||||||
|
numCtx: 2048,
|
||||||
|
numKeep: 2047,
|
||||||
|
inputLen: 5000,
|
||||||
|
expected: 2953,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No Op",
|
||||||
|
numCtx: 2048,
|
||||||
|
numKeep: 5,
|
||||||
|
inputLen: 512,
|
||||||
|
expected: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := InputCache{numCtx: tt.numCtx}
|
||||||
|
result := c.ShiftDiscard(tt.inputLen, tt.numKeep)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -122,9 +122,11 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
||||||
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
|
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
|
||||||
|
|
||||||
if len(inputs) > s.cache.numCtx {
|
if len(inputs) > s.cache.numCtx {
|
||||||
slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "numKeep", params.numKeep)
|
discard := len(inputs) - s.cache.numCtx
|
||||||
newInputs := inputs[:params.numKeep]
|
newInputs := inputs[:params.numKeep]
|
||||||
newInputs = append(newInputs, inputs[len(inputs)-s.cache.numCtx+params.numKeep:]...)
|
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
|
||||||
|
|
||||||
|
slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
|
||||||
inputs = newInputs
|
inputs = newInputs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue