runner.go: Enforce NUM_PARALLEL directly in the runner

NUM_PARALLEL is currently enforced by the Ollama server process - it
will only issue requests to the runner if the maximum number of
concurrent requests has not been exceeded. Although this should
be sufficient, it is good for the runner to protect its own data
structures. Currently, if too many requests get through to the
runner, they will just get stuck and never return.

This may help with reports of Ollama hanging, though it is unclear
how excess requests would actually reach the runner.
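
For illustration only, here is a minimal standalone sketch of the
admission pattern this change introduces: a weighted semaphore sized to
the parallel limit gates each request before it touches any per-sequence
state. The handler and variable names below are made up for the example
and are not the runner's actual API.

package main

import (
	"log/slog"
	"net/http"

	"golang.org/x/sync/semaphore"
)

// one unit per in-flight request; capacity mirrors the parallel limit
var seqsSem *semaphore.Weighted

func handleCompletion(w http.ResponseWriter, r *http.Request) {
	// Block until a slot is free, or bail out if the client cancels.
	if err := seqsSem.Acquire(r.Context(), 1); err != nil {
		slog.Error("Failed to acquire semaphore", "error", err)
		return
	}
	defer seqsSem.Release(1)

	// ... find a free sequence slot and run decoding here ...
	w.WriteHeader(http.StatusOK)
}

func main() {
	maxParallel := 4 // the runner reads this from its *parallel flag in main()
	seqsSem = semaphore.NewWeighted(int64(maxParallel))

	http.HandleFunc("/completion", handleCompletion)
	_ = http.ListenAndServe("127.0.0.1:8080", nil)
}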

Bug #7573
Jesse Gross 2024-11-12 11:23:46 -08:00 committed by Jesse Gross
parent 549c2bdfcf
commit 17b386a891

@@ -20,6 +20,8 @@ import (
 	"time"
 	"unicode/utf8"
 
+	"golang.org/x/sync/semaphore"
+
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llama"
 )
@@ -203,38 +205,51 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
 }
 
 type Server struct {
-	model *llama.Model
-	lc    *llama.Context
+	// is the server ready to process requests?
+	// protects access to model and image
+	ready sync.WaitGroup
 
-	// required for image embeddings
+	// loaded model
+	model *llama.Model
+
+	// image model context for multi-modal models
 	image *ImageContext
 
+	// status for external health reporting - loading, ready to serve, etc.
+	status ServerStatus
+
+	// current progress on loading the model
+	progress float32
+
+	// number of simultaneous requests to handle
+	parallel int
+
+	// maximum number of elements in a batch (per sequence)
 	// TODO (jmorganca): make this n_batch
 	batchSize int
 
-	// parallel is the number of parallel requests to handle
-	parallel int
+	// protects access to everything below this line
+	// this is context state needed for decoding
+	mu sync.Mutex
 
-	// seqs is the list of parallel sequences being evaluated
-	// TODO (jmorganca): this can probably be moved into run()
+	// indicates that data is ready for processing
+	cond *sync.Cond
+
+	// decoding state
+	lc *llama.Context
+
+	// the list of simultaneous sequences being evaluated
 	seqs []*Sequence
 
+	// seqs can have a maximum of parallel entries, which
+	// is enforced by seqsSem
+	seqsSem *semaphore.Weighted
+
 	// KV cache
 	cache *InputCache
 
 	// next sequence for prompt processing to avoid starvation
 	nextSeq int
-
-	// is the server ready to process requests?
-	ready sync.WaitGroup
-
-	mu sync.Mutex
-
-	cond *sync.Cond
-
-	progress float32
-
-	status ServerStatus
 }
 
 func (s *Server) allNil() bool {
@@ -616,8 +631,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	// TODO (jmorganca): add to sequence queue instead of
-	// failing if a slot isn't available
+	// Ensure that a place to put the sequence is available
+	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return
+	}
+	defer s.seqsSem.Release(1)
+
 	s.mu.Lock()
 	for i, sq := range s.seqs {
 		if sq == nil {
@@ -700,7 +720,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	// TODO (jessegross): Wait for a free slot instead of failing and blocking forever
+	// Ensure that a place to put the sequence is available
+	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return
+	}
+	defer s.seqsSem.Release(1)
+
 	s.mu.Lock()
 	for i, sq := range s.seqs {
 		if sq == nil {
@@ -855,6 +881,7 @@ func main() {
 		batchSize: *batchSize,
 		parallel:  *parallel,
 		seqs:      make([]*Sequence, *parallel),
+		seqsSem:   semaphore.NewWeighted(int64(*parallel)),
 		status:    ServerStatusLoadingModel,
 	}
 
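One design note worth calling out: Acquire is passed r.Context(), so a
request that is cancelled while waiting for a slot gets an error back
instead of blocking forever. A standalone sketch of that behavior (not
part of the commit) using the same semaphore package:

package main

import (
	"context"
	"fmt"
	"time"

	"golang.org/x/sync/semaphore"
)

func main() {
	sem := semaphore.NewWeighted(1)

	// Occupy the only slot, as a long-running request would.
	_ = sem.Acquire(context.Background(), 1)

	// A second request waits, but only for as long as its context lives.
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	if err := sem.Acquire(ctx, 1); err != nil {
		// Prints the context error instead of hanging forever.
		fmt.Println("gave up waiting for a slot:", err)
		return
	}
	defer sem.Release(1)
}

Run as-is, this prints "gave up waiting for a slot: context deadline
exceeded" after roughly 100ms, which is the behavior the completion and
embeddings handlers rely on when a client goes away.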