diff --git a/envconfig/config.go b/envconfig/config.go
index c02c4878..105b9af6 100644
--- a/envconfig/config.go
+++ b/envconfig/config.go
@@ -4,12 +4,14 @@ import (
 	"errors"
 	"fmt"
 	"log/slog"
+	"math"
 	"net"
 	"os"
 	"path/filepath"
 	"runtime"
 	"strconv"
 	"strings"
+	"time"
 )
 
 type OllamaHost struct {
@@ -34,7 +36,7 @@ var (
 	// Set via OLLAMA_HOST in the environment
 	Host *OllamaHost
 	// Set via OLLAMA_KEEP_ALIVE in the environment
-	KeepAlive string
+	KeepAlive time.Duration
 	// Set via OLLAMA_LLM_LIBRARY in the environment
 	LLMLibrary string
 	// Set via OLLAMA_MAX_LOADED_MODELS in the environment
@@ -132,6 +134,7 @@ func init() {
 	NumParallel = 0 // Autoselect
 	MaxRunners = 0  // Autoselect
 	MaxQueuedRequests = 512
+	KeepAlive = 5 * time.Minute
 
 	LoadConfig()
 }
@@ -266,7 +269,10 @@ func LoadConfig() {
 		}
 	}
 
-	KeepAlive = clean("OLLAMA_KEEP_ALIVE")
+	ka := clean("OLLAMA_KEEP_ALIVE")
+	if ka != "" {
+		loadKeepAlive(ka)
+	}
 
 	var err error
 	ModelsDir, err = getModelsDir()
@@ -344,3 +350,24 @@ func getOllamaHost() (*OllamaHost, error) {
 		Port:   port,
 	}, nil
 }
+
+func loadKeepAlive(ka string) {
+	v, err := strconv.Atoi(ka)
+	if err != nil {
+		d, err := time.ParseDuration(ka)
+		if err == nil {
+			if d < 0 {
+				KeepAlive = time.Duration(math.MaxInt64)
+			} else {
+				KeepAlive = d
+			}
+		}
+	} else {
+		d := time.Duration(v) * time.Second
+		if d < 0 {
+			KeepAlive = time.Duration(math.MaxInt64)
+		} else {
+			KeepAlive = d
+		}
+	}
+}
diff --git a/envconfig/config_test.go b/envconfig/config_test.go
index 7d923d62..a5d73fd7 100644
--- a/envconfig/config_test.go
+++ b/envconfig/config_test.go
@@ -2,8 +2,10 @@ package envconfig
 
 import (
 	"fmt"
+	"math"
 	"net"
 	"testing"
+	"time"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -23,6 +25,21 @@ func TestConfig(t *testing.T) {
 	t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
 	LoadConfig()
 	require.True(t, FlashAttention)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "")
+	LoadConfig()
+	require.Equal(t, 5*time.Minute, KeepAlive)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "3")
+	LoadConfig()
+	require.Equal(t, 3*time.Second, KeepAlive)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "1h")
+	LoadConfig()
+	require.Equal(t, 1*time.Hour, KeepAlive)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "-1s")
+	LoadConfig()
+	require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "-1")
+	LoadConfig()
+	require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
 }
 
 func TestClientFromEnvironment(t *testing.T) {
diff --git a/server/routes.go b/server/routes.go
index b14a146c..ac6b713a 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -9,7 +9,6 @@ import (
 	"io"
 	"io/fs"
 	"log/slog"
-	"math"
 	"net"
 	"net/http"
 	"net/netip"
@@ -17,7 +16,6 @@ import (
 	"os/signal"
 	"path/filepath"
 	"slices"
-	"strconv"
 	"strings"
 	"syscall"
 	"time"
@@ -56,8 +54,6 @@ func init() {
 	gin.SetMode(mode)
 }
 
-var defaultSessionDuration = 5 * time.Minute
-
 func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
 	opts := api.DefaultOptions()
 	if err := opts.FromMap(model.Options); err != nil {
@@ -133,14 +129,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	var sessionDuration time.Duration
-	if req.KeepAlive == nil {
-		sessionDuration = getDefaultSessionDuration()
-	} else {
-		sessionDuration = req.KeepAlive.Duration
-	}
-
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
 	var runner *runnerRef
 	select {
 	case runner = <-rCh:
@@ -320,32 +309,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 	streamResponse(c, ch)
 }
 
-func getDefaultSessionDuration() time.Duration {
-	if envconfig.KeepAlive != "" {
-		v, err := strconv.Atoi(envconfig.KeepAlive)
-		if err != nil {
-			d, err := time.ParseDuration(envconfig.KeepAlive)
-			if err != nil {
-				return defaultSessionDuration
-			}
-
-			if d < 0 {
-				return time.Duration(math.MaxInt64)
-			}
-
-			return d
-		}
-
-		d := time.Duration(v) * time.Second
-		if d < 0 {
-			return time.Duration(math.MaxInt64)
-		}
-		return d
-	}
-
-	return defaultSessionDuration
-}
-
 func (s *Server) EmbeddingsHandler(c *gin.Context) {
 	var req api.EmbeddingRequest
 	err := c.ShouldBindJSON(&req)
@@ -380,14 +343,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	var sessionDuration time.Duration
-	if req.KeepAlive == nil {
-		sessionDuration = getDefaultSessionDuration()
-	} else {
-		sessionDuration = req.KeepAlive.Duration
-	}
-
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
 	var runner *runnerRef
 	select {
 	case runner = <-rCh:
@@ -1318,14 +1274,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}
 
-	var sessionDuration time.Duration
-	if req.KeepAlive == nil {
-		sessionDuration = getDefaultSessionDuration()
-	} else {
-		sessionDuration = req.KeepAlive.Duration
-	}
-
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
 	var runner *runnerRef
 	select {
 	case runner = <-rCh:
diff --git a/server/sched.go b/server/sched.go
index 71b535ae..dc492cfb 100644
--- a/server/sched.go
+++ b/server/sched.go
@@ -24,7 +24,7 @@ type LlmRequest struct {
 	model           *Model
 	opts            api.Options
 	origNumCtx      int // Track the initial ctx request
-	sessionDuration time.Duration
+	sessionDuration *api.Duration
 	successCh       chan *runnerRef
 	errCh           chan error
 	schedAttempts   uint
@@ -75,7 +75,7 @@ func InitScheduler(ctx context.Context) *Scheduler {
 }
 
 // context must be canceled to decrement ref count and release the runner
-func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
+func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
 	if opts.NumCtx < 4 {
 		opts.NumCtx = 4
 	}
@@ -389,7 +389,9 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
 		runner.expireTimer.Stop()
 		runner.expireTimer = nil
 	}
-	runner.sessionDuration = pending.sessionDuration
+	if pending.sessionDuration != nil {
+		runner.sessionDuration = pending.sessionDuration.Duration
+	}
 	pending.successCh <- runner
 	go func() {
 		<-pending.ctx.Done()
@@ -402,6 +404,10 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
 	if numParallel < 1 {
 		numParallel = 1
 	}
+	sessionDuration := envconfig.KeepAlive
+	if req.sessionDuration != nil {
+		sessionDuration = req.sessionDuration.Duration
+	}
 	llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
 	if err != nil {
 		// some older models are not compatible with newer versions of llama.cpp
@@ -419,7 +425,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
 		modelPath:       req.model.ModelPath,
 		llama:           llama,
 		Options:         &req.opts,
-		sessionDuration: req.sessionDuration,
+		sessionDuration: sessionDuration,
 		gpus:            gpus,
 		estimatedVRAM:   llama.EstimatedVRAM(),
 		estimatedTotal:  llama.EstimatedTotal(),
diff --git a/server/sched_test.go b/server/sched_test.go
index be0830a3..d957927e 100644
--- a/server/sched_test.go
+++ b/server/sched_test.go
@@ -44,7 +44,7 @@ func TestLoad(t *testing.T) {
 		opts:            api.DefaultOptions(),
 		successCh:       make(chan *runnerRef, 1),
 		errCh:           make(chan error, 1),
-		sessionDuration: 2,
+		sessionDuration: &api.Duration{Duration: 2 * time.Second},
 	}
 	// Fail to load model first
 	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
@@ -142,7 +142,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
 		ctx:             scenario.ctx,
 		model:           model,
 		opts:            api.DefaultOptions(),
-		sessionDuration: 5 * time.Millisecond,
+		sessionDuration: &api.Duration{Duration: 5 * time.Millisecond},
 		successCh:       make(chan *runnerRef, 1),
 		errCh:           make(chan error, 1),
 	}
@@ -156,18 +156,18 @@ func TestRequests(t *testing.T) {
 
 	// Same model, same request
 	scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
-	scenario1a.req.sessionDuration = 5 * time.Millisecond
+	scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
 	scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
 	scenario1b.req.model = scenario1a.req.model
 	scenario1b.ggml = scenario1a.ggml
-	scenario1b.req.sessionDuration = 0
+	scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
 
 	// simple reload of same model
 	scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
 	tmpModel := *scenario1a.req.model
 	scenario2a.req.model = &tmpModel
 	scenario2a.ggml = scenario1a.ggml
-	scenario2a.req.sessionDuration = 5 * time.Millisecond
+	scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
 
 	// Multiple loaded models
 	scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
@@ -318,11 +318,11 @@ func TestGetRunner(t *testing.T) {
 	defer done()
 
 	scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
-	scenario1a.req.sessionDuration = 0
+	scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
 	scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
-	scenario1b.req.sessionDuration = 0
+	scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
 	scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
-	scenario1c.req.sessionDuration = 0
+	scenario1c.req.sessionDuration = &api.Duration{Duration: 0}
 	envconfig.MaxQueuedRequests = 1
 	s := InitScheduler(ctx)
 	s.getGpuFn = func() gpu.GpuInfoList {
@@ -402,7 +402,7 @@ func TestPrematureExpired(t *testing.T) {
 	case <-ctx.Done():
 		t.Fatal("timeout")
 	}
-	time.Sleep(scenario1a.req.sessionDuration)
+	time.Sleep(scenario1a.req.sessionDuration.Duration)
 	scenario1a.ctxDone()
 	time.Sleep(20 * time.Millisecond)
 	require.LessOrEqual(t, len(s.finishedReqCh), 1)
@@ -423,7 +423,7 @@ func TestUseLoadedRunner(t *testing.T) {
 		ctx:             ctx,
 		opts:            api.DefaultOptions(),
 		successCh:       make(chan *runnerRef, 1),
-		sessionDuration: 2,
+		sessionDuration: &api.Duration{Duration: 2},
 	}
 	finished := make(chan *LlmRequest)
 	llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
@@ -614,7 +614,7 @@ func TestAlreadyCanceled(t *testing.T) {
 	dctx, done2 := context.WithCancel(ctx)
 	done2()
 	scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
-	scenario1a.req.sessionDuration = 0
+	scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
 	s := InitScheduler(ctx)
 	slog.Info("scenario1a")
 	s.pendingReqCh <- scenario1a.req
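
Reviewer note: the parsing rules loadKeepAlive introduces are: a bare integer is read as seconds, anything else is tried as a Go duration string, a negative result means "keep the model loaded indefinitely" (clamped to math.MaxInt64), and an empty or unparseable value leaves the 5-minute default from init() in place. A per-request KeepAlive still wins; a nil *api.Duration now falls back to envconfig.KeepAlive in sched.load. Below is a minimal standalone sketch of the same parsing semantics; the parseKeepAlive helper and its fallback parameter are illustrative names, not part of this patch.

package main

import (
	"fmt"
	"math"
	"strconv"
	"time"
)

// parseKeepAlive mirrors loadKeepAlive from the patch: a bare integer is
// interpreted as seconds, anything else goes through time.ParseDuration,
// and a negative result is clamped to math.MaxInt64, i.e. "never unload".
// Empty or unparseable input leaves the fallback (the 5-minute default)
// in place.
func parseKeepAlive(ka string, fallback time.Duration) time.Duration {
	d := fallback
	if v, err := strconv.Atoi(ka); err == nil {
		d = time.Duration(v) * time.Second
	} else if parsed, err := time.ParseDuration(ka); err == nil {
		d = parsed
	}
	if d < 0 {
		return time.Duration(math.MaxInt64)
	}
	return d
}

func main() {
	def := 5 * time.Minute // matches the init() default in the patch
	for _, ka := range []string{"", "3", "1h", "-1s", "-1"} {
		fmt.Printf("%q -> %v\n", ka, parseKeepAlive(ka, def))
	}
}

The sample inputs correspond to the new assertions in config_test.go: "" keeps 5m0s, "3" yields 3s, "1h" yields 1h0m0s, and both negative forms print 2562047h47m16.854775807s (math.MaxInt64 nanoseconds).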