diff --git a/llama/runner/runner.go b/llama/runner/runner.go index c3d0353f..db8092f3 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -300,6 +300,7 @@ func (s *Server) removeSequence(seqIndex int, reason string) { close(seq.embedding) seq.cache.InUse = false s.seqs[seqIndex] = nil + s.seqsSem.Release(1) } func (s *Server) run(ctx context.Context) { @@ -649,7 +650,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } - // Ensure that a place to put the sequence is available + // Ensure there is a place to put the sequence, released when removed from s.seqs if err := s.seqsSem.Acquire(r.Context(), 1); err != nil { if errors.Is(err, context.Canceled) { slog.Info("aborting completion request due to client closing the connection") @@ -658,9 +659,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { } return } - defer s.seqsSem.Release(1) s.mu.Lock() + found := false for i, sq := range s.seqs { if sq == nil { seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) @@ -674,11 +675,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { s.seqs[i] = seq s.cond.Signal() + found = true break } } s.mu.Unlock() + if !found { + http.Error(w, "could not find an available sequence", http.StatusInternalServerError) + return + } + for { select { case <-r.Context().Done(): @@ -742,7 +749,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { return } - // Ensure that a place to put the sequence is available + // Ensure there is a place to put the sequence, released when removed from s.seqs if err := s.seqsSem.Acquire(r.Context(), 1); err != nil { if errors.Is(err, context.Canceled) { slog.Info("aborting embeddings request due to client closing the connection") @@ -751,9 +758,9 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { } return } - defer s.seqsSem.Release(1) s.mu.Lock() + found := false for i, sq := range s.seqs { if sq == nil { seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) @@ -764,11 +771,17 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { } s.seqs[i] = seq s.cond.Signal() + found = true break } } s.mu.Unlock() + if !found { + http.Error(w, "could not find an available sequence", http.StatusInternalServerError) + return + } + embedding := <-seq.embedding if err := json.NewEncoder(w).Encode(&EmbeddingResponse{