diff --git a/api/client.go b/api/client.go index c59fbc42..e02b21bf 100644 --- a/api/client.go +++ b/api/client.go @@ -20,7 +20,6 @@ import ( "encoding/json" "fmt" "io" - "net" "net/http" "net/url" "runtime" @@ -63,13 +62,8 @@ func checkError(resp *http.Response, body []byte) error { // If the variable is not specified, a default ollama host and port will be // used. func ClientFromEnvironment() (*Client, error) { - ollamaHost := envconfig.Host - return &Client{ - base: &url.URL{ - Scheme: ollamaHost.Scheme, - Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port), - }, + base: envconfig.Host(), http: http.DefaultClient, }, nil } diff --git a/api/client_test.go b/api/client_test.go index fe9fd74f..23fe9334 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -2,8 +2,6 @@ package api import ( "testing" - - "github.com/ollama/ollama/envconfig" ) func TestClientFromEnvironment(t *testing.T) { @@ -33,7 +31,6 @@ func TestClientFromEnvironment(t *testing.T) { for k, v := range testCases { t.Run(k, func(t *testing.T) { t.Setenv("OLLAMA_HOST", v.value) - envconfig.LoadConfig() client, err := ClientFromEnvironment() if err != v.err { diff --git a/app/lifecycle/logging.go b/app/lifecycle/logging.go index a8f1f7cd..3672aad5 100644 --- a/app/lifecycle/logging.go +++ b/app/lifecycle/logging.go @@ -14,7 +14,7 @@ import ( func InitLogging() { level := slog.LevelInfo - if envconfig.Debug { + if envconfig.Debug() { level = slog.LevelDebug } diff --git a/cmd/cmd.go b/cmd/cmd.go index 610fddcb..86910bf0 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1076,7 +1076,7 @@ func RunServer(cmd *cobra.Command, _ []string) error { return err } - ln, err := net.Listen("tcp", net.JoinHostPort(envconfig.Host.Host, envconfig.Host.Port)) + ln, err := net.Listen("tcp", envconfig.Host().Host) if err != nil { return err } diff --git a/cmd/interactive.go b/cmd/interactive.go index 70afc6ea..e1b81753 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -160,7 +160,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { return err } - if envconfig.NoHistory { + if envconfig.NoHistory() { scanner.HistoryDisable() } diff --git a/envconfig/config.go b/envconfig/config.go index 0abc6968..b82b773d 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -1,11 +1,11 @@ package envconfig import ( - "errors" "fmt" "log/slog" "math" "net" + "net/url" "os" "path/filepath" "runtime" @@ -14,296 +14,16 @@ import ( "time" ) -type OllamaHost struct { - Scheme string - Host string - Port string -} - -func (o OllamaHost) String() string { - return fmt.Sprintf("%s://%s:%s", o.Scheme, o.Host, o.Port) -} - -var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST") - -var ( - // Set via OLLAMA_ORIGINS in the environment - AllowOrigins []string - // Set via OLLAMA_DEBUG in the environment - Debug bool - // Experimental flash attention - FlashAttention bool - // Set via OLLAMA_HOST in the environment - Host *OllamaHost - // Set via OLLAMA_KEEP_ALIVE in the environment - KeepAlive time.Duration - // Set via OLLAMA_LLM_LIBRARY in the environment - LLMLibrary string - // Set via OLLAMA_MAX_LOADED_MODELS in the environment - MaxRunners int - // Set via OLLAMA_MAX_QUEUE in the environment - MaxQueuedRequests int - // Set via OLLAMA_MODELS in the environment - ModelsDir string - // Set via OLLAMA_NOHISTORY in the environment - NoHistory bool - // Set via OLLAMA_NOPRUNE in the environment - NoPrune bool - // Set via OLLAMA_NUM_PARALLEL in the environment - NumParallel int - // Set via 
OLLAMA_RUNNERS_DIR in the environment - RunnersDir string - // Set via OLLAMA_SCHED_SPREAD in the environment - SchedSpread bool - // Set via OLLAMA_TMPDIR in the environment - TmpDir string - // Set via OLLAMA_INTEL_GPU in the environment - IntelGpu bool - - // Set via CUDA_VISIBLE_DEVICES in the environment - CudaVisibleDevices string - // Set via HIP_VISIBLE_DEVICES in the environment - HipVisibleDevices string - // Set via ROCR_VISIBLE_DEVICES in the environment - RocrVisibleDevices string - // Set via GPU_DEVICE_ORDINAL in the environment - GpuDeviceOrdinal string - // Set via HSA_OVERRIDE_GFX_VERSION in the environment - HsaOverrideGfxVersion string -) - -type EnvVar struct { - Name string - Value any - Description string -} - -func AsMap() map[string]EnvVar { - ret := map[string]EnvVar{ - "OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug, "Show additional debug information (e.g. OLLAMA_DEBUG=1)"}, - "OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention, "Enabled flash attention"}, - "OLLAMA_HOST": {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"}, - "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"}, - "OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"}, - "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"}, - "OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"}, - "OLLAMA_MODELS": {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"}, - "OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"}, - "OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"}, - "OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"}, - "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"}, - "OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"}, - "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"}, - "OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"}, - } - if runtime.GOOS != "darwin" { - ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices, "Set which NVIDIA devices are visible"} - ret["HIP_VISIBLE_DEVICES"] = EnvVar{"HIP_VISIBLE_DEVICES", HipVisibleDevices, "Set which AMD devices are visible"} - ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices, "Set which AMD devices are visible"} - ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal, "Set which AMD devices are visible"} - ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion, "Override the gfx used for all detected AMD GPUs"} - ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGpu, "Enable experimental Intel GPU detection"} - } - return ret -} - -func Values() map[string]string { - vals := make(map[string]string) - for k, v := range AsMap() { - vals[k] = fmt.Sprintf("%v", v.Value) - } - return vals -} - -var defaultAllowOrigins = []string{ - "localhost", - "127.0.0.1", - "0.0.0.0", -} - -// Clean quotes and spaces from the value -func clean(key string) string { - return strings.Trim(os.Getenv(key), "\"' ") -} - -func init() { - // default values - NumParallel = 0 // Autoselect - MaxRunners = 0 // Autoselect - 
MaxQueuedRequests = 512 - KeepAlive = 5 * time.Minute - - LoadConfig() -} - -func LoadConfig() { - if debug := clean("OLLAMA_DEBUG"); debug != "" { - d, err := strconv.ParseBool(debug) - if err == nil { - Debug = d - } else { - Debug = true - } - } - - if fa := clean("OLLAMA_FLASH_ATTENTION"); fa != "" { - d, err := strconv.ParseBool(fa) - if err == nil { - FlashAttention = d - } - } - - RunnersDir = clean("OLLAMA_RUNNERS_DIR") - if runtime.GOOS == "windows" && RunnersDir == "" { - // On Windows we do not carry the payloads inside the main executable - appExe, err := os.Executable() - if err != nil { - slog.Error("failed to lookup executable path", "error", err) - } - - cwd, err := os.Getwd() - if err != nil { - slog.Error("failed to lookup working directory", "error", err) - } - - var paths []string - for _, root := range []string{filepath.Dir(appExe), cwd} { - paths = append(paths, - root, - filepath.Join(root, "windows-"+runtime.GOARCH), - filepath.Join(root, "dist", "windows-"+runtime.GOARCH), - ) - } - - // Try a few variations to improve developer experience when building from source in the local tree - for _, p := range paths { - candidate := filepath.Join(p, "ollama_runners") - _, err := os.Stat(candidate) - if err == nil { - RunnersDir = candidate - break - } - } - if RunnersDir == "" { - slog.Error("unable to locate llm runner directory. Set OLLAMA_RUNNERS_DIR to the location of 'ollama_runners'") - } - } - - TmpDir = clean("OLLAMA_TMPDIR") - - LLMLibrary = clean("OLLAMA_LLM_LIBRARY") - - if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" { - val, err := strconv.Atoi(onp) - if err != nil { - slog.Error("invalid setting, ignoring", "OLLAMA_NUM_PARALLEL", onp, "error", err) - } else { - NumParallel = val - } - } - - if nohistory := clean("OLLAMA_NOHISTORY"); nohistory != "" { - NoHistory = true - } - - if spread := clean("OLLAMA_SCHED_SPREAD"); spread != "" { - s, err := strconv.ParseBool(spread) - if err == nil { - SchedSpread = s - } else { - SchedSpread = true - } - } - - if noprune := clean("OLLAMA_NOPRUNE"); noprune != "" { - NoPrune = true - } - - if origins := clean("OLLAMA_ORIGINS"); origins != "" { - AllowOrigins = strings.Split(origins, ",") - } - for _, allowOrigin := range defaultAllowOrigins { - AllowOrigins = append(AllowOrigins, - fmt.Sprintf("http://%s", allowOrigin), - fmt.Sprintf("https://%s", allowOrigin), - fmt.Sprintf("http://%s", net.JoinHostPort(allowOrigin, "*")), - fmt.Sprintf("https://%s", net.JoinHostPort(allowOrigin, "*")), - ) - } - - AllowOrigins = append(AllowOrigins, - "app://*", - "file://*", - "tauri://*", - ) - - maxRunners := clean("OLLAMA_MAX_LOADED_MODELS") - if maxRunners != "" { - m, err := strconv.Atoi(maxRunners) - if err != nil { - slog.Error("invalid setting, ignoring", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err) - } else { - MaxRunners = m - } - } - - if onp := os.Getenv("OLLAMA_MAX_QUEUE"); onp != "" { - p, err := strconv.Atoi(onp) - if err != nil || p <= 0 { - slog.Error("invalid setting, ignoring", "OLLAMA_MAX_QUEUE", onp, "error", err) - } else { - MaxQueuedRequests = p - } - } - - ka := clean("OLLAMA_KEEP_ALIVE") - if ka != "" { - loadKeepAlive(ka) - } - - var err error - ModelsDir, err = getModelsDir() - if err != nil { - slog.Error("invalid setting", "OLLAMA_MODELS", ModelsDir, "error", err) - } - - Host, err = getOllamaHost() - if err != nil { - slog.Error("invalid setting", "OLLAMA_HOST", Host, "error", err, "using default port", Host.Port) - } - - if set, err := strconv.ParseBool(clean("OLLAMA_INTEL_GPU")); err == nil 
{ - IntelGpu = set - } - - CudaVisibleDevices = clean("CUDA_VISIBLE_DEVICES") - HipVisibleDevices = clean("HIP_VISIBLE_DEVICES") - RocrVisibleDevices = clean("ROCR_VISIBLE_DEVICES") - GpuDeviceOrdinal = clean("GPU_DEVICE_ORDINAL") - HsaOverrideGfxVersion = clean("HSA_OVERRIDE_GFX_VERSION") -} - -func getModelsDir() (string, error) { - if models, exists := os.LookupEnv("OLLAMA_MODELS"); exists { - return models, nil - } - home, err := os.UserHomeDir() - if err != nil { - return "", err - } - return filepath.Join(home, ".ollama", "models"), nil -} - -func getOllamaHost() (*OllamaHost, error) { +// Host returns the scheme and host. Host can be configured via the OLLAMA_HOST environment variable. +// Default is scheme "http" and host "127.0.0.1:11434" +func Host() *url.URL { defaultPort := "11434" - hostVar := os.Getenv("OLLAMA_HOST") - hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'")) - - scheme, hostport, ok := strings.Cut(hostVar, "://") + s := strings.TrimSpace(Var("OLLAMA_HOST")) + scheme, hostport, ok := strings.Cut(s, "://") switch { case !ok: - scheme, hostport = "http", hostVar + scheme, hostport = "http", s case scheme == "http": defaultPort = "80" case scheme == "https": @@ -323,38 +43,242 @@ func getOllamaHost() (*OllamaHost, error) { } } - if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 { - return &OllamaHost{ + if n, err := strconv.ParseInt(port, 10, 32); err != nil || n > 65535 || n < 0 { + slog.Warn("invalid port, using default", "port", port, "default", defaultPort) + return &url.URL{ Scheme: scheme, - Host: host, - Port: defaultPort, - }, ErrInvalidHostPort + Host: net.JoinHostPort(host, defaultPort), + } } - return &OllamaHost{ + return &url.URL{ Scheme: scheme, - Host: host, - Port: port, - }, nil + Host: net.JoinHostPort(host, port), + } } -func loadKeepAlive(ka string) { - v, err := strconv.Atoi(ka) +// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable. +func Origins() (origins []string) { + if s := Var("OLLAMA_ORIGINS"); s != "" { + origins = strings.Split(s, ",") + } + + for _, origin := range []string{"localhost", "127.0.0.1", "0.0.0.0"} { + origins = append(origins, + fmt.Sprintf("http://%s", origin), + fmt.Sprintf("https://%s", origin), + fmt.Sprintf("http://%s", net.JoinHostPort(origin, "*")), + fmt.Sprintf("https://%s", net.JoinHostPort(origin, "*")), + ) + } + + origins = append(origins, + "app://*", + "file://*", + "tauri://*", + ) + + return origins +} + +// Models returns the path to the models directory. Models directory can be configured via the OLLAMA_MODELS environment variable. +// Default is $HOME/.ollama/models +func Models() string { + if s := Var("OLLAMA_MODELS"); s != "" { + return s + } + + home, err := os.UserHomeDir() if err != nil { - d, err := time.ParseDuration(ka) - if err == nil { - if d < 0 { - KeepAlive = time.Duration(math.MaxInt64) + panic(err) + } + + return filepath.Join(home, ".ollama", "models") +} + +// KeepAlive returns the duration that models stay loaded in memory. KeepAlive can be configured via the OLLAMA_KEEP_ALIVE environment variable. +// Negative values are treated as infinite. Zero is treated as no keep alive. +// Default is 5 minutes. 
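+// For example, OLLAMA_KEEP_ALIVE=10m unloads idle models after ten minutes, OLLAMA_KEEP_ALIVE=3600 (a bare integer is read as seconds) after one hour, and OLLAMA_KEEP_ALIVE=-1 keeps them loaded indefinitely.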
+func KeepAlive() (keepAlive time.Duration) { + keepAlive = 5 * time.Minute + if s := Var("OLLAMA_KEEP_ALIVE"); s != "" { + if d, err := time.ParseDuration(s); err == nil { + keepAlive = d + } else if n, err := strconv.ParseInt(s, 10, 64); err == nil { + keepAlive = time.Duration(n) * time.Second + } + } + + if keepAlive < 0 { + return time.Duration(math.MaxInt64) + } + + return keepAlive +} + +func Bool(k string) func() bool { + return func() bool { + if s := Var(k); s != "" { + b, err := strconv.ParseBool(s) + if err != nil { + return true + } + + return b + } + + return false + } +} + +var ( + // Debug enables additional debug information. + Debug = Bool("OLLAMA_DEBUG") + // FlashAttention enables the experimental flash attention feature. + FlashAttention = Bool("OLLAMA_FLASH_ATTENTION") + // NoHistory disables readline history. + NoHistory = Bool("OLLAMA_NOHISTORY") + // NoPrune disables pruning of model blobs on startup. + NoPrune = Bool("OLLAMA_NOPRUNE") + // SchedSpread allows scheduling models across all GPUs. + SchedSpread = Bool("OLLAMA_SCHED_SPREAD") + // IntelGPU enables experimental Intel GPU detection. + IntelGPU = Bool("OLLAMA_INTEL_GPU") +) + +func String(s string) func() string { + return func() string { + return Var(s) + } +} + +var ( + LLMLibrary = String("OLLAMA_LLM_LIBRARY") + TmpDir = String("OLLAMA_TMPDIR") + + CudaVisibleDevices = String("CUDA_VISIBLE_DEVICES") + HipVisibleDevices = String("HIP_VISIBLE_DEVICES") + RocrVisibleDevices = String("ROCR_VISIBLE_DEVICES") + GpuDeviceOrdinal = String("GPU_DEVICE_ORDINAL") + HsaOverrideGfxVersion = String("HSA_OVERRIDE_GFX_VERSION") +) + +func RunnersDir() (p string) { + if p := Var("OLLAMA_RUNNERS_DIR"); p != "" { + return p + } + + if runtime.GOOS != "windows" { + return + } + + defer func() { + if p == "" { + slog.Error("unable to locate llm runner directory. Set OLLAMA_RUNNERS_DIR to the location of 'ollama_runners'") + } + }() + + // On Windows we do not carry the payloads inside the main executable + exe, err := os.Executable() + if err != nil { + return + } + + cwd, err := os.Getwd() + if err != nil { + return + } + + var paths []string + for _, root := range []string{filepath.Dir(exe), cwd} { + paths = append(paths, + root, + filepath.Join(root, "windows-"+runtime.GOARCH), + filepath.Join(root, "dist", "windows-"+runtime.GOARCH), + ) + } + + // Try a few variations to improve developer experience when building from source in the local tree + for _, path := range paths { + candidate := filepath.Join(path, "ollama_runners") + if _, err := os.Stat(candidate); err == nil { + p = candidate + break + } + } + + return p +} + +func Uint(key string, defaultValue uint) func() uint { + return func() uint { + if s := Var(key); s != "" { + if n, err := strconv.ParseUint(s, 10, 64); err != nil { + slog.Warn("invalid environment variable, using default", "key", key, "value", s, "default", defaultValue) } else { - KeepAlive = d + return uint(n) } } - } else { - d := time.Duration(v) * time.Second - if d < 0 { - KeepAlive = time.Duration(math.MaxInt64) - } else { - KeepAlive = d - } + + return defaultValue } } + +var ( + // NumParallel sets the number of parallel model requests. NumParallel can be configured via the OLLAMA_NUM_PARALLEL environment variable. + NumParallel = Uint("OLLAMA_NUM_PARALLEL", 0) + // MaxRunners sets the maximum number of loaded models. MaxRunners can be configured via the OLLAMA_MAX_LOADED_MODELS environment variable. 
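+ // A value of 0 (the default) leaves the limit to the scheduler, which picks a value based on the detected GPUs.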
+ MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", 0) + // MaxQueue sets the maximum number of queued requests. MaxQueue can be configured via the OLLAMA_MAX_QUEUE environment variable. + MaxQueue = Uint("OLLAMA_MAX_QUEUE", 512) + // MaxVRAM sets a maximum VRAM override in bytes. MaxVRAM can be configured via the OLLAMA_MAX_VRAM environment variable. + MaxVRAM = Uint("OLLAMA_MAX_VRAM", 0) +) + +type EnvVar struct { + Name string + Value any + Description string +} + +func AsMap() map[string]EnvVar { + ret := map[string]EnvVar{ + "OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"}, + "OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enables flash attention"}, + "OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"}, + "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"}, + "OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"}, + "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"}, + "OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"}, + "OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"}, + "OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"}, + "OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"}, + "OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"}, + "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"}, + "OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir(), "Location for runners"}, + "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"}, + "OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir(), "Location for temporary files"}, + } + if runtime.GOOS != "darwin" { + ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices(), "Set which NVIDIA devices are visible"} + ret["HIP_VISIBLE_DEVICES"] = EnvVar{"HIP_VISIBLE_DEVICES", HipVisibleDevices(), "Set which AMD devices are visible"} + ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices(), "Set which AMD devices are visible"} + ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal(), "Set which AMD devices are visible"} + ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion(), "Override the gfx used for all detected AMD GPUs"} + ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGPU(), "Enable experimental Intel GPU detection"} + } + return ret +} + +func Values() map[string]string { + vals := make(map[string]string) + for k, v := range AsMap() { + vals[k] = fmt.Sprintf("%v", v.Value) + } + return vals +} + +// Var returns an environment variable stripped of leading and trailing quotes or spaces +func Var(key string) string { + return strings.Trim(strings.TrimSpace(os.Getenv(key)), "\"'") +} diff --git a/envconfig/config_test.go b/envconfig/config_test.go index a5d73fd7..92a500f1 100644 --- a/envconfig/config_test.go +++ b/envconfig/config_test.go @@ -1,87 +1,234 @@ package envconfig import ( - "fmt" "math" - "net" "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/google/go-cmp/cmp" ) -func TestConfig(t *testing.T) 
{ - Debug = false // Reset whatever was loaded in init() - t.Setenv("OLLAMA_DEBUG", "") - LoadConfig() - require.False(t, Debug) - t.Setenv("OLLAMA_DEBUG", "false") - LoadConfig() - require.False(t, Debug) - t.Setenv("OLLAMA_DEBUG", "1") - LoadConfig() - require.True(t, Debug) - t.Setenv("OLLAMA_FLASH_ATTENTION", "1") - LoadConfig() - require.True(t, FlashAttention) - t.Setenv("OLLAMA_KEEP_ALIVE", "") - LoadConfig() - require.Equal(t, 5*time.Minute, KeepAlive) - t.Setenv("OLLAMA_KEEP_ALIVE", "3") - LoadConfig() - require.Equal(t, 3*time.Second, KeepAlive) - t.Setenv("OLLAMA_KEEP_ALIVE", "1h") - LoadConfig() - require.Equal(t, 1*time.Hour, KeepAlive) - t.Setenv("OLLAMA_KEEP_ALIVE", "-1s") - LoadConfig() - require.Equal(t, time.Duration(math.MaxInt64), KeepAlive) - t.Setenv("OLLAMA_KEEP_ALIVE", "-1") - LoadConfig() - require.Equal(t, time.Duration(math.MaxInt64), KeepAlive) -} - -func TestClientFromEnvironment(t *testing.T) { - type testCase struct { +func TestHost(t *testing.T) { + cases := map[string]struct { value string expect string - err error + }{ + "empty": {"", "127.0.0.1:11434"}, + "only address": {"1.2.3.4", "1.2.3.4:11434"}, + "only port": {":1234", ":1234"}, + "address and port": {"1.2.3.4:1234", "1.2.3.4:1234"}, + "hostname": {"example.com", "example.com:11434"}, + "hostname and port": {"example.com:1234", "example.com:1234"}, + "zero port": {":0", ":0"}, + "too large port": {":66000", ":11434"}, + "too small port": {":-1", ":11434"}, + "ipv6 localhost": {"[::1]", "[::1]:11434"}, + "ipv6 world open": {"[::]", "[::]:11434"}, + "ipv6 no brackets": {"::1", "[::1]:11434"}, + "ipv6 + port": {"[::1]:1337", "[::1]:1337"}, + "extra space": {" 1.2.3.4 ", "1.2.3.4:11434"}, + "extra quotes": {"\"1.2.3.4\"", "1.2.3.4:11434"}, + "extra space+quotes": {" \" 1.2.3.4 \" ", "1.2.3.4:11434"}, + "extra single quotes": {"'1.2.3.4'", "1.2.3.4:11434"}, + "http": {"http://1.2.3.4", "1.2.3.4:80"}, + "http port": {"http://1.2.3.4:4321", "1.2.3.4:4321"}, + "https": {"https://1.2.3.4", "1.2.3.4:443"}, + "https port": {"https://1.2.3.4:4321", "1.2.3.4:4321"}, } - hostTestCases := map[string]*testCase{ - "empty": {value: "", expect: "127.0.0.1:11434"}, - "only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"}, - "only port": {value: ":1234", expect: ":1234"}, - "address and port": {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"}, - "hostname": {value: "example.com", expect: "example.com:11434"}, - "hostname and port": {value: "example.com:1234", expect: "example.com:1234"}, - "zero port": {value: ":0", expect: ":0"}, - "too large port": {value: ":66000", err: ErrInvalidHostPort}, - "too small port": {value: ":-1", err: ErrInvalidHostPort}, - "ipv6 localhost": {value: "[::1]", expect: "[::1]:11434"}, - "ipv6 world open": {value: "[::]", expect: "[::]:11434"}, - "ipv6 no brackets": {value: "::1", expect: "[::1]:11434"}, - "ipv6 + port": {value: "[::1]:1337", expect: "[::1]:1337"}, - "extra space": {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"}, - "extra quotes": {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"}, - "extra space+quotes": {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"}, - "extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"}, - } - - for k, v := range hostTestCases { - t.Run(k, func(t *testing.T) { - t.Setenv("OLLAMA_HOST", v.value) - LoadConfig() - - oh, err := getOllamaHost() - if err != v.err { - t.Fatalf("expected %s, got %s", v.err, err) - } - - if err == nil { - host := net.JoinHostPort(oh.Host, oh.Port) - assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, 
got %s", k, v.expect, host)) + for name, tt := range cases { + t.Run(name, func(t *testing.T) { + t.Setenv("OLLAMA_HOST", tt.value) + if host := Host(); host.Host != tt.expect { + t.Errorf("%s: expected %s, got %s", name, tt.expect, host.Host) + } + }) + } +} + +func TestOrigins(t *testing.T) { + cases := []struct { + value string + expect []string + }{ + {"", []string{ + "http://localhost", + "https://localhost", + "http://localhost:*", + "https://localhost:*", + "http://127.0.0.1", + "https://127.0.0.1", + "http://127.0.0.1:*", + "https://127.0.0.1:*", + "http://0.0.0.0", + "https://0.0.0.0", + "http://0.0.0.0:*", + "https://0.0.0.0:*", + "app://*", + "file://*", + "tauri://*", + }}, + {"http://10.0.0.1", []string{ + "http://10.0.0.1", + "http://localhost", + "https://localhost", + "http://localhost:*", + "https://localhost:*", + "http://127.0.0.1", + "https://127.0.0.1", + "http://127.0.0.1:*", + "https://127.0.0.1:*", + "http://0.0.0.0", + "https://0.0.0.0", + "http://0.0.0.0:*", + "https://0.0.0.0:*", + "app://*", + "file://*", + "tauri://*", + }}, + {"http://172.16.0.1,https://192.168.0.1", []string{ + "http://172.16.0.1", + "https://192.168.0.1", + "http://localhost", + "https://localhost", + "http://localhost:*", + "https://localhost:*", + "http://127.0.0.1", + "https://127.0.0.1", + "http://127.0.0.1:*", + "https://127.0.0.1:*", + "http://0.0.0.0", + "https://0.0.0.0", + "http://0.0.0.0:*", + "https://0.0.0.0:*", + "app://*", + "file://*", + "tauri://*", + }}, + {"http://totally.safe,http://definitely.legit", []string{ + "http://totally.safe", + "http://definitely.legit", + "http://localhost", + "https://localhost", + "http://localhost:*", + "https://localhost:*", + "http://127.0.0.1", + "https://127.0.0.1", + "http://127.0.0.1:*", + "https://127.0.0.1:*", + "http://0.0.0.0", + "https://0.0.0.0", + "http://0.0.0.0:*", + "https://0.0.0.0:*", + "app://*", + "file://*", + "tauri://*", + }}, + } + for _, tt := range cases { + t.Run(tt.value, func(t *testing.T) { + t.Setenv("OLLAMA_ORIGINS", tt.value) + + if diff := cmp.Diff(Origins(), tt.expect); diff != "" { + t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff) + } + }) + } +} + +func TestBool(t *testing.T) { + cases := map[string]bool{ + "": false, + "true": true, + "false": false, + "1": true, + "0": false, + // invalid values + "random": true, + "something": true, + } + + for k, v := range cases { + t.Run(k, func(t *testing.T) { + t.Setenv("OLLAMA_BOOL", k) + if b := Bool("OLLAMA_BOOL")(); b != v { + t.Errorf("%s: expected %t, got %t", k, v, b) + } + }) + } +} + +func TestUint(t *testing.T) { + cases := map[string]uint{ + "0": 0, + "1": 1, + "1337": 1337, + // default values + "": 11434, + "-1": 11434, + "0o10": 11434, + "0x10": 11434, + "string": 11434, + } + + for k, v := range cases { + t.Run(k, func(t *testing.T) { + t.Setenv("OLLAMA_UINT", k) + if i := Uint("OLLAMA_UINT", 11434)(); i != v { + t.Errorf("%s: expected %d, got %d", k, v, i) + } + }) + } +} + +func TestKeepAlive(t *testing.T) { + cases := map[string]time.Duration{ + "": 5 * time.Minute, + "1s": time.Second, + "1m": time.Minute, + "1h": time.Hour, + "5m0s": 5 * time.Minute, + "1h2m3s": 1*time.Hour + 2*time.Minute + 3*time.Second, + "0": time.Duration(0), + "60": 60 * time.Second, + "120": 2 * time.Minute, + "3600": time.Hour, + "-0": time.Duration(0), + "-1": time.Duration(math.MaxInt64), + "-1m": time.Duration(math.MaxInt64), + // invalid values + " ": 5 * time.Minute, + "???": 5 * time.Minute, + "1d": 5 * time.Minute, + "1y": 5 * time.Minute, + "1w": 5 * 
time.Minute, + } + + for tt, expect := range cases { + t.Run(tt, func(t *testing.T) { + t.Setenv("OLLAMA_KEEP_ALIVE", tt) + if actual := KeepAlive(); actual != expect { + t.Errorf("%s: expected %s, got %s", tt, expect, actual) + } + }) + } +} + +func TestVar(t *testing.T) { + cases := map[string]string{ + "value": "value", + " value ": "value", + " 'value' ": "value", + ` "value" `: "value", + " ' value ' ": " value ", + ` " value " `: " value ", + } + + for k, v := range cases { + t.Run(k, func(t *testing.T) { + t.Setenv("OLLAMA_VAR", k) + if s := Var("OLLAMA_VAR"); s != v { + t.Errorf("%s: expected %q, got %q", k, v, s) } }) } diff --git a/gpu/amd_linux.go b/gpu/amd_linux.go index 6493af9e..1ad4b906 100644 --- a/gpu/amd_linux.go +++ b/gpu/amd_linux.go @@ -61,9 +61,9 @@ func AMDGetGPUInfo() []RocmGPUInfo { // Determine if the user has already pre-selected which GPUs to look at, then ignore the others var visibleDevices []string - hipVD := envconfig.HipVisibleDevices // zero based index only - rocrVD := envconfig.RocrVisibleDevices // zero based index or UUID, but consumer cards seem to not support UUID - gpuDO := envconfig.GpuDeviceOrdinal // zero based index + hipVD := envconfig.HipVisibleDevices() // zero based index only + rocrVD := envconfig.RocrVisibleDevices() // zero based index or UUID, but consumer cards seem to not support UUID + gpuDO := envconfig.GpuDeviceOrdinal() // zero based index switch { // TODO is this priorty order right? case hipVD != "": @@ -76,7 +76,7 @@ func AMDGetGPUInfo() []RocmGPUInfo { visibleDevices = strings.Split(gpuDO, ",") } - gfxOverride := envconfig.HsaOverrideGfxVersion + gfxOverride := envconfig.HsaOverrideGfxVersion() var supported []string libDir := "" diff --git a/gpu/amd_windows.go b/gpu/amd_windows.go index 20aed447..a170dfdc 100644 --- a/gpu/amd_windows.go +++ b/gpu/amd_windows.go @@ -53,7 +53,7 @@ func AMDGetGPUInfo() []RocmGPUInfo { } var supported []string - gfxOverride := envconfig.HsaOverrideGfxVersion + gfxOverride := envconfig.HsaOverrideGfxVersion() if gfxOverride == "" { supported, err = GetSupportedGFX(libDir) if err != nil { diff --git a/gpu/assets.go b/gpu/assets.go index 073d2e81..39ff7c21 100644 --- a/gpu/assets.go +++ b/gpu/assets.go @@ -26,7 +26,7 @@ func PayloadsDir() (string, error) { defer lock.Unlock() var err error if payloadsDir == "" { - runnersDir := envconfig.RunnersDir + runnersDir := envconfig.RunnersDir() if runnersDir != "" { payloadsDir = runnersDir @@ -35,7 +35,7 @@ func PayloadsDir() (string, error) { // The remainder only applies on non-windows where we still carry payloads in the main executable cleanupTmpDirs() - tmpDir := envconfig.TmpDir + tmpDir := envconfig.TmpDir() if tmpDir == "" { tmpDir, err = os.MkdirTemp("", "ollama") if err != nil { @@ -105,7 +105,7 @@ func cleanupTmpDirs() { func Cleanup() { lock.Lock() defer lock.Unlock() - runnersDir := envconfig.RunnersDir + runnersDir := envconfig.RunnersDir() if payloadsDir != "" && runnersDir == "" && runtime.GOOS != "windows" { // We want to fully clean up the tmpdir parent of the payloads dir tmpDir := filepath.Clean(filepath.Join(payloadsDir, "..")) diff --git a/gpu/gpu.go b/gpu/gpu.go index 6e25cb46..acab1c8d 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -230,8 +230,8 @@ func GetGPUInfo() GpuInfoList { // On windows we bundle the nvidia library one level above the runner dir depPath := "" - if runtime.GOOS == "windows" && envconfig.RunnersDir != "" { - depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir), "cuda") + if runtime.GOOS == "windows" && 
envconfig.RunnersDir() != "" { - depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir), "cuda") + depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir()), "cuda") } // Load ALL libraries @@ -302,12 +302,12 @@ func GetGPUInfo() GpuInfoList { } // Intel - if envconfig.IntelGpu { + if envconfig.IntelGPU() { oHandles = initOneAPIHandles() // On windows we bundle the oneapi library one level above the runner dir depPath = "" - if runtime.GOOS == "windows" && envconfig.RunnersDir != "" { - depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir), "oneapi") + if runtime.GOOS == "windows" && envconfig.RunnersDir() != "" { + depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir()), "oneapi") } for d := range oHandles.oneapi.num_drivers { @@ -611,7 +611,7 @@ func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) { } func getVerboseState() C.uint16_t { - if envconfig.Debug { + if envconfig.Debug() { return C.uint16_t(1) } return C.uint16_t(0) diff --git a/integration/basic_test.go b/integration/basic_test.go index 6e632a1c..8e35b5c5 100644 --- a/integration/basic_test.go +++ b/integration/basic_test.go @@ -45,14 +45,7 @@ func TestUnicodeModelDir(t *testing.T) { defer os.RemoveAll(modelDir) slog.Info("unicode", "OLLAMA_MODELS", modelDir) - oldModelsDir := os.Getenv("OLLAMA_MODELS") - if oldModelsDir == "" { - defer os.Unsetenv("OLLAMA_MODELS") - } else { - defer os.Setenv("OLLAMA_MODELS", oldModelsDir) - } - err = os.Setenv("OLLAMA_MODELS", modelDir) - require.NoError(t, err) + t.Setenv("OLLAMA_MODELS", modelDir) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() diff --git a/integration/concurrency_test.go b/integration/concurrency_test.go index 8593285b..81d0b587 100644 --- a/integration/concurrency_test.go +++ b/integration/concurrency_test.go @@ -5,14 +5,16 @@ import ( "context" "log/slog" "os" "strconv" "sync" "testing" "time" - "github.com/ollama/ollama/api" "github.com/stretchr/testify/require" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/format" ) func TestMultiModelConcurrency(t *testing.T) { @@ -106,13 +108,16 @@ func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) { // Stress the system if we know how much VRAM it has, and attempt to load more models than will fit func TestMultiModelStress(t *testing.T) { - vram := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM - if vram == "" { + s := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM + if s == "" { t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test") } - max, err := strconv.ParseUint(vram, 10, 64) - require.NoError(t, err) - const MB = uint64(1024 * 1024) + + maxVram, err := strconv.ParseUint(s, 10, 64) + if err != nil { + t.Fatal(err) + } + type model struct { name string size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM } smallModels := []model{ { name: "orca-mini", - size: 2992 * MB, + size: 2992 * format.MebiByte, }, { name: "phi", - size: 2616 * MB, + size: 2616 * format.MebiByte, }, { name: "gemma:2b", - size: 2364 * MB, + size: 2364 * format.MebiByte, }, { name: "stable-code:3b", - size: 2608 * MB, + size: 2608 * format.MebiByte, }, { name: "starcoder2:3b", - size: 2166 * MB, + size: 2166 * format.MebiByte, }, } mediumModels := []model{ { name: "llama2", - size: 5118 * MB, + size: 5118 * format.MebiByte, }, { name: "mistral", - size: 4620 * MB, + size: 4620 * format.MebiByte, }, { name: 
"orca-mini:7b", - size: 5118 * MB, + size: 5118 * format.MebiByte, }, { name: "dolphin-mistral", - size: 4620 * MB, + size: 4620 * format.MebiByte, }, { name: "gemma:7b", - size: 5000 * MB, + size: 5000 * format.MebiByte, + }, + { + name: "codellama:7b", + size: 5118 * format.MebiByte, }, - // TODO - uncomment this once #3565 is merged and this is rebased on it - // { - // name: "codellama:7b", - // size: 5118 * MB, - // }, } // These seem to be too slow to be useful... // largeModels := []model{ // { // name: "llama2:13b", - // size: 7400 * MB, + // size: 7400 * format.MebiByte, // }, // { // name: "codellama:13b", - // size: 7400 * MB, + // size: 7400 * format.MebiByte, // }, // { // name: "orca-mini:13b", - // size: 7400 * MB, + // size: 7400 * format.MebiByte, // }, // { // name: "gemma:7b", - // size: 5000 * MB, + // size: 5000 * format.MebiByte, // }, // { // name: "starcoder2:15b", - // size: 9100 * MB, + // size: 9100 * format.MebiByte, // }, // } var chosenModels []model switch { - case max < 10000*MB: + case maxVram < 10000*format.MebiByte: slog.Info("selecting small models") chosenModels = smallModels - // case max < 30000*MB: + // case maxVram < 30000*format.MebiByte: default: slog.Info("selecting medium models") chosenModels = mediumModels @@ -226,15 +230,15 @@ func TestMultiModelStress(t *testing.T) { } var wg sync.WaitGroup - consumed := uint64(256 * MB) // Assume some baseline usage + consumed := uint64(256 * format.MebiByte) // Assume some baseline usage for i := 0; i < len(req); i++ { // Always get at least 2 models, but dont' overshoot VRAM too much or we'll take too long - if i > 1 && consumed > max { - slog.Info("achieved target vram exhaustion", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024) + if i > 1 && consumed > vram { + slog.Info("achieved target vram exhaustion", "count", i, "vram", format.HumanBytes2(vram), "models", format.HumanBytes2(consumed)) break } consumed += chosenModels[i].size - slog.Info("target vram", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024) + slog.Info("target vram", "count", i, "vram", format.HumanBytes2(vram), "models", format.HumanBytes2(consumed)) wg.Add(1) go func(i int) { diff --git a/integration/max_queue_test.go b/integration/max_queue_test.go index dfa5eae0..b06197e1 100644 --- a/integration/max_queue_test.go +++ b/integration/max_queue_test.go @@ -5,7 +5,6 @@ package integration import ( "context" "errors" - "fmt" "log/slog" "os" "strconv" @@ -14,8 +13,10 @@ import ( "testing" "time" - "github.com/ollama/ollama/api" "github.com/stretchr/testify/require" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/envconfig" ) func TestMaxQueue(t *testing.T) { @@ -27,13 +28,10 @@ func TestMaxQueue(t *testing.T) { // Note: This test can be quite slow when running in CPU mode, so keep the threadCount low unless your on GPU // Also note that by default Darwin can't sustain > ~128 connections without adjusting limits threadCount := 32 - mq := os.Getenv("OLLAMA_MAX_QUEUE") - if mq != "" { - var err error - threadCount, err = strconv.Atoi(mq) - require.NoError(t, err) + if maxQueue := envconfig.MaxQueue(); maxQueue != 0 { + threadCount = maxQueue } else { - os.Setenv("OLLAMA_MAX_QUEUE", fmt.Sprintf("%d", threadCount)) + t.Setenv("OLLAMA_MAX_QUEUE", strconv.Itoa(threadCount)) } req := api.GenerateRequest{ diff --git a/llm/memory_test.go b/llm/memory_test.go index f972f927..06ae7438 100644 --- a/llm/memory_test.go +++ b/llm/memory_test.go @@ -8,14 +8,14 @@ import ( "testing" 
"github.com/ollama/ollama/api" - "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/gpu" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestEstimateGPULayers(t *testing.T) { - envconfig.Debug = true + t.Setenv("OLLAMA_DEBUG", "1") + modelName := "dummy" f, err := os.CreateTemp(t.TempDir(), modelName) require.NoError(t, err) diff --git a/llm/server.go b/llm/server.go index afde077e..7fadb0c9 100644 --- a/llm/server.go +++ b/llm/server.go @@ -163,7 +163,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr } else { servers = serversForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant } - demandLib := envconfig.LLMLibrary + demandLib := envconfig.LLMLibrary() if demandLib != "" { serverPath := availableServers[demandLib] if serverPath == "" { @@ -195,7 +195,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr params = append(params, "--n-gpu-layers", fmt.Sprintf("%d", opts.NumGPU)) } - if envconfig.Debug { + if envconfig.Debug() { params = append(params, "--verbose") } @@ -221,7 +221,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr params = append(params, "--memory-f32") } - flashAttnEnabled := envconfig.FlashAttention + flashAttnEnabled := envconfig.FlashAttention() for _, g := range gpus { // only cuda (compute capability 7+) and metal support flash attention @@ -382,7 +382,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr } slog.Info("starting llama server", "cmd", s.cmd.String()) - if envconfig.Debug { + if envconfig.Debug() { filteredEnv := []string{} for _, ev := range s.cmd.Env { if strings.HasPrefix(ev, "CUDA_") || diff --git a/server/images.go b/server/images.go index 836dbcc2..1ffe057a 100644 --- a/server/images.go +++ b/server/images.go @@ -646,7 +646,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio return err } - if !envconfig.NoPrune && old != nil { + if !envconfig.NoPrune() && old != nil { if err := old.RemoveLayers(); err != nil { return err } @@ -885,7 +885,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu // build deleteMap to prune unused layers deleteMap := make(map[string]struct{}) - if !envconfig.NoPrune { + if !envconfig.NoPrune() { manifest, _, err = GetManifest(mp) if err != nil && !errors.Is(err, os.ErrNotExist) { return err diff --git a/server/manifest_test.go b/server/manifest_test.go index ca6c3d2e..a4af5d5e 100644 --- a/server/manifest_test.go +++ b/server/manifest_test.go @@ -7,7 +7,6 @@ import ( "slices" "testing" - "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/types/model" ) @@ -108,7 +107,6 @@ func TestManifests(t *testing.T) { t.Run(n, func(t *testing.T) { d := t.TempDir() t.Setenv("OLLAMA_MODELS", d) - envconfig.LoadConfig() for _, p := range wants.ps { createManifest(t, d, p) diff --git a/server/modelpath.go b/server/modelpath.go index 3fdb4238..354eeed7 100644 --- a/server/modelpath.go +++ b/server/modelpath.go @@ -105,9 +105,7 @@ func (mp ModelPath) GetShortTagname() string { // GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist. 
func (mp ModelPath) GetManifestPath() (string, error) { - dir := envconfig.ModelsDir - - return filepath.Join(dir, "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil + return filepath.Join(envconfig.Models(), "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil } func (mp ModelPath) BaseURL() *url.URL { @@ -118,9 +116,7 @@ func (mp ModelPath) BaseURL() *url.URL { } func GetManifestPath() (string, error) { - dir := envconfig.ModelsDir - - path := filepath.Join(dir, "manifests") + path := filepath.Join(envconfig.Models(), "manifests") if err := os.MkdirAll(path, 0o755); err != nil { return "", err } @@ -129,8 +125,6 @@ func GetManifestPath() (string, error) { } func GetBlobsPath(digest string) (string, error) { - dir := envconfig.ModelsDir - // only accept actual sha256 digests pattern := "^sha256[:-][0-9a-fA-F]{64}$" re := regexp.MustCompile(pattern) @@ -140,7 +134,7 @@ func GetBlobsPath(digest string) (string, error) { } digest = strings.ReplaceAll(digest, ":", "-") - path := filepath.Join(dir, "blobs", digest) + path := filepath.Join(envconfig.Models(), "blobs", digest) dirPath := filepath.Dir(path) if digest == "" { dirPath = path diff --git a/server/modelpath_test.go b/server/modelpath_test.go index 6c4dfbee..849e0fa7 100644 --- a/server/modelpath_test.go +++ b/server/modelpath_test.go @@ -7,8 +7,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/ollama/ollama/envconfig" ) func TestGetBlobsPath(t *testing.T) { @@ -63,7 +61,6 @@ func TestGetBlobsPath(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { t.Setenv("OLLAMA_MODELS", dir) - envconfig.LoadConfig() got, err := GetBlobsPath(tc.digest) diff --git a/server/routes.go b/server/routes.go index a560f369..adb9ed42 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1053,7 +1053,7 @@ func (s *Server) GenerateRoutes() http.Handler { for _, prop := range openAIProperties { config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop) } - config.AllowOrigins = envconfig.AllowOrigins + config.AllowOrigins = envconfig.Origins() r := gin.Default() r.Use( @@ -1098,7 +1098,7 @@ func (s *Server) GenerateRoutes() http.Handler { func Serve(ln net.Listener) error { level := slog.LevelInfo - if envconfig.Debug { + if envconfig.Debug() { level = slog.LevelDebug } @@ -1126,7 +1126,7 @@ func Serve(ln net.Listener) error { return err } - if !envconfig.NoPrune { + if !envconfig.NoPrune() { // clean up unused layers and manifests if err := PruneLayers(); err != nil { return err diff --git a/server/routes_create_test.go b/server/routes_create_test.go index e801a74f..8c714209 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -15,7 +15,6 @@ import ( "github.com/gin-gonic/gin" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/llm" ) @@ -89,7 +88,6 @@ func TestCreateFromBin(t *testing.T) { p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) - envconfig.LoadConfig() var s Server w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ @@ -117,7 +115,6 @@ func TestCreateFromModel(t *testing.T) { p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) - envconfig.LoadConfig() var s Server w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ @@ -160,7 +157,6 @@ func TestCreateRemovesLayers(t *testing.T) { p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) - envconfig.LoadConfig() var s Server w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ @@ -209,7 +205,6 
@@ func TestCreateUnsetsSystem(t *testing.T) { p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) - envconfig.LoadConfig() var s Server w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ @@ -267,7 +262,6 @@ func TestCreateMergeParameters(t *testing.T) { p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) - envconfig.LoadConfig() var s Server w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ @@ -372,7 +366,6 @@ func TestCreateReplacesMessages(t *testing.T) { p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) - envconfig.LoadConfig() var s Server w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ @@ -450,7 +443,6 @@ func TestCreateTemplateSystem(t *testing.T) { p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) - envconfig.LoadConfig() var s Server w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ @@ -534,7 +526,6 @@ func TestCreateLicenses(t *testing.T) { p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) - envconfig.LoadConfig() var s Server w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ @@ -582,7 +573,6 @@ func TestCreateDetectTemplate(t *testing.T) { p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) - envconfig.LoadConfig() var s Server t.Run("matched", func(t *testing.T) { diff --git a/server/routes_delete_test.go b/server/routes_delete_test.go index 33a97a73..2354d730 100644 --- a/server/routes_delete_test.go +++ b/server/routes_delete_test.go @@ -10,7 +10,6 @@ import ( "github.com/gin-gonic/gin" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/types/model" ) @@ -19,7 +18,6 @@ func TestDelete(t *testing.T) { p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) - envconfig.LoadConfig() var s Server diff --git a/server/routes_list_test.go b/server/routes_list_test.go index c2d9c113..29e3214c 100644 --- a/server/routes_list_test.go +++ b/server/routes_list_test.go @@ -9,14 +9,12 @@ import ( "github.com/gin-gonic/gin" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/envconfig" ) func TestList(t *testing.T) { gin.SetMode(gin.TestMode) t.Setenv("OLLAMA_MODELS", t.TempDir()) - envconfig.LoadConfig() expectNames := []string{ "mistral:7b-instruct-q4_0", diff --git a/server/routes_test.go b/server/routes_test.go index 97786ba2..17da2305 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -19,7 +19,6 @@ import ( "github.com/stretchr/testify/require" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/openai" "github.com/ollama/ollama/parser" @@ -347,7 +346,6 @@ func Test_Routes(t *testing.T) { } t.Setenv("OLLAMA_MODELS", t.TempDir()) - envconfig.LoadConfig() s := &Server{} router := s.GenerateRoutes() @@ -378,7 +376,6 @@ func Test_Routes(t *testing.T) { func TestCase(t *testing.T) { t.Setenv("OLLAMA_MODELS", t.TempDir()) - envconfig.LoadConfig() cases := []string{ "mistral", @@ -458,7 +455,6 @@ func TestCase(t *testing.T) { func TestShow(t *testing.T) { t.Setenv("OLLAMA_MODELS", t.TempDir()) - envconfig.LoadConfig() var s Server diff --git a/server/sched.go b/server/sched.go index 92b8d508..700642c6 100644 --- a/server/sched.go +++ b/server/sched.go @@ -5,9 +5,11 @@ import ( "errors" "fmt" "log/slog" + "os" "reflect" "runtime" "sort" + "strconv" "strings" "sync" "time" @@ -59,11 +61,12 @@ var defaultParallel = 4 var ErrMaxQueue = fmt.Errorf("server busy, please try again. 
maximum pending requests exceeded") func InitScheduler(ctx context.Context) *Scheduler { + maxQueue := envconfig.MaxQueue() sched := &Scheduler{ - pendingReqCh: make(chan *LlmRequest, envconfig.MaxQueuedRequests), - finishedReqCh: make(chan *LlmRequest, envconfig.MaxQueuedRequests), - expiredCh: make(chan *runnerRef, envconfig.MaxQueuedRequests), - unloadedCh: make(chan interface{}, envconfig.MaxQueuedRequests), + pendingReqCh: make(chan *LlmRequest, maxQueue), + finishedReqCh: make(chan *LlmRequest, maxQueue), + expiredCh: make(chan *runnerRef, maxQueue), + unloadedCh: make(chan interface{}, maxQueue), loaded: make(map[string]*runnerRef), newServerFn: llm.NewLlamaServer, getGpuFn: gpu.GetGPUInfo, @@ -126,7 +129,7 @@ func (s *Scheduler) processPending(ctx context.Context) { slog.Debug("pending request cancelled or timed out, skipping scheduling") continue } - numParallel := envconfig.NumParallel + numParallel := int(envconfig.NumParallel()) // TODO (jmorganca): multimodal models don't support parallel yet // see https://github.com/ollama/ollama/issues/4165 if len(pending.model.ProjectorPaths) > 0 && numParallel != 1 { @@ -148,7 +151,7 @@ func (s *Scheduler) processPending(ctx context.Context) { pending.useLoadedRunner(runner, s.finishedReqCh) break } - } else if envconfig.MaxRunners > 0 && loadedCount >= envconfig.MaxRunners { + } else if envconfig.MaxRunners() > 0 && loadedCount >= int(envconfig.MaxRunners()) { slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount) runnerToExpire = s.findRunnerToUnload() } else { @@ -161,7 +164,7 @@ func (s *Scheduler) processPending(ctx context.Context) { gpus = s.getGpuFn() } - if envconfig.MaxRunners <= 0 { + if envconfig.MaxRunners() <= 0 { // No user specified MaxRunners, so figure out what automatic setting to use // If all GPUs have reliable free memory reporting, defaultModelsPerGPU * the number of GPUs // if any GPU has unreliable free memory reporting, 1x the number of GPUs @@ -173,11 +176,13 @@ func (s *Scheduler) processPending(ctx context.Context) { } } if allReliable { - envconfig.MaxRunners = defaultModelsPerGPU * len(gpus) + // HACK + os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(defaultModelsPerGPU*len(gpus))) - slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", envconfig.MaxRunners, "gpu_count", len(gpus)) + slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", envconfig.MaxRunners(), "gpu_count", len(gpus)) } else { + // HACK + os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(len(gpus))) slog.Info("one or more GPUs detected that are unable to accurately report free memory - disabling default concurrency") - envconfig.MaxRunners = len(gpus) } } @@ -404,7 +409,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, if numParallel < 1 { numParallel = 1 } - sessionDuration := envconfig.KeepAlive + sessionDuration := envconfig.KeepAlive() if req.sessionDuration != nil { sessionDuration = req.sessionDuration.Duration } @@ -699,7 +704,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoL // First attempt to fit the model into a single GPU for _, p := range numParallelToTry { req.opts.NumCtx = req.origNumCtx * p - if !envconfig.SchedSpread { + if !envconfig.SchedSpread() { for _, g := range sgl { if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok { slog.Info("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "parallel", p, "available", g.FreeMemory, "required", 
format.HumanBytes2(estimatedVRAM)) diff --git a/server/sched_test.go b/server/sched_test.go index 4f8789fa..6959dace 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -12,7 +12,6 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/app/lifecycle" - "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/gpu" "github.com/ollama/ollama/llm" @@ -272,7 +271,7 @@ func TestRequestsMultipleLoadedModels(t *testing.T) { c.req.opts.NumGPU = 0 // CPU load, will be allowed d := newScenarioRequest(t, ctx, "ollama-model-3c", 30, nil) // Needs prior unloaded - envconfig.MaxRunners = 1 + t.Setenv("OLLAMA_MAX_LOADED_MODELS", "1") s.newServerFn = a.newServer slog.Info("a") s.pendingReqCh <- a.req @@ -291,7 +290,7 @@ func TestRequestsMultipleLoadedModels(t *testing.T) { require.Len(t, s.loaded, 1) s.loadedMu.Unlock() - envconfig.MaxRunners = 0 + t.Setenv("OLLAMA_MAX_LOADED_MODELS", "0") s.newServerFn = b.newServer slog.Info("b") s.pendingReqCh <- b.req @@ -362,7 +361,7 @@ func TestGetRunner(t *testing.T) { a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond}) b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond}) c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond}) - envconfig.MaxQueuedRequests = 1 + t.Setenv("OLLAMA_MAX_QUEUE", "1") s := InitScheduler(ctx) s.getGpuFn = getGpuFn s.getCpuFn = getCpuFn