diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index a41573ae..a38cce91 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -324,7 +324,11 @@ func (s *Server) run(ctx context.Context) {
 		case <-ctx.Done():
 			return
 		default:
-			s.processBatch(tokenBatch, embedBatch)
+			err := s.processBatch(tokenBatch, embedBatch)
+			if err != nil {
+				panic(err)
+			}
+
 			tokenBatch.Clear()
 			embedBatch.Clear()
 		}
@@ -338,7 +342,7 @@ func (s *Server) run(ctx context.Context) {
 // these should instead be handled by the handlers
 // it should only be responsible for accepting tokens or embeddings and
 // processing batches as fast as possible
-func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) {
+func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) error {
 	s.mu.Lock()
 	for s.allNil() {
 		s.cond.Wait() // Wait until an item is added
@@ -357,14 +361,6 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			continue
 		}
 
-		// If an error occurred during the processing of a previous batch then we may have emptied the inputs
-		// without adding a new one. In this case, end the sequence rather than infinite looping.
-		if len(seq.inputs) == 0 {
-			slog.Error("removing sequence due to no input tokens", "index", seqIdx, "cache id", seq.cache.Id)
-			s.removeSequence(seqIdx, "error")
-			continue
-		}
-
 		// if past the num predict limit
 		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
 			s.removeSequence(seqIdx, "limit")
@@ -419,7 +415,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 	}
 
 	if batch == nil || batch.NumTokens() == 0 {
-		return
+		return nil
 	}
 
 	s.lc.SetCrossAttention(crossAttention)
@@ -432,8 +428,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			err = s.lc.Decode(batch)
 		}
 		if err != nil {
-			slog.Error("failed to decode batch", "error", err)
-			return
+			return fmt.Errorf("failed to decode batch: %w", err)
 		}
 	}
 
@@ -531,6 +526,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			s.removeSequence(i, "connection")
 		}
 	}
+
+	return nil
 }
 
 // TODO (jmorganca): use structs from the api package to avoid duplication
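
For reference, the decode hunk above swaps log-and-return for Go's standard error wrapping, and the run loop now panics on the propagated error. Below is a minimal, self-contained sketch of the same wrapping pattern; errDecode, decode, and this stripped-down processBatch are illustrative placeholders, not the runner's actual code.

package main

import (
	"errors"
	"fmt"
)

// errDecode stands in for whatever error the decode step might return;
// it is a placeholder for illustration only.
var errDecode = errors.New("kv cache slot not found")

// processBatch mirrors the pattern in the diff: instead of logging and
// returning silently, the underlying error is wrapped with %w and propagated.
func processBatch() error {
	if err := decode(); err != nil {
		return fmt.Errorf("failed to decode batch: %w", err)
	}
	return nil
}

func decode() error { return errDecode }

func main() {
	if err := processBatch(); err != nil {
		// %w keeps the original error available to callers via errors.Is / errors.As.
		fmt.Println(err)                       // failed to decode batch: kv cache slot not found
		fmt.Println(errors.Is(err, errDecode)) // true
	}
}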