Merge pull request #5410 from dhiltgen/ctx_cleanup

Fix case for NumCtx
This commit is contained in:
Daniel Hiltgen 2024-07-01 09:54:20 -07:00 committed by GitHub
commit e70610ef06
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 9 additions and 9 deletions

View file

@ -85,13 +85,13 @@ func AsMap() map[string]EnvVar {
"OLLAMA_HOST": {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"}, "OLLAMA_HOST": {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"},
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"}, "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"},
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"}, "OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU (default auto)"}, "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"},
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"}, "OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
"OLLAMA_MAX_VRAM": {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"}, "OLLAMA_MAX_VRAM": {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"},
"OLLAMA_MODELS": {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"}, "OLLAMA_MODELS": {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"},
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"}, "OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"}, "OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests (default auto)"}, "OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"}, "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"},
"OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"}, "OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"},

View file

@ -23,7 +23,7 @@ type LlmRequest struct {
ctx context.Context //nolint:containedctx ctx context.Context //nolint:containedctx
model *Model model *Model
opts api.Options opts api.Options
origNumCTX int // Track the initial ctx request origNumCtx int // Track the initial ctx request
sessionDuration time.Duration sessionDuration time.Duration
successCh chan *runnerRef successCh chan *runnerRef
errCh chan error errCh chan error
@ -118,8 +118,8 @@ func (s *Scheduler) processPending(ctx context.Context) {
case pending := <-s.pendingReqCh: case pending := <-s.pendingReqCh:
// Block other requests until we get this pending request running // Block other requests until we get this pending request running
pending.schedAttempts++ pending.schedAttempts++
if pending.origNumCTX == 0 { if pending.origNumCtx == 0 {
pending.origNumCTX = pending.opts.NumCtx pending.origNumCtx = pending.opts.NumCtx
} }
if pending.ctx.Err() != nil { if pending.ctx.Err() != nil {
@ -135,7 +135,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
} }
// Keep NumCtx and numParallel in sync // Keep NumCtx and numParallel in sync
if numParallel > 1 { if numParallel > 1 {
pending.opts.NumCtx = pending.origNumCTX * numParallel pending.opts.NumCtx = pending.origNumCtx * numParallel
} }
for { for {
@ -197,7 +197,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
// simplifying assumption of defaultParallel when in CPU mode // simplifying assumption of defaultParallel when in CPU mode
if numParallel <= 0 { if numParallel <= 0 {
numParallel = defaultParallel numParallel = defaultParallel
pending.opts.NumCtx = pending.origNumCTX * numParallel pending.opts.NumCtx = pending.origNumCtx * numParallel
} }
if loadedCount == 0 { if loadedCount == 0 {
@ -691,7 +691,7 @@ func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numP
// First attempt to fit the model into a single GPU // First attempt to fit the model into a single GPU
for _, p := range numParallelToTry { for _, p := range numParallelToTry {
req.opts.NumCtx = req.origNumCTX * p req.opts.NumCtx = req.origNumCtx * p
if !envconfig.SchedSpread { if !envconfig.SchedSpread {
for _, g := range sgl { for _, g := range sgl {
if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok { if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
@ -709,7 +709,7 @@ func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numP
// Now try all the GPUs // Now try all the GPUs
for _, p := range numParallelToTry { for _, p := range numParallelToTry {
req.opts.NumCtx = req.origNumCTX * p req.opts.NumCtx = req.origNumCtx * p
if ok, estimatedVRAM = llm.PredictServerFit(sgl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok { if ok, estimatedVRAM = llm.PredictServerFit(sgl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
slog.Info("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "parallel", p, "required", format.HumanBytes2(estimatedVRAM)) slog.Info("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "parallel", p, "required", format.HumanBytes2(estimatedVRAM))
*numParallel = p *numParallel = p