From f56aa20014efeb383af5380d3de35475d1f08c36 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Sat, 4 May 2024 11:46:01 -0700 Subject: [PATCH] Centralize server config handling This moves all the env var reading into one central module and logs the loaded config once at startup which should help in troubleshooting user server logs --- app/lifecycle/logging.go | 4 +- app/lifecycle/updater_windows.go | 5 +- gpu/assets.go | 45 +------- gpu/gpu.go | 3 +- llm/memory.go | 14 +-- llm/server.go | 17 +-- server/envconfig/config.go | 174 +++++++++++++++++++++++++++++++ server/envconfig/config_test.go | 20 ++++ server/images.go | 5 +- server/routes.go | 26 +---- server/sched.go | 53 ++-------- server/sched_test.go | 31 +----- 12 files changed, 235 insertions(+), 162 deletions(-) create mode 100644 server/envconfig/config.go create mode 100644 server/envconfig/config_test.go diff --git a/app/lifecycle/logging.go b/app/lifecycle/logging.go index 98df9b41..4be90648 100644 --- a/app/lifecycle/logging.go +++ b/app/lifecycle/logging.go @@ -5,12 +5,14 @@ import ( "log/slog" "os" "path/filepath" + + "github.com/ollama/ollama/server/envconfig" ) func InitLogging() { level := slog.LevelInfo - if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" { + if envconfig.Debug { level = slog.LevelDebug } diff --git a/app/lifecycle/updater_windows.go b/app/lifecycle/updater_windows.go index f26c43c9..4053671a 100644 --- a/app/lifecycle/updater_windows.go +++ b/app/lifecycle/updater_windows.go @@ -31,16 +31,13 @@ func DoUpgrade(cancel context.CancelFunc, done chan int) error { "/LOG=" + filepath.Base(UpgradeLogFile), // Only relative seems reliable, so set pwd "/FORCECLOSEAPPLICATIONS", // Force close the tray app - might be needed } - // When we're not in debug mode, make the upgrade as quiet as possible (no GUI, no prompts) - // TODO - temporarily disable since we're pinning in debug mode for the preview - // if debug := os.Getenv("OLLAMA_DEBUG"); debug == "" { + // make the upgrade as quiet as possible (no GUI, no prompts) installArgs = append(installArgs, "/SP", // Skip the "This will install... Do you wish to continue" prompt "/SUPPRESSMSGBOXES", "/SILENT", "/VERYSILENT", ) - // } // Safeguard in case we have requests in flight that need to drain... slog.Info("Waiting for server to shutdown") diff --git a/gpu/assets.go b/gpu/assets.go index f9b018cd..911a6977 100644 --- a/gpu/assets.go +++ b/gpu/assets.go @@ -12,6 +12,8 @@ import ( "sync" "syscall" "time" + + "github.com/ollama/ollama/server/envconfig" ) var ( @@ -24,45 +26,8 @@ func PayloadsDir() (string, error) { defer lock.Unlock() var err error if payloadsDir == "" { - runnersDir := os.Getenv("OLLAMA_RUNNERS_DIR") - // On Windows we do not carry the payloads inside the main executable - if runtime.GOOS == "windows" && runnersDir == "" { - appExe, err := os.Executable() - if err != nil { - slog.Error("failed to lookup executable path", "error", err) - return "", err - } + runnersDir := envconfig.RunnersDir - cwd, err := os.Getwd() - if err != nil { - slog.Error("failed to lookup working directory", "error", err) - return "", err - } - - var paths []string - for _, root := range []string{filepath.Dir(appExe), cwd} { - paths = append(paths, - filepath.Join(root), - filepath.Join(root, "windows-"+runtime.GOARCH), - filepath.Join(root, "dist", "windows-"+runtime.GOARCH), - ) - } - - // Try a few variations to improve developer experience when building from source in the local tree - for _, p := range paths { - candidate := filepath.Join(p, "ollama_runners") - _, err := os.Stat(candidate) - if err == nil { - runnersDir = candidate - break - } - } - if runnersDir == "" { - err = fmt.Errorf("unable to locate llm runner directory. Set OLLAMA_RUNNERS_DIR to the location of 'ollama_runners'") - slog.Error("incomplete distribution", "error", err) - return "", err - } - } if runnersDir != "" { payloadsDir = runnersDir return payloadsDir, nil @@ -70,7 +35,7 @@ func PayloadsDir() (string, error) { // The remainder only applies on non-windows where we still carry payloads in the main executable cleanupTmpDirs() - tmpDir := os.Getenv("OLLAMA_TMPDIR") + tmpDir := envconfig.TmpDir if tmpDir == "" { tmpDir, err = os.MkdirTemp("", "ollama") if err != nil { @@ -133,7 +98,7 @@ func cleanupTmpDirs() { func Cleanup() { lock.Lock() defer lock.Unlock() - runnersDir := os.Getenv("OLLAMA_RUNNERS_DIR") + runnersDir := envconfig.RunnersDir if payloadsDir != "" && runnersDir == "" && runtime.GOOS != "windows" { // We want to fully clean up the tmpdir parent of the payloads dir tmpDir := filepath.Clean(filepath.Join(payloadsDir, "..")) diff --git a/gpu/gpu.go b/gpu/gpu.go index 9b915015..a056a90b 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -21,6 +21,7 @@ import ( "unsafe" "github.com/ollama/ollama/format" + "github.com/ollama/ollama/server/envconfig" ) type handles struct { @@ -268,7 +269,7 @@ func LoadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string) { } func getVerboseState() C.uint16_t { - if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" { + if envconfig.Debug { return C.uint16_t(1) } return C.uint16_t(0) diff --git a/llm/memory.go b/llm/memory.go index b705aefe..661a0c50 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -3,12 +3,11 @@ package llm import ( "fmt" "log/slog" - "os" - "strconv" "github.com/ollama/ollama/api" "github.com/ollama/ollama/format" "github.com/ollama/ollama/gpu" + "github.com/ollama/ollama/server/envconfig" ) // This algorithm looks for a complete fit to determine if we need to unload other models @@ -50,15 +49,8 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts for _, info := range gpus { memoryAvailable += info.FreeMemory } - userLimit := os.Getenv("OLLAMA_MAX_VRAM") - if userLimit != "" { - avail, err := strconv.ParseUint(userLimit, 10, 64) - if err != nil { - slog.Error("invalid setting, ignoring", "OLLAMA_MAX_VRAM", userLimit, "error", err) - } else { - slog.Info("user override memory limit", "OLLAMA_MAX_VRAM", avail, "actual", memoryAvailable) - memoryAvailable = avail - } + if envconfig.MaxVRAM > 0 { + memoryAvailable = envconfig.MaxVRAM } slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", format.HumanBytes2(memoryAvailable)) diff --git a/llm/server.go b/llm/server.go index b41f393d..2272ac83 100644 --- a/llm/server.go +++ b/llm/server.go @@ -26,6 +26,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/format" "github.com/ollama/ollama/gpu" + "github.com/ollama/ollama/server/envconfig" ) type LlamaServer interface { @@ -124,7 +125,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr } else { servers = serversForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant } - demandLib := strings.Trim(os.Getenv("OLLAMA_LLM_LIBRARY"), "\"' ") + demandLib := envconfig.LLMLibrary if demandLib != "" { serverPath := availableServers[demandLib] if serverPath == "" { @@ -145,7 +146,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr "--batch-size", fmt.Sprintf("%d", opts.NumBatch), "--embedding", } - if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" { + if envconfig.Debug { params = append(params, "--log-format", "json") } else { params = append(params, "--log-disable") @@ -155,7 +156,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr params = append(params, "--n-gpu-layers", fmt.Sprintf("%d", opts.NumGPU)) } - if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" { + if envconfig.Debug { params = append(params, "--verbose") } @@ -194,15 +195,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr } // "--cont-batching", // TODO - doesn't seem to have any noticeable perf change for multiple requests - numParallel := 1 - if onp := os.Getenv("OLLAMA_NUM_PARALLEL"); onp != "" { - numParallel, err = strconv.Atoi(onp) - if err != nil || numParallel <= 0 { - err = fmt.Errorf("invalid OLLAMA_NUM_PARALLEL=%s must be greater than zero - %w", onp, err) - slog.Error("misconfiguration", "error", err) - return nil, err - } - } + numParallel := envconfig.NumParallel params = append(params, "--parallel", fmt.Sprintf("%d", numParallel)) for i := 0; i < len(servers); i++ { diff --git a/server/envconfig/config.go b/server/envconfig/config.go new file mode 100644 index 00000000..9ad68180 --- /dev/null +++ b/server/envconfig/config.go @@ -0,0 +1,174 @@ +package envconfig + +import ( + "fmt" + "log/slog" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" +) + +var ( + // Set via OLLAMA_ORIGINS in the environment + AllowOrigins []string + // Set via OLLAMA_DEBUG in the environment + Debug bool + // Set via OLLAMA_LLM_LIBRARY in the environment + LLMLibrary string + // Set via OLLAMA_MAX_LOADED_MODELS in the environment + MaxRunners int + // Set via OLLAMA_MAX_QUEUE in the environment + MaxQueuedRequests int + // Set via OLLAMA_MAX_VRAM in the environment + MaxVRAM uint64 + // Set via OLLAMA_NOPRUNE in the environment + NoPrune bool + // Set via OLLAMA_NUM_PARALLEL in the environment + NumParallel int + // Set via OLLAMA_RUNNERS_DIR in the environment + RunnersDir string + // Set via OLLAMA_TMPDIR in the environment + TmpDir string +) + +func AsMap() map[string]string { + return map[string]string{ + "OLLAMA_ORIGINS": fmt.Sprintf("%v", AllowOrigins), + "OLLAMA_DEBUG": fmt.Sprintf("%v", Debug), + "OLLAMA_LLM_LIBRARY": fmt.Sprintf("%v", LLMLibrary), + "OLLAMA_MAX_LOADED_MODELS": fmt.Sprintf("%v", MaxRunners), + "OLLAMA_MAX_QUEUE": fmt.Sprintf("%v", MaxQueuedRequests), + "OLLAMA_MAX_VRAM": fmt.Sprintf("%v", MaxVRAM), + "OLLAMA_NOPRUNE": fmt.Sprintf("%v", NoPrune), + "OLLAMA_NUM_PARALLEL": fmt.Sprintf("%v", NumParallel), + "OLLAMA_RUNNERS_DIR": fmt.Sprintf("%v", RunnersDir), + "OLLAMA_TMPDIR": fmt.Sprintf("%v", TmpDir), + } +} + +var defaultAllowOrigins = []string{ + "localhost", + "127.0.0.1", + "0.0.0.0", +} + +// Clean quotes and spaces from the value +func clean(key string) string { + return strings.Trim(os.Getenv(key), "\"' ") +} + +func init() { + // default values + NumParallel = 1 + MaxRunners = 1 + MaxQueuedRequests = 512 + + LoadConfig() +} + +func LoadConfig() { + if debug := clean("OLLAMA_DEBUG"); debug != "" { + d, err := strconv.ParseBool(debug) + if err == nil { + Debug = d + } else { + Debug = true + } + } + + RunnersDir = clean("OLLAMA_RUNNERS_DIR") + if runtime.GOOS == "windows" && RunnersDir == "" { + // On Windows we do not carry the payloads inside the main executable + appExe, err := os.Executable() + if err != nil { + slog.Error("failed to lookup executable path", "error", err) + } + + cwd, err := os.Getwd() + if err != nil { + slog.Error("failed to lookup working directory", "error", err) + } + + var paths []string + for _, root := range []string{filepath.Dir(appExe), cwd} { + paths = append(paths, + filepath.Join(root), + filepath.Join(root, "windows-"+runtime.GOARCH), + filepath.Join(root, "dist", "windows-"+runtime.GOARCH), + ) + } + + // Try a few variations to improve developer experience when building from source in the local tree + for _, p := range paths { + candidate := filepath.Join(p, "ollama_runners") + _, err := os.Stat(candidate) + if err == nil { + RunnersDir = candidate + break + } + } + if RunnersDir == "" { + slog.Error("unable to locate llm runner directory. Set OLLAMA_RUNNERS_DIR to the location of 'ollama_runners'") + } + } + + TmpDir = clean("OLLAMA_TMPDIR") + + userLimit := clean("OLLAMA_MAX_VRAM") + if userLimit != "" { + avail, err := strconv.ParseUint(userLimit, 10, 64) + if err != nil { + slog.Error("invalid setting, ignoring", "OLLAMA_MAX_VRAM", userLimit, "error", err) + } else { + MaxVRAM = avail + } + } + + LLMLibrary = clean("OLLAMA_LLM_LIBRARY") + + if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" { + val, err := strconv.Atoi(onp) + if err != nil || val <= 0 { + slog.Error("invalid setting must be greater than zero", "OLLAMA_NUM_PARALLEL", onp, "error", err) + } else { + NumParallel = val + } + } + + if noprune := clean("OLLAMA_NOPRUNE"); noprune != "" { + NoPrune = true + } + + if origins := clean("OLLAMA_ORIGINS"); origins != "" { + AllowOrigins = strings.Split(origins, ",") + } + for _, allowOrigin := range defaultAllowOrigins { + AllowOrigins = append(AllowOrigins, + fmt.Sprintf("http://%s", allowOrigin), + fmt.Sprintf("https://%s", allowOrigin), + fmt.Sprintf("http://%s:*", allowOrigin), + fmt.Sprintf("https://%s:*", allowOrigin), + ) + } + + maxRunners := clean("OLLAMA_MAX_LOADED_MODELS") + if maxRunners != "" { + m, err := strconv.Atoi(maxRunners) + if err != nil { + slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err) + } else { + MaxRunners = m + } + } + + if onp := os.Getenv("OLLAMA_MAX_QUEUE"); onp != "" { + p, err := strconv.Atoi(onp) + if err != nil || p <= 0 { + slog.Error("invalid setting", "OLLAMA_MAX_QUEUE", onp, "error", err) + } else { + MaxQueuedRequests = p + } + } +} diff --git a/server/envconfig/config_test.go b/server/envconfig/config_test.go new file mode 100644 index 00000000..b2760299 --- /dev/null +++ b/server/envconfig/config_test.go @@ -0,0 +1,20 @@ +package envconfig + +import ( + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestConfig(t *testing.T) { + os.Setenv("OLLAMA_DEBUG", "") + LoadConfig() + require.False(t, Debug) + os.Setenv("OLLAMA_DEBUG", "false") + LoadConfig() + require.False(t, Debug) + os.Setenv("OLLAMA_DEBUG", "1") + LoadConfig() + require.True(t, Debug) +} diff --git a/server/images.go b/server/images.go index 75a41d4a..76205392 100644 --- a/server/images.go +++ b/server/images.go @@ -29,6 +29,7 @@ import ( "github.com/ollama/ollama/convert" "github.com/ollama/ollama/format" "github.com/ollama/ollama/llm" + "github.com/ollama/ollama/server/envconfig" "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" @@ -695,7 +696,7 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m return err } - if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { + if !envconfig.NoPrune { if err := deleteUnusedLayers(nil, deleteMap, false); err != nil { return err } @@ -1026,7 +1027,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu // build deleteMap to prune unused layers deleteMap := make(map[string]struct{}) - if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { + if !envconfig.NoPrune { manifest, _, err = GetManifest(mp) if err != nil && !errors.Is(err, os.ErrNotExist) { return err diff --git a/server/routes.go b/server/routes.go index 3b24735f..e878598a 100644 --- a/server/routes.go +++ b/server/routes.go @@ -29,6 +29,7 @@ import ( "github.com/ollama/ollama/gpu" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/openai" + "github.com/ollama/ollama/server/envconfig" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" ) @@ -859,12 +860,6 @@ func (s *Server) CreateBlobHandler(c *gin.Context) { c.Status(http.StatusCreated) } -var defaultAllowOrigins = []string{ - "localhost", - "127.0.0.1", - "0.0.0.0", -} - func isLocalIP(ip netip.Addr) bool { if interfaces, err := net.Interfaces(); err == nil { for _, iface := range interfaces { @@ -948,19 +943,7 @@ func (s *Server) GenerateRoutes() http.Handler { config := cors.DefaultConfig() config.AllowWildcard = true config.AllowBrowserExtensions = true - - if allowedOrigins := strings.Trim(os.Getenv("OLLAMA_ORIGINS"), "\"'"); allowedOrigins != "" { - config.AllowOrigins = strings.Split(allowedOrigins, ",") - } - - for _, allowOrigin := range defaultAllowOrigins { - config.AllowOrigins = append(config.AllowOrigins, - fmt.Sprintf("http://%s", allowOrigin), - fmt.Sprintf("https://%s", allowOrigin), - fmt.Sprintf("http://%s:*", allowOrigin), - fmt.Sprintf("https://%s:*", allowOrigin), - ) - } + config.AllowOrigins = envconfig.AllowOrigins r := gin.Default() r.Use( @@ -999,10 +982,11 @@ func (s *Server) GenerateRoutes() http.Handler { func Serve(ln net.Listener) error { level := slog.LevelInfo - if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" { + if envconfig.Debug { level = slog.LevelDebug } + slog.Info("server config", "env", envconfig.AsMap()) handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ Level: level, AddSource: true, @@ -1026,7 +1010,7 @@ func Serve(ln net.Listener) error { return err } - if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { + if !envconfig.NoPrune { // clean up unused layers and manifests if err := PruneLayers(); err != nil { return err diff --git a/server/sched.go b/server/sched.go index f3d5c276..9d97c632 100644 --- a/server/sched.go +++ b/server/sched.go @@ -5,10 +5,8 @@ import ( "errors" "fmt" "log/slog" - "os" "reflect" "sort" - "strconv" "strings" "sync" "time" @@ -17,6 +15,7 @@ import ( "github.com/ollama/ollama/format" "github.com/ollama/ollama/gpu" "github.com/ollama/ollama/llm" + "github.com/ollama/ollama/server/envconfig" "golang.org/x/exp/slices" ) @@ -43,46 +42,14 @@ type Scheduler struct { getGpuFn func() gpu.GpuInfoList } -var ( - // TODO set this to zero after a release or two, to enable multiple models by default - loadedMax = 1 // Maximum runners; < 1 maps to as many as will fit in VRAM (unlimited for CPU runners) - maxQueuedRequests = 512 - numParallel = 1 - ErrMaxQueue = fmt.Errorf("server busy, please try again. maximum pending requests exceeded") -) +var ErrMaxQueue = fmt.Errorf("server busy, please try again. maximum pending requests exceeded") func InitScheduler(ctx context.Context) *Scheduler { - maxRunners := os.Getenv("OLLAMA_MAX_LOADED_MODELS") - if maxRunners != "" { - m, err := strconv.Atoi(maxRunners) - if err != nil { - slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err) - } else { - loadedMax = m - } - } - if onp := os.Getenv("OLLAMA_NUM_PARALLEL"); onp != "" { - p, err := strconv.Atoi(onp) - if err != nil || p <= 0 { - slog.Error("invalid parallel setting, must be greater than zero", "OLLAMA_NUM_PARALLEL", onp, "error", err) - } else { - numParallel = p - } - } - if onp := os.Getenv("OLLAMA_MAX_QUEUE"); onp != "" { - p, err := strconv.Atoi(onp) - if err != nil || p <= 0 { - slog.Error("invalid setting", "OLLAMA_MAX_QUEUE", onp, "error", err) - } else { - maxQueuedRequests = p - } - } - sched := &Scheduler{ - pendingReqCh: make(chan *LlmRequest, maxQueuedRequests), - finishedReqCh: make(chan *LlmRequest, maxQueuedRequests), - expiredCh: make(chan *runnerRef, maxQueuedRequests), - unloadedCh: make(chan interface{}, maxQueuedRequests), + pendingReqCh: make(chan *LlmRequest, envconfig.MaxQueuedRequests), + finishedReqCh: make(chan *LlmRequest, envconfig.MaxQueuedRequests), + expiredCh: make(chan *runnerRef, envconfig.MaxQueuedRequests), + unloadedCh: make(chan interface{}, envconfig.MaxQueuedRequests), loaded: make(map[string]*runnerRef), newServerFn: llm.NewLlamaServer, getGpuFn: gpu.GetGPUInfo, @@ -94,7 +61,7 @@ func InitScheduler(ctx context.Context) *Scheduler { // context must be canceled to decrement ref count and release the runner func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) { // allocate a large enough kv cache for all parallel requests - opts.NumCtx = opts.NumCtx * numParallel + opts.NumCtx = opts.NumCtx * envconfig.NumParallel req := &LlmRequest{ ctx: c, @@ -147,11 +114,11 @@ func (s *Scheduler) processPending(ctx context.Context) { pending.useLoadedRunner(runner, s.finishedReqCh) break } - } else if loadedMax > 0 && loadedCount >= loadedMax { + } else if envconfig.MaxRunners > 0 && loadedCount >= envconfig.MaxRunners { slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount) runnerToExpire = s.findRunnerToUnload(pending) } else { - // Either no models are loaded or below loadedMax + // Either no models are loaded or below envconfig.MaxRunners // Get a refreshed GPU list gpus := s.getGpuFn() @@ -162,7 +129,7 @@ func (s *Scheduler) processPending(ctx context.Context) { break } - // If we're CPU only mode, just limit by loadedMax above + // If we're CPU only mode, just limit by envconfig.MaxRunners above // TODO handle system memory exhaustion if (len(gpus) == 1 && gpus[0].Library == "cpu") || pending.opts.NumGPU == 0 { slog.Debug("cpu mode with existing models, loading") diff --git a/server/sched_test.go b/server/sched_test.go index ff1421ed..0e70b843 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -15,6 +15,7 @@ import ( "github.com/ollama/ollama/format" "github.com/ollama/ollama/gpu" "github.com/ollama/ollama/llm" + "github.com/ollama/ollama/server/envconfig" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -27,34 +28,10 @@ func init() { func TestInitScheduler(t *testing.T) { ctx, done := context.WithCancel(context.Background()) defer done() - initialMax := loadedMax - initialParallel := numParallel s := InitScheduler(ctx) - require.Equal(t, initialMax, loadedMax) s.loadedMu.Lock() require.NotNil(t, s.loaded) s.loadedMu.Unlock() - - os.Setenv("OLLAMA_MAX_LOADED_MODELS", "blue") - s = InitScheduler(ctx) - require.Equal(t, initialMax, loadedMax) - s.loadedMu.Lock() - require.NotNil(t, s.loaded) - s.loadedMu.Unlock() - - os.Setenv("OLLAMA_MAX_LOADED_MODELS", "0") - s = InitScheduler(ctx) - require.Equal(t, 0, loadedMax) - s.loadedMu.Lock() - require.NotNil(t, s.loaded) - s.loadedMu.Unlock() - - os.Setenv("OLLAMA_NUM_PARALLEL", "blue") - _ = InitScheduler(ctx) - require.Equal(t, initialParallel, numParallel) - os.Setenv("OLLAMA_NUM_PARALLEL", "10") - _ = InitScheduler(ctx) - require.Equal(t, 10, numParallel) } func TestLoad(t *testing.T) { @@ -249,7 +226,7 @@ func TestRequests(t *testing.T) { t.Errorf("timeout") } - loadedMax = 1 + envconfig.MaxRunners = 1 s.newServerFn = scenario3a.newServer slog.Info("scenario3a") s.pendingReqCh <- scenario3a.req @@ -268,7 +245,7 @@ func TestRequests(t *testing.T) { require.Len(t, s.loaded, 1) s.loadedMu.Unlock() - loadedMax = 0 + envconfig.MaxRunners = 0 s.newServerFn = scenario3b.newServer slog.Info("scenario3b") s.pendingReqCh <- scenario3b.req @@ -339,7 +316,7 @@ func TestGetRunner(t *testing.T) { scenario1b.req.sessionDuration = 0 scenario1c := newScenario(t, ctx, "ollama-model-1c", 10) scenario1c.req.sessionDuration = 0 - maxQueuedRequests = 1 + envconfig.MaxQueuedRequests = 1 s := InitScheduler(ctx) s.getGpuFn = func() gpu.GpuInfoList { g := gpu.GpuInfo{Library: "metal"}