This commit is contained in:
Michael Yang 2024-07-08 10:34:12 -07:00
parent 78140a712c
commit 85d9d73a72
3 changed files with 90 additions and 41 deletions

View file

@ -1,7 +1,6 @@
package envconfig package envconfig
import ( import (
"errors"
"fmt" "fmt"
"log/slog" "log/slog"
"math" "math"
@ -15,15 +14,12 @@ import (
"time" "time"
) )
var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
// Host returns the scheme and host. Host can be configured via the OLLAMA_HOST environment variable. // Host returns the scheme and host. Host can be configured via the OLLAMA_HOST environment variable.
// Default is scheme "http" and host "127.0.0.1:11434" // Default is scheme "http" and host "127.0.0.1:11434"
func Host() *url.URL { func Host() *url.URL {
defaultPort := "11434" defaultPort := "11434"
s := os.Getenv("OLLAMA_HOST") s := strings.TrimSpace(Var("OLLAMA_HOST"))
s = strings.TrimSpace(strings.Trim(strings.TrimSpace(s), "\"'"))
scheme, hostport, ok := strings.Cut(s, "://") scheme, hostport, ok := strings.Cut(s, "://")
switch { switch {
case !ok: case !ok:
@ -48,6 +44,7 @@ func Host() *url.URL {
} }
if n, err := strconv.ParseInt(port, 10, 32); err != nil || n > 65535 || n < 0 { if n, err := strconv.ParseInt(port, 10, 32); err != nil || n > 65535 || n < 0 {
slog.Warn("invalid port, using default", "port", port, "default", defaultPort)
return &url.URL{ return &url.URL{
Scheme: scheme, Scheme: scheme,
Host: net.JoinHostPort(host, defaultPort), Host: net.JoinHostPort(host, defaultPort),
@ -62,7 +59,7 @@ func Host() *url.URL {
// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable. // Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
func Origins() (origins []string) { func Origins() (origins []string) {
if s := getenv("OLLAMA_ORIGINS"); s != "" { if s := Var("OLLAMA_ORIGINS"); s != "" {
origins = strings.Split(s, ",") origins = strings.Split(s, ",")
} }
@ -87,7 +84,7 @@ func Origins() (origins []string) {
// Models returns the path to the models directory. Models directory can be configured via the OLLAMA_MODELS environment variable. // Models returns the path to the models directory. Models directory can be configured via the OLLAMA_MODELS environment variable.
// Default is $HOME/.ollama/models // Default is $HOME/.ollama/models
func Models() string { func Models() string {
if s, ok := os.LookupEnv("OLLAMA_MODELS"); ok { if s := Var("OLLAMA_MODELS"); s != "" {
return s return s
} }
@ -104,7 +101,7 @@ func Models() string {
// Default is 5 minutes. // Default is 5 minutes.
func KeepAlive() (keepAlive time.Duration) { func KeepAlive() (keepAlive time.Duration) {
keepAlive = 5 * time.Minute keepAlive = 5 * time.Minute
if s := os.Getenv("OLLAMA_KEEP_ALIVE"); s != "" { if s := Var("OLLAMA_KEEP_ALIVE"); s != "" {
if d, err := time.ParseDuration(s); err == nil { if d, err := time.ParseDuration(s); err == nil {
keepAlive = d keepAlive = d
} else if n, err := strconv.ParseInt(s, 10, 64); err == nil { } else if n, err := strconv.ParseInt(s, 10, 64); err == nil {
@ -121,7 +118,7 @@ func KeepAlive() (keepAlive time.Duration) {
func Bool(k string) func() bool { func Bool(k string) func() bool {
return func() bool { return func() bool {
if s := getenv(k); s != "" { if s := Var(k); s != "" {
b, err := strconv.ParseBool(s) b, err := strconv.ParseBool(s)
if err != nil { if err != nil {
return true return true
@ -151,7 +148,7 @@ var (
func String(s string) func() string { func String(s string) func() string {
return func() string { return func() string {
return getenv(s) return Var(s)
} }
} }
@ -167,7 +164,7 @@ var (
) )
func RunnersDir() (p string) { func RunnersDir() (p string) {
if p := getenv("OLLAMA_RUNNERS_DIR"); p != "" { if p := Var("OLLAMA_RUNNERS_DIR"); p != "" {
return p return p
} }
@ -213,22 +210,29 @@ func RunnersDir() (p string) {
return p return p
} }
func Int(k string, n int) func() int { func Uint(key string, defaultValue uint) func() uint {
return func() int { return func() uint {
if s := getenv(k); s != "" { if s := Var(key); s != "" {
if n, err := strconv.ParseInt(s, 10, 64); err == nil && n >= 0 { if n, err := strconv.ParseUint(s, 10, 64); err != nil {
return int(n) slog.Warn("invalid environment variable, using default", "key", key, "value", s, "default", defaultValue)
} else {
return uint(n)
} }
} }
return n return defaultValue
} }
} }
var ( var (
NumParallel = Int("OLLAMA_NUM_PARALLEL", 0) // NumParallel sets the number of parallel model requests. NumParallel can be configured via the OLLAMA_NUM_PARALLEL environment variable.
MaxRunners = Int("OLLAMA_MAX_LOADED_MODELS", 0) NumParallel = Uint("OLLAMA_NUM_PARALLEL", 0)
MaxQueue = Int("OLLAMA_MAX_QUEUE", 512) // MaxRunners sets the maximum number of loaded models. MaxRunners can be configured via the OLLAMA_MAX_LOADED_MODELS environment variable.
MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", 0)
// MaxQueue sets the maximum number of queued requests. MaxQueue can be configured via the OLLAMA_MAX_QUEUE environment variable.
MaxQueue = Uint("OLLAMA_MAX_QUEUE", 512)
// MaxVRAM sets a maximum VRAM override in bytes. MaxVRAM can be configured via the OLLAMA_MAX_VRAM environment variable.
MaxVRAM = Uint("OLLAMA_MAX_VRAM", 0)
) )
type EnvVar struct { type EnvVar struct {
@ -274,7 +278,7 @@ func Values() map[string]string {
return vals return vals
} }
// getenv returns an environment variable stripped of leading and trailing quotes or spaces // Var returns an environment variable stripped of leading and trailing quotes or spaces
func getenv(key string) string { func Var(key string) string {
return strings.Trim(os.Getenv(key), "\"' ") return strings.Trim(strings.TrimSpace(os.Getenv(key)), "\"'")
} }

View file

@ -30,6 +30,10 @@ func TestHost(t *testing.T) {
"extra quotes": {"\"1.2.3.4\"", "1.2.3.4:11434"}, "extra quotes": {"\"1.2.3.4\"", "1.2.3.4:11434"},
"extra space+quotes": {" \" 1.2.3.4 \" ", "1.2.3.4:11434"}, "extra space+quotes": {" \" 1.2.3.4 \" ", "1.2.3.4:11434"},
"extra single quotes": {"'1.2.3.4'", "1.2.3.4:11434"}, "extra single quotes": {"'1.2.3.4'", "1.2.3.4:11434"},
"http": {"http://1.2.3.4", "1.2.3.4:80"},
"http port": {"http://1.2.3.4:4321", "1.2.3.4:4321"},
"https": {"https://1.2.3.4", "1.2.3.4:443"},
"https port": {"https://1.2.3.4:4321", "1.2.3.4:4321"},
} }
for name, tt := range cases { for name, tt := range cases {
@ -133,24 +137,45 @@ func TestOrigins(t *testing.T) {
} }
func TestBool(t *testing.T) { func TestBool(t *testing.T) {
cases := map[string]struct { cases := map[string]bool{
value string "": false,
expect bool "true": true,
}{ "false": false,
"empty": {"", false}, "1": true,
"true": {"true", true}, "0": false,
"false": {"false", false}, // invalid values
"1": {"1", true}, "random": true,
"0": {"0", false}, "something": true,
"random": {"random", true},
"something": {"something", true},
} }
for name, tt := range cases { for k, v := range cases {
t.Run(name, func(t *testing.T) { t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_BOOL", tt.value) t.Setenv("OLLAMA_BOOL", k)
if b := Bool("OLLAMA_BOOL"); b() != tt.expect { if b := Bool("OLLAMA_BOOL")(); b != v {
t.Errorf("%s: expected %t, got %t", name, tt.expect, b()) t.Errorf("%s: expected %t, got %t", k, v, b)
}
})
}
}
func TestUint(t *testing.T) {
cases := map[string]uint{
"0": 0,
"1": 1,
"1337": 1337,
// default values
"": 11434,
"-1": 11434,
"0o10": 11434,
"0x10": 11434,
"string": 11434,
}
for k, v := range cases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_UINT", k)
if i := Uint("OLLAMA_UINT", 11434)(); i != v {
t.Errorf("%s: expected %d, got %d", k, v, i)
} }
}) })
} }
@ -188,3 +213,23 @@ func TestKeepAlive(t *testing.T) {
}) })
} }
} }
func TestVar(t *testing.T) {
cases := map[string]string{
"value": "value",
" value ": "value",
" 'value' ": "value",
` "value" `: "value",
" ' value ' ": " value ",
` " value " `: " value ",
}
for k, v := range cases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_VAR", k)
if s := Var("OLLAMA_VAR"); s != v {
t.Errorf("%s: expected %q, got %q", k, v, s)
}
})
}
}

View file

@ -129,7 +129,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
slog.Debug("pending request cancelled or timed out, skipping scheduling") slog.Debug("pending request cancelled or timed out, skipping scheduling")
continue continue
} }
numParallel := envconfig.NumParallel() numParallel := int(envconfig.NumParallel())
// TODO (jmorganca): multimodal models don't support parallel yet // TODO (jmorganca): multimodal models don't support parallel yet
// see https://github.com/ollama/ollama/issues/4165 // see https://github.com/ollama/ollama/issues/4165
if len(pending.model.ProjectorPaths) > 0 && numParallel != 1 { if len(pending.model.ProjectorPaths) > 0 && numParallel != 1 {
@ -151,7 +151,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
pending.useLoadedRunner(runner, s.finishedReqCh) pending.useLoadedRunner(runner, s.finishedReqCh)
break break
} }
} else if envconfig.MaxRunners() > 0 && loadedCount >= envconfig.MaxRunners() { } else if envconfig.MaxRunners() > 0 && loadedCount >= int(envconfig.MaxRunners()) {
slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount) slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount)
runnerToExpire = s.findRunnerToUnload() runnerToExpire = s.findRunnerToUnload()
} else { } else {