From 3fc1dc0e6f32a22063db22a4dc72a75f8411a663 Mon Sep 17 00:00:00 2001
From: Jesse Gross
Date: Tue, 19 Nov 2024 10:55:29 -0800
Subject: [PATCH] runner.go: Hard fail on errors rather than potentially infinite looping

We try to recover from errors by dropping the tokens that caused the
problem and retrying. However, dropping the tokens is not correct and
continuing often leads to infinite loops. To avoid this, we end the
sequence if such a condition is detected, which is also surprising
behavior. At this point, it is better to just report the error. This
will make it easier to find problems, and the alternatives are perhaps
even more surprising to users.

This is not a very satisfactory solution either - we should isolate the
error and return it to the user without killing the whole process.
However, this is an incremental step and consistent with most other
failures (which either manifest as abort() or panic).
---
 llama/runner/runner.go | 23 ++++++++++-------------
 1 file changed, 10 insertions(+), 13 deletions(-)

diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index a41573ae..a38cce91 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -324,7 +324,11 @@ func (s *Server) run(ctx context.Context) {
 		case <-ctx.Done():
 			return
 		default:
-			s.processBatch(tokenBatch, embedBatch)
+			err := s.processBatch(tokenBatch, embedBatch)
+			if err != nil {
+				panic(err)
+			}
+
 			tokenBatch.Clear()
 			embedBatch.Clear()
 		}
@@ -338,7 +342,7 @@
 // these should instead be handled by the handlers
 // it should only be responsible for accepting tokens or embeddings and
 // processing batches as fast as possible
-func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) {
+func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) error {
 	s.mu.Lock()
 	for s.allNil() {
 		s.cond.Wait() // Wait until an item is added
@@ -357,14 +361,6 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			continue
 		}
 
-		// If an error occurred during the processing of a previous batch then we may have emptied the inputs
-		// without adding a new one. In this case, end the sequence rather than infinite looping.
-		if len(seq.inputs) == 0 {
-			slog.Error("removing sequence due to no input tokens", "index", seqIdx, "cache id", seq.cache.Id)
-			s.removeSequence(seqIdx, "error")
-			continue
-		}
-
 		// if past the num predict limit
 		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
 			s.removeSequence(seqIdx, "limit")
@@ -419,7 +415,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 	}
 
 	if batch == nil || batch.NumTokens() == 0 {
-		return
+		return nil
 	}
 
 	s.lc.SetCrossAttention(crossAttention)
@@ -432,8 +428,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			err = s.lc.Decode(batch)
 		}
 		if err != nil {
-			slog.Error("failed to decode batch", "error", err)
-			return
+			return fmt.Errorf("failed to decode batch: %w", err)
 		}
 	}
 
@@ -531,6 +526,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			s.removeSequence(i, "connection")
 		}
 	}
+
+	return nil
 }
 
 // TODO (jmorganca): use structs from the api package to avoid duplication
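
Note (not part of the patch): below is a minimal, self-contained Go sketch of the
control flow this change moves to. The names fakeDecode and processOneBatch are
illustrative stand-ins, not the real runner API; the point is only that
batch-processing errors now propagate to the caller, which fails hard instead of
dropping inputs and retrying.

// Illustrative sketch only. fakeDecode stands in for the decode call
// that can fail, and processOneBatch stands in for processBatch.
package main

import (
	"errors"
	"fmt"
)

var errDecode = errors.New("decode failed")

// fakeDecode stands in for the decode call that can fail mid-batch.
func fakeDecode(batch []int) error {
	if len(batch) > 3 {
		return errDecode
	}
	return nil
}

// processOneBatch mirrors the new processBatch contract: the error is
// returned to the caller rather than logged and swallowed.
func processOneBatch(batch []int) error {
	if err := fakeDecode(batch); err != nil {
		return fmt.Errorf("failed to decode batch: %w", err)
	}
	return nil
}

func main() {
	batches := [][]int{{1, 2}, {3, 4, 5, 6}}

	// Mirrors the new run loop: any processing error is fatal, so the
	// failure is reported loudly instead of silently dropping inputs
	// and potentially looping forever.
	for _, b := range batches {
		if err := processOneBatch(b); err != nil {
			panic(err)
		}
	}
}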