From 17b386a891af182650f93d528ff78f2fded9efc6 Mon Sep 17 00:00:00 2001
From: Jesse Gross
Date: Tue, 12 Nov 2024 11:23:46 -0800
Subject: [PATCH] runner.go: Enforce NUM_PARALLEL directly in the runner

NUM_PARALLEL is currently enforced by the Ollama server process - it
will only issue requests to the runner if the maximum number of
concurrent requests has not been exceeded. Although this should be
sufficient, it is good for the runner to protect its own data
structures as well. Currently, if too many requests get through to the
runner, they will just get stuck and never return.

This may help with reports of Ollama hanging, though it is unclear how
such a state would actually be reached.

Bug #7573
---
 llama/runner/runner.go | 69 +++++++++++++++++++++++++++++-------------
 1 file changed, 48 insertions(+), 21 deletions(-)

diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index e65bd637..c034bc46 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -20,6 +20,8 @@ import (
 	"time"
 	"unicode/utf8"
 
+	"golang.org/x/sync/semaphore"
+
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llama"
 )
@@ -203,38 +205,51 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
 }
 
 type Server struct {
-	model *llama.Model
-	lc    *llama.Context
+	// is the server ready to process requests?
+	// protects access to model and image
+	ready sync.WaitGroup
 
-	// required for image embeddings
+	// loaded model
+	model *llama.Model
+
+	// image model context for multi-modal models
 	image *ImageContext
 
+	// status for external health reporting - loading, ready to serve, etc.
+	status ServerStatus
+
+	// current progress on loading the model
+	progress float32
+
+	// number of simultaneous requests to handle
+	parallel int
+
+	// maximum number of elements in a batch (per sequence)
 	// TODO (jmorganca): make this n_batch
 	batchSize int
 
-	// parallel is the number of parallel requests to handle
-	parallel int
+	// protects access to everything below this line
+	// this is context state needed for decoding
+	mu sync.Mutex
 
-	// seqs is the list of parallel sequences being evaluated
-	// TODO (jmorganca): this can probably be moved into run()
+	// indicates that data is ready for processing
+	cond *sync.Cond
+
+	// decoding state
+	lc *llama.Context
+
+	// the list of simultaneous sequences being evaluated
 	seqs []*Sequence
 
+	// seqs can have a maximum of parallel entries, which
+	// is enforced by seqsSem
+	seqsSem *semaphore.Weighted
+
 	// KV cache
 	cache *InputCache
 
 	// next sequence for prompt processing to avoid starvation
 	nextSeq int
-
-	// is the server ready to process requests?
-	ready sync.WaitGroup
-
-	mu sync.Mutex
-
-	cond *sync.Cond
-
-	progress float32
-
-	status ServerStatus
 }
 
 func (s *Server) allNil() bool {
@@ -616,8 +631,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	// TODO (jmorganca): add to sequence queue instead of
-	// failing if a slot isn't available
+	// Ensure that a place to put the sequence is available
+	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return
+	}
+	defer s.seqsSem.Release(1)
+
 	s.mu.Lock()
 	for i, sq := range s.seqs {
 		if sq == nil {
@@ -700,7 +720,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	// TODO (jessegross): Wait for a free slot instead of failing and blocking forever
+	// Ensure that a place to put the sequence is available
+	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return
+	}
+	defer s.seqsSem.Release(1)
+
 	s.mu.Lock()
 	for i, sq := range s.seqs {
 		if sq == nil {
@@ -855,6 +881,7 @@ func main() {
 	server := &Server{
 		batchSize: *batchSize,
 		parallel:  *parallel,
 		seqs:      make([]*Sequence, *parallel),
+		seqsSem:   semaphore.NewWeighted(int64(*parallel)),
 		status:    ServerStatusLoadingModel,
 	}
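
Note: the pattern this patch relies on is that semaphore.Weighted.Acquire
blocks until a slot is free or the request's context is canceled, so by the
time a handler scans s.seqs under the mutex, a nil entry is guaranteed to
exist. Below is a minimal, self-contained sketch of that admission-control
pattern for reference; it is not part of the patch, and names such as
maxSeqs, claim, and handle are illustrative only.

// Sketch of the bounded-slot pattern used above (illustrative names).
// A weighted semaphore sized to the number of sequence slots bounds how
// many requests may hold a slot at once; Acquire blocks until a slot
// frees up or the request's context is canceled, so no request can get
// stuck waiting forever on a nil slot.
package main

import (
	"context"
	"fmt"
	"sync"
	"time"

	"golang.org/x/sync/semaphore"
)

const maxSeqs = 2 // stands in for the runner's parallel setting

var (
	seqsSem = semaphore.NewWeighted(maxSeqs)
	mu      sync.Mutex
	seqs    = make([]*int, maxSeqs) // stands in for []*Sequence
)

// claim finds a free slot; with the semaphore held, one must exist.
func claim(id *int) int {
	mu.Lock()
	defer mu.Unlock()
	for i := range seqs {
		if seqs[i] == nil {
			seqs[i] = id
			return i
		}
	}
	panic("unreachable: semaphore guarantees a free slot")
}

func handle(ctx context.Context, id int) error {
	// Block until a slot is guaranteed, or give up when the caller's
	// context is canceled (e.g. the client disconnects).
	if err := seqsSem.Acquire(ctx, 1); err != nil {
		return fmt.Errorf("failed to acquire semaphore: %w", err)
	}
	defer seqsSem.Release(1)

	slot := claim(&id)
	time.Sleep(10 * time.Millisecond) // simulate decoding work

	mu.Lock()
	seqs[slot] = nil // free the slot for the next waiter
	mu.Unlock()
	return nil
}

func main() {
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ { // more requests than slots
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			if err := handle(context.Background(), id); err != nil {
				fmt.Println(err)
			}
		}(i)
	}
	wg.Wait()
	fmt.Println("all requests completed")
}

Acquiring the semaphore before taking the mutex keeps the critical section
short: the lock only protects the slot scan, while all waiting happens on
the semaphore, which is context-aware and therefore cannot wedge a request
indefinitely.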