diff --git a/envconfig/config.go b/envconfig/config.go index 01abea42..b82b773d 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -1,7 +1,6 @@ package envconfig import ( - "errors" "fmt" "log/slog" "math" @@ -15,15 +14,12 @@ import ( "time" ) -var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST") - // Host returns the scheme and host. Host can be configured via the OLLAMA_HOST environment variable. // Default is scheme "http" and host "127.0.0.1:11434" func Host() *url.URL { defaultPort := "11434" - s := os.Getenv("OLLAMA_HOST") - s = strings.TrimSpace(strings.Trim(strings.TrimSpace(s), "\"'")) + s := strings.TrimSpace(Var("OLLAMA_HOST")) scheme, hostport, ok := strings.Cut(s, "://") switch { case !ok: @@ -48,6 +44,7 @@ func Host() *url.URL { } if n, err := strconv.ParseInt(port, 10, 32); err != nil || n > 65535 || n < 0 { + slog.Warn("invalid port, using default", "port", port, "default", defaultPort) return &url.URL{ Scheme: scheme, Host: net.JoinHostPort(host, defaultPort), @@ -62,7 +59,7 @@ func Host() *url.URL { // Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable. func Origins() (origins []string) { - if s := getenv("OLLAMA_ORIGINS"); s != "" { + if s := Var("OLLAMA_ORIGINS"); s != "" { origins = strings.Split(s, ",") } @@ -87,7 +84,7 @@ func Origins() (origins []string) { // Models returns the path to the models directory. Models directory can be configured via the OLLAMA_MODELS environment variable. // Default is $HOME/.ollama/models func Models() string { - if s, ok := os.LookupEnv("OLLAMA_MODELS"); ok { + if s := Var("OLLAMA_MODELS"); s != "" { return s } @@ -104,7 +101,7 @@ func Models() string { // Default is 5 minutes. func KeepAlive() (keepAlive time.Duration) { keepAlive = 5 * time.Minute - if s := os.Getenv("OLLAMA_KEEP_ALIVE"); s != "" { + if s := Var("OLLAMA_KEEP_ALIVE"); s != "" { if d, err := time.ParseDuration(s); err == nil { keepAlive = d } else if n, err := strconv.ParseInt(s, 10, 64); err == nil { @@ -121,7 +118,7 @@ func KeepAlive() (keepAlive time.Duration) { func Bool(k string) func() bool { return func() bool { - if s := getenv(k); s != "" { + if s := Var(k); s != "" { b, err := strconv.ParseBool(s) if err != nil { return true @@ -151,7 +148,7 @@ var ( func String(s string) func() string { return func() string { - return getenv(s) + return Var(s) } } @@ -167,7 +164,7 @@ var ( ) func RunnersDir() (p string) { - if p := getenv("OLLAMA_RUNNERS_DIR"); p != "" { + if p := Var("OLLAMA_RUNNERS_DIR"); p != "" { return p } @@ -213,22 +210,29 @@ func RunnersDir() (p string) { return p } -func Int(k string, n int) func() int { - return func() int { - if s := getenv(k); s != "" { - if n, err := strconv.ParseInt(s, 10, 64); err == nil && n >= 0 { - return int(n) +func Uint(key string, defaultValue uint) func() uint { + return func() uint { + if s := Var(key); s != "" { + if n, err := strconv.ParseUint(s, 10, 64); err != nil { + slog.Warn("invalid environment variable, using default", "key", key, "value", s, "default", defaultValue) + } else { + return uint(n) } } - return n + return defaultValue } } var ( - NumParallel = Int("OLLAMA_NUM_PARALLEL", 0) - MaxRunners = Int("OLLAMA_MAX_LOADED_MODELS", 0) - MaxQueue = Int("OLLAMA_MAX_QUEUE", 512) + // NumParallel sets the number of parallel model requests. NumParallel can be configured via the OLLAMA_NUM_PARALLEL environment variable. + NumParallel = Uint("OLLAMA_NUM_PARALLEL", 0) + // MaxRunners sets the maximum number of loaded models. MaxRunners can be configured via the OLLAMA_MAX_LOADED_MODELS environment variable. + MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", 0) + // MaxQueue sets the maximum number of queued requests. MaxQueue can be configured via the OLLAMA_MAX_QUEUE environment variable. + MaxQueue = Uint("OLLAMA_MAX_QUEUE", 512) + // MaxVRAM sets a maximum VRAM override in bytes. MaxVRAM can be configured via the OLLAMA_MAX_VRAM environment variable. + MaxVRAM = Uint("OLLAMA_MAX_VRAM", 0) ) type EnvVar struct { @@ -274,7 +278,7 @@ func Values() map[string]string { return vals } -// getenv returns an environment variable stripped of leading and trailing quotes or spaces -func getenv(key string) string { - return strings.Trim(os.Getenv(key), "\"' ") +// Var returns an environment variable stripped of leading and trailing quotes or spaces +func Var(key string) string { + return strings.Trim(strings.TrimSpace(os.Getenv(key)), "\"'") } diff --git a/envconfig/config_test.go b/envconfig/config_test.go index 977298aa..92a500f1 100644 --- a/envconfig/config_test.go +++ b/envconfig/config_test.go @@ -30,6 +30,10 @@ func TestHost(t *testing.T) { "extra quotes": {"\"1.2.3.4\"", "1.2.3.4:11434"}, "extra space+quotes": {" \" 1.2.3.4 \" ", "1.2.3.4:11434"}, "extra single quotes": {"'1.2.3.4'", "1.2.3.4:11434"}, + "http": {"http://1.2.3.4", "1.2.3.4:80"}, + "http port": {"http://1.2.3.4:4321", "1.2.3.4:4321"}, + "https": {"https://1.2.3.4", "1.2.3.4:443"}, + "https port": {"https://1.2.3.4:4321", "1.2.3.4:4321"}, } for name, tt := range cases { @@ -133,24 +137,45 @@ func TestOrigins(t *testing.T) { } func TestBool(t *testing.T) { - cases := map[string]struct { - value string - expect bool - }{ - "empty": {"", false}, - "true": {"true", true}, - "false": {"false", false}, - "1": {"1", true}, - "0": {"0", false}, - "random": {"random", true}, - "something": {"something", true}, + cases := map[string]bool{ + "": false, + "true": true, + "false": false, + "1": true, + "0": false, + // invalid values + "random": true, + "something": true, } - for name, tt := range cases { - t.Run(name, func(t *testing.T) { - t.Setenv("OLLAMA_BOOL", tt.value) - if b := Bool("OLLAMA_BOOL"); b() != tt.expect { - t.Errorf("%s: expected %t, got %t", name, tt.expect, b()) + for k, v := range cases { + t.Run(k, func(t *testing.T) { + t.Setenv("OLLAMA_BOOL", k) + if b := Bool("OLLAMA_BOOL")(); b != v { + t.Errorf("%s: expected %t, got %t", k, v, b) + } + }) + } +} + +func TestUint(t *testing.T) { + cases := map[string]uint{ + "0": 0, + "1": 1, + "1337": 1337, + // default values + "": 11434, + "-1": 11434, + "0o10": 11434, + "0x10": 11434, + "string": 11434, + } + + for k, v := range cases { + t.Run(k, func(t *testing.T) { + t.Setenv("OLLAMA_UINT", k) + if i := Uint("OLLAMA_UINT", 11434)(); i != v { + t.Errorf("%s: expected %d, got %d", k, v, i) } }) } @@ -188,3 +213,23 @@ func TestKeepAlive(t *testing.T) { }) } } + +func TestVar(t *testing.T) { + cases := map[string]string{ + "value": "value", + " value ": "value", + " 'value' ": "value", + ` "value" `: "value", + " ' value ' ": " value ", + ` " value " `: " value ", + } + + for k, v := range cases { + t.Run(k, func(t *testing.T) { + t.Setenv("OLLAMA_VAR", k) + if s := Var("OLLAMA_VAR"); s != v { + t.Errorf("%s: expected %q, got %q", k, v, s) + } + }) + } +} diff --git a/server/sched.go b/server/sched.go index 610a2c50..ce2945d8 100644 --- a/server/sched.go +++ b/server/sched.go @@ -129,7 +129,7 @@ func (s *Scheduler) processPending(ctx context.Context) { slog.Debug("pending request cancelled or timed out, skipping scheduling") continue } - numParallel := envconfig.NumParallel() + numParallel := int(envconfig.NumParallel()) // TODO (jmorganca): multimodal models don't support parallel yet // see https://github.com/ollama/ollama/issues/4165 if len(pending.model.ProjectorPaths) > 0 && numParallel != 1 { @@ -151,7 +151,7 @@ func (s *Scheduler) processPending(ctx context.Context) { pending.useLoadedRunner(runner, s.finishedReqCh) break } - } else if envconfig.MaxRunners() > 0 && loadedCount >= envconfig.MaxRunners() { + } else if envconfig.MaxRunners() > 0 && loadedCount >= int(envconfig.MaxRunners()) { slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount) runnerToExpire = s.findRunnerToUnload() } else {