diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index e65bd637..c034bc46 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -20,6 +20,8 @@ import (
 	"time"
 	"unicode/utf8"
 
+	"golang.org/x/sync/semaphore"
+
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llama"
 )
@@ -203,38 +205,51 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
 }
 
 type Server struct {
-	model *llama.Model
-	lc    *llama.Context
+	// is the server ready to process requests?
+	// protects access to model and image
+	ready sync.WaitGroup
 
-	// required for image embeddings
+	// loaded model
+	model *llama.Model
+
+	// image model context for multi-modal models
 	image *ImageContext
 
+	// status for external health reporting - loading, ready to serve, etc.
+	status ServerStatus
+
+	// current progress on loading the model
+	progress float32
+
+	// number of simultaneous requests to handle
+	parallel int
+
+	// maximum number of elements in a batch (per sequence)
 	// TODO (jmorganca): make this n_batch
 	batchSize int
 
-	// parallel is the number of parallel requests to handle
-	parallel int
+	// protects access to everything below this line
+	// this is context state needed for decoding
+	mu sync.Mutex
 
-	// seqs is the list of parallel sequences being evaluated
-	// TODO (jmorganca): this can probably be moved into run()
+	// indicates that data is ready for processing
+	cond *sync.Cond
+
+	// decoding state
+	lc *llama.Context
+
+	// the list of simultaneous sequences being evaluated
 	seqs []*Sequence
 
+	// seqs can have a maximum of parallel entries, which
+	// is enforced by seqsSem
+	seqsSem *semaphore.Weighted
+
 	// KV cache
 	cache *InputCache
 
 	// next sequence for prompt processing to avoid starvation
 	nextSeq int
-
-	// is the server ready to process requests?
-	ready sync.WaitGroup
-
-	mu sync.Mutex
-
-	cond *sync.Cond
-
-	progress float32
-
-	status ServerStatus
 }
 
 func (s *Server) allNil() bool {
@@ -616,8 +631,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	// TODO (jmorganca): add to sequence queue instead of
-	// failing if a slot isn't available
+	// Ensure that a place to put the sequence is available
+	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return
+	}
+	defer s.seqsSem.Release(1)
+
 	s.mu.Lock()
 	for i, sq := range s.seqs {
 		if sq == nil {
@@ -700,7 +720,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	// TODO (jessegross): Wait for a free slot instead of failing and blocking forever
+	// Ensure that a place to put the sequence is available
+	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return
+	}
+	defer s.seqsSem.Release(1)
+
 	s.mu.Lock()
 	for i, sq := range s.seqs {
 		if sq == nil {
@@ -855,6 +881,7 @@ func main() {
 		batchSize: *batchSize,
 		parallel:  *parallel,
 		seqs:      make([]*Sequence, *parallel),
+		seqsSem:   semaphore.NewWeighted(int64(*parallel)),
 		status:    ServerStatusLoadingModel,
 	}
 
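
Note on the pattern: the change replaces fail-fast slot allocation with blocking admission. A weighted semaphore sized to parallel guarantees that by the time a handler scans seqs under the mutex, a free (nil) slot exists, and acquiring with r.Context() lets a queued request give up if the client disconnects. Below is a minimal standalone sketch of that pattern, assuming hypothetical stand-ins (handle, a []*string slot pool, a fixed parallel constant) for the runner's completion handler, seqs, and the *parallel flag; it is not the runner code itself.

package main

import (
	"context"
	"fmt"
	"sync"

	"golang.org/x/sync/semaphore"
)

const parallel = 2 // stand-in for the *parallel flag

var (
	seqsSem = semaphore.NewWeighted(int64(parallel)) // caps in-flight requests
	mu      sync.Mutex
	seqs    = make([]*string, parallel) // stand-in for []*Sequence
)

// handle waits for a free slot instead of failing when all are busy.
// Acquire returns an error only if ctx is canceled first, e.g. the
// client disconnected while waiting in the queue.
func handle(ctx context.Context, name string) error {
	if err := seqsSem.Acquire(ctx, 1); err != nil {
		return err
	}
	defer seqsSem.Release(1)

	// Holding the semaphore guarantees a nil entry exists, so this
	// scan under the mutex always finds a slot.
	mu.Lock()
	slot := -1
	for i, sq := range seqs {
		if sq == nil {
			seqs[i], slot = &name, i
			break
		}
	}
	mu.Unlock()

	// ... decode here ...

	mu.Lock()
	seqs[slot] = nil // free the slot before Release runs
	mu.Unlock()
	return nil
}

func main() {
	var wg sync.WaitGroup
	for _, name := range []string{"a", "b", "c", "d", "e"} {
		wg.Add(1)
		go func(n string) {
			defer wg.Done()
			_ = handle(context.Background(), n)
		}(name)
	}
	wg.Wait()
	fmt.Println("served 5 requests with at most", parallel, "in flight")
}

As in the diff, Release is deferred rather than called right after the slot scan, so a waiting request is admitted only once the previous handler is fully done with its slot; in the sketch this keeps the semaphore count and the number of non-nil seqs entries in agreement.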