diff --git a/envconfig/config.go b/envconfig/config.go
index ea78585b..62bfad64 100644
--- a/envconfig/config.go
+++ b/envconfig/config.go
@@ -99,6 +99,32 @@ func Models() string {
 	return filepath.Join(home, ".ollama", "models")
 }
 
+// KeepAlive returns the duration that models stay loaded in memory. KeepAlive can be configured via the OLLAMA_KEEP_ALIVE environment variable.
+// The value is a Go duration string (e.g. "5m") or an integer number of seconds; values that cannot be parsed are ignored.
+// Negative values are treated as infinite. Zero is treated as no keep alive.
+// Second counts too large for a time.Duration are treated as infinite.
+// Default is 5 minutes.
+func KeepAlive() (keepAlive time.Duration) {
+	keepAlive = 5 * time.Minute
+	if s := getenv("OLLAMA_KEEP_ALIVE"); s != "" {
+		if d, err := time.ParseDuration(s); err == nil {
+			keepAlive = d
+		} else if n, err := strconv.ParseInt(s, 10, 64); err == nil {
+			// Saturate instead of letting n*time.Second wrap around int64.
+			if n < 0 || n > math.MaxInt64/int64(time.Second) {
+				return time.Duration(math.MaxInt64)
+			}
+			keepAlive = time.Duration(n) * time.Second
+		}
+	}
+
+	if keepAlive < 0 {
+		return time.Duration(math.MaxInt64)
+	}
+
+	return keepAlive
+}
+
 func Bool(k string) func() bool {
 	return func() bool {
 		if s := getenv(k); s != "" {
@@ -130,8 +156,6 @@ var (
 )
 
 var (
-	// Set via OLLAMA_KEEP_ALIVE in the environment
-	KeepAlive time.Duration
 	// Set via OLLAMA_LLM_LIBRARY in the environment
 	LLMLibrary string
 	// Set via OLLAMA_MAX_LOADED_MODELS in the environment
@@ -168,7 +192,7 @@ func AsMap() map[string]EnvVar {
 		"OLLAMA_DEBUG":             {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
 		"OLLAMA_FLASH_ATTENTION":   {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"},
 		"OLLAMA_HOST":              {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
-		"OLLAMA_KEEP_ALIVE":        {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"},
+		"OLLAMA_KEEP_ALIVE":        {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
 		"OLLAMA_LLM_LIBRARY":       {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
 		"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"},
 		"OLLAMA_MAX_QUEUE":         {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
@@ -210,7 +234,6 @@ func init() {
 	NumParallel = 0 // Autoselect
 	MaxRunners = 0  // Autoselect
 	MaxQueuedRequests = 512
-	KeepAlive = 5 * time.Minute
 
 	LoadConfig()
 }
@@ -284,35 +307,9 @@ func LoadConfig() {
 		}
 	}
 
-	ka := getenv("OLLAMA_KEEP_ALIVE")
-	if ka != "" {
-		loadKeepAlive(ka)
-	}
-
 	CudaVisibleDevices = getenv("CUDA_VISIBLE_DEVICES")
 	HipVisibleDevices = getenv("HIP_VISIBLE_DEVICES")
 	RocrVisibleDevices = getenv("ROCR_VISIBLE_DEVICES")
 	GpuDeviceOrdinal = getenv("GPU_DEVICE_ORDINAL")
 	HsaOverrideGfxVersion = getenv("HSA_OVERRIDE_GFX_VERSION")
 }
-
-func loadKeepAlive(ka string) {
-	v, err := strconv.Atoi(ka)
-	if err != nil {
-		d, err := time.ParseDuration(ka)
-		if err == nil {
-			if d < 0 {
-				KeepAlive = time.Duration(math.MaxInt64)
-			} else {
-				KeepAlive = d
-			}
-		}
-	} else {
-		d := time.Duration(v) * time.Second
-		if d < 0 {
-			KeepAlive = time.Duration(math.MaxInt64)
-		} else {
-			KeepAlive = d
-		}
-	}
-}
diff --git a/envconfig/config_test.go b/envconfig/config_test.go
index b364b009..87c808ca 100644
--- a/envconfig/config_test.go
+++ b/envconfig/config_test.go
@@ -21,22 +21,6 @@ func TestSmoke(t *testing.T) {
 
 	t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
 	require.True(t, FlashAttention())
-
-	t.Setenv("OLLAMA_KEEP_ALIVE", "")
-	LoadConfig()
-	require.Equal(t, 5*time.Minute, KeepAlive)
-	t.Setenv("OLLAMA_KEEP_ALIVE", "3")
-	LoadConfig()
-	require.Equal(t, 3*time.Second, KeepAlive)
-	t.Setenv("OLLAMA_KEEP_ALIVE", "1h")
-	LoadConfig()
-	require.Equal(t, 1*time.Hour, KeepAlive)
-	t.Setenv("OLLAMA_KEEP_ALIVE", "-1s")
-	LoadConfig()
-	require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
-	t.Setenv("OLLAMA_KEEP_ALIVE", "-1")
-	LoadConfig()
-	require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
 }
 
 func TestHost(t *testing.T) {
@@ -186,3 +170,41 @@ func TestBool(t *testing.T) {
 		})
 	}
 }
+
+// TestKeepAlive checks how KeepAlive interprets OLLAMA_KEEP_ALIVE: duration strings, bare seconds, negative, oversized and invalid values.
+func TestKeepAlive(t *testing.T) {
+	cases := map[string]time.Duration{
+		"":       5 * time.Minute,
+		"1s":     time.Second,
+		"1m":     time.Minute,
+		"1h":     time.Hour,
+		"5m0s":   5 * time.Minute,
+		"1h2m3s": 1*time.Hour + 2*time.Minute + 3*time.Second,
+		"0":      time.Duration(0),
+		"60":     60 * time.Second,
+		"120":    2 * time.Minute,
+		"3600":   time.Hour,
+		"-0":     time.Duration(0),
+		"-1":     time.Duration(math.MaxInt64),
+		"-1m":    time.Duration(math.MaxInt64),
+		// second counts that do not fit in a time.Duration saturate to infinite
+		"10000000000":          time.Duration(math.MaxInt64),
+		"9223372036854775807":  time.Duration(math.MaxInt64),
+		"-9223372036854775808": time.Duration(math.MaxInt64),
+		// invalid values
+		" ":   5 * time.Minute,
+		"???": 5 * time.Minute,
+		"1d":  5 * time.Minute,
+		"1y":  5 * time.Minute,
+		"1w":  5 * time.Minute,
+	}
+
+	for tt, expect := range cases {
+		t.Run(tt, func(t *testing.T) {
+			t.Setenv("OLLAMA_KEEP_ALIVE", tt)
+			if actual := KeepAlive(); actual != expect {
+				t.Errorf("%s: expected %s, got %s", tt, expect, actual)
+			}
+		})
+	}
+}
diff --git a/server/sched.go b/server/sched.go
index e1e986a5..ad40c4ef 100644
--- a/server/sched.go
+++ b/server/sched.go
@@ -401,7 +401,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
 	if numParallel < 1 {
 		numParallel = 1
 	}
-	sessionDuration := envconfig.KeepAlive
+	sessionDuration := envconfig.KeepAlive()
 	if req.sessionDuration != nil {
 		sessionDuration = req.sessionDuration.Duration
 	}