From 35b89b2eaba4ac6fc4ae1ba4bf2ec6c8bafe9529 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 3 Jul 2024 16:00:54 -0700 Subject: [PATCH 01/11] rfc: dynamic environ lookup --- app/lifecycle/logging.go | 2 +- envconfig/config.go | 28 ++++++++++++++++------------ envconfig/config_test.go | 13 ++++++------- gpu/gpu.go | 2 +- llm/memory_test.go | 4 ++-- llm/server.go | 4 ++-- server/routes.go | 2 +- 7 files changed, 29 insertions(+), 26 deletions(-) diff --git a/app/lifecycle/logging.go b/app/lifecycle/logging.go index a8f1f7cd..3672aad5 100644 --- a/app/lifecycle/logging.go +++ b/app/lifecycle/logging.go @@ -14,7 +14,7 @@ import ( func InitLogging() { level := slog.LevelInfo - if envconfig.Debug { + if envconfig.Debug() { level = slog.LevelDebug } diff --git a/envconfig/config.go b/envconfig/config.go index 0abc6968..426507be 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -26,11 +26,24 @@ func (o OllamaHost) String() string { var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST") +// Debug returns true if the OLLAMA_DEBUG environment variable is set to a truthy value. +func Debug() bool { + if s := clean("OLLAMA_DEBUG"); s != "" { + b, err := strconv.ParseBool(s) + if err != nil { + // non-empty value is truthy + return true + } + + return b + } + + return false +} + var ( // Set via OLLAMA_ORIGINS in the environment AllowOrigins []string - // Set via OLLAMA_DEBUG in the environment - Debug bool // Experimental flash attention FlashAttention bool // Set via OLLAMA_HOST in the environment @@ -80,7 +93,7 @@ type EnvVar struct { func AsMap() map[string]EnvVar { ret := map[string]EnvVar{ - "OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug, "Show additional debug information (e.g. OLLAMA_DEBUG=1)"}, + "OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"}, "OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention, "Enabled flash attention"}, "OLLAMA_HOST": {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"}, "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"}, @@ -137,15 +150,6 @@ func init() { } func LoadConfig() { - if debug := clean("OLLAMA_DEBUG"); debug != "" { - d, err := strconv.ParseBool(debug) - if err == nil { - Debug = d - } else { - Debug = true - } - } - if fa := clean("OLLAMA_FLASH_ATTENTION"); fa != "" { d, err := strconv.ParseBool(fa) if err == nil { diff --git a/envconfig/config_test.go b/envconfig/config_test.go index a5d73fd7..f083bb03 100644 --- a/envconfig/config_test.go +++ b/envconfig/config_test.go @@ -12,16 +12,15 @@ import ( ) func TestConfig(t *testing.T) { - Debug = false // Reset whatever was loaded in init() t.Setenv("OLLAMA_DEBUG", "") - LoadConfig() - require.False(t, Debug) + require.False(t, Debug()) + t.Setenv("OLLAMA_DEBUG", "false") - LoadConfig() - require.False(t, Debug) + require.False(t, Debug()) + t.Setenv("OLLAMA_DEBUG", "1") - LoadConfig() - require.True(t, Debug) + require.True(t, Debug()) + t.Setenv("OLLAMA_FLASH_ATTENTION", "1") LoadConfig() require.True(t, FlashAttention) diff --git a/gpu/gpu.go b/gpu/gpu.go index 6e25cb46..1815668f 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -611,7 +611,7 @@ func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) { } func getVerboseState() C.uint16_t { - if envconfig.Debug { + if envconfig.Debug() { return C.uint16_t(1) } return C.uint16_t(0) diff --git a/llm/memory_test.go b/llm/memory_test.go index f972f927..06ae7438 100644 --- a/llm/memory_test.go +++ b/llm/memory_test.go @@ -8,14 +8,14 @@ import ( "testing" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/gpu" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestEstimateGPULayers(t *testing.T) { - envconfig.Debug = true + t.Setenv("OLLAMA_DEBUG", "1") + modelName := "dummy" f, err := os.CreateTemp(t.TempDir(), modelName) require.NoError(t, err) diff --git a/llm/server.go b/llm/server.go index 08463ef0..eb966650 100644 --- a/llm/server.go +++ b/llm/server.go @@ -195,7 +195,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr params = append(params, "--n-gpu-layers", fmt.Sprintf("%d", opts.NumGPU)) } - if envconfig.Debug { + if envconfig.Debug() { params = append(params, "--verbose") } @@ -381,7 +381,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr } slog.Info("starting llama server", "cmd", s.cmd.String()) - if envconfig.Debug { + if envconfig.Debug() { filteredEnv := []string{} for _, ev := range s.cmd.Env { if strings.HasPrefix(ev, "CUDA_") || diff --git a/server/routes.go b/server/routes.go index 0d7ca003..c049421b 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1093,7 +1093,7 @@ func (s *Server) GenerateRoutes() http.Handler { func Serve(ln net.Listener) error { level := slog.LevelInfo - if envconfig.Debug { + if envconfig.Debug() { level = slog.LevelDebug } From 4f1afd575d1dfd803b0d9abb995862d61e8d0734 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 3 Jul 2024 16:44:57 -0700 Subject: [PATCH 02/11] host --- api/client.go | 8 +-- cmd/cmd.go | 2 +- envconfig/config.go | 107 ++++++++++++++++----------------------- envconfig/config_test.go | 62 +++++++++-------------- 4 files changed, 71 insertions(+), 108 deletions(-) diff --git a/api/client.go b/api/client.go index c59fbc42..e02b21bf 100644 --- a/api/client.go +++ b/api/client.go @@ -20,7 +20,6 @@ import ( "encoding/json" "fmt" "io" - "net" "net/http" "net/url" "runtime" @@ -63,13 +62,8 @@ func checkError(resp *http.Response, body []byte) error { // If the variable is not specified, a default ollama host and port will be // used. func ClientFromEnvironment() (*Client, error) { - ollamaHost := envconfig.Host - return &Client{ - base: &url.URL{ - Scheme: ollamaHost.Scheme, - Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port), - }, + base: envconfig.Host(), http: http.DefaultClient, }, nil } diff --git a/cmd/cmd.go b/cmd/cmd.go index b761d018..5f3735f4 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1076,7 +1076,7 @@ func RunServer(cmd *cobra.Command, _ []string) error { return err } - ln, err := net.Listen("tcp", net.JoinHostPort(envconfig.Host.Host, envconfig.Host.Port)) + ln, err := net.Listen("tcp", envconfig.Host().Host) if err != nil { return err } diff --git a/envconfig/config.go b/envconfig/config.go index 426507be..23f93270 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -6,6 +6,7 @@ import ( "log/slog" "math" "net" + "net/url" "os" "path/filepath" "runtime" @@ -14,16 +15,6 @@ import ( "time" ) -type OllamaHost struct { - Scheme string - Host string - Port string -} - -func (o OllamaHost) String() string { - return fmt.Sprintf("%s://%s:%s", o.Scheme, o.Host, o.Port) -} - var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST") // Debug returns true if the OLLAMA_DEBUG environment variable is set to a truthy value. @@ -41,13 +32,54 @@ func Debug() bool { return false } +// Host returns the scheme and host. Host can be configured via the OLLAMA_HOST environment variable. +// Default is scheme "http" and host "127.0.0.1:11434" +func Host() *url.URL { + defaultPort := "11434" + + s := os.Getenv("OLLAMA_HOST") + s = strings.TrimSpace(strings.Trim(strings.TrimSpace(s), "\"'")) + scheme, hostport, ok := strings.Cut(s, "://") + switch { + case !ok: + scheme, hostport = "http", s + case scheme == "http": + defaultPort = "80" + case scheme == "https": + defaultPort = "443" + } + + // trim trailing slashes + hostport = strings.TrimRight(hostport, "/") + + host, port, err := net.SplitHostPort(hostport) + if err != nil { + host, port = "127.0.0.1", defaultPort + if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil { + host = ip.String() + } else if hostport != "" { + host = hostport + } + } + + if n, err := strconv.ParseInt(port, 10, 32); err != nil || n > 65535 || n < 0 { + return &url.URL{ + Scheme: scheme, + Host: net.JoinHostPort(host, defaultPort), + } + } + + return &url.URL{ + Scheme: scheme, + Host: net.JoinHostPort(host, port), + } +} + var ( // Set via OLLAMA_ORIGINS in the environment AllowOrigins []string // Experimental flash attention FlashAttention bool - // Set via OLLAMA_HOST in the environment - Host *OllamaHost // Set via OLLAMA_KEEP_ALIVE in the environment KeepAlive time.Duration // Set via OLLAMA_LLM_LIBRARY in the environment @@ -95,7 +127,7 @@ func AsMap() map[string]EnvVar { ret := map[string]EnvVar{ "OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"}, "OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention, "Enabled flash attention"}, - "OLLAMA_HOST": {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"}, + "OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"}, "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"}, "OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"}, "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"}, @@ -271,11 +303,6 @@ func LoadConfig() { slog.Error("invalid setting", "OLLAMA_MODELS", ModelsDir, "error", err) } - Host, err = getOllamaHost() - if err != nil { - slog.Error("invalid setting", "OLLAMA_HOST", Host, "error", err, "using default port", Host.Port) - } - if set, err := strconv.ParseBool(clean("OLLAMA_INTEL_GPU")); err == nil { IntelGpu = set } @@ -298,50 +325,6 @@ func getModelsDir() (string, error) { return filepath.Join(home, ".ollama", "models"), nil } -func getOllamaHost() (*OllamaHost, error) { - defaultPort := "11434" - - hostVar := os.Getenv("OLLAMA_HOST") - hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'")) - - scheme, hostport, ok := strings.Cut(hostVar, "://") - switch { - case !ok: - scheme, hostport = "http", hostVar - case scheme == "http": - defaultPort = "80" - case scheme == "https": - defaultPort = "443" - } - - // trim trailing slashes - hostport = strings.TrimRight(hostport, "/") - - host, port, err := net.SplitHostPort(hostport) - if err != nil { - host, port = "127.0.0.1", defaultPort - if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil { - host = ip.String() - } else if hostport != "" { - host = hostport - } - } - - if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 { - return &OllamaHost{ - Scheme: scheme, - Host: host, - Port: defaultPort, - }, ErrInvalidHostPort - } - - return &OllamaHost{ - Scheme: scheme, - Host: host, - Port: port, - }, nil -} - func loadKeepAlive(ka string) { v, err := strconv.Atoi(ka) if err != nil { diff --git a/envconfig/config_test.go b/envconfig/config_test.go index f083bb03..af89e7b7 100644 --- a/envconfig/config_test.go +++ b/envconfig/config_test.go @@ -1,13 +1,10 @@ package envconfig import ( - "fmt" "math" - "net" "testing" "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -42,45 +39,34 @@ func TestConfig(t *testing.T) { } func TestClientFromEnvironment(t *testing.T) { - type testCase struct { + cases := map[string]struct { value string expect string - err error + }{ + "empty": {"", "127.0.0.1:11434"}, + "only address": {"1.2.3.4", "1.2.3.4:11434"}, + "only port": {":1234", ":1234"}, + "address and port": {"1.2.3.4:1234", "1.2.3.4:1234"}, + "hostname": {"example.com", "example.com:11434"}, + "hostname and port": {"example.com:1234", "example.com:1234"}, + "zero port": {":0", ":0"}, + "too large port": {":66000", ":11434"}, + "too small port": {":-1", ":11434"}, + "ipv6 localhost": {"[::1]", "[::1]:11434"}, + "ipv6 world open": {"[::]", "[::]:11434"}, + "ipv6 no brackets": {"::1", "[::1]:11434"}, + "ipv6 + port": {"[::1]:1337", "[::1]:1337"}, + "extra space": {" 1.2.3.4 ", "1.2.3.4:11434"}, + "extra quotes": {"\"1.2.3.4\"", "1.2.3.4:11434"}, + "extra space+quotes": {" \" 1.2.3.4 \" ", "1.2.3.4:11434"}, + "extra single quotes": {"'1.2.3.4'", "1.2.3.4:11434"}, } - hostTestCases := map[string]*testCase{ - "empty": {value: "", expect: "127.0.0.1:11434"}, - "only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"}, - "only port": {value: ":1234", expect: ":1234"}, - "address and port": {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"}, - "hostname": {value: "example.com", expect: "example.com:11434"}, - "hostname and port": {value: "example.com:1234", expect: "example.com:1234"}, - "zero port": {value: ":0", expect: ":0"}, - "too large port": {value: ":66000", err: ErrInvalidHostPort}, - "too small port": {value: ":-1", err: ErrInvalidHostPort}, - "ipv6 localhost": {value: "[::1]", expect: "[::1]:11434"}, - "ipv6 world open": {value: "[::]", expect: "[::]:11434"}, - "ipv6 no brackets": {value: "::1", expect: "[::1]:11434"}, - "ipv6 + port": {value: "[::1]:1337", expect: "[::1]:1337"}, - "extra space": {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"}, - "extra quotes": {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"}, - "extra space+quotes": {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"}, - "extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"}, - } - - for k, v := range hostTestCases { - t.Run(k, func(t *testing.T) { - t.Setenv("OLLAMA_HOST", v.value) - LoadConfig() - - oh, err := getOllamaHost() - if err != v.err { - t.Fatalf("expected %s, got %s", v.err, err) - } - - if err == nil { - host := net.JoinHostPort(oh.Host, oh.Port) - assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host)) + for name, tt := range cases { + t.Run(name, func(t *testing.T) { + t.Setenv("OLLAMA_HOST", tt.value) + if host := Host(); host.Host != tt.expect { + t.Errorf("%s: expected %s, got %s", name, tt.expect, host.Host) } }) } From d1a5227cadf6ae736f8dd5cb9fb7452dd015f820 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 3 Jul 2024 17:02:07 -0700 Subject: [PATCH 03/11] origins --- envconfig/config.go | 52 +++++++++++----------- envconfig/config_test.go | 95 +++++++++++++++++++++++++++++++++++++++- server/routes.go | 2 +- 3 files changed, 119 insertions(+), 30 deletions(-) diff --git a/envconfig/config.go b/envconfig/config.go index 23f93270..7ae521ab 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -75,9 +75,31 @@ func Host() *url.URL { } } +// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable. +func Origins() (origins []string) { + if s := clean("OLLAMA_ORIGINS"); s != "" { + origins = strings.Split(s, ",") + } + + for _, origin := range []string{"localhost", "127.0.0.1", "0.0.0.0"} { + origins = append(origins, + fmt.Sprintf("http://%s", origin), + fmt.Sprintf("https://%s", origin), + fmt.Sprintf("http://%s", net.JoinHostPort(origin, "*")), + fmt.Sprintf("https://%s", net.JoinHostPort(origin, "*")), + ) + } + + origins = append(origins, + "app://*", + "file://*", + "tauri://*", + ) + + return origins +} + var ( - // Set via OLLAMA_ORIGINS in the environment - AllowOrigins []string // Experimental flash attention FlashAttention bool // Set via OLLAMA_KEEP_ALIVE in the environment @@ -136,7 +158,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"}, "OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"}, "OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"}, - "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"}, + "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"}, "OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"}, "OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"}, @@ -160,12 +182,6 @@ func Values() map[string]string { return vals } -var defaultAllowOrigins = []string{ - "localhost", - "127.0.0.1", - "0.0.0.0", -} - // Clean quotes and spaces from the value func clean(key string) string { return strings.Trim(os.Getenv(key), "\"' ") @@ -255,24 +271,6 @@ func LoadConfig() { NoPrune = true } - if origins := clean("OLLAMA_ORIGINS"); origins != "" { - AllowOrigins = strings.Split(origins, ",") - } - for _, allowOrigin := range defaultAllowOrigins { - AllowOrigins = append(AllowOrigins, - fmt.Sprintf("http://%s", allowOrigin), - fmt.Sprintf("https://%s", allowOrigin), - fmt.Sprintf("http://%s", net.JoinHostPort(allowOrigin, "*")), - fmt.Sprintf("https://%s", net.JoinHostPort(allowOrigin, "*")), - ) - } - - AllowOrigins = append(AllowOrigins, - "app://*", - "file://*", - "tauri://*", - ) - maxRunners := clean("OLLAMA_MAX_LOADED_MODELS") if maxRunners != "" { m, err := strconv.Atoi(maxRunners) diff --git a/envconfig/config_test.go b/envconfig/config_test.go index af89e7b7..dc65ef70 100644 --- a/envconfig/config_test.go +++ b/envconfig/config_test.go @@ -5,10 +5,11 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" ) -func TestConfig(t *testing.T) { +func TestSmoke(t *testing.T) { t.Setenv("OLLAMA_DEBUG", "") require.False(t, Debug()) @@ -38,7 +39,7 @@ func TestConfig(t *testing.T) { require.Equal(t, time.Duration(math.MaxInt64), KeepAlive) } -func TestClientFromEnvironment(t *testing.T) { +func TestHost(t *testing.T) { cases := map[string]struct { value string expect string @@ -71,3 +72,93 @@ func TestClientFromEnvironment(t *testing.T) { }) } } + +func TestOrigins(t *testing.T) { + cases := []struct { + value string + expect []string + }{ + {"", []string{ + "http://localhost", + "https://localhost", + "http://localhost:*", + "https://localhost:*", + "http://127.0.0.1", + "https://127.0.0.1", + "http://127.0.0.1:*", + "https://127.0.0.1:*", + "http://0.0.0.0", + "https://0.0.0.0", + "http://0.0.0.0:*", + "https://0.0.0.0:*", + "app://*", + "file://*", + "tauri://*", + }}, + {"http://10.0.0.1", []string{ + "http://10.0.0.1", + "http://localhost", + "https://localhost", + "http://localhost:*", + "https://localhost:*", + "http://127.0.0.1", + "https://127.0.0.1", + "http://127.0.0.1:*", + "https://127.0.0.1:*", + "http://0.0.0.0", + "https://0.0.0.0", + "http://0.0.0.0:*", + "https://0.0.0.0:*", + "app://*", + "file://*", + "tauri://*", + }}, + {"http://172.16.0.1,https://192.168.0.1", []string{ + "http://172.16.0.1", + "https://192.168.0.1", + "http://localhost", + "https://localhost", + "http://localhost:*", + "https://localhost:*", + "http://127.0.0.1", + "https://127.0.0.1", + "http://127.0.0.1:*", + "https://127.0.0.1:*", + "http://0.0.0.0", + "https://0.0.0.0", + "http://0.0.0.0:*", + "https://0.0.0.0:*", + "app://*", + "file://*", + "tauri://*", + }}, + {"http://totally.safe,http://definitely.legit", []string{ + "http://totally.safe", + "http://definitely.legit", + "http://localhost", + "https://localhost", + "http://localhost:*", + "https://localhost:*", + "http://127.0.0.1", + "https://127.0.0.1", + "http://127.0.0.1:*", + "https://127.0.0.1:*", + "http://0.0.0.0", + "https://0.0.0.0", + "http://0.0.0.0:*", + "https://0.0.0.0:*", + "app://*", + "file://*", + "tauri://*", + }}, + } + for _, tt := range cases { + t.Run(tt.value, func(t *testing.T) { + t.Setenv("OLLAMA_ORIGINS", tt.value) + + if diff := cmp.Diff(Origins(), tt.expect); diff != "" { + t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff) + } + }) + } +} diff --git a/server/routes.go b/server/routes.go index c049421b..07898d9b 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1048,7 +1048,7 @@ func (s *Server) GenerateRoutes() http.Handler { for _, prop := range openAIProperties { config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop) } - config.AllowOrigins = envconfig.AllowOrigins + config.AllowOrigins = envconfig.Origins() r := gin.Default() r.Use( From 66fe77f0841622054e29f5fd3d3643f514991004 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 3 Jul 2024 17:07:42 -0700 Subject: [PATCH 04/11] models --- envconfig/config.go | 34 ++++++++++++++++------------------ server/modelpath.go | 12 +++--------- 2 files changed, 19 insertions(+), 27 deletions(-) diff --git a/envconfig/config.go b/envconfig/config.go index 7ae521ab..286f51d4 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -99,6 +99,21 @@ func Origins() (origins []string) { return origins } +// Models returns the path to the models directory. Models directory can be configured via the OLLAMA_MODELS environment variable. +// Default is $HOME/.ollama/models +func Models() string { + if s, ok := os.LookupEnv("OLLAMA_MODELS"); ok { + return s + } + + home, err := os.UserHomeDir() + if err != nil { + panic(err) + } + + return filepath.Join(home, ".ollama", "models") +} + var ( // Experimental flash attention FlashAttention bool @@ -154,7 +169,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"}, "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"}, "OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"}, - "OLLAMA_MODELS": {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"}, + "OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"}, "OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"}, "OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"}, "OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"}, @@ -295,12 +310,6 @@ func LoadConfig() { loadKeepAlive(ka) } - var err error - ModelsDir, err = getModelsDir() - if err != nil { - slog.Error("invalid setting", "OLLAMA_MODELS", ModelsDir, "error", err) - } - if set, err := strconv.ParseBool(clean("OLLAMA_INTEL_GPU")); err == nil { IntelGpu = set } @@ -312,17 +321,6 @@ func LoadConfig() { HsaOverrideGfxVersion = clean("HSA_OVERRIDE_GFX_VERSION") } -func getModelsDir() (string, error) { - if models, exists := os.LookupEnv("OLLAMA_MODELS"); exists { - return models, nil - } - home, err := os.UserHomeDir() - if err != nil { - return "", err - } - return filepath.Join(home, ".ollama", "models"), nil -} - func loadKeepAlive(ka string) { v, err := strconv.Atoi(ka) if err != nil { diff --git a/server/modelpath.go b/server/modelpath.go index 3fdb4238..354eeed7 100644 --- a/server/modelpath.go +++ b/server/modelpath.go @@ -105,9 +105,7 @@ func (mp ModelPath) GetShortTagname() string { // GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist. func (mp ModelPath) GetManifestPath() (string, error) { - dir := envconfig.ModelsDir - - return filepath.Join(dir, "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil + return filepath.Join(envconfig.Models(), "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil } func (mp ModelPath) BaseURL() *url.URL { @@ -118,9 +116,7 @@ func (mp ModelPath) BaseURL() *url.URL { } func GetManifestPath() (string, error) { - dir := envconfig.ModelsDir - - path := filepath.Join(dir, "manifests") + path := filepath.Join(envconfig.Models(), "manifests") if err := os.MkdirAll(path, 0o755); err != nil { return "", err } @@ -129,8 +125,6 @@ func GetManifestPath() (string, error) { } func GetBlobsPath(digest string) (string, error) { - dir := envconfig.ModelsDir - // only accept actual sha256 digests pattern := "^sha256[:-][0-9a-fA-F]{64}$" re := regexp.MustCompile(pattern) @@ -140,7 +134,7 @@ func GetBlobsPath(digest string) (string, error) { } digest = strings.ReplaceAll(digest, ":", "-") - path := filepath.Join(dir, "blobs", digest) + path := filepath.Join(envconfig.Models(), "blobs", digest) dirPath := filepath.Dir(path) if digest == "" { dirPath = path From 55cd3ddccac14d48f5f129ec35b3a109be215d01 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 3 Jul 2024 17:22:13 -0700 Subject: [PATCH 05/11] bool --- cmd/interactive.go | 2 +- envconfig/config.go | 123 ++++++++++++++++----------------------- envconfig/config_test.go | 28 ++++++++- gpu/gpu.go | 2 +- llm/server.go | 2 +- server/images.go | 4 +- server/routes.go | 2 +- server/sched.go | 2 +- 8 files changed, 82 insertions(+), 83 deletions(-) diff --git a/cmd/interactive.go b/cmd/interactive.go index adbc3e9f..9fb66851 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -157,7 +157,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { return err } - if envconfig.NoHistory { + if envconfig.NoHistory() { scanner.HistoryDisable() } diff --git a/envconfig/config.go b/envconfig/config.go index 286f51d4..ea78585b 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -17,21 +17,6 @@ import ( var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST") -// Debug returns true if the OLLAMA_DEBUG environment variable is set to a truthy value. -func Debug() bool { - if s := clean("OLLAMA_DEBUG"); s != "" { - b, err := strconv.ParseBool(s) - if err != nil { - // non-empty value is truthy - return true - } - - return b - } - - return false -} - // Host returns the scheme and host. Host can be configured via the OLLAMA_HOST environment variable. // Default is scheme "http" and host "127.0.0.1:11434" func Host() *url.URL { @@ -77,7 +62,7 @@ func Host() *url.URL { // Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable. func Origins() (origins []string) { - if s := clean("OLLAMA_ORIGINS"); s != "" { + if s := getenv("OLLAMA_ORIGINS"); s != "" { origins = strings.Split(s, ",") } @@ -114,9 +99,37 @@ func Models() string { return filepath.Join(home, ".ollama", "models") } +func Bool(k string) func() bool { + return func() bool { + if s := getenv(k); s != "" { + b, err := strconv.ParseBool(s) + if err != nil { + return true + } + + return b + } + + return false + } +} + +var ( + // Debug enabled additional debug information. + Debug = Bool("OLLAMA_DEBUG") + // FlashAttention enables the experimental flash attention feature. + FlashAttention = Bool("OLLAMA_FLASH_ATTENTION") + // NoHistory disables readline history. + NoHistory = Bool("OLLAMA_NOHISTORY") + // NoPrune disables pruning of model blobs on startup. + NoPrune = Bool("OLLAMA_NOPRUNE") + // SchedSpread allows scheduling models across all GPUs. + SchedSpread = Bool("OLLAMA_SCHED_SPREAD") + // IntelGPU enables experimental Intel GPU detection. + IntelGPU = Bool("OLLAMA_INTEL_GPU") +) + var ( - // Experimental flash attention - FlashAttention bool // Set via OLLAMA_KEEP_ALIVE in the environment KeepAlive time.Duration // Set via OLLAMA_LLM_LIBRARY in the environment @@ -125,22 +138,12 @@ var ( MaxRunners int // Set via OLLAMA_MAX_QUEUE in the environment MaxQueuedRequests int - // Set via OLLAMA_MODELS in the environment - ModelsDir string - // Set via OLLAMA_NOHISTORY in the environment - NoHistory bool - // Set via OLLAMA_NOPRUNE in the environment - NoPrune bool // Set via OLLAMA_NUM_PARALLEL in the environment NumParallel int // Set via OLLAMA_RUNNERS_DIR in the environment RunnersDir string - // Set via OLLAMA_SCHED_SPREAD in the environment - SchedSpread bool // Set via OLLAMA_TMPDIR in the environment TmpDir string - // Set via OLLAMA_INTEL_GPU in the environment - IntelGpu bool // Set via CUDA_VISIBLE_DEVICES in the environment CudaVisibleDevices string @@ -163,19 +166,19 @@ type EnvVar struct { func AsMap() map[string]EnvVar { ret := map[string]EnvVar{ "OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"}, - "OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention, "Enabled flash attention"}, + "OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"}, "OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"}, "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"}, "OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"}, "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"}, "OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"}, "OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"}, - "OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"}, - "OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"}, + "OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"}, + "OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"}, "OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"}, "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"}, "OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"}, - "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"}, + "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"}, "OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"}, } if runtime.GOOS != "darwin" { @@ -184,7 +187,7 @@ func AsMap() map[string]EnvVar { ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices, "Set which AMD devices are visible"} ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal, "Set which AMD devices are visible"} ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion, "Override the gfx used for all detected AMD GPUs"} - ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGpu, "Enable experimental Intel GPU detection"} + ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGPU(), "Enable experimental Intel GPU detection"} } return ret } @@ -197,8 +200,8 @@ func Values() map[string]string { return vals } -// Clean quotes and spaces from the value -func clean(key string) string { +// getenv returns an environment variable stripped of leading and trailing quotes or spaces +func getenv(key string) string { return strings.Trim(os.Getenv(key), "\"' ") } @@ -213,14 +216,7 @@ func init() { } func LoadConfig() { - if fa := clean("OLLAMA_FLASH_ATTENTION"); fa != "" { - d, err := strconv.ParseBool(fa) - if err == nil { - FlashAttention = d - } - } - - RunnersDir = clean("OLLAMA_RUNNERS_DIR") + RunnersDir = getenv("OLLAMA_RUNNERS_DIR") if runtime.GOOS == "windows" && RunnersDir == "" { // On Windows we do not carry the payloads inside the main executable appExe, err := os.Executable() @@ -256,11 +252,11 @@ func LoadConfig() { } } - TmpDir = clean("OLLAMA_TMPDIR") + TmpDir = getenv("OLLAMA_TMPDIR") - LLMLibrary = clean("OLLAMA_LLM_LIBRARY") + LLMLibrary = getenv("OLLAMA_LLM_LIBRARY") - if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" { + if onp := getenv("OLLAMA_NUM_PARALLEL"); onp != "" { val, err := strconv.Atoi(onp) if err != nil { slog.Error("invalid setting, ignoring", "OLLAMA_NUM_PARALLEL", onp, "error", err) @@ -269,24 +265,7 @@ func LoadConfig() { } } - if nohistory := clean("OLLAMA_NOHISTORY"); nohistory != "" { - NoHistory = true - } - - if spread := clean("OLLAMA_SCHED_SPREAD"); spread != "" { - s, err := strconv.ParseBool(spread) - if err == nil { - SchedSpread = s - } else { - SchedSpread = true - } - } - - if noprune := clean("OLLAMA_NOPRUNE"); noprune != "" { - NoPrune = true - } - - maxRunners := clean("OLLAMA_MAX_LOADED_MODELS") + maxRunners := getenv("OLLAMA_MAX_LOADED_MODELS") if maxRunners != "" { m, err := strconv.Atoi(maxRunners) if err != nil { @@ -305,20 +284,16 @@ func LoadConfig() { } } - ka := clean("OLLAMA_KEEP_ALIVE") + ka := getenv("OLLAMA_KEEP_ALIVE") if ka != "" { loadKeepAlive(ka) } - if set, err := strconv.ParseBool(clean("OLLAMA_INTEL_GPU")); err == nil { - IntelGpu = set - } - - CudaVisibleDevices = clean("CUDA_VISIBLE_DEVICES") - HipVisibleDevices = clean("HIP_VISIBLE_DEVICES") - RocrVisibleDevices = clean("ROCR_VISIBLE_DEVICES") - GpuDeviceOrdinal = clean("GPU_DEVICE_ORDINAL") - HsaOverrideGfxVersion = clean("HSA_OVERRIDE_GFX_VERSION") + CudaVisibleDevices = getenv("CUDA_VISIBLE_DEVICES") + HipVisibleDevices = getenv("HIP_VISIBLE_DEVICES") + RocrVisibleDevices = getenv("ROCR_VISIBLE_DEVICES") + GpuDeviceOrdinal = getenv("GPU_DEVICE_ORDINAL") + HsaOverrideGfxVersion = getenv("HSA_OVERRIDE_GFX_VERSION") } func loadKeepAlive(ka string) { diff --git a/envconfig/config_test.go b/envconfig/config_test.go index dc65ef70..b364b009 100644 --- a/envconfig/config_test.go +++ b/envconfig/config_test.go @@ -20,8 +20,8 @@ func TestSmoke(t *testing.T) { require.True(t, Debug()) t.Setenv("OLLAMA_FLASH_ATTENTION", "1") - LoadConfig() - require.True(t, FlashAttention) + require.True(t, FlashAttention()) + t.Setenv("OLLAMA_KEEP_ALIVE", "") LoadConfig() require.Equal(t, 5*time.Minute, KeepAlive) @@ -162,3 +162,27 @@ func TestOrigins(t *testing.T) { }) } } + +func TestBool(t *testing.T) { + cases := map[string]struct { + value string + expect bool + }{ + "empty": {"", false}, + "true": {"true", true}, + "false": {"false", false}, + "1": {"1", true}, + "0": {"0", false}, + "random": {"random", true}, + "something": {"something", true}, + } + + for name, tt := range cases { + t.Run(name, func(t *testing.T) { + t.Setenv("OLLAMA_BOOL", tt.value) + if b := Bool("OLLAMA_BOOL"); b() != tt.expect { + t.Errorf("%s: expected %t, got %t", name, tt.expect, b()) + } + }) + } +} diff --git a/gpu/gpu.go b/gpu/gpu.go index 1815668f..c3059542 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -302,7 +302,7 @@ func GetGPUInfo() GpuInfoList { } // Intel - if envconfig.IntelGpu { + if envconfig.IntelGPU() { oHandles = initOneAPIHandles() // On windows we bundle the oneapi library one level above the runner dir depPath = "" diff --git a/llm/server.go b/llm/server.go index eb966650..84d9e93a 100644 --- a/llm/server.go +++ b/llm/server.go @@ -221,7 +221,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr params = append(params, "--memory-f32") } - flashAttnEnabled := envconfig.FlashAttention + flashAttnEnabled := envconfig.FlashAttention() for _, g := range gpus { // only cuda (compute capability 7+) and metal support flash attention diff --git a/server/images.go b/server/images.go index 574dec19..3eb3b3fa 100644 --- a/server/images.go +++ b/server/images.go @@ -644,7 +644,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio return err } - if !envconfig.NoPrune && old != nil { + if !envconfig.NoPrune() && old != nil { if err := old.RemoveLayers(); err != nil { return err } @@ -883,7 +883,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu // build deleteMap to prune unused layers deleteMap := make(map[string]struct{}) - if !envconfig.NoPrune { + if !envconfig.NoPrune() { manifest, _, err = GetManifest(mp) if err != nil && !errors.Is(err, os.ErrNotExist) { return err diff --git a/server/routes.go b/server/routes.go index 07898d9b..41a73cb4 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1121,7 +1121,7 @@ func Serve(ln net.Listener) error { return err } - if !envconfig.NoPrune { + if !envconfig.NoPrune() { // clean up unused layers and manifests if err := PruneLayers(); err != nil { return err diff --git a/server/sched.go b/server/sched.go index 2daed3ab..e1e986a5 100644 --- a/server/sched.go +++ b/server/sched.go @@ -695,7 +695,7 @@ func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numP // First attempt to fit the model into a single GPU for _, p := range numParallelToTry { req.opts.NumCtx = req.origNumCtx * p - if !envconfig.SchedSpread { + if !envconfig.SchedSpread() { for _, g := range sgl { if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok { slog.Info("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "parallel", p, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM)) From 8570c1c0ef73e89448f6724645f56b9b10efef44 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 3 Jul 2024 18:39:35 -0700 Subject: [PATCH 06/11] keepalive --- envconfig/config.go | 51 +++++++++++++++++----------------------- envconfig/config_test.go | 49 +++++++++++++++++++++++++------------- server/sched.go | 2 +- 3 files changed, 55 insertions(+), 47 deletions(-) diff --git a/envconfig/config.go b/envconfig/config.go index ea78585b..62bfad64 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -99,6 +99,26 @@ func Models() string { return filepath.Join(home, ".ollama", "models") } +// KeepAlive returns the duration that models stay loaded in memory. KeepAlive can be configured via the OLLAMA_KEEP_ALIVE environment variable. +// Negative values are treated as infinite. Zero is treated as no keep alive. +// Default is 5 minutes. +func KeepAlive() (keepAlive time.Duration) { + keepAlive = 5 * time.Minute + if s := os.Getenv("OLLAMA_KEEP_ALIVE"); s != "" { + if d, err := time.ParseDuration(s); err == nil { + keepAlive = d + } else if n, err := strconv.ParseInt(s, 10, 64); err == nil { + keepAlive = time.Duration(n) * time.Second + } + } + + if keepAlive < 0 { + return time.Duration(math.MaxInt64) + } + + return keepAlive +} + func Bool(k string) func() bool { return func() bool { if s := getenv(k); s != "" { @@ -130,8 +150,6 @@ var ( ) var ( - // Set via OLLAMA_KEEP_ALIVE in the environment - KeepAlive time.Duration // Set via OLLAMA_LLM_LIBRARY in the environment LLMLibrary string // Set via OLLAMA_MAX_LOADED_MODELS in the environment @@ -168,7 +186,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"}, "OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"}, "OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"}, - "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"}, + "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"}, "OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"}, "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"}, "OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"}, @@ -210,7 +228,6 @@ func init() { NumParallel = 0 // Autoselect MaxRunners = 0 // Autoselect MaxQueuedRequests = 512 - KeepAlive = 5 * time.Minute LoadConfig() } @@ -284,35 +301,9 @@ func LoadConfig() { } } - ka := getenv("OLLAMA_KEEP_ALIVE") - if ka != "" { - loadKeepAlive(ka) - } - CudaVisibleDevices = getenv("CUDA_VISIBLE_DEVICES") HipVisibleDevices = getenv("HIP_VISIBLE_DEVICES") RocrVisibleDevices = getenv("ROCR_VISIBLE_DEVICES") GpuDeviceOrdinal = getenv("GPU_DEVICE_ORDINAL") HsaOverrideGfxVersion = getenv("HSA_OVERRIDE_GFX_VERSION") } - -func loadKeepAlive(ka string) { - v, err := strconv.Atoi(ka) - if err != nil { - d, err := time.ParseDuration(ka) - if err == nil { - if d < 0 { - KeepAlive = time.Duration(math.MaxInt64) - } else { - KeepAlive = d - } - } - } else { - d := time.Duration(v) * time.Second - if d < 0 { - KeepAlive = time.Duration(math.MaxInt64) - } else { - KeepAlive = d - } - } -} diff --git a/envconfig/config_test.go b/envconfig/config_test.go index b364b009..87c808ca 100644 --- a/envconfig/config_test.go +++ b/envconfig/config_test.go @@ -21,22 +21,6 @@ func TestSmoke(t *testing.T) { t.Setenv("OLLAMA_FLASH_ATTENTION", "1") require.True(t, FlashAttention()) - - t.Setenv("OLLAMA_KEEP_ALIVE", "") - LoadConfig() - require.Equal(t, 5*time.Minute, KeepAlive) - t.Setenv("OLLAMA_KEEP_ALIVE", "3") - LoadConfig() - require.Equal(t, 3*time.Second, KeepAlive) - t.Setenv("OLLAMA_KEEP_ALIVE", "1h") - LoadConfig() - require.Equal(t, 1*time.Hour, KeepAlive) - t.Setenv("OLLAMA_KEEP_ALIVE", "-1s") - LoadConfig() - require.Equal(t, time.Duration(math.MaxInt64), KeepAlive) - t.Setenv("OLLAMA_KEEP_ALIVE", "-1") - LoadConfig() - require.Equal(t, time.Duration(math.MaxInt64), KeepAlive) } func TestHost(t *testing.T) { @@ -186,3 +170,36 @@ func TestBool(t *testing.T) { }) } } + +func TestKeepAlive(t *testing.T) { + cases := map[string]time.Duration{ + "": 5 * time.Minute, + "1s": time.Second, + "1m": time.Minute, + "1h": time.Hour, + "5m0s": 5 * time.Minute, + "1h2m3s": 1*time.Hour + 2*time.Minute + 3*time.Second, + "0": time.Duration(0), + "60": 60 * time.Second, + "120": 2 * time.Minute, + "3600": time.Hour, + "-0": time.Duration(0), + "-1": time.Duration(math.MaxInt64), + "-1m": time.Duration(math.MaxInt64), + // invalid values + " ": 5 * time.Minute, + "???": 5 * time.Minute, + "1d": 5 * time.Minute, + "1y": 5 * time.Minute, + "1w": 5 * time.Minute, + } + + for tt, expect := range cases { + t.Run(tt, func(t *testing.T) { + t.Setenv("OLLAMA_KEEP_ALIVE", tt) + if actual := KeepAlive(); actual != expect { + t.Errorf("%s: expected %s, got %s", tt, expect, actual) + } + }) + } +} diff --git a/server/sched.go b/server/sched.go index e1e986a5..ad40c4ef 100644 --- a/server/sched.go +++ b/server/sched.go @@ -401,7 +401,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, if numParallel < 1 { numParallel = 1 } - sessionDuration := envconfig.KeepAlive + sessionDuration := envconfig.KeepAlive() if req.sessionDuration != nil { sessionDuration = req.sessionDuration.Duration } From e2c3f6b3e2de014656ab9ddffccf7b89d1bcc09e Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 3 Jul 2024 19:30:19 -0700 Subject: [PATCH 07/11] string --- envconfig/config.go | 143 ++++++++++++++++++++++---------------------- gpu/amd_linux.go | 8 +-- gpu/amd_windows.go | 2 +- gpu/assets.go | 6 +- gpu/gpu.go | 8 +-- llm/server.go | 2 +- 6 files changed, 85 insertions(+), 84 deletions(-) diff --git a/envconfig/config.go b/envconfig/config.go index 62bfad64..34cc4dac 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -149,30 +149,77 @@ var ( IntelGPU = Bool("OLLAMA_INTEL_GPU") ) +func String(s string) func() string { + return func() string { + return getenv(s) + } +} + +var ( + LLMLibrary = String("OLLAMA_LLM_LIBRARY") + TmpDir = String("OLLAMA_TMPDIR") + + CudaVisibleDevices = String("CUDA_VISIBLE_DEVICES") + HipVisibleDevices = String("HIP_VISIBLE_DEVICES") + RocrVisibleDevices = String("ROCR_VISIBLE_DEVICES") + GpuDeviceOrdinal = String("GPU_DEVICE_ORDINAL") + HsaOverrideGfxVersion = String("HSA_OVERRIDE_GFX_VERSION") +) + +func RunnersDir() (p string) { + if p := getenv("OLLAMA_RUNNERS_DIR"); p != "" { + return p + } + + if runtime.GOOS != "windows" { + return + } + + defer func() { + if p == "" { + slog.Error("unable to locate llm runner directory. Set OLLAMA_RUNNERS_DIR to the location of 'ollama_runners'") + } + }() + + // On Windows we do not carry the payloads inside the main executable + exe, err := os.Executable() + if err != nil { + return + } + + cwd, err := os.Getwd() + if err != nil { + return + } + + var paths []string + for _, root := range []string{filepath.Dir(exe), cwd} { + paths = append(paths, + root, + filepath.Join(root, "windows-"+runtime.GOARCH), + filepath.Join(root, "dist", "windows-"+runtime.GOARCH), + ) + } + + // Try a few variations to improve developer experience when building from source in the local tree + for _, path := range paths { + candidate := filepath.Join(path, "ollama_runners") + if _, err := os.Stat(candidate); err == nil { + p = candidate + break + } + } + + return p +} + var ( - // Set via OLLAMA_LLM_LIBRARY in the environment - LLMLibrary string // Set via OLLAMA_MAX_LOADED_MODELS in the environment MaxRunners int // Set via OLLAMA_MAX_QUEUE in the environment MaxQueuedRequests int // Set via OLLAMA_NUM_PARALLEL in the environment NumParallel int - // Set via OLLAMA_RUNNERS_DIR in the environment - RunnersDir string - // Set via OLLAMA_TMPDIR in the environment - TmpDir string - - // Set via CUDA_VISIBLE_DEVICES in the environment - CudaVisibleDevices string - // Set via HIP_VISIBLE_DEVICES in the environment - HipVisibleDevices string - // Set via ROCR_VISIBLE_DEVICES in the environment - RocrVisibleDevices string - // Set via GPU_DEVICE_ORDINAL in the environment - GpuDeviceOrdinal string - // Set via HSA_OVERRIDE_GFX_VERSION in the environment - HsaOverrideGfxVersion string ) type EnvVar struct { @@ -187,7 +234,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"}, "OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"}, "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"}, - "OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"}, + "OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"}, "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"}, "OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"}, "OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"}, @@ -195,16 +242,16 @@ func AsMap() map[string]EnvVar { "OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"}, "OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"}, "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"}, - "OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"}, + "OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir(), "Location for runners"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"}, - "OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"}, + "OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir(), "Location for temporary files"}, } if runtime.GOOS != "darwin" { - ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices, "Set which NVIDIA devices are visible"} - ret["HIP_VISIBLE_DEVICES"] = EnvVar{"HIP_VISIBLE_DEVICES", HipVisibleDevices, "Set which AMD devices are visible"} - ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices, "Set which AMD devices are visible"} - ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal, "Set which AMD devices are visible"} - ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion, "Override the gfx used for all detected AMD GPUs"} + ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices(), "Set which NVIDIA devices are visible"} + ret["HIP_VISIBLE_DEVICES"] = EnvVar{"HIP_VISIBLE_DEVICES", HipVisibleDevices(), "Set which AMD devices are visible"} + ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices(), "Set which AMD devices are visible"} + ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal(), "Set which AMD devices are visible"} + ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion(), "Override the gfx used for all detected AMD GPUs"} ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGPU(), "Enable experimental Intel GPU detection"} } return ret @@ -233,46 +280,6 @@ func init() { } func LoadConfig() { - RunnersDir = getenv("OLLAMA_RUNNERS_DIR") - if runtime.GOOS == "windows" && RunnersDir == "" { - // On Windows we do not carry the payloads inside the main executable - appExe, err := os.Executable() - if err != nil { - slog.Error("failed to lookup executable path", "error", err) - } - - cwd, err := os.Getwd() - if err != nil { - slog.Error("failed to lookup working directory", "error", err) - } - - var paths []string - for _, root := range []string{filepath.Dir(appExe), cwd} { - paths = append(paths, - root, - filepath.Join(root, "windows-"+runtime.GOARCH), - filepath.Join(root, "dist", "windows-"+runtime.GOARCH), - ) - } - - // Try a few variations to improve developer experience when building from source in the local tree - for _, p := range paths { - candidate := filepath.Join(p, "ollama_runners") - _, err := os.Stat(candidate) - if err == nil { - RunnersDir = candidate - break - } - } - if RunnersDir == "" { - slog.Error("unable to locate llm runner directory. Set OLLAMA_RUNNERS_DIR to the location of 'ollama_runners'") - } - } - - TmpDir = getenv("OLLAMA_TMPDIR") - - LLMLibrary = getenv("OLLAMA_LLM_LIBRARY") - if onp := getenv("OLLAMA_NUM_PARALLEL"); onp != "" { val, err := strconv.Atoi(onp) if err != nil { @@ -300,10 +307,4 @@ func LoadConfig() { MaxQueuedRequests = p } } - - CudaVisibleDevices = getenv("CUDA_VISIBLE_DEVICES") - HipVisibleDevices = getenv("HIP_VISIBLE_DEVICES") - RocrVisibleDevices = getenv("ROCR_VISIBLE_DEVICES") - GpuDeviceOrdinal = getenv("GPU_DEVICE_ORDINAL") - HsaOverrideGfxVersion = getenv("HSA_OVERRIDE_GFX_VERSION") } diff --git a/gpu/amd_linux.go b/gpu/amd_linux.go index 15b6fc61..33dd03ab 100644 --- a/gpu/amd_linux.go +++ b/gpu/amd_linux.go @@ -60,9 +60,9 @@ func AMDGetGPUInfo() []RocmGPUInfo { // Determine if the user has already pre-selected which GPUs to look at, then ignore the others var visibleDevices []string - hipVD := envconfig.HipVisibleDevices // zero based index only - rocrVD := envconfig.RocrVisibleDevices // zero based index or UUID, but consumer cards seem to not support UUID - gpuDO := envconfig.GpuDeviceOrdinal // zero based index + hipVD := envconfig.HipVisibleDevices() // zero based index only + rocrVD := envconfig.RocrVisibleDevices() // zero based index or UUID, but consumer cards seem to not support UUID + gpuDO := envconfig.GpuDeviceOrdinal() // zero based index switch { // TODO is this priorty order right? case hipVD != "": @@ -75,7 +75,7 @@ func AMDGetGPUInfo() []RocmGPUInfo { visibleDevices = strings.Split(gpuDO, ",") } - gfxOverride := envconfig.HsaOverrideGfxVersion + gfxOverride := envconfig.HsaOverrideGfxVersion() var supported []string libDir := "" diff --git a/gpu/amd_windows.go b/gpu/amd_windows.go index 20aed447..a170dfdc 100644 --- a/gpu/amd_windows.go +++ b/gpu/amd_windows.go @@ -53,7 +53,7 @@ func AMDGetGPUInfo() []RocmGPUInfo { } var supported []string - gfxOverride := envconfig.HsaOverrideGfxVersion + gfxOverride := envconfig.HsaOverrideGfxVersion() if gfxOverride == "" { supported, err = GetSupportedGFX(libDir) if err != nil { diff --git a/gpu/assets.go b/gpu/assets.go index 073d2e81..39ff7c21 100644 --- a/gpu/assets.go +++ b/gpu/assets.go @@ -26,7 +26,7 @@ func PayloadsDir() (string, error) { defer lock.Unlock() var err error if payloadsDir == "" { - runnersDir := envconfig.RunnersDir + runnersDir := envconfig.RunnersDir() if runnersDir != "" { payloadsDir = runnersDir @@ -35,7 +35,7 @@ func PayloadsDir() (string, error) { // The remainder only applies on non-windows where we still carry payloads in the main executable cleanupTmpDirs() - tmpDir := envconfig.TmpDir + tmpDir := envconfig.TmpDir() if tmpDir == "" { tmpDir, err = os.MkdirTemp("", "ollama") if err != nil { @@ -105,7 +105,7 @@ func cleanupTmpDirs() { func Cleanup() { lock.Lock() defer lock.Unlock() - runnersDir := envconfig.RunnersDir + runnersDir := envconfig.RunnersDir() if payloadsDir != "" && runnersDir == "" && runtime.GOOS != "windows" { // We want to fully clean up the tmpdir parent of the payloads dir tmpDir := filepath.Clean(filepath.Join(payloadsDir, "..")) diff --git a/gpu/gpu.go b/gpu/gpu.go index c3059542..acab1c8d 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -230,8 +230,8 @@ func GetGPUInfo() GpuInfoList { // On windows we bundle the nvidia library one level above the runner dir depPath := "" - if runtime.GOOS == "windows" && envconfig.RunnersDir != "" { - depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir), "cuda") + if runtime.GOOS == "windows" && envconfig.RunnersDir() != "" { + depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir()), "cuda") } // Load ALL libraries @@ -306,8 +306,8 @@ func GetGPUInfo() GpuInfoList { oHandles = initOneAPIHandles() // On windows we bundle the oneapi library one level above the runner dir depPath = "" - if runtime.GOOS == "windows" && envconfig.RunnersDir != "" { - depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir), "oneapi") + if runtime.GOOS == "windows" && envconfig.RunnersDir() != "" { + depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir()), "oneapi") } for d := range oHandles.oneapi.num_drivers { diff --git a/llm/server.go b/llm/server.go index 84d9e93a..0741d386 100644 --- a/llm/server.go +++ b/llm/server.go @@ -163,7 +163,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr } else { servers = serversForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant } - demandLib := envconfig.LLMLibrary + demandLib := envconfig.LLMLibrary() if demandLib != "" { serverPath := availableServers[demandLib] if serverPath == "" { From 0f1910129f0a73c469ce2c012d39c8d98b79ef80 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 3 Jul 2024 19:41:17 -0700 Subject: [PATCH 08/11] int --- envconfig/config.go | 66 ++++++++++------------------------- integration/basic_test.go | 9 +---- integration/max_queue_test.go | 14 ++++---- server/sched.go | 23 +++++++----- server/sched_test.go | 7 ++-- 5 files changed, 42 insertions(+), 77 deletions(-) diff --git a/envconfig/config.go b/envconfig/config.go index 34cc4dac..01abea42 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -213,13 +213,22 @@ func RunnersDir() (p string) { return p } +func Int(k string, n int) func() int { + return func() int { + if s := getenv(k); s != "" { + if n, err := strconv.ParseInt(s, 10, 64); err == nil && n >= 0 { + return int(n) + } + } + + return n + } +} + var ( - // Set via OLLAMA_MAX_LOADED_MODELS in the environment - MaxRunners int - // Set via OLLAMA_MAX_QUEUE in the environment - MaxQueuedRequests int - // Set via OLLAMA_NUM_PARALLEL in the environment - NumParallel int + NumParallel = Int("OLLAMA_NUM_PARALLEL", 0) + MaxRunners = Int("OLLAMA_MAX_LOADED_MODELS", 0) + MaxQueue = Int("OLLAMA_MAX_QUEUE", 512) ) type EnvVar struct { @@ -235,12 +244,12 @@ func AsMap() map[string]EnvVar { "OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"}, "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"}, "OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"}, - "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"}, - "OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"}, + "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"}, + "OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"}, "OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"}, "OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"}, "OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"}, - "OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"}, + "OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"}, "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"}, "OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir(), "Location for runners"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"}, @@ -269,42 +278,3 @@ func Values() map[string]string { func getenv(key string) string { return strings.Trim(os.Getenv(key), "\"' ") } - -func init() { - // default values - NumParallel = 0 // Autoselect - MaxRunners = 0 // Autoselect - MaxQueuedRequests = 512 - - LoadConfig() -} - -func LoadConfig() { - if onp := getenv("OLLAMA_NUM_PARALLEL"); onp != "" { - val, err := strconv.Atoi(onp) - if err != nil { - slog.Error("invalid setting, ignoring", "OLLAMA_NUM_PARALLEL", onp, "error", err) - } else { - NumParallel = val - } - } - - maxRunners := getenv("OLLAMA_MAX_LOADED_MODELS") - if maxRunners != "" { - m, err := strconv.Atoi(maxRunners) - if err != nil { - slog.Error("invalid setting, ignoring", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err) - } else { - MaxRunners = m - } - } - - if onp := os.Getenv("OLLAMA_MAX_QUEUE"); onp != "" { - p, err := strconv.Atoi(onp) - if err != nil || p <= 0 { - slog.Error("invalid setting, ignoring", "OLLAMA_MAX_QUEUE", onp, "error", err) - } else { - MaxQueuedRequests = p - } - } -} diff --git a/integration/basic_test.go b/integration/basic_test.go index 6e632a1c..8e35b5c5 100644 --- a/integration/basic_test.go +++ b/integration/basic_test.go @@ -45,14 +45,7 @@ func TestUnicodeModelDir(t *testing.T) { defer os.RemoveAll(modelDir) slog.Info("unicode", "OLLAMA_MODELS", modelDir) - oldModelsDir := os.Getenv("OLLAMA_MODELS") - if oldModelsDir == "" { - defer os.Unsetenv("OLLAMA_MODELS") - } else { - defer os.Setenv("OLLAMA_MODELS", oldModelsDir) - } - err = os.Setenv("OLLAMA_MODELS", modelDir) - require.NoError(t, err) + t.Setenv("OLLAMA_MODELS", modelDir) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() diff --git a/integration/max_queue_test.go b/integration/max_queue_test.go index dfa5eae0..b06197e1 100644 --- a/integration/max_queue_test.go +++ b/integration/max_queue_test.go @@ -5,7 +5,6 @@ package integration import ( "context" "errors" - "fmt" "log/slog" "os" "strconv" @@ -14,8 +13,10 @@ import ( "testing" "time" - "github.com/ollama/ollama/api" "github.com/stretchr/testify/require" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/envconfig" ) func TestMaxQueue(t *testing.T) { @@ -27,13 +28,10 @@ func TestMaxQueue(t *testing.T) { // Note: This test can be quite slow when running in CPU mode, so keep the threadCount low unless your on GPU // Also note that by default Darwin can't sustain > ~128 connections without adjusting limits threadCount := 32 - mq := os.Getenv("OLLAMA_MAX_QUEUE") - if mq != "" { - var err error - threadCount, err = strconv.Atoi(mq) - require.NoError(t, err) + if maxQueue := envconfig.MaxQueue(); maxQueue != 0 { + threadCount = maxQueue } else { - os.Setenv("OLLAMA_MAX_QUEUE", fmt.Sprintf("%d", threadCount)) + t.Setenv("OLLAMA_MAX_QUEUE", strconv.Itoa(threadCount)) } req := api.GenerateRequest{ diff --git a/server/sched.go b/server/sched.go index ad40c4ef..610a2c50 100644 --- a/server/sched.go +++ b/server/sched.go @@ -5,9 +5,11 @@ import ( "errors" "fmt" "log/slog" + "os" "reflect" "runtime" "sort" + "strconv" "strings" "sync" "time" @@ -59,11 +61,12 @@ var defaultParallel = 4 var ErrMaxQueue = fmt.Errorf("server busy, please try again. maximum pending requests exceeded") func InitScheduler(ctx context.Context) *Scheduler { + maxQueue := envconfig.MaxQueue() sched := &Scheduler{ - pendingReqCh: make(chan *LlmRequest, envconfig.MaxQueuedRequests), - finishedReqCh: make(chan *LlmRequest, envconfig.MaxQueuedRequests), - expiredCh: make(chan *runnerRef, envconfig.MaxQueuedRequests), - unloadedCh: make(chan interface{}, envconfig.MaxQueuedRequests), + pendingReqCh: make(chan *LlmRequest, maxQueue), + finishedReqCh: make(chan *LlmRequest, maxQueue), + expiredCh: make(chan *runnerRef, maxQueue), + unloadedCh: make(chan interface{}, maxQueue), loaded: make(map[string]*runnerRef), newServerFn: llm.NewLlamaServer, getGpuFn: gpu.GetGPUInfo, @@ -126,7 +129,7 @@ func (s *Scheduler) processPending(ctx context.Context) { slog.Debug("pending request cancelled or timed out, skipping scheduling") continue } - numParallel := envconfig.NumParallel + numParallel := envconfig.NumParallel() // TODO (jmorganca): multimodal models don't support parallel yet // see https://github.com/ollama/ollama/issues/4165 if len(pending.model.ProjectorPaths) > 0 && numParallel != 1 { @@ -148,7 +151,7 @@ func (s *Scheduler) processPending(ctx context.Context) { pending.useLoadedRunner(runner, s.finishedReqCh) break } - } else if envconfig.MaxRunners > 0 && loadedCount >= envconfig.MaxRunners { + } else if envconfig.MaxRunners() > 0 && loadedCount >= envconfig.MaxRunners() { slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount) runnerToExpire = s.findRunnerToUnload() } else { @@ -161,7 +164,7 @@ func (s *Scheduler) processPending(ctx context.Context) { gpus = s.getGpuFn() } - if envconfig.MaxRunners <= 0 { + if envconfig.MaxRunners() <= 0 { // No user specified MaxRunners, so figure out what automatic setting to use // If all GPUs have reliable free memory reporting, defaultModelsPerGPU * the number of GPUs // if any GPU has unreliable free memory reporting, 1x the number of GPUs @@ -173,11 +176,13 @@ func (s *Scheduler) processPending(ctx context.Context) { } } if allReliable { - envconfig.MaxRunners = defaultModelsPerGPU * len(gpus) + // HACK + os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(defaultModelsPerGPU*len(gpus))) slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", envconfig.MaxRunners, "gpu_count", len(gpus)) } else { + // HACK + os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(len(gpus))) slog.Info("one or more GPUs detected that are unable to accurately report free memory - disabling default concurrency") - envconfig.MaxRunners = len(gpus) } } diff --git a/server/sched_test.go b/server/sched_test.go index 9ddd1fab..3166ff66 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -12,7 +12,6 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/app/lifecycle" - "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/gpu" "github.com/ollama/ollama/llm" @@ -272,7 +271,7 @@ func TestRequestsMultipleLoadedModels(t *testing.T) { c.req.opts.NumGPU = 0 // CPU load, will be allowed d := newScenarioRequest(t, ctx, "ollama-model-3c", 30, nil) // Needs prior unloaded - envconfig.MaxRunners = 1 + t.Setenv("OLLAMA_MAX_LOADED_MODELS", "1") s.newServerFn = a.newServer slog.Info("a") s.pendingReqCh <- a.req @@ -291,7 +290,7 @@ func TestRequestsMultipleLoadedModels(t *testing.T) { require.Len(t, s.loaded, 1) s.loadedMu.Unlock() - envconfig.MaxRunners = 0 + t.Setenv("OLLAMA_MAX_LOADED_MODELS", "0") s.newServerFn = b.newServer slog.Info("b") s.pendingReqCh <- b.req @@ -362,7 +361,7 @@ func TestGetRunner(t *testing.T) { a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond}) b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond}) c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond}) - envconfig.MaxQueuedRequests = 1 + t.Setenv("OLLAMA_MAX_QUEUE", "1") s := InitScheduler(ctx) s.getGpuFn = getGpuFn s.getCpuFn = getCpuFn From 1954ec5917bf81ac743ba19bf0e7a6da47766778 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 3 Jul 2024 19:43:17 -0700 Subject: [PATCH 09/11] uint64 --- api/client_test.go | 3 -- integration/concurrency_test.go | 70 +++++++++++++++++---------------- server/manifest_test.go | 2 - server/modelpath_test.go | 3 -- server/routes_create_test.go | 10 ----- server/routes_delete_test.go | 2 - server/routes_list_test.go | 2 - server/routes_test.go | 4 -- 8 files changed, 37 insertions(+), 59 deletions(-) diff --git a/api/client_test.go b/api/client_test.go index fe9fd74f..23fe9334 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -2,8 +2,6 @@ package api import ( "testing" - - "github.com/ollama/ollama/envconfig" ) func TestClientFromEnvironment(t *testing.T) { @@ -33,7 +31,6 @@ func TestClientFromEnvironment(t *testing.T) { for k, v := range testCases { t.Run(k, func(t *testing.T) { t.Setenv("OLLAMA_HOST", v.value) - envconfig.LoadConfig() client, err := ClientFromEnvironment() if err != v.err { diff --git a/integration/concurrency_test.go b/integration/concurrency_test.go index 8593285b..81d0b587 100644 --- a/integration/concurrency_test.go +++ b/integration/concurrency_test.go @@ -5,14 +5,16 @@ package integration import ( "context" "log/slog" - "os" "strconv" "sync" "testing" "time" - "github.com/ollama/ollama/api" "github.com/stretchr/testify/require" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/format" ) func TestMultiModelConcurrency(t *testing.T) { @@ -106,13 +108,16 @@ func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) { // Stress the system if we know how much VRAM it has, and attempt to load more models than will fit func TestMultiModelStress(t *testing.T) { - vram := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM - if vram == "" { + s := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM + if s == "" { t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test") } - max, err := strconv.ParseUint(vram, 10, 64) - require.NoError(t, err) - const MB = uint64(1024 * 1024) + + maxVram, err := strconv.ParseUint(s, 10, 64) + if err != nil { + t.Fatal(err) + } + type model struct { name string size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM @@ -121,83 +126,82 @@ func TestMultiModelStress(t *testing.T) { smallModels := []model{ { name: "orca-mini", - size: 2992 * MB, + size: 2992 * format.MebiByte, }, { name: "phi", - size: 2616 * MB, + size: 2616 * format.MebiByte, }, { name: "gemma:2b", - size: 2364 * MB, + size: 2364 * format.MebiByte, }, { name: "stable-code:3b", - size: 2608 * MB, + size: 2608 * format.MebiByte, }, { name: "starcoder2:3b", - size: 2166 * MB, + size: 2166 * format.MebiByte, }, } mediumModels := []model{ { name: "llama2", - size: 5118 * MB, + size: 5118 * format.MebiByte, }, { name: "mistral", - size: 4620 * MB, + size: 4620 * format.MebiByte, }, { name: "orca-mini:7b", - size: 5118 * MB, + size: 5118 * format.MebiByte, }, { name: "dolphin-mistral", - size: 4620 * MB, + size: 4620 * format.MebiByte, }, { name: "gemma:7b", - size: 5000 * MB, + size: 5000 * format.MebiByte, + }, + { + name: "codellama:7b", + size: 5118 * format.MebiByte, }, - // TODO - uncomment this once #3565 is merged and this is rebased on it - // { - // name: "codellama:7b", - // size: 5118 * MB, - // }, } // These seem to be too slow to be useful... // largeModels := []model{ // { // name: "llama2:13b", - // size: 7400 * MB, + // size: 7400 * format.MebiByte, // }, // { // name: "codellama:13b", - // size: 7400 * MB, + // size: 7400 * format.MebiByte, // }, // { // name: "orca-mini:13b", - // size: 7400 * MB, + // size: 7400 * format.MebiByte, // }, // { // name: "gemma:7b", - // size: 5000 * MB, + // size: 5000 * format.MebiByte, // }, // { // name: "starcoder2:15b", - // size: 9100 * MB, + // size: 9100 * format.MebiByte, // }, // } var chosenModels []model switch { - case max < 10000*MB: + case maxVram < 10000*format.MebiByte: slog.Info("selecting small models") chosenModels = smallModels - // case max < 30000*MB: + // case maxVram < 30000*format.MebiByte: default: slog.Info("selecting medium models") chosenModels = mediumModels @@ -226,15 +230,15 @@ func TestMultiModelStress(t *testing.T) { } var wg sync.WaitGroup - consumed := uint64(256 * MB) // Assume some baseline usage + consumed := uint64(256 * format.MebiByte) // Assume some baseline usage for i := 0; i < len(req); i++ { // Always get at least 2 models, but dont' overshoot VRAM too much or we'll take too long - if i > 1 && consumed > max { - slog.Info("achieved target vram exhaustion", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024) + if i > 1 && consumed > vram { + slog.Info("achieved target vram exhaustion", "count", i, "vram", format.HumanBytes2(vram), "models", format.HumanBytes2(consumed)) break } consumed += chosenModels[i].size - slog.Info("target vram", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024) + slog.Info("target vram", "count", i, "vram", format.HumanBytes2(vram), "models", format.HumanBytes2(consumed)) wg.Add(1) go func(i int) { diff --git a/server/manifest_test.go b/server/manifest_test.go index ca6c3d2e..a4af5d5e 100644 --- a/server/manifest_test.go +++ b/server/manifest_test.go @@ -7,7 +7,6 @@ import ( "slices" "testing" - "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/types/model" ) @@ -108,7 +107,6 @@ func TestManifests(t *testing.T) { t.Run(n, func(t *testing.T) { d := t.TempDir() t.Setenv("OLLAMA_MODELS", d) - envconfig.LoadConfig() for _, p := range wants.ps { createManifest(t, d, p) diff --git a/server/modelpath_test.go b/server/modelpath_test.go index 6c4dfbee..849e0fa7 100644 --- a/server/modelpath_test.go +++ b/server/modelpath_test.go @@ -7,8 +7,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/ollama/ollama/envconfig" ) func TestGetBlobsPath(t *testing.T) { @@ -63,7 +61,6 @@ func TestGetBlobsPath(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { t.Setenv("OLLAMA_MODELS", dir) - envconfig.LoadConfig() got, err := GetBlobsPath(tc.digest) diff --git a/server/routes_create_test.go b/server/routes_create_test.go index 3234ea5e..c853a9e9 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -15,7 +15,6 @@ import ( "github.com/gin-gonic/gin" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/llm" ) @@ -89,7 +88,6 @@ func TestCreateFromBin(t *testing.T) { p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) - envconfig.LoadConfig() var s Server w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ @@ -117,7 +115,6 @@ func TestCreateFromModel(t *testing.T) { p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) - envconfig.LoadConfig() var s Server w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ @@ -160,7 +157,6 @@ func TestCreateRemovesLayers(t *testing.T) { p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) - envconfig.LoadConfig() var s Server w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ @@ -209,7 +205,6 @@ func TestCreateUnsetsSystem(t *testing.T) { p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) - envconfig.LoadConfig() var s Server w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ @@ -267,7 +262,6 @@ func TestCreateMergeParameters(t *testing.T) { p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) - envconfig.LoadConfig() var s Server w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ @@ -372,7 +366,6 @@ func TestCreateReplacesMessages(t *testing.T) { p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) - envconfig.LoadConfig() var s Server w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ @@ -450,7 +443,6 @@ func TestCreateTemplateSystem(t *testing.T) { p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) - envconfig.LoadConfig() var s Server w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ @@ -534,7 +526,6 @@ func TestCreateLicenses(t *testing.T) { p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) - envconfig.LoadConfig() var s Server w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ @@ -582,7 +573,6 @@ func TestCreateDetectTemplate(t *testing.T) { p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) - envconfig.LoadConfig() var s Server t.Run("matched", func(t *testing.T) { diff --git a/server/routes_delete_test.go b/server/routes_delete_test.go index 33a97a73..2354d730 100644 --- a/server/routes_delete_test.go +++ b/server/routes_delete_test.go @@ -10,7 +10,6 @@ import ( "github.com/gin-gonic/gin" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/types/model" ) @@ -19,7 +18,6 @@ func TestDelete(t *testing.T) { p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) - envconfig.LoadConfig() var s Server diff --git a/server/routes_list_test.go b/server/routes_list_test.go index c2d9c113..29e3214c 100644 --- a/server/routes_list_test.go +++ b/server/routes_list_test.go @@ -9,14 +9,12 @@ import ( "github.com/gin-gonic/gin" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/envconfig" ) func TestList(t *testing.T) { gin.SetMode(gin.TestMode) t.Setenv("OLLAMA_MODELS", t.TempDir()) - envconfig.LoadConfig() expectNames := []string{ "mistral:7b-instruct-q4_0", diff --git a/server/routes_test.go b/server/routes_test.go index 97786ba2..17da2305 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -19,7 +19,6 @@ import ( "github.com/stretchr/testify/require" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/openai" "github.com/ollama/ollama/parser" @@ -347,7 +346,6 @@ func Test_Routes(t *testing.T) { } t.Setenv("OLLAMA_MODELS", t.TempDir()) - envconfig.LoadConfig() s := &Server{} router := s.GenerateRoutes() @@ -378,7 +376,6 @@ func Test_Routes(t *testing.T) { func TestCase(t *testing.T) { t.Setenv("OLLAMA_MODELS", t.TempDir()) - envconfig.LoadConfig() cases := []string{ "mistral", @@ -458,7 +455,6 @@ func TestCase(t *testing.T) { func TestShow(t *testing.T) { t.Setenv("OLLAMA_MODELS", t.TempDir()) - envconfig.LoadConfig() var s Server From 78140a712ce8feac6fad2ae2c0043056f1a47fdc Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 5 Jul 2024 16:52:01 -0700 Subject: [PATCH 10/11] cleanup tests --- envconfig/config_test.go | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/envconfig/config_test.go b/envconfig/config_test.go index 87c808ca..977298aa 100644 --- a/envconfig/config_test.go +++ b/envconfig/config_test.go @@ -6,23 +6,8 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/stretchr/testify/require" ) -func TestSmoke(t *testing.T) { - t.Setenv("OLLAMA_DEBUG", "") - require.False(t, Debug()) - - t.Setenv("OLLAMA_DEBUG", "false") - require.False(t, Debug()) - - t.Setenv("OLLAMA_DEBUG", "1") - require.True(t, Debug()) - - t.Setenv("OLLAMA_FLASH_ATTENTION", "1") - require.True(t, FlashAttention()) -} - func TestHost(t *testing.T) { cases := map[string]struct { value string From 85d9d73a7253fce232208a2355113c8ae6d69353 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 8 Jul 2024 10:34:12 -0700 Subject: [PATCH 11/11] comments --- envconfig/config.go | 50 ++++++++++++++------------ envconfig/config_test.go | 77 +++++++++++++++++++++++++++++++--------- server/sched.go | 4 +-- 3 files changed, 90 insertions(+), 41 deletions(-) diff --git a/envconfig/config.go b/envconfig/config.go index 01abea42..b82b773d 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -1,7 +1,6 @@ package envconfig import ( - "errors" "fmt" "log/slog" "math" @@ -15,15 +14,12 @@ import ( "time" ) -var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST") - // Host returns the scheme and host. Host can be configured via the OLLAMA_HOST environment variable. // Default is scheme "http" and host "127.0.0.1:11434" func Host() *url.URL { defaultPort := "11434" - s := os.Getenv("OLLAMA_HOST") - s = strings.TrimSpace(strings.Trim(strings.TrimSpace(s), "\"'")) + s := strings.TrimSpace(Var("OLLAMA_HOST")) scheme, hostport, ok := strings.Cut(s, "://") switch { case !ok: @@ -48,6 +44,7 @@ func Host() *url.URL { } if n, err := strconv.ParseInt(port, 10, 32); err != nil || n > 65535 || n < 0 { + slog.Warn("invalid port, using default", "port", port, "default", defaultPort) return &url.URL{ Scheme: scheme, Host: net.JoinHostPort(host, defaultPort), @@ -62,7 +59,7 @@ func Host() *url.URL { // Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable. func Origins() (origins []string) { - if s := getenv("OLLAMA_ORIGINS"); s != "" { + if s := Var("OLLAMA_ORIGINS"); s != "" { origins = strings.Split(s, ",") } @@ -87,7 +84,7 @@ func Origins() (origins []string) { // Models returns the path to the models directory. Models directory can be configured via the OLLAMA_MODELS environment variable. // Default is $HOME/.ollama/models func Models() string { - if s, ok := os.LookupEnv("OLLAMA_MODELS"); ok { + if s := Var("OLLAMA_MODELS"); s != "" { return s } @@ -104,7 +101,7 @@ func Models() string { // Default is 5 minutes. func KeepAlive() (keepAlive time.Duration) { keepAlive = 5 * time.Minute - if s := os.Getenv("OLLAMA_KEEP_ALIVE"); s != "" { + if s := Var("OLLAMA_KEEP_ALIVE"); s != "" { if d, err := time.ParseDuration(s); err == nil { keepAlive = d } else if n, err := strconv.ParseInt(s, 10, 64); err == nil { @@ -121,7 +118,7 @@ func KeepAlive() (keepAlive time.Duration) { func Bool(k string) func() bool { return func() bool { - if s := getenv(k); s != "" { + if s := Var(k); s != "" { b, err := strconv.ParseBool(s) if err != nil { return true @@ -151,7 +148,7 @@ var ( func String(s string) func() string { return func() string { - return getenv(s) + return Var(s) } } @@ -167,7 +164,7 @@ var ( ) func RunnersDir() (p string) { - if p := getenv("OLLAMA_RUNNERS_DIR"); p != "" { + if p := Var("OLLAMA_RUNNERS_DIR"); p != "" { return p } @@ -213,22 +210,29 @@ func RunnersDir() (p string) { return p } -func Int(k string, n int) func() int { - return func() int { - if s := getenv(k); s != "" { - if n, err := strconv.ParseInt(s, 10, 64); err == nil && n >= 0 { - return int(n) +func Uint(key string, defaultValue uint) func() uint { + return func() uint { + if s := Var(key); s != "" { + if n, err := strconv.ParseUint(s, 10, 64); err != nil { + slog.Warn("invalid environment variable, using default", "key", key, "value", s, "default", defaultValue) + } else { + return uint(n) } } - return n + return defaultValue } } var ( - NumParallel = Int("OLLAMA_NUM_PARALLEL", 0) - MaxRunners = Int("OLLAMA_MAX_LOADED_MODELS", 0) - MaxQueue = Int("OLLAMA_MAX_QUEUE", 512) + // NumParallel sets the number of parallel model requests. NumParallel can be configured via the OLLAMA_NUM_PARALLEL environment variable. + NumParallel = Uint("OLLAMA_NUM_PARALLEL", 0) + // MaxRunners sets the maximum number of loaded models. MaxRunners can be configured via the OLLAMA_MAX_LOADED_MODELS environment variable. + MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", 0) + // MaxQueue sets the maximum number of queued requests. MaxQueue can be configured via the OLLAMA_MAX_QUEUE environment variable. + MaxQueue = Uint("OLLAMA_MAX_QUEUE", 512) + // MaxVRAM sets a maximum VRAM override in bytes. MaxVRAM can be configured via the OLLAMA_MAX_VRAM environment variable. + MaxVRAM = Uint("OLLAMA_MAX_VRAM", 0) ) type EnvVar struct { @@ -274,7 +278,7 @@ func Values() map[string]string { return vals } -// getenv returns an environment variable stripped of leading and trailing quotes or spaces -func getenv(key string) string { - return strings.Trim(os.Getenv(key), "\"' ") +// Var returns an environment variable stripped of leading and trailing quotes or spaces +func Var(key string) string { + return strings.Trim(strings.TrimSpace(os.Getenv(key)), "\"'") } diff --git a/envconfig/config_test.go b/envconfig/config_test.go index 977298aa..92a500f1 100644 --- a/envconfig/config_test.go +++ b/envconfig/config_test.go @@ -30,6 +30,10 @@ func TestHost(t *testing.T) { "extra quotes": {"\"1.2.3.4\"", "1.2.3.4:11434"}, "extra space+quotes": {" \" 1.2.3.4 \" ", "1.2.3.4:11434"}, "extra single quotes": {"'1.2.3.4'", "1.2.3.4:11434"}, + "http": {"http://1.2.3.4", "1.2.3.4:80"}, + "http port": {"http://1.2.3.4:4321", "1.2.3.4:4321"}, + "https": {"https://1.2.3.4", "1.2.3.4:443"}, + "https port": {"https://1.2.3.4:4321", "1.2.3.4:4321"}, } for name, tt := range cases { @@ -133,24 +137,45 @@ func TestOrigins(t *testing.T) { } func TestBool(t *testing.T) { - cases := map[string]struct { - value string - expect bool - }{ - "empty": {"", false}, - "true": {"true", true}, - "false": {"false", false}, - "1": {"1", true}, - "0": {"0", false}, - "random": {"random", true}, - "something": {"something", true}, + cases := map[string]bool{ + "": false, + "true": true, + "false": false, + "1": true, + "0": false, + // invalid values + "random": true, + "something": true, } - for name, tt := range cases { - t.Run(name, func(t *testing.T) { - t.Setenv("OLLAMA_BOOL", tt.value) - if b := Bool("OLLAMA_BOOL"); b() != tt.expect { - t.Errorf("%s: expected %t, got %t", name, tt.expect, b()) + for k, v := range cases { + t.Run(k, func(t *testing.T) { + t.Setenv("OLLAMA_BOOL", k) + if b := Bool("OLLAMA_BOOL")(); b != v { + t.Errorf("%s: expected %t, got %t", k, v, b) + } + }) + } +} + +func TestUint(t *testing.T) { + cases := map[string]uint{ + "0": 0, + "1": 1, + "1337": 1337, + // default values + "": 11434, + "-1": 11434, + "0o10": 11434, + "0x10": 11434, + "string": 11434, + } + + for k, v := range cases { + t.Run(k, func(t *testing.T) { + t.Setenv("OLLAMA_UINT", k) + if i := Uint("OLLAMA_UINT", 11434)(); i != v { + t.Errorf("%s: expected %d, got %d", k, v, i) } }) } @@ -188,3 +213,23 @@ func TestKeepAlive(t *testing.T) { }) } } + +func TestVar(t *testing.T) { + cases := map[string]string{ + "value": "value", + " value ": "value", + " 'value' ": "value", + ` "value" `: "value", + " ' value ' ": " value ", + ` " value " `: " value ", + } + + for k, v := range cases { + t.Run(k, func(t *testing.T) { + t.Setenv("OLLAMA_VAR", k) + if s := Var("OLLAMA_VAR"); s != v { + t.Errorf("%s: expected %q, got %q", k, v, s) + } + }) + } +} diff --git a/server/sched.go b/server/sched.go index 610a2c50..ce2945d8 100644 --- a/server/sched.go +++ b/server/sched.go @@ -129,7 +129,7 @@ func (s *Scheduler) processPending(ctx context.Context) { slog.Debug("pending request cancelled or timed out, skipping scheduling") continue } - numParallel := envconfig.NumParallel() + numParallel := int(envconfig.NumParallel()) // TODO (jmorganca): multimodal models don't support parallel yet // see https://github.com/ollama/ollama/issues/4165 if len(pending.model.ProjectorPaths) > 0 && numParallel != 1 { @@ -151,7 +151,7 @@ func (s *Scheduler) processPending(ctx context.Context) { pending.useLoadedRunner(runner, s.finishedReqCh) break } - } else if envconfig.MaxRunners() > 0 && loadedCount >= envconfig.MaxRunners() { + } else if envconfig.MaxRunners() > 0 && loadedCount >= int(envconfig.MaxRunners()) { slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount) runnerToExpire = s.findRunnerToUnload() } else {