From 2cd11ae365a9423578069457312dce6b9e1e5a37 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Mon, 25 Nov 2024 14:49:38 -0800 Subject: [PATCH] runner.go: Add unit tests for context shifting This also makes it easier to truncate long inputs the same as shifting but does not actually implement it. This type of truncation has a trade off between quality and time to first token. --- llama/runner/cache.go | 20 +++++++++--- llama/runner/cache_test.go | 63 ++++++++++++++++++++++++++++++++++++++ llama/runner/runner.go | 6 ++-- 3 files changed, 82 insertions(+), 7 deletions(-) diff --git a/llama/runner/cache.go b/llama/runner/cache.go index b487fe25..0f5f0a09 100644 --- a/llama/runner/cache.go +++ b/llama/runner/cache.go @@ -199,6 +199,20 @@ func countCommonPrefix(a []input, b []input) int { return count } +func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int { + targetFree := (c.numCtx - numKeep) / 2 + targetFree = max(targetFree, 1) + + currentFree := c.numCtx - inputLen + discard := targetFree - currentFree + + if discard < 0 { + discard = 0 + } + + return discard +} + // Frees up space in the KV cache by deleting the oldest half of history and shifting // the newest half into that space (saving numKeep inputs at the beginning). // @@ -208,11 +222,7 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error { return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx) } - targetFree := (c.numCtx - numKeep) / 2 - targetFree = max(targetFree, 1) - - currentFree := c.numCtx - len(slot.Inputs) - discard := targetFree - currentFree + discard := c.ShiftDiscard(len(slot.Inputs), numKeep) if discard <= 0 { return nil diff --git a/llama/runner/cache_test.go b/llama/runner/cache_test.go index 0e38c67d..79cd93cb 100644 --- a/llama/runner/cache_test.go +++ b/llama/runner/cache_test.go @@ -227,3 +227,66 @@ func TestFindCacheSlot(t *testing.T) { }) } } + +func TestShiftDiscard(t *testing.T) { + tests := []struct { + name string + numCtx int + numKeep int + inputLen int + expected int + }{ + { + name: "Shift", + numCtx: 2048, + numKeep: 5, + inputLen: 2048, + expected: 1021, + }, + { + name: "Max Keep", + numCtx: 2048, + numKeep: 2047, + inputLen: 2048, + expected: 1, + }, + { + name: "No Keep", + numCtx: 2048, + numKeep: 0, + inputLen: 2048, + expected: 1024, + }, + { + name: "Truncate", + numCtx: 2048, + numKeep: 5, + inputLen: 5000, + expected: 3973, + }, + { + name: "Truncate Keep", + numCtx: 2048, + numKeep: 2047, + inputLen: 5000, + expected: 2953, + }, + { + name: "No Op", + numCtx: 2048, + numKeep: 5, + inputLen: 512, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := InputCache{numCtx: tt.numCtx} + result := c.ShiftDiscard(tt.inputLen, tt.numKeep) + if result != tt.expected { + t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected) + } + }) + } +} diff --git a/llama/runner/runner.go b/llama/runner/runner.go index db8092f3..8762b3da 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -122,9 +122,11 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen params.numKeep = min(params.numKeep, s.cache.numCtx-1) if len(inputs) > s.cache.numCtx { - slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "numKeep", params.numKeep) + discard := len(inputs) - s.cache.numCtx newInputs := inputs[:params.numKeep] - newInputs = append(newInputs, inputs[len(inputs)-s.cache.numCtx+params.numKeep:]...) + newInputs = append(newInputs, inputs[params.numKeep+discard:]...) + + slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs)) inputs = newInputs }