Merge pull request #4154 from dhiltgen/central_config

Centralize server config handling
Daniel Hiltgen authored on 2024-05-05 17:08:26 -07:00 (committed by GitHub)
commit 840424a2c4
12 changed files with 235 additions and 162 deletions

View file

@@ -5,12 +5,14 @@ import (
 	"log/slog"
 	"os"
 	"path/filepath"
+
+	"github.com/ollama/ollama/server/envconfig"
 )
 
 func InitLogging() {
 	level := slog.LevelInfo
-	if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
+	if envconfig.Debug {
 		level = slog.LevelDebug
 	}

View file

@@ -31,16 +31,13 @@ func DoUpgrade(cancel context.CancelFunc, done chan int) error {
 		"/LOG=" + filepath.Base(UpgradeLogFile), // Only relative seems reliable, so set pwd
 		"/FORCECLOSEAPPLICATIONS",               // Force close the tray app - might be needed
 	}
-	// When we're not in debug mode, make the upgrade as quiet as possible (no GUI, no prompts)
-	// TODO - temporarily disable since we're pinning in debug mode for the preview
-	// if debug := os.Getenv("OLLAMA_DEBUG"); debug == "" {
+	// make the upgrade as quiet as possible (no GUI, no prompts)
 	installArgs = append(installArgs,
 		"/SP", // Skip the "This will install... Do you wish to continue" prompt
 		"/SUPPRESSMSGBOXES",
 		"/SILENT",
 		"/VERYSILENT",
 	)
-	// }
 
 	// Safeguard in case we have requests in flight that need to drain...
 	slog.Info("Waiting for server to shutdown")

View file

@@ -12,6 +12,8 @@ import (
 	"sync"
 	"syscall"
 	"time"
+
+	"github.com/ollama/ollama/server/envconfig"
 )
 
 var (
@@ -24,45 +26,8 @@ func PayloadsDir() (string, error) {
 	defer lock.Unlock()
 	var err error
 	if payloadsDir == "" {
-		runnersDir := os.Getenv("OLLAMA_RUNNERS_DIR")
-		// On Windows we do not carry the payloads inside the main executable
-		if runtime.GOOS == "windows" && runnersDir == "" {
-			appExe, err := os.Executable()
-			if err != nil {
-				slog.Error("failed to lookup executable path", "error", err)
-				return "", err
-			}
-			cwd, err := os.Getwd()
-			if err != nil {
-				slog.Error("failed to lookup working directory", "error", err)
-				return "", err
-			}
-			var paths []string
-			for _, root := range []string{filepath.Dir(appExe), cwd} {
-				paths = append(paths,
-					filepath.Join(root),
-					filepath.Join(root, "windows-"+runtime.GOARCH),
-					filepath.Join(root, "dist", "windows-"+runtime.GOARCH),
-				)
-			}
-			// Try a few variations to improve developer experience when building from source in the local tree
-			for _, p := range paths {
-				candidate := filepath.Join(p, "ollama_runners")
-				_, err := os.Stat(candidate)
-				if err == nil {
-					runnersDir = candidate
-					break
-				}
-			}
-			if runnersDir == "" {
-				err = fmt.Errorf("unable to locate llm runner directory. Set OLLAMA_RUNNERS_DIR to the location of 'ollama_runners'")
-				slog.Error("incomplete distribution", "error", err)
-				return "", err
-			}
-		}
+		runnersDir := envconfig.RunnersDir
 		if runnersDir != "" {
 			payloadsDir = runnersDir
 			return payloadsDir, nil
@@ -70,7 +35,7 @@ func PayloadsDir() (string, error) {
 		// The remainder only applies on non-windows where we still carry payloads in the main executable
 		cleanupTmpDirs()
-		tmpDir := os.Getenv("OLLAMA_TMPDIR")
+		tmpDir := envconfig.TmpDir
 		if tmpDir == "" {
 			tmpDir, err = os.MkdirTemp("", "ollama")
 			if err != nil {
@@ -133,7 +98,7 @@ func cleanupTmpDirs() {
 func Cleanup() {
 	lock.Lock()
 	defer lock.Unlock()
-	runnersDir := os.Getenv("OLLAMA_RUNNERS_DIR")
+	runnersDir := envconfig.RunnersDir
 	if payloadsDir != "" && runnersDir == "" && runtime.GOOS != "windows" {
 		// We want to fully clean up the tmpdir parent of the payloads dir
 		tmpDir := filepath.Clean(filepath.Join(payloadsDir, ".."))

View file

@@ -21,6 +21,7 @@ import (
 	"unsafe"
 
 	"github.com/ollama/ollama/format"
+	"github.com/ollama/ollama/server/envconfig"
 )
 
 type handles struct {
@@ -268,7 +269,7 @@ func LoadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string) {
 }
 
 func getVerboseState() C.uint16_t {
-	if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
+	if envconfig.Debug {
 		return C.uint16_t(1)
 	}
 	return C.uint16_t(0)

View file

@@ -3,12 +3,11 @@ package llm
 import (
 	"fmt"
 	"log/slog"
-	"os"
-	"strconv"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/gpu"
+	"github.com/ollama/ollama/server/envconfig"
 )
 
 // This algorithm looks for a complete fit to determine if we need to unload other models
@@ -50,15 +49,8 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
 	for _, info := range gpus {
 		memoryAvailable += info.FreeMemory
 	}
-	userLimit := os.Getenv("OLLAMA_MAX_VRAM")
-	if userLimit != "" {
-		avail, err := strconv.ParseUint(userLimit, 10, 64)
-		if err != nil {
-			slog.Error("invalid setting, ignoring", "OLLAMA_MAX_VRAM", userLimit, "error", err)
-		} else {
-			slog.Info("user override memory limit", "OLLAMA_MAX_VRAM", avail, "actual", memoryAvailable)
-			memoryAvailable = avail
-		}
+	if envconfig.MaxVRAM > 0 {
+		memoryAvailable = envconfig.MaxVRAM
 	}
 
 	slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", format.HumanBytes2(memoryAvailable))

View file

@@ -26,6 +26,7 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/gpu"
+	"github.com/ollama/ollama/server/envconfig"
 )
 
 type LlamaServer interface {
@@ -124,7 +125,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 	} else {
 		servers = serversForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant
 	}
-	demandLib := strings.Trim(os.Getenv("OLLAMA_LLM_LIBRARY"), "\"' ")
+	demandLib := envconfig.LLMLibrary
 	if demandLib != "" {
 		serverPath := availableServers[demandLib]
 		if serverPath == "" {
@@ -145,7 +146,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
 		"--embedding",
 	}
-	if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
+	if envconfig.Debug {
 		params = append(params, "--log-format", "json")
 	} else {
 		params = append(params, "--log-disable")
@@ -155,7 +156,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		params = append(params, "--n-gpu-layers", fmt.Sprintf("%d", opts.NumGPU))
 	}
 
-	if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
+	if envconfig.Debug {
 		params = append(params, "--verbose")
 	}
@@ -194,15 +195,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 	}
 
 	// "--cont-batching", // TODO - doesn't seem to have any noticeable perf change for multiple requests
-	numParallel := 1
-	if onp := os.Getenv("OLLAMA_NUM_PARALLEL"); onp != "" {
-		numParallel, err = strconv.Atoi(onp)
-		if err != nil || numParallel <= 0 {
-			err = fmt.Errorf("invalid OLLAMA_NUM_PARALLEL=%s must be greater than zero - %w", onp, err)
-			slog.Error("misconfiguration", "error", err)
-			return nil, err
-		}
-	}
+	numParallel := envconfig.NumParallel
 	params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
 
 	for i := 0; i < len(servers); i++ {

server/envconfig/config.go (new file, 174 lines)
View file

@@ -0,0 +1,174 @@
package envconfig

import (
	"fmt"
	"log/slog"
	"os"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
)

var (
	// Set via OLLAMA_ORIGINS in the environment
	AllowOrigins []string
	// Set via OLLAMA_DEBUG in the environment
	Debug bool
	// Set via OLLAMA_LLM_LIBRARY in the environment
	LLMLibrary string
	// Set via OLLAMA_MAX_LOADED_MODELS in the environment
	MaxRunners int
	// Set via OLLAMA_MAX_QUEUE in the environment
	MaxQueuedRequests int
	// Set via OLLAMA_MAX_VRAM in the environment
	MaxVRAM uint64
	// Set via OLLAMA_NOPRUNE in the environment
	NoPrune bool
	// Set via OLLAMA_NUM_PARALLEL in the environment
	NumParallel int
	// Set via OLLAMA_RUNNERS_DIR in the environment
	RunnersDir string
	// Set via OLLAMA_TMPDIR in the environment
	TmpDir string
)

func AsMap() map[string]string {
	return map[string]string{
		"OLLAMA_ORIGINS":           fmt.Sprintf("%v", AllowOrigins),
		"OLLAMA_DEBUG":             fmt.Sprintf("%v", Debug),
		"OLLAMA_LLM_LIBRARY":       fmt.Sprintf("%v", LLMLibrary),
		"OLLAMA_MAX_LOADED_MODELS": fmt.Sprintf("%v", MaxRunners),
		"OLLAMA_MAX_QUEUE":         fmt.Sprintf("%v", MaxQueuedRequests),
		"OLLAMA_MAX_VRAM":          fmt.Sprintf("%v", MaxVRAM),
		"OLLAMA_NOPRUNE":           fmt.Sprintf("%v", NoPrune),
		"OLLAMA_NUM_PARALLEL":      fmt.Sprintf("%v", NumParallel),
		"OLLAMA_RUNNERS_DIR":       fmt.Sprintf("%v", RunnersDir),
		"OLLAMA_TMPDIR":            fmt.Sprintf("%v", TmpDir),
	}
}

var defaultAllowOrigins = []string{
	"localhost",
	"127.0.0.1",
	"0.0.0.0",
}

// Clean quotes and spaces from the value
func clean(key string) string {
	return strings.Trim(os.Getenv(key), "\"' ")
}

func init() {
	// default values
	NumParallel = 1
	MaxRunners = 1
	MaxQueuedRequests = 512

	LoadConfig()
}

func LoadConfig() {
	if debug := clean("OLLAMA_DEBUG"); debug != "" {
		d, err := strconv.ParseBool(debug)
		if err == nil {
			Debug = d
		} else {
			Debug = true
		}
	}
	RunnersDir = clean("OLLAMA_RUNNERS_DIR")
	if runtime.GOOS == "windows" && RunnersDir == "" {
		// On Windows we do not carry the payloads inside the main executable
		appExe, err := os.Executable()
		if err != nil {
			slog.Error("failed to lookup executable path", "error", err)
		}
		cwd, err := os.Getwd()
		if err != nil {
			slog.Error("failed to lookup working directory", "error", err)
		}
		var paths []string
		for _, root := range []string{filepath.Dir(appExe), cwd} {
			paths = append(paths,
				filepath.Join(root),
				filepath.Join(root, "windows-"+runtime.GOARCH),
				filepath.Join(root, "dist", "windows-"+runtime.GOARCH),
			)
		}
		// Try a few variations to improve developer experience when building from source in the local tree
		for _, p := range paths {
			candidate := filepath.Join(p, "ollama_runners")
			_, err := os.Stat(candidate)
			if err == nil {
				RunnersDir = candidate
				break
			}
		}
		if RunnersDir == "" {
			slog.Error("unable to locate llm runner directory. Set OLLAMA_RUNNERS_DIR to the location of 'ollama_runners'")
		}
	}
	TmpDir = clean("OLLAMA_TMPDIR")

	userLimit := clean("OLLAMA_MAX_VRAM")
	if userLimit != "" {
		avail, err := strconv.ParseUint(userLimit, 10, 64)
		if err != nil {
			slog.Error("invalid setting, ignoring", "OLLAMA_MAX_VRAM", userLimit, "error", err)
		} else {
			MaxVRAM = avail
		}
	}

	LLMLibrary = clean("OLLAMA_LLM_LIBRARY")

	if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" {
		val, err := strconv.Atoi(onp)
		if err != nil || val <= 0 {
			slog.Error("invalid setting must be greater than zero", "OLLAMA_NUM_PARALLEL", onp, "error", err)
		} else {
			NumParallel = val
		}
	}

	if noprune := clean("OLLAMA_NOPRUNE"); noprune != "" {
		NoPrune = true
	}

	if origins := clean("OLLAMA_ORIGINS"); origins != "" {
		AllowOrigins = strings.Split(origins, ",")
	}
	for _, allowOrigin := range defaultAllowOrigins {
		AllowOrigins = append(AllowOrigins,
			fmt.Sprintf("http://%s", allowOrigin),
			fmt.Sprintf("https://%s", allowOrigin),
			fmt.Sprintf("http://%s:*", allowOrigin),
			fmt.Sprintf("https://%s:*", allowOrigin),
		)
	}

	maxRunners := clean("OLLAMA_MAX_LOADED_MODELS")
	if maxRunners != "" {
		m, err := strconv.Atoi(maxRunners)
		if err != nil {
			slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
		} else {
			MaxRunners = m
		}
	}

	if onp := os.Getenv("OLLAMA_MAX_QUEUE"); onp != "" {
		p, err := strconv.Atoi(onp)
		if err != nil || p <= 0 {
			slog.Error("invalid setting", "OLLAMA_MAX_QUEUE", onp, "error", err)
		} else {
			MaxQueuedRequests = p
		}
	}
}
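
For orientation, here is a minimal sketch (not part of the diff) of how a caller consumes the centralized settings. It only uses names introduced above (LoadConfig, Debug, NumParallel, AsMap); the main wrapper is illustrative only.

package main

import (
	"fmt"
	"os"

	"github.com/ollama/ollama/server/envconfig"
)

func main() {
	// Values are trimmed of quotes and spaces by clean(), so quoted forms also work.
	os.Setenv("OLLAMA_DEBUG", "1")
	os.Setenv("OLLAMA_NUM_PARALLEL", "\"4\"")
	envconfig.LoadConfig() // re-read after changing the environment

	fmt.Println(envconfig.Debug)                       // true
	fmt.Println(envconfig.NumParallel)                 // 4
	fmt.Println(envconfig.AsMap()["OLLAMA_MAX_QUEUE"]) // "512" unless overridden in the environment
}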

View file

@@ -0,0 +1,20 @@
package envconfig

import (
	"os"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestConfig(t *testing.T) {
	os.Setenv("OLLAMA_DEBUG", "")
	LoadConfig()
	require.False(t, Debug)
	os.Setenv("OLLAMA_DEBUG", "false")
	LoadConfig()
	require.False(t, Debug)
	os.Setenv("OLLAMA_DEBUG", "1")
	LoadConfig()
	require.True(t, Debug)
}
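
A follow-up test in the same file could cover the numeric parsing paths; the sketch below is hypothetical (not in this PR) and mirrors the OLLAMA_NUM_PARALLEL checks that are dropped from sched_test.go further down. It reuses the imports of the test above.

// Hypothetical addition, not part of this change: invalid values keep the
// default, valid values override it.
func TestNumParallelConfig(t *testing.T) {
	NumParallel = 1 // start from the documented default

	os.Setenv("OLLAMA_NUM_PARALLEL", "blue") // not a number: the default is kept
	LoadConfig()
	require.Equal(t, 1, NumParallel)

	os.Setenv("OLLAMA_NUM_PARALLEL", "10")
	LoadConfig()
	require.Equal(t, 10, NumParallel)

	os.Unsetenv("OLLAMA_NUM_PARALLEL")
}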

View file

@@ -29,6 +29,7 @@ import (
 	"github.com/ollama/ollama/convert"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/server/envconfig"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
@@ -695,7 +696,7 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
 		return err
 	}
 
-	if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
+	if !envconfig.NoPrune {
 		if err := deleteUnusedLayers(nil, deleteMap, false); err != nil {
 			return err
 		}
@@ -1026,7 +1027,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 	// build deleteMap to prune unused layers
 	deleteMap := make(map[string]struct{})
-	if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
+	if !envconfig.NoPrune {
 		manifest, _, err = GetManifest(mp)
 		if err != nil && !errors.Is(err, os.ErrNotExist) {
 			return err

View file

@@ -29,6 +29,7 @@ import (
 	"github.com/ollama/ollama/gpu"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/openai"
+	"github.com/ollama/ollama/server/envconfig"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
 )
@@ -859,12 +860,6 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
 	c.Status(http.StatusCreated)
 }
 
-var defaultAllowOrigins = []string{
-	"localhost",
-	"127.0.0.1",
-	"0.0.0.0",
-}
-
 func isLocalIP(ip netip.Addr) bool {
 	if interfaces, err := net.Interfaces(); err == nil {
 		for _, iface := range interfaces {
@@ -948,19 +943,7 @@ func (s *Server) GenerateRoutes() http.Handler {
 	config := cors.DefaultConfig()
 	config.AllowWildcard = true
 	config.AllowBrowserExtensions = true
-	if allowedOrigins := strings.Trim(os.Getenv("OLLAMA_ORIGINS"), "\"'"); allowedOrigins != "" {
-		config.AllowOrigins = strings.Split(allowedOrigins, ",")
-	}
-	for _, allowOrigin := range defaultAllowOrigins {
-		config.AllowOrigins = append(config.AllowOrigins,
-			fmt.Sprintf("http://%s", allowOrigin),
-			fmt.Sprintf("https://%s", allowOrigin),
-			fmt.Sprintf("http://%s:*", allowOrigin),
-			fmt.Sprintf("https://%s:*", allowOrigin),
-		)
-	}
+	config.AllowOrigins = envconfig.AllowOrigins
 
 	r := gin.Default()
 	r.Use(
@@ -999,10 +982,11 @@ func (s *Server) GenerateRoutes() http.Handler {
 func Serve(ln net.Listener) error {
 	level := slog.LevelInfo
-	if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
+	if envconfig.Debug {
 		level = slog.LevelDebug
 	}
+	slog.Info("server config", "env", envconfig.AsMap())
 	handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
 		Level:     level,
 		AddSource: true,
@@ -1026,7 +1010,7 @@ func Serve(ln net.Listener) error {
 		return err
 	}
 
-	if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
+	if !envconfig.NoPrune {
 		// clean up unused layers and manifests
 		if err := PruneLayers(); err != nil {
 			return err
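
The net effect in GenerateRoutes is that origin policy now comes from one place. Below is a condensed sketch of the resulting wiring, assuming gin-contrib/cors is the middleware behind cors.DefaultConfig above; the newRouter wrapper is illustrative only and not part of this PR.

package server

import (
	"github.com/gin-contrib/cors"
	"github.com/gin-gonic/gin"

	"github.com/ollama/ollama/server/envconfig"
)

// newRouter condenses the CORS portion of GenerateRoutes: the origin list is
// taken from envconfig, which already holds OLLAMA_ORIGINS plus the default
// localhost/127.0.0.1/0.0.0.0 entries appended by LoadConfig.
func newRouter() *gin.Engine {
	config := cors.DefaultConfig()
	config.AllowWildcard = true
	config.AllowBrowserExtensions = true
	config.AllowOrigins = envconfig.AllowOrigins

	r := gin.Default()
	r.Use(cors.New(config))
	return r
}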

View file

@@ -5,10 +5,8 @@ import (
 	"errors"
 	"fmt"
 	"log/slog"
-	"os"
 	"reflect"
 	"sort"
-	"strconv"
 	"strings"
 	"sync"
 	"time"
@@ -17,6 +15,7 @@ import (
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/gpu"
 	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/server/envconfig"
 	"golang.org/x/exp/slices"
 )
 
@@ -43,46 +42,14 @@ type Scheduler struct {
 	getGpuFn func() gpu.GpuInfoList
 }
 
-var (
-	// TODO set this to zero after a release or two, to enable multiple models by default
-	loadedMax         = 1 // Maximum runners; < 1 maps to as many as will fit in VRAM (unlimited for CPU runners)
-	maxQueuedRequests = 512
-	numParallel       = 1
-	ErrMaxQueue       = fmt.Errorf("server busy, please try again. maximum pending requests exceeded")
-)
+var ErrMaxQueue = fmt.Errorf("server busy, please try again. maximum pending requests exceeded")
 
 func InitScheduler(ctx context.Context) *Scheduler {
-	maxRunners := os.Getenv("OLLAMA_MAX_LOADED_MODELS")
-	if maxRunners != "" {
-		m, err := strconv.Atoi(maxRunners)
-		if err != nil {
-			slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
-		} else {
-			loadedMax = m
-		}
-	}
-	if onp := os.Getenv("OLLAMA_NUM_PARALLEL"); onp != "" {
-		p, err := strconv.Atoi(onp)
-		if err != nil || p <= 0 {
-			slog.Error("invalid parallel setting, must be greater than zero", "OLLAMA_NUM_PARALLEL", onp, "error", err)
-		} else {
-			numParallel = p
-		}
-	}
-	if onp := os.Getenv("OLLAMA_MAX_QUEUE"); onp != "" {
-		p, err := strconv.Atoi(onp)
-		if err != nil || p <= 0 {
-			slog.Error("invalid setting", "OLLAMA_MAX_QUEUE", onp, "error", err)
-		} else {
-			maxQueuedRequests = p
-		}
-	}
 	sched := &Scheduler{
-		pendingReqCh:  make(chan *LlmRequest, maxQueuedRequests),
-		finishedReqCh: make(chan *LlmRequest, maxQueuedRequests),
-		expiredCh:     make(chan *runnerRef, maxQueuedRequests),
-		unloadedCh:    make(chan interface{}, maxQueuedRequests),
+		pendingReqCh:  make(chan *LlmRequest, envconfig.MaxQueuedRequests),
+		finishedReqCh: make(chan *LlmRequest, envconfig.MaxQueuedRequests),
+		expiredCh:     make(chan *runnerRef, envconfig.MaxQueuedRequests),
+		unloadedCh:    make(chan interface{}, envconfig.MaxQueuedRequests),
 		loaded:        make(map[string]*runnerRef),
 		newServerFn:   llm.NewLlamaServer,
 		getGpuFn:      gpu.GetGPUInfo,
@@ -94,7 +61,7 @@ func InitScheduler(ctx context.Context) *Scheduler {
// context must be canceled to decrement ref count and release the runner
 func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
 	// allocate a large enough kv cache for all parallel requests
-	opts.NumCtx = opts.NumCtx * numParallel
+	opts.NumCtx = opts.NumCtx * envconfig.NumParallel
 
 	req := &LlmRequest{
 		ctx: c,
@@ -147,11 +114,11 @@ func (s *Scheduler) processPending(ctx context.Context) {
 					pending.useLoadedRunner(runner, s.finishedReqCh)
 					break
 				}
-			} else if loadedMax > 0 && loadedCount >= loadedMax {
+			} else if envconfig.MaxRunners > 0 && loadedCount >= envconfig.MaxRunners {
 				slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount)
 				runnerToExpire = s.findRunnerToUnload(pending)
 			} else {
-				// Either no models are loaded or below loadedMax
+				// Either no models are loaded or below envconfig.MaxRunners
 				// Get a refreshed GPU list
 				gpus := s.getGpuFn()
@@ -162,7 +129,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
 					break
 				}
 
-				// If we're CPU only mode, just limit by loadedMax above
+				// If we're CPU only mode, just limit by envconfig.MaxRunners above
 				// TODO handle system memory exhaustion
 				if (len(gpus) == 1 && gpus[0].Library == "cpu") || pending.opts.NumGPU == 0 {
 					slog.Debug("cpu mode with existing models, loading")

View file

@@ -15,6 +15,7 @@ import (
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/gpu"
 	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/server/envconfig"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -27,34 +28,10 @@ func init() {
 func TestInitScheduler(t *testing.T) {
 	ctx, done := context.WithCancel(context.Background())
 	defer done()
-	initialMax := loadedMax
-	initialParallel := numParallel
 	s := InitScheduler(ctx)
-	require.Equal(t, initialMax, loadedMax)
 	s.loadedMu.Lock()
 	require.NotNil(t, s.loaded)
 	s.loadedMu.Unlock()
-
-	os.Setenv("OLLAMA_MAX_LOADED_MODELS", "blue")
-	s = InitScheduler(ctx)
-	require.Equal(t, initialMax, loadedMax)
-	s.loadedMu.Lock()
-	require.NotNil(t, s.loaded)
-	s.loadedMu.Unlock()
-
-	os.Setenv("OLLAMA_MAX_LOADED_MODELS", "0")
-	s = InitScheduler(ctx)
-	require.Equal(t, 0, loadedMax)
-	s.loadedMu.Lock()
-	require.NotNil(t, s.loaded)
-	s.loadedMu.Unlock()
-
-	os.Setenv("OLLAMA_NUM_PARALLEL", "blue")
-	_ = InitScheduler(ctx)
-	require.Equal(t, initialParallel, numParallel)
-	os.Setenv("OLLAMA_NUM_PARALLEL", "10")
-	_ = InitScheduler(ctx)
-	require.Equal(t, 10, numParallel)
 }
 
 func TestLoad(t *testing.T) {
@@ -249,7 +226,7 @@ func TestRequests(t *testing.T) {
 		t.Errorf("timeout")
 	}
 
-	loadedMax = 1
+	envconfig.MaxRunners = 1
 	s.newServerFn = scenario3a.newServer
 	slog.Info("scenario3a")
 	s.pendingReqCh <- scenario3a.req
@@ -268,7 +245,7 @@ func TestRequests(t *testing.T) {
 	require.Len(t, s.loaded, 1)
 	s.loadedMu.Unlock()
 
-	loadedMax = 0
+	envconfig.MaxRunners = 0
 	s.newServerFn = scenario3b.newServer
 	slog.Info("scenario3b")
 	s.pendingReqCh <- scenario3b.req
@@ -339,7 +316,7 @@ func TestGetRunner(t *testing.T) {
 	scenario1b.req.sessionDuration = 0
 	scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
 	scenario1c.req.sessionDuration = 0
-	maxQueuedRequests = 1
+	envconfig.MaxQueuedRequests = 1
 	s := InitScheduler(ctx)
 	s.getGpuFn = func() gpu.GpuInfoList {
 		g := gpu.GpuInfo{Library: "metal"}
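
Since the scheduler limits are now exported package variables, the tests mutate them directly (envconfig.MaxRunners = 1 above). One way to keep such overrides from leaking between tests is a small cleanup helper; the withMaxRunners function below is an illustrative sketch, not something this PR adds.

package server

import (
	"testing"

	"github.com/ollama/ollama/server/envconfig"
)

// withMaxRunners overrides the centralized runner limit for one test and
// restores the previous value when the test finishes.
func withMaxRunners(t *testing.T, n int) {
	t.Helper()
	prev := envconfig.MaxRunners
	envconfig.MaxRunners = n
	t.Cleanup(func() { envconfig.MaxRunners = prev })
}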