diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index 0ae50608..3ffb57bb 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -14,6 +14,7 @@ import (
 	"path/filepath"
 	"regexp"
 	"runtime"
+	"runtime/debug"
 	"strconv"
 	"strings"
 	"sync"
@@ -339,6 +340,15 @@ func (s *Server) run(ctx context.Context) {
 // it should only be responsible for accepting tokens or embeddings and
 // processing batches as fast as possible
 func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) {
+	// Try to keep going even if we hit a panic so that corner cases don't take the whole
+	// runner down. In most cases, this will result in dropping the tokens that we are currently
+	// processing and then continuing with what is remaining.
+	defer func() {
+		if err := recover(); err != nil {
+			slog.Error("error while processing batch", "error", err, "stack", debug.Stack())
+		}
+	}()
+
 	s.mu.Lock()
 	for s.allNil() {
 		s.cond.Wait() // Wait until an item is added
@@ -357,6 +367,14 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			continue
 		}
 
+		// If an error occurred during the processing of a previous batch then we may have emptied the inputs
+		// without adding a new one. In this case, end the sequence rather than infinite looping.
+		if len(seq.inputs) == 0 {
+			slog.Error("removing sequence due to no input tokens", "index", seqIdx, "cache id", seq.cache.Id)
+			s.removeSequence(seqIdx, "error")
+			continue
+		}
+
 		// if past the num predict limit
 		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
 			s.removeSequence(seqIdx, "limit")
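
As context for the change (not part of the patch itself), here is a minimal, self-contained sketch of the defer/recover idiom the diff adds to `processBatch`: a panic inside one call is caught, logged with a stack trace via `runtime/debug`, and the caller's loop keeps running. The `processOnce` helper and `main` driver are hypothetical names used only for illustration.

```go
package main

import (
	"log/slog"
	"runtime/debug"
)

// processOnce stands in for a single processBatch call (hypothetical helper).
// The deferred recover turns a panic into a logged error instead of a crash.
func processOnce(i int) {
	defer func() {
		if err := recover(); err != nil {
			slog.Error("error while processing batch", "error", err, "stack", debug.Stack())
		}
	}()

	if i == 2 {
		panic("corner case") // simulated failure; only this iteration's work is lost
	}
	slog.Info("processed batch", "index", i)
}

func main() {
	for i := 0; i < 4; i++ {
		processOnce(i) // the loop survives the panic on i == 2 and continues with i == 3
	}
}
```

The second hunk pairs with this: if a panic left a sequence with an empty `inputs` slice, the next pass over `s.seqs` removes that sequence instead of spinning on it forever.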