Michael Yang 2024-07-08 10:34:12 -07:00
parent 78140a712c
commit 85d9d73a72
3 changed files with 90 additions and 41 deletions

@@ -1,7 +1,6 @@
package envconfig
import (
"errors"
"fmt"
"log/slog"
"math"
@@ -15,15 +14,12 @@ import (
"time"
)
var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
// Host returns the scheme and host. Host can be configured via the OLLAMA_HOST environment variable.
// Default is scheme "http" and host "127.0.0.1:11434"
func Host() *url.URL {
defaultPort := "11434"
s := os.Getenv("OLLAMA_HOST")
s = strings.TrimSpace(strings.Trim(strings.TrimSpace(s), "\"'"))
s := strings.TrimSpace(Var("OLLAMA_HOST"))
scheme, hostport, ok := strings.Cut(s, "://")
switch {
case !ok:
@@ -48,6 +44,7 @@ func Host() *url.URL {
}
if n, err := strconv.ParseInt(port, 10, 32); err != nil || n > 65535 || n < 0 {
slog.Warn("invalid port, using default", "port", port, "default", defaultPort)
return &url.URL{
Scheme: scheme,
Host: net.JoinHostPort(host, defaultPort),
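The hunk above, together with the new TestHost cases further down, pins down how OLLAMA_HOST is resolved: the value is trimmed, split on "://", given a scheme-appropriate default port (80 for http, 443 for https, 11434 otherwise), and an out-of-range port falls back to the default with a warning. The standalone sketch below reconstructs that flow from the visible lines and the test expectations; the function and variable names are illustrative, not the package's exact code.

```go
package main

import (
	"fmt"
	"net"
	"strconv"
	"strings"
)

// resolveHost approximates the parsing shown in the hunks above.
func resolveHost(s string) string {
	defaultPort := "11434"

	s = strings.TrimSpace(strings.Trim(strings.TrimSpace(s), "\"'"))
	scheme, hostport, ok := strings.Cut(s, "://")
	switch {
	case !ok:
		scheme, hostport = "http", s
	case scheme == "http":
		defaultPort = "80"
	case scheme == "https":
		defaultPort = "443"
	}

	host, port, err := net.SplitHostPort(hostport)
	if err != nil {
		// No explicit port in the value: use the scheme's default.
		host, port = hostport, defaultPort
	}
	if host == "" {
		host = "127.0.0.1"
	}

	// An out-of-range or non-numeric port falls back to the default.
	if n, err := strconv.ParseInt(port, 10, 32); err != nil || n > 65535 || n < 0 {
		port = defaultPort
	}
	return scheme + "://" + net.JoinHostPort(host, port)
}

func main() {
	for _, v := range []string{"", "1.2.3.4", "http://1.2.3.4", "https://1.2.3.4", "0.0.0.0:66000"} {
		fmt.Printf("%-20q -> %s\n", v, resolveHost(v))
	}
}
```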
@@ -62,7 +59,7 @@ func Host() *url.URL {
// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
func Origins() (origins []string) {
if s := getenv("OLLAMA_ORIGINS"); s != "" {
if s := Var("OLLAMA_ORIGINS"); s != "" {
origins = strings.Split(s, ",")
}
@@ -87,7 +84,7 @@ func Origins() (origins []string) {
// Models returns the path to the models directory. Models directory can be configured via the OLLAMA_MODELS environment variable.
// Default is $HOME/.ollama/models
func Models() string {
if s, ok := os.LookupEnv("OLLAMA_MODELS"); ok {
if s := Var("OLLAMA_MODELS"); s != "" {
return s
}
@@ -104,7 +101,7 @@ func Models() string {
// Default is 5 minutes.
func KeepAlive() (keepAlive time.Duration) {
keepAlive = 5 * time.Minute
if s := os.Getenv("OLLAMA_KEEP_ALIVE"); s != "" {
if s := Var("OLLAMA_KEEP_ALIVE"); s != "" {
if d, err := time.ParseDuration(s); err == nil {
keepAlive = d
} else if n, err := strconv.ParseInt(s, 10, 64); err == nil {
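As the hunk shows, OLLAMA_KEEP_ALIVE accepts either a Go duration string or a bare integer, with a 5-minute default. The sketch below restates that fallback order; treating the integer as seconds is an assumption, since the line applying it falls outside the visible hunk.

```go
package main

import (
	"fmt"
	"strconv"
	"time"
)

// parseKeepAlive mirrors the fallback order visible in the hunk above:
// try a duration string first, then a bare integer (assumed to be seconds).
func parseKeepAlive(s string) time.Duration {
	keepAlive := 5 * time.Minute // default from the diff
	if s != "" {
		if d, err := time.ParseDuration(s); err == nil {
			keepAlive = d
		} else if n, err := strconv.ParseInt(s, 10, 64); err == nil {
			keepAlive = time.Duration(n) * time.Second
		}
	}
	return keepAlive
}

func main() {
	fmt.Println(parseKeepAlive(""))    // 5m0s
	fmt.Println(parseKeepAlive("10m")) // 10m0s
	fmt.Println(parseKeepAlive("300")) // 5m0s, assuming seconds
}
```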
@@ -121,7 +118,7 @@ func KeepAlive() (keepAlive time.Duration) {
func Bool(k string) func() bool {
return func() bool {
if s := getenv(k); s != "" {
if s := Var(k); s != "" {
b, err := strconv.ParseBool(s)
if err != nil {
return true
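Bool keeps its permissive semantics after the switch to Var: an unset or empty variable is false, values strconv.ParseBool understands are honoured, and any other non-empty string counts as enabled, which is what the rewritten TestBool table further down asserts. A small standalone restatement of that rule:

```go
package main

import (
	"fmt"
	"strconv"
)

// boolFromEnv restates the rule in the hunk above: empty means false,
// recognised boolean strings are parsed, any other non-empty value is true.
func boolFromEnv(s string) bool {
	if s == "" {
		return false
	}
	if b, err := strconv.ParseBool(s); err == nil {
		return b
	}
	return true
}

func main() {
	for _, v := range []string{"", "true", "0", "random"} {
		fmt.Printf("%q -> %v\n", v, boolFromEnv(v))
	}
}
```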
@@ -151,7 +148,7 @@ var (
func String(s string) func() string {
return func() string {
return getenv(s)
return Var(s)
}
}
@@ -167,7 +164,7 @@ var (
)
func RunnersDir() (p string) {
if p := getenv("OLLAMA_RUNNERS_DIR"); p != "" {
if p := Var("OLLAMA_RUNNERS_DIR"); p != "" {
return p
}
@@ -213,22 +210,29 @@ func RunnersDir() (p string) {
return p
}
func Int(k string, n int) func() int {
return func() int {
if s := getenv(k); s != "" {
if n, err := strconv.ParseInt(s, 10, 64); err == nil && n >= 0 {
return int(n)
func Uint(key string, defaultValue uint) func() uint {
return func() uint {
if s := Var(key); s != "" {
if n, err := strconv.ParseUint(s, 10, 64); err != nil {
slog.Warn("invalid environment variable, using default", "key", key, "value", s, "default", defaultValue)
} else {
return uint(n)
}
}
return n
return defaultValue
}
}
var (
NumParallel = Int("OLLAMA_NUM_PARALLEL", 0)
MaxRunners = Int("OLLAMA_MAX_LOADED_MODELS", 0)
MaxQueue = Int("OLLAMA_MAX_QUEUE", 512)
// NumParallel sets the number of parallel model requests. NumParallel can be configured via the OLLAMA_NUM_PARALLEL environment variable.
NumParallel = Uint("OLLAMA_NUM_PARALLEL", 0)
// MaxRunners sets the maximum number of loaded models. MaxRunners can be configured via the OLLAMA_MAX_LOADED_MODELS environment variable.
MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", 0)
// MaxQueue sets the maximum number of queued requests. MaxQueue can be configured via the OLLAMA_MAX_QUEUE environment variable.
MaxQueue = Uint("OLLAMA_MAX_QUEUE", 512)
// MaxVRAM sets a maximum VRAM override in bytes. MaxVRAM can be configured via the OLLAMA_MAX_VRAM environment variable.
MaxVRAM = Uint("OLLAMA_MAX_VRAM", 0)
)
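Uint replaces the old Int helper: it returns a closure bound to a key and a default, logs a warning for anything strconv.ParseUint rejects, and NumParallel, MaxRunners, MaxQueue and MaxVRAM are now such closures. A hedged usage sketch follows; the import path is assumed from the repository layout, and the behaviour is taken from the code above.

```go
package main

import (
	"fmt"
	"os"

	"github.com/ollama/ollama/envconfig"
)

func main() {
	// An invalid value falls back to the declared default (512 for MaxQueue)
	// and logs a warning via slog.
	os.Setenv("OLLAMA_MAX_QUEUE", "not-a-number")
	fmt.Println(envconfig.MaxQueue()) // 512

	os.Setenv("OLLAMA_MAX_QUEUE", "64")
	fmt.Println(envconfig.MaxQueue()) // 64
}
```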
type EnvVar struct {
@@ -274,7 +278,7 @@ func Values() map[string]string {
return vals
}
// getenv returns an environment variable stripped of leading and trailing quotes or spaces
func getenv(key string) string {
return strings.Trim(os.Getenv(key), "\"' ")
// Var returns an environment variable stripped of leading and trailing quotes or spaces
func Var(key string) string {
return strings.Trim(strings.TrimSpace(os.Getenv(key)), "\"'")
}
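Var (formerly getenv) now trims outer whitespace first and then surrounding quote characters, rather than stripping quotes and spaces in a single pass, so whitespace inside a quoted value survives; that difference is exactly what the new TestVar below exercises. A minimal equivalent:

```go
package main

import (
	"fmt"
	"strings"
)

// trim mirrors Var's order of operations: outer whitespace first,
// then surrounding quote characters.
func trim(s string) string {
	return strings.Trim(strings.TrimSpace(s), "\"'")
}

func main() {
	fmt.Printf("%q\n", trim(" 'value' "))    // "value"
	fmt.Printf("%q\n", trim(`  " value " `)) // " value ": inner spaces kept
}
```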

@@ -30,6 +30,10 @@ func TestHost(t *testing.T) {
"extra quotes": {"\"1.2.3.4\"", "1.2.3.4:11434"},
"extra space+quotes": {" \" 1.2.3.4 \" ", "1.2.3.4:11434"},
"extra single quotes": {"'1.2.3.4'", "1.2.3.4:11434"},
"http": {"http://1.2.3.4", "1.2.3.4:80"},
"http port": {"http://1.2.3.4:4321", "1.2.3.4:4321"},
"https": {"https://1.2.3.4", "1.2.3.4:443"},
"https port": {"https://1.2.3.4:4321", "1.2.3.4:4321"},
}
for name, tt := range cases {
@@ -133,24 +137,45 @@ func TestOrigins(t *testing.T) {
}
func TestBool(t *testing.T) {
cases := map[string]struct {
value string
expect bool
}{
"empty": {"", false},
"true": {"true", true},
"false": {"false", false},
"1": {"1", true},
"0": {"0", false},
"random": {"random", true},
"something": {"something", true},
cases := map[string]bool{
"": false,
"true": true,
"false": false,
"1": true,
"0": false,
// invalid values
"random": true,
"something": true,
}
for name, tt := range cases {
t.Run(name, func(t *testing.T) {
t.Setenv("OLLAMA_BOOL", tt.value)
if b := Bool("OLLAMA_BOOL"); b() != tt.expect {
t.Errorf("%s: expected %t, got %t", name, tt.expect, b())
for k, v := range cases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_BOOL", k)
if b := Bool("OLLAMA_BOOL")(); b != v {
t.Errorf("%s: expected %t, got %t", k, v, b)
}
})
}
}
func TestUint(t *testing.T) {
cases := map[string]uint{
"0": 0,
"1": 1,
"1337": 1337,
// default values
"": 11434,
"-1": 11434,
"0o10": 11434,
"0x10": 11434,
"string": 11434,
}
for k, v := range cases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_UINT", k)
if i := Uint("OLLAMA_UINT", 11434)(); i != v {
t.Errorf("%s: expected %d, got %d", k, v, i)
}
})
}
@@ -188,3 +213,23 @@ func TestKeepAlive(t *testing.T) {
})
}
}
func TestVar(t *testing.T) {
cases := map[string]string{
"value": "value",
" value ": "value",
" 'value' ": "value",
` "value" `: "value",
" ' value ' ": " value ",
` " value " `: " value ",
}
for k, v := range cases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_VAR", k)
if s := Var("OLLAMA_VAR"); s != v {
t.Errorf("%s: expected %q, got %q", k, v, s)
}
})
}
}

@@ -129,7 +129,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
slog.Debug("pending request cancelled or timed out, skipping scheduling")
continue
}
numParallel := envconfig.NumParallel()
numParallel := int(envconfig.NumParallel())
// TODO (jmorganca): multimodal models don't support parallel yet
// see https://github.com/ollama/ollama/issues/4165
if len(pending.model.ProjectorPaths) > 0 && numParallel != 1 {
@@ -151,7 +151,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
pending.useLoadedRunner(runner, s.finishedReqCh)
break
}
} else if envconfig.MaxRunners() > 0 && loadedCount >= envconfig.MaxRunners() {
} else if envconfig.MaxRunners() > 0 && loadedCount >= int(envconfig.MaxRunners()) {
slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount)
runnerToExpire = s.findRunnerToUnload()
} else {
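Because NumParallel and MaxRunners now return uint, the scheduler converts explicitly before comparing against its int counters, as the two hunks above show. A minimal illustration of the same pattern; loadedCount stands in for the scheduler's loaded-runner count, and the import path is assumed from the repository layout.

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/envconfig"
)

func main() {
	loadedCount := 3 // placeholder for the scheduler's loaded-runner count

	// The limit is a uint, so the comparison casts it to int explicitly.
	if limit := envconfig.MaxRunners(); limit > 0 && loadedCount >= int(limit) {
		fmt.Println("max runners reached: unload one to make room")
	} else {
		fmt.Println("room for another runner")
	}
}
```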