runner.go: Handle truncation of tokens for stop sequences

When a single token contains both text to be return and a stop
sequence, this causes an out of bounds error when we update the
cache to match our text. This is because we currently assume that
the removing the stop sequence will consume at least one token.

This also inverts the logic to deal with positive numbers, rather
than a value to be subtracted, which is easier to reason about.

Fixes #7153
This commit is contained in:
Jesse Gross 2024-10-09 16:12:23 -07:00 committed by Jesse Gross
parent 03408f3437
commit 0077e22d52
3 changed files with 60 additions and 33 deletions

View file

@ -451,14 +451,27 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
sequence := strings.Join(seq.pendingResponses, "") sequence := strings.Join(seq.pendingResponses, "")
if ok, stop := findStop(sequence, seq.stop); ok { if ok, stop := findStop(sequence, seq.stop); ok {
slog.Debug("hit stop token", "stop", seq.stop) slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
trimCacheLen := len(seq.pendingResponses) - 1 var tokenTruncated bool
seq.pendingResponses = truncateStop(seq.pendingResponses, stop) origLen := len(seq.pendingResponses)
trimCacheLen -= len(seq.pendingResponses) seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop)
newLen := len(seq.pendingResponses)
// Update the cache based on the tokens that will be returned:
// - We have 1 token more than is currently in the cache because
// the last one generated wasn't submitted to Decode
// - Remove any stop sequences that we stripped out
// - If truncateStop removed a portion of a token, drop that
// - As defense-in-depth, if truncatedToken didn't find a stop token
// remove the extra one that we added to the cache len
tokenLen := len(seq.cache.Inputs) + 1
tokenLen -= origLen - newLen
if tokenTruncated || origLen == newLen {
tokenLen--
}
seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
// remove any tokens from the cache that we don't actually return
seq.cache.Inputs = seq.cache.Inputs[:len(seq.cache.Inputs)-trimCacheLen]
s.removeSequence(i, "stop") s.removeSequence(i, "stop")
continue continue
} }

View file

@ -28,13 +28,13 @@ func containsStopSuffix(sequence string, stops []string) bool {
// truncateStop removes the provided stop string from pieces, // truncateStop removes the provided stop string from pieces,
// returning the partial pieces with stop removed, including truncating // returning the partial pieces with stop removed, including truncating
// the last piece if required // the last piece if required (and signalling if this was the case)
func truncateStop(pieces []string, stop string) []string { func truncateStop(pieces []string, stop string) ([]string, bool) {
joined := strings.Join(pieces, "") joined := strings.Join(pieces, "")
index := strings.Index(joined, stop) index := strings.Index(joined, stop)
if index == -1 { if index == -1 {
return pieces return pieces, false
} }
joined = joined[:index] joined = joined[:index]
@ -46,6 +46,7 @@ func truncateStop(pieces []string, stop string) []string {
} }
var result []string var result []string
tokenTruncated := false
start := 0 start := 0
for _, length := range lengths { for _, length := range lengths {
if start >= len(joined) { if start >= len(joined) {
@ -55,12 +56,13 @@ func truncateStop(pieces []string, stop string) []string {
end := start + length end := start + length
if end > len(joined) { if end > len(joined) {
end = len(joined) end = len(joined)
tokenTruncated = true
} }
result = append(result, joined[start:end]) result = append(result, joined[start:end])
start = end start = end
} }
return result return result, tokenTruncated
} }
func incompleteUnicode(token string) bool { func incompleteUnicode(token string) bool {

View file

@ -7,42 +7,54 @@ import (
func TestTruncateStop(t *testing.T) { func TestTruncateStop(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
pieces []string pieces []string
stop string stop string
expected []string expected []string
expectedTrunc bool
}{ }{
{ {
name: "Single word", name: "Single word",
pieces: []string{"hello", "world"}, pieces: []string{"hello", "world"},
stop: "world", stop: "world",
expected: []string{"hello"}, expected: []string{"hello"},
expectedTrunc: false,
}, },
{ {
name: "Partial", name: "Partial",
pieces: []string{"hello", "wor"}, pieces: []string{"hello", "wor"},
stop: "or", stop: "or",
expected: []string{"hello", "w"}, expected: []string{"hello", "w"},
expectedTrunc: true,
}, },
{ {
name: "Suffix", name: "Suffix",
pieces: []string{"Hello", " there", "!"}, pieces: []string{"Hello", " there", "!"},
stop: "!", stop: "!",
expected: []string{"Hello", " there"}, expected: []string{"Hello", " there"},
expectedTrunc: false,
}, },
{ {
name: "Middle", name: "Suffix partial",
pieces: []string{"hello", " wor"}, pieces: []string{"Hello", " the", "re!"},
stop: "llo w", stop: "there!",
expected: []string{"he"}, expected: []string{"Hello", " "},
expectedTrunc: true,
},
{
name: "Middle",
pieces: []string{"hello", " wor"},
stop: "llo w",
expected: []string{"he"},
expectedTrunc: true,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := truncateStop(tt.pieces, tt.stop) result, resultTrunc := truncateStop(tt.pieces, tt.stop)
if !reflect.DeepEqual(result, tt.expected) { if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
t.Errorf("truncateStop(%v, %s): have %v; want %v", tt.pieces, tt.stop, result, tt.expected) t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
} }
}) })
} }