keepalive

This commit is contained in:
Michael Yang 2024-07-03 18:39:35 -07:00
parent 55cd3ddcca
commit 8570c1c0ef
3 changed files with 55 additions and 47 deletions

View file

@ -99,6 +99,26 @@ func Models() string {
return filepath.Join(home, ".ollama", "models") return filepath.Join(home, ".ollama", "models")
} }
// KeepAlive returns the duration that models stay loaded in memory. KeepAlive can be configured via the OLLAMA_KEEP_ALIVE environment variable.
// Negative values are treated as infinite. Zero is treated as no keep alive.
// Default is 5 minutes.
func KeepAlive() (keepAlive time.Duration) {
keepAlive = 5 * time.Minute
if s := os.Getenv("OLLAMA_KEEP_ALIVE"); s != "" {
if d, err := time.ParseDuration(s); err == nil {
keepAlive = d
} else if n, err := strconv.ParseInt(s, 10, 64); err == nil {
keepAlive = time.Duration(n) * time.Second
}
}
if keepAlive < 0 {
return time.Duration(math.MaxInt64)
}
return keepAlive
}
func Bool(k string) func() bool { func Bool(k string) func() bool {
return func() bool { return func() bool {
if s := getenv(k); s != "" { if s := getenv(k); s != "" {
@ -130,8 +150,6 @@ var (
) )
var ( var (
// Set via OLLAMA_KEEP_ALIVE in the environment
KeepAlive time.Duration
// Set via OLLAMA_LLM_LIBRARY in the environment // Set via OLLAMA_LLM_LIBRARY in the environment
LLMLibrary string LLMLibrary string
// Set via OLLAMA_MAX_LOADED_MODELS in the environment // Set via OLLAMA_MAX_LOADED_MODELS in the environment
@ -168,7 +186,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"}, "OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"}, "OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"},
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"}, "OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"}, "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"}, "OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"}, "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"},
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"}, "OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
@ -210,7 +228,6 @@ func init() {
NumParallel = 0 // Autoselect NumParallel = 0 // Autoselect
MaxRunners = 0 // Autoselect MaxRunners = 0 // Autoselect
MaxQueuedRequests = 512 MaxQueuedRequests = 512
KeepAlive = 5 * time.Minute
LoadConfig() LoadConfig()
} }
@ -284,35 +301,9 @@ func LoadConfig() {
} }
} }
ka := getenv("OLLAMA_KEEP_ALIVE")
if ka != "" {
loadKeepAlive(ka)
}
CudaVisibleDevices = getenv("CUDA_VISIBLE_DEVICES") CudaVisibleDevices = getenv("CUDA_VISIBLE_DEVICES")
HipVisibleDevices = getenv("HIP_VISIBLE_DEVICES") HipVisibleDevices = getenv("HIP_VISIBLE_DEVICES")
RocrVisibleDevices = getenv("ROCR_VISIBLE_DEVICES") RocrVisibleDevices = getenv("ROCR_VISIBLE_DEVICES")
GpuDeviceOrdinal = getenv("GPU_DEVICE_ORDINAL") GpuDeviceOrdinal = getenv("GPU_DEVICE_ORDINAL")
HsaOverrideGfxVersion = getenv("HSA_OVERRIDE_GFX_VERSION") HsaOverrideGfxVersion = getenv("HSA_OVERRIDE_GFX_VERSION")
} }
func loadKeepAlive(ka string) {
v, err := strconv.Atoi(ka)
if err != nil {
d, err := time.ParseDuration(ka)
if err == nil {
if d < 0 {
KeepAlive = time.Duration(math.MaxInt64)
} else {
KeepAlive = d
}
}
} else {
d := time.Duration(v) * time.Second
if d < 0 {
KeepAlive = time.Duration(math.MaxInt64)
} else {
KeepAlive = d
}
}
}

View file

@ -21,22 +21,6 @@ func TestSmoke(t *testing.T) {
t.Setenv("OLLAMA_FLASH_ATTENTION", "1") t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
require.True(t, FlashAttention()) require.True(t, FlashAttention())
t.Setenv("OLLAMA_KEEP_ALIVE", "")
LoadConfig()
require.Equal(t, 5*time.Minute, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "3")
LoadConfig()
require.Equal(t, 3*time.Second, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "1h")
LoadConfig()
require.Equal(t, 1*time.Hour, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "-1s")
LoadConfig()
require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "-1")
LoadConfig()
require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
} }
func TestHost(t *testing.T) { func TestHost(t *testing.T) {
@ -186,3 +170,36 @@ func TestBool(t *testing.T) {
}) })
} }
} }
func TestKeepAlive(t *testing.T) {
cases := map[string]time.Duration{
"": 5 * time.Minute,
"1s": time.Second,
"1m": time.Minute,
"1h": time.Hour,
"5m0s": 5 * time.Minute,
"1h2m3s": 1*time.Hour + 2*time.Minute + 3*time.Second,
"0": time.Duration(0),
"60": 60 * time.Second,
"120": 2 * time.Minute,
"3600": time.Hour,
"-0": time.Duration(0),
"-1": time.Duration(math.MaxInt64),
"-1m": time.Duration(math.MaxInt64),
// invalid values
" ": 5 * time.Minute,
"???": 5 * time.Minute,
"1d": 5 * time.Minute,
"1y": 5 * time.Minute,
"1w": 5 * time.Minute,
}
for tt, expect := range cases {
t.Run(tt, func(t *testing.T) {
t.Setenv("OLLAMA_KEEP_ALIVE", tt)
if actual := KeepAlive(); actual != expect {
t.Errorf("%s: expected %s, got %s", tt, expect, actual)
}
})
}
}

View file

@ -401,7 +401,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
if numParallel < 1 { if numParallel < 1 {
numParallel = 1 numParallel = 1
} }
sessionDuration := envconfig.KeepAlive sessionDuration := envconfig.KeepAlive()
if req.sessionDuration != nil { if req.sessionDuration != nil {
sessionDuration = req.sessionDuration.Duration sessionDuration = req.sessionDuration.Duration
} }