runner.go: Hard fail on errors rather than potentially infinite looping

We try to recover from errors by dropping the tokens that caused the
problem and retrying. However, dropping the tokens is not correct and
continuing often leads to infinite loops. To avoid this, we currently
end the sequence if such a condition is detected, which is also
surprising behavior for users.

At this point, it is better to just report the error. This makes it
easier to find problems, and the alternatives are perhaps even more
surprising to users.

This is not a very satisfactory solution either - ideally we would
isolate the error and return it to the user without killing the whole
process. However, this is an incremental step, and it is consistent
with most other failures (which manifest as either abort() or panic).
commit 3fc1dc0e6f
parent 7121dfa309
Author: Jesse Gross
Date:   2024-11-19 10:55:29 -08:00 (committed by Jesse Gross)

@@ -324,7 +324,11 @@ func (s *Server) run(ctx context.Context) {
 		case <-ctx.Done():
 			return
 		default:
-			s.processBatch(tokenBatch, embedBatch)
+			err := s.processBatch(tokenBatch, embedBatch)
+			if err != nil {
+				panic(err)
+			}
+
 			tokenBatch.Clear()
 			embedBatch.Clear()
 		}
@@ -338,7 +342,7 @@ func (s *Server) run(ctx context.Context) {
 // these should instead be handled by the handlers
 // it should only be responsible for accepting tokens or embeddings and
 // processing batches as fast as possible
-func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) {
+func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) error {
 	s.mu.Lock()
 	for s.allNil() {
 		s.cond.Wait() // Wait until an item is added
@@ -357,14 +361,6 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			continue
 		}
 
-		// If an error occurred during the processing of a previous batch then we may have emptied the inputs
-		// without adding a new one. In this case, end the sequence rather than infinite looping.
-		if len(seq.inputs) == 0 {
-			slog.Error("removing sequence due to no input tokens", "index", seqIdx, "cache id", seq.cache.Id)
-			s.removeSequence(seqIdx, "error")
-			continue
-		}
-
 		// if past the num predict limit
 		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
 			s.removeSequence(seqIdx, "limit")
@@ -419,7 +415,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 	}
 
 	if batch == nil || batch.NumTokens() == 0 {
-		return
+		return nil
 	}
 
 	s.lc.SetCrossAttention(crossAttention)
@@ -432,8 +428,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			err = s.lc.Decode(batch)
 		}
 		if err != nil {
-			slog.Error("failed to decode batch", "error", err)
-			return
+			return fmt.Errorf("failed to decode batch: %w", err)
 		}
 	}
 
@@ -531,6 +526,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			s.removeSequence(i, "connection")
 		}
 	}
+
+	return nil
 }
 
 // TODO (jmorganca): use structs from the api package to avoid duplication
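The follow-up direction mentioned in the commit message - isolating the
error and returning it to the affected request rather than killing the
process - could look roughly like the sketch below. Everything here
(sequence, errCh, processBatch returning the failing index) is a
hypothetical illustration, not the actual runner.go API:

	package main

	import (
		"errors"
		"fmt"
	)

	// Hypothetical sketch: deliver a batch error to the sequence that
	// owns it, so one bad request fails cleanly while the server and
	// the remaining sequences keep running.
	type sequence struct {
		id    int
		errCh chan error // illustrative; not a real runner.go field
	}

	// processBatch stands in for the real batch processing; here it
	// always blames sequence 1.
	func processBatch() (int, error) {
		return 1, errors.New("failed to decode batch")
	}

	func main() {
		seqs := map[int]*sequence{
			0: {id: 0, errCh: make(chan error, 1)},
			1: {id: 1, errCh: make(chan error, 1)},
		}

		if idx, err := processBatch(); err != nil {
			// Route the error to the offending sequence only; the
			// process is not torn down.
			seq := seqs[idx]
			delete(seqs, idx)
			seq.errCh <- err
			fmt.Println("reported to request", seq.id, ":", <-seq.errCh)
		}

		fmt.Println("sequences still being served:", len(seqs))
	}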