runner.go: Hard fail on errors rather than potentially infinite looping
We try to recover from errors by dropping the tokens that caused the problem and re-trying. However, dropping the tokens is not correct and continuing often leads to infinite loops. To avoid this, we end the sequence if such a condition is detected, which is also surprising. At this point, it is better to just report the error. This will make it easier to find problems and the alternatives are perhaps even more surprising to users. This is not a very satisfactory solution either - we should isolate the error and return it to the user without killing the whole process. However, this is an incremental step and consistent with most other failures (which either manifest as abort() or panic).
This commit is contained in:
parent
7121dfa309
commit
3fc1dc0e6f
1 changed file with 10 additions and 13 deletions
|
@ -324,7 +324,11 @@ func (s *Server) run(ctx context.Context) {
|
|||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
s.processBatch(tokenBatch, embedBatch)
|
||||
err := s.processBatch(tokenBatch, embedBatch)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
tokenBatch.Clear()
|
||||
embedBatch.Clear()
|
||||
}
|
||||
|
@ -338,7 +342,7 @@ func (s *Server) run(ctx context.Context) {
|
|||
// these should instead be handled by the handlers
|
||||
// it should only be responsible for accepting tokens or embeddings and
|
||||
// processing batches as fast as possible
|
||||
func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) {
|
||||
func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) error {
|
||||
s.mu.Lock()
|
||||
for s.allNil() {
|
||||
s.cond.Wait() // Wait until an item is added
|
||||
|
@ -357,14 +361,6 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|||
continue
|
||||
}
|
||||
|
||||
// If an error occurred during the processing of a previous batch then we may have emptied the inputs
|
||||
// without adding a new one. In this case, end the sequence rather than infinite looping.
|
||||
if len(seq.inputs) == 0 {
|
||||
slog.Error("removing sequence due to no input tokens", "index", seqIdx, "cache id", seq.cache.Id)
|
||||
s.removeSequence(seqIdx, "error")
|
||||
continue
|
||||
}
|
||||
|
||||
// if past the num predict limit
|
||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||
s.removeSequence(seqIdx, "limit")
|
||||
|
@ -419,7 +415,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|||
}
|
||||
|
||||
if batch == nil || batch.NumTokens() == 0 {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
s.lc.SetCrossAttention(crossAttention)
|
||||
|
@ -432,8 +428,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|||
err = s.lc.Decode(batch)
|
||||
}
|
||||
if err != nil {
|
||||
slog.Error("failed to decode batch", "error", err)
|
||||
return
|
||||
return fmt.Errorf("failed to decode batch: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -531,6 +526,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|||
s.removeSequence(i, "connection")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO (jmorganca): use structs from the api package to avoid duplication
|
||||
|
|
Loading…
Reference in a new issue