Merge pull request #5473 from ollama/mxyng/environ

fix: environ lookup
Michael Yang 2024-07-31 10:18:05 -07:00 committed by GitHub
commit 5c1912769e
GPG key ID: B5690EEEBB952194
27 changed files with 547 additions and 515 deletions

View file

@@ -20,7 +20,6 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
-	"net"
 	"net/http"
 	"net/url"
 	"runtime"
@@ -63,13 +62,8 @@ func checkError(resp *http.Response, body []byte) error {
 // If the variable is not specified, a default ollama host and port will be
 // used.
 func ClientFromEnvironment() (*Client, error) {
-	ollamaHost := envconfig.Host
-
 	return &Client{
-		base: &url.URL{
-			Scheme: ollamaHost.Scheme,
-			Host:   net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
-		},
+		base: envconfig.Host(),
 		http: http.DefaultClient,
 	}, nil
 }
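For orientation (not part of the diff), a minimal usage sketch of the new behavior: ClientFromEnvironment now takes its base URL straight from envconfig.Host(), which re-reads OLLAMA_HOST on every call, so no separate LoadConfig step is needed. The host value below is an assumed example.

package main

import (
	"log"
	"os"

	"github.com/ollama/ollama/api"
)

func main() {
	// Assumed example value; envconfig.Host() falls back to http://127.0.0.1:11434 when unset.
	os.Setenv("OLLAMA_HOST", "https://example.com:8443")

	// No envconfig.LoadConfig() call is required: Host() reads the environment each time it is called.
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	_ = client // requests will be sent to https://example.com:8443
}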

View file

@@ -2,8 +2,6 @@ package api
 import (
 	"testing"
-
-	"github.com/ollama/ollama/envconfig"
 )
 func TestClientFromEnvironment(t *testing.T) {
@@ -33,7 +31,6 @@ func TestClientFromEnvironment(t *testing.T) {
 	for k, v := range testCases {
 		t.Run(k, func(t *testing.T) {
 			t.Setenv("OLLAMA_HOST", v.value)
-			envconfig.LoadConfig()
 			client, err := ClientFromEnvironment()
 			if err != v.err {

View file

@@ -14,7 +14,7 @@ import (
 func InitLogging() {
 	level := slog.LevelInfo
-	if envconfig.Debug {
+	if envconfig.Debug() {
 		level = slog.LevelDebug
 	}

View file

@@ -1076,7 +1076,7 @@ func RunServer(cmd *cobra.Command, _ []string) error {
 		return err
 	}
-	ln, err := net.Listen("tcp", net.JoinHostPort(envconfig.Host.Host, envconfig.Host.Port))
+	ln, err := net.Listen("tcp", envconfig.Host().Host)
 	if err != nil {
 		return err
 	}

View file

@@ -160,7 +160,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 		return err
 	}
-	if envconfig.NoHistory {
+	if envconfig.NoHistory() {
 		scanner.HistoryDisable()
 	}

View file

@@ -1,11 +1,11 @@
 package envconfig
 import (
-	"errors"
 	"fmt"
 	"log/slog"
 	"math"
 	"net"
+	"net/url"
 	"os"
 	"path/filepath"
 	"runtime"
@@ -14,296 +14,16 @@ import (
 	"time"
 )
-type OllamaHost struct {
-	Scheme string
-	Host   string
Port string
}
func (o OllamaHost) String() string {
return fmt.Sprintf("%s://%s:%s", o.Scheme, o.Host, o.Port)
}
var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
var (
// Set via OLLAMA_ORIGINS in the environment
AllowOrigins []string
// Set via OLLAMA_DEBUG in the environment
Debug bool
// Experimental flash attention
FlashAttention bool
// Set via OLLAMA_HOST in the environment
Host *OllamaHost
// Set via OLLAMA_KEEP_ALIVE in the environment
KeepAlive time.Duration
// Set via OLLAMA_LLM_LIBRARY in the environment
LLMLibrary string
// Set via OLLAMA_MAX_LOADED_MODELS in the environment
MaxRunners int
// Set via OLLAMA_MAX_QUEUE in the environment
MaxQueuedRequests int
// Set via OLLAMA_MODELS in the environment
ModelsDir string
// Set via OLLAMA_NOHISTORY in the environment
NoHistory bool
// Set via OLLAMA_NOPRUNE in the environment
NoPrune bool
// Set via OLLAMA_NUM_PARALLEL in the environment
NumParallel int
// Set via OLLAMA_RUNNERS_DIR in the environment
RunnersDir string
// Set via OLLAMA_SCHED_SPREAD in the environment
SchedSpread bool
// Set via OLLAMA_TMPDIR in the environment
TmpDir string
// Set via OLLAMA_INTEL_GPU in the environment
IntelGpu bool
// Set via CUDA_VISIBLE_DEVICES in the environment
CudaVisibleDevices string
// Set via HIP_VISIBLE_DEVICES in the environment
HipVisibleDevices string
// Set via ROCR_VISIBLE_DEVICES in the environment
RocrVisibleDevices string
// Set via GPU_DEVICE_ORDINAL in the environment
GpuDeviceOrdinal string
// Set via HSA_OVERRIDE_GFX_VERSION in the environment
HsaOverrideGfxVersion string
)
type EnvVar struct {
Name string
Value any
Description string
}
func AsMap() map[string]EnvVar {
ret := map[string]EnvVar{
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug, "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention, "Enabled flash attention"},
"OLLAMA_HOST": {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"},
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"},
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"},
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
"OLLAMA_MODELS": {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"},
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"},
"OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"},
"OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"},
}
if runtime.GOOS != "darwin" {
ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices, "Set which NVIDIA devices are visible"}
ret["HIP_VISIBLE_DEVICES"] = EnvVar{"HIP_VISIBLE_DEVICES", HipVisibleDevices, "Set which AMD devices are visible"}
ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices, "Set which AMD devices are visible"}
ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal, "Set which AMD devices are visible"}
ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion, "Override the gfx used for all detected AMD GPUs"}
ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGpu, "Enable experimental Intel GPU detection"}
}
return ret
}
func Values() map[string]string {
vals := make(map[string]string)
for k, v := range AsMap() {
vals[k] = fmt.Sprintf("%v", v.Value)
}
return vals
}
var defaultAllowOrigins = []string{
"localhost",
"127.0.0.1",
"0.0.0.0",
}
// Clean quotes and spaces from the value
func clean(key string) string {
return strings.Trim(os.Getenv(key), "\"' ")
}
func init() {
// default values
NumParallel = 0 // Autoselect
MaxRunners = 0 // Autoselect
MaxQueuedRequests = 512
KeepAlive = 5 * time.Minute
LoadConfig()
}
func LoadConfig() {
if debug := clean("OLLAMA_DEBUG"); debug != "" {
d, err := strconv.ParseBool(debug)
if err == nil {
Debug = d
} else {
Debug = true
}
}
if fa := clean("OLLAMA_FLASH_ATTENTION"); fa != "" {
d, err := strconv.ParseBool(fa)
if err == nil {
FlashAttention = d
}
}
RunnersDir = clean("OLLAMA_RUNNERS_DIR")
if runtime.GOOS == "windows" && RunnersDir == "" {
// On Windows we do not carry the payloads inside the main executable
appExe, err := os.Executable()
if err != nil {
slog.Error("failed to lookup executable path", "error", err)
}
cwd, err := os.Getwd()
if err != nil {
slog.Error("failed to lookup working directory", "error", err)
}
var paths []string
for _, root := range []string{filepath.Dir(appExe), cwd} {
paths = append(paths,
root,
filepath.Join(root, "windows-"+runtime.GOARCH),
filepath.Join(root, "dist", "windows-"+runtime.GOARCH),
)
}
// Try a few variations to improve developer experience when building from source in the local tree
for _, p := range paths {
candidate := filepath.Join(p, "ollama_runners")
_, err := os.Stat(candidate)
if err == nil {
RunnersDir = candidate
break
}
}
if RunnersDir == "" {
slog.Error("unable to locate llm runner directory. Set OLLAMA_RUNNERS_DIR to the location of 'ollama_runners'")
}
}
TmpDir = clean("OLLAMA_TMPDIR")
LLMLibrary = clean("OLLAMA_LLM_LIBRARY")
if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" {
val, err := strconv.Atoi(onp)
if err != nil {
slog.Error("invalid setting, ignoring", "OLLAMA_NUM_PARALLEL", onp, "error", err)
} else {
NumParallel = val
}
}
if nohistory := clean("OLLAMA_NOHISTORY"); nohistory != "" {
NoHistory = true
}
if spread := clean("OLLAMA_SCHED_SPREAD"); spread != "" {
s, err := strconv.ParseBool(spread)
if err == nil {
SchedSpread = s
} else {
SchedSpread = true
}
}
if noprune := clean("OLLAMA_NOPRUNE"); noprune != "" {
NoPrune = true
}
if origins := clean("OLLAMA_ORIGINS"); origins != "" {
AllowOrigins = strings.Split(origins, ",")
}
for _, allowOrigin := range defaultAllowOrigins {
AllowOrigins = append(AllowOrigins,
fmt.Sprintf("http://%s", allowOrigin),
fmt.Sprintf("https://%s", allowOrigin),
fmt.Sprintf("http://%s", net.JoinHostPort(allowOrigin, "*")),
fmt.Sprintf("https://%s", net.JoinHostPort(allowOrigin, "*")),
)
}
AllowOrigins = append(AllowOrigins,
"app://*",
"file://*",
"tauri://*",
)
maxRunners := clean("OLLAMA_MAX_LOADED_MODELS")
if maxRunners != "" {
m, err := strconv.Atoi(maxRunners)
if err != nil {
slog.Error("invalid setting, ignoring", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
} else {
MaxRunners = m
}
}
if onp := os.Getenv("OLLAMA_MAX_QUEUE"); onp != "" {
p, err := strconv.Atoi(onp)
if err != nil || p <= 0 {
slog.Error("invalid setting, ignoring", "OLLAMA_MAX_QUEUE", onp, "error", err)
} else {
MaxQueuedRequests = p
}
}
ka := clean("OLLAMA_KEEP_ALIVE")
if ka != "" {
loadKeepAlive(ka)
}
var err error
ModelsDir, err = getModelsDir()
if err != nil {
slog.Error("invalid setting", "OLLAMA_MODELS", ModelsDir, "error", err)
}
Host, err = getOllamaHost()
if err != nil {
slog.Error("invalid setting", "OLLAMA_HOST", Host, "error", err, "using default port", Host.Port)
}
if set, err := strconv.ParseBool(clean("OLLAMA_INTEL_GPU")); err == nil {
IntelGpu = set
}
CudaVisibleDevices = clean("CUDA_VISIBLE_DEVICES")
HipVisibleDevices = clean("HIP_VISIBLE_DEVICES")
RocrVisibleDevices = clean("ROCR_VISIBLE_DEVICES")
GpuDeviceOrdinal = clean("GPU_DEVICE_ORDINAL")
HsaOverrideGfxVersion = clean("HSA_OVERRIDE_GFX_VERSION")
}
func getModelsDir() (string, error) {
if models, exists := os.LookupEnv("OLLAMA_MODELS"); exists {
return models, nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", "models"), nil
}
-func getOllamaHost() (*OllamaHost, error) {
+// Host returns the scheme and host. Host can be configured via the OLLAMA_HOST environment variable.
+// Default is scheme "http" and host "127.0.0.1:11434"
+func Host() *url.URL {
 	defaultPort := "11434"
-	hostVar := os.Getenv("OLLAMA_HOST")
-	hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
-	scheme, hostport, ok := strings.Cut(hostVar, "://")
+	s := strings.TrimSpace(Var("OLLAMA_HOST"))
+	scheme, hostport, ok := strings.Cut(s, "://")
 	switch {
 	case !ok:
-		scheme, hostport = "http", hostVar
+		scheme, hostport = "http", s
 	case scheme == "http":
 		defaultPort = "80"
 	case scheme == "https":
@@ -323,38 +43,242 @@ func getOllamaHost() (*OllamaHost, error) {
 		}
 	}
-	if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
-		return &OllamaHost{
-			Scheme: scheme,
-			Host:   host,
-			Port:   defaultPort,
-		}, ErrInvalidHostPort
+	if n, err := strconv.ParseInt(port, 10, 32); err != nil || n > 65535 || n < 0 {
+		slog.Warn("invalid port, using default", "port", port, "default", defaultPort)
+		return &url.URL{
+			Scheme: scheme,
+			Host:   net.JoinHostPort(host, defaultPort),
+		}
 	}
-	return &OllamaHost{
-		Scheme: scheme,
-		Host:   host,
-		Port:   port,
-	}, nil
+	return &url.URL{
+		Scheme: scheme,
+		Host:   net.JoinHostPort(host, port),
+	}
 }
-func loadKeepAlive(ka string) {
-	v, err := strconv.Atoi(ka)
-	if err != nil {
-		d, err := time.ParseDuration(ka)
-		if err == nil {
-			if d < 0 {
-				KeepAlive = time.Duration(math.MaxInt64)
-			} else {
-				KeepAlive = d
-			}
-		}
-	} else {
-		d := time.Duration(v) * time.Second
-		if d < 0 {
-			KeepAlive = time.Duration(math.MaxInt64)
-		} else {
-			KeepAlive = d
-		}
-	}
-}
+// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
+func Origins() (origins []string) {
if s := Var("OLLAMA_ORIGINS"); s != "" {
origins = strings.Split(s, ",")
}
for _, origin := range []string{"localhost", "127.0.0.1", "0.0.0.0"} {
origins = append(origins,
fmt.Sprintf("http://%s", origin),
fmt.Sprintf("https://%s", origin),
fmt.Sprintf("http://%s", net.JoinHostPort(origin, "*")),
fmt.Sprintf("https://%s", net.JoinHostPort(origin, "*")),
)
}
origins = append(origins,
"app://*",
"file://*",
"tauri://*",
)
return origins
}
// Models returns the path to the models directory. Models directory can be configured via the OLLAMA_MODELS environment variable.
// Default is $HOME/.ollama/models
func Models() string {
if s := Var("OLLAMA_MODELS"); s != "" {
return s
}
home, err := os.UserHomeDir()
	if err != nil {
		panic(err)
	}

	return filepath.Join(home, ".ollama", "models")
}
// KeepAlive returns the duration that models stay loaded in memory. KeepAlive can be configured via the OLLAMA_KEEP_ALIVE environment variable.
// Negative values are treated as infinite. Zero is treated as no keep alive.
// Default is 5 minutes.
func KeepAlive() (keepAlive time.Duration) {
keepAlive = 5 * time.Minute
if s := Var("OLLAMA_KEEP_ALIVE"); s != "" {
if d, err := time.ParseDuration(s); err == nil {
keepAlive = d
} else if n, err := strconv.ParseInt(s, 10, 64); err == nil {
keepAlive = time.Duration(n) * time.Second
}
}
if keepAlive < 0 {
return time.Duration(math.MaxInt64)
}
return keepAlive
}
func Bool(k string) func() bool {
return func() bool {
if s := Var(k); s != "" {
b, err := strconv.ParseBool(s)
if err != nil {
return true
}
return b
}
return false
}
}
var (
// Debug enabled additional debug information.
Debug = Bool("OLLAMA_DEBUG")
// FlashAttention enables the experimental flash attention feature.
FlashAttention = Bool("OLLAMA_FLASH_ATTENTION")
// NoHistory disables readline history.
NoHistory = Bool("OLLAMA_NOHISTORY")
// NoPrune disables pruning of model blobs on startup.
NoPrune = Bool("OLLAMA_NOPRUNE")
// SchedSpread allows scheduling models across all GPUs.
SchedSpread = Bool("OLLAMA_SCHED_SPREAD")
// IntelGPU enables experimental Intel GPU detection.
IntelGPU = Bool("OLLAMA_INTEL_GPU")
)
func String(s string) func() string {
return func() string {
return Var(s)
}
}
var (
LLMLibrary = String("OLLAMA_LLM_LIBRARY")
TmpDir = String("OLLAMA_TMPDIR")
CudaVisibleDevices = String("CUDA_VISIBLE_DEVICES")
HipVisibleDevices = String("HIP_VISIBLE_DEVICES")
RocrVisibleDevices = String("ROCR_VISIBLE_DEVICES")
GpuDeviceOrdinal = String("GPU_DEVICE_ORDINAL")
HsaOverrideGfxVersion = String("HSA_OVERRIDE_GFX_VERSION")
)
func RunnersDir() (p string) {
if p := Var("OLLAMA_RUNNERS_DIR"); p != "" {
return p
}
if runtime.GOOS != "windows" {
return
}
defer func() {
if p == "" {
slog.Error("unable to locate llm runner directory. Set OLLAMA_RUNNERS_DIR to the location of 'ollama_runners'")
}
}()
// On Windows we do not carry the payloads inside the main executable
exe, err := os.Executable()
if err != nil {
return
}
cwd, err := os.Getwd()
if err != nil {
return
}
var paths []string
for _, root := range []string{filepath.Dir(exe), cwd} {
paths = append(paths,
root,
filepath.Join(root, "windows-"+runtime.GOARCH),
filepath.Join(root, "dist", "windows-"+runtime.GOARCH),
)
}
// Try a few variations to improve developer experience when building from source in the local tree
for _, path := range paths {
candidate := filepath.Join(path, "ollama_runners")
if _, err := os.Stat(candidate); err == nil {
p = candidate
break
}
}
return p
}
func Uint(key string, defaultValue uint) func() uint {
return func() uint {
if s := Var(key); s != "" {
if n, err := strconv.ParseUint(s, 10, 64); err != nil {
slog.Warn("invalid environment variable, using default", "key", key, "value", s, "default", defaultValue)
			} else {
				return uint(n)
			}
		}

		return defaultValue
	}
}
var (
// NumParallel sets the number of parallel model requests. NumParallel can be configured via the OLLAMA_NUM_PARALLEL environment variable.
NumParallel = Uint("OLLAMA_NUM_PARALLEL", 0)
// MaxRunners sets the maximum number of loaded models. MaxRunners can be configured via the OLLAMA_MAX_LOADED_MODELS environment variable.
MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", 0)
// MaxQueue sets the maximum number of queued requests. MaxQueue can be configured via the OLLAMA_MAX_QUEUE environment variable.
MaxQueue = Uint("OLLAMA_MAX_QUEUE", 512)
// MaxVRAM sets a maximum VRAM override in bytes. MaxVRAM can be configured via the OLLAMA_MAX_VRAM environment variable.
MaxVRAM = Uint("OLLAMA_MAX_VRAM", 0)
)
type EnvVar struct {
Name string
Value any
Description string
}
func AsMap() map[string]EnvVar {
ret := map[string]EnvVar{
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"},
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
"OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir(), "Location for runners"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir(), "Location for temporary files"},
}
if runtime.GOOS != "darwin" {
ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices(), "Set which NVIDIA devices are visible"}
ret["HIP_VISIBLE_DEVICES"] = EnvVar{"HIP_VISIBLE_DEVICES", HipVisibleDevices(), "Set which AMD devices are visible"}
ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices(), "Set which AMD devices are visible"}
ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal(), "Set which AMD devices are visible"}
ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion(), "Override the gfx used for all detected AMD GPUs"}
ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGPU(), "Enable experimental Intel GPU detection"}
}
return ret
}
func Values() map[string]string {
vals := make(map[string]string)
for k, v := range AsMap() {
vals[k] = fmt.Sprintf("%v", v.Value)
}
return vals
}
// Var returns an environment variable stripped of leading and trailing quotes or spaces
func Var(key string) string {
return strings.Trim(strings.TrimSpace(os.Getenv(key)), "\"'")
}
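For orientation (not part of the diff), a sketch of how callers consume the new accessor-style API; the environment values below are illustrative assumptions, and each accessor re-reads its variable on every call, so plain os.Setenv (or t.Setenv in tests) is enough to change behavior.

package main

import (
	"fmt"
	"os"

	"github.com/ollama/ollama/envconfig"
)

func main() {
	// Illustrative values only.
	os.Setenv("OLLAMA_DEBUG", "1")
	os.Setenv("OLLAMA_KEEP_ALIVE", "10m")

	fmt.Println(envconfig.Debug())     // true
	fmt.Println(envconfig.KeepAlive()) // 10m0s
	fmt.Println(envconfig.Host())      // http://127.0.0.1:11434 unless OLLAMA_HOST is set
	fmt.Println(envconfig.Models())    // $HOME/.ollama/models unless OLLAMA_MODELS is set
}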

View file

@@ -1,87 +1,234 @@
 package envconfig
 import (
-	"fmt"
 	"math"
-	"net"
 	"testing"
 	"time"
-	"github.com/stretchr/testify/assert"
-	"github.com/stretchr/testify/require"
+	"github.com/google/go-cmp/cmp"
 )
-func TestConfig(t *testing.T) {
-	Debug = false // Reset whatever was loaded in init()
+func TestHost(t *testing.T) {
+	cases := map[string]struct {
t.Setenv("OLLAMA_DEBUG", "")
LoadConfig()
require.False(t, Debug)
t.Setenv("OLLAMA_DEBUG", "false")
LoadConfig()
require.False(t, Debug)
t.Setenv("OLLAMA_DEBUG", "1")
LoadConfig()
require.True(t, Debug)
t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
LoadConfig()
require.True(t, FlashAttention)
t.Setenv("OLLAMA_KEEP_ALIVE", "")
LoadConfig()
require.Equal(t, 5*time.Minute, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "3")
LoadConfig()
require.Equal(t, 3*time.Second, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "1h")
LoadConfig()
require.Equal(t, 1*time.Hour, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "-1s")
LoadConfig()
require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "-1")
LoadConfig()
require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
}
-func TestClientFromEnvironment(t *testing.T) {
-	type testCase struct {
-		value  string
-		expect string
-		err    error
+		value  string
+		expect string
+	}{
"empty": {"", "127.0.0.1:11434"},
"only address": {"1.2.3.4", "1.2.3.4:11434"},
"only port": {":1234", ":1234"},
"address and port": {"1.2.3.4:1234", "1.2.3.4:1234"},
"hostname": {"example.com", "example.com:11434"},
"hostname and port": {"example.com:1234", "example.com:1234"},
"zero port": {":0", ":0"},
"too large port": {":66000", ":11434"},
"too small port": {":-1", ":11434"},
"ipv6 localhost": {"[::1]", "[::1]:11434"},
"ipv6 world open": {"[::]", "[::]:11434"},
"ipv6 no brackets": {"::1", "[::1]:11434"},
"ipv6 + port": {"[::1]:1337", "[::1]:1337"},
"extra space": {" 1.2.3.4 ", "1.2.3.4:11434"},
"extra quotes": {"\"1.2.3.4\"", "1.2.3.4:11434"},
"extra space+quotes": {" \" 1.2.3.4 \" ", "1.2.3.4:11434"},
"extra single quotes": {"'1.2.3.4'", "1.2.3.4:11434"},
"http": {"http://1.2.3.4", "1.2.3.4:80"},
"http port": {"http://1.2.3.4:4321", "1.2.3.4:4321"},
"https": {"https://1.2.3.4", "1.2.3.4:443"},
"https port": {"https://1.2.3.4:4321", "1.2.3.4:4321"},
	}
-	hostTestCases := map[string]*testCase{
-		"empty":               {value: "", expect: "127.0.0.1:11434"},
-		"only address":        {value: "1.2.3.4", expect: "1.2.3.4:11434"},
-		"only port":           {value: ":1234", expect: ":1234"},
-		"address and port":    {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"},
-		"hostname":            {value: "example.com", expect: "example.com:11434"},
-		"hostname and port":   {value: "example.com:1234", expect: "example.com:1234"},
-		"zero port":           {value: ":0", expect: ":0"},
-		"too large port":      {value: ":66000", err: ErrInvalidHostPort},
-		"too small port":      {value: ":-1", err: ErrInvalidHostPort},
-		"ipv6 localhost":      {value: "[::1]", expect: "[::1]:11434"},
-		"ipv6 world open":     {value: "[::]", expect: "[::]:11434"},
-		"ipv6 no brackets":    {value: "::1", expect: "[::1]:11434"},
-		"ipv6 + port":         {value: "[::1]:1337", expect: "[::1]:1337"},
-		"extra space":         {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"},
-		"extra quotes":        {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"},
-		"extra space+quotes":  {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"},
-		"extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"},
-	}
-	for k, v := range hostTestCases {
-		t.Run(k, func(t *testing.T) {
-			t.Setenv("OLLAMA_HOST", v.value)
-			LoadConfig()
-			oh, err := getOllamaHost()
-			if err != v.err {
-				t.Fatalf("expected %s, got %s", v.err, err)
-			}
-			if err == nil {
-				host := net.JoinHostPort(oh.Host, oh.Port)
-				assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host))
+	for name, tt := range cases {
+		t.Run(name, func(t *testing.T) {
+			t.Setenv("OLLAMA_HOST", tt.value)
+			if host := Host(); host.Host != tt.expect {
+				t.Errorf("%s: expected %s, got %s", name, tt.expect, host.Host)
+			}
+		})
+	}
+}
+
+func TestOrigins(t *testing.T) {
+	cases := []struct {
+		value  string
+		expect []string
+	}{
+		{"", []string{
+			"http://localhost",
+			"https://localhost",
+			"http://localhost:*",
+			"https://localhost:*",
+			"http://127.0.0.1",
+			"https://127.0.0.1",
+			"http://127.0.0.1:*",
+			"https://127.0.0.1:*",
+			"http://0.0.0.0",
+			"https://0.0.0.0",
+			"http://0.0.0.0:*",
+			"https://0.0.0.0:*",
+			"app://*",
+			"file://*",
+			"tauri://*",
+		}},
+		{"http://10.0.0.1", []string{
"http://10.0.0.1",
"http://localhost",
"https://localhost",
"http://localhost:*",
"https://localhost:*",
"http://127.0.0.1",
"https://127.0.0.1",
"http://127.0.0.1:*",
"https://127.0.0.1:*",
"http://0.0.0.0",
"https://0.0.0.0",
"http://0.0.0.0:*",
"https://0.0.0.0:*",
"app://*",
"file://*",
"tauri://*",
}},
{"http://172.16.0.1,https://192.168.0.1", []string{
"http://172.16.0.1",
"https://192.168.0.1",
"http://localhost",
"https://localhost",
"http://localhost:*",
"https://localhost:*",
"http://127.0.0.1",
"https://127.0.0.1",
"http://127.0.0.1:*",
"https://127.0.0.1:*",
"http://0.0.0.0",
"https://0.0.0.0",
"http://0.0.0.0:*",
"https://0.0.0.0:*",
"app://*",
"file://*",
"tauri://*",
}},
{"http://totally.safe,http://definitely.legit", []string{
"http://totally.safe",
"http://definitely.legit",
"http://localhost",
"https://localhost",
"http://localhost:*",
"https://localhost:*",
"http://127.0.0.1",
"https://127.0.0.1",
"http://127.0.0.1:*",
"https://127.0.0.1:*",
"http://0.0.0.0",
"https://0.0.0.0",
"http://0.0.0.0:*",
"https://0.0.0.0:*",
"app://*",
"file://*",
"tauri://*",
}},
}
for _, tt := range cases {
t.Run(tt.value, func(t *testing.T) {
t.Setenv("OLLAMA_ORIGINS", tt.value)
if diff := cmp.Diff(Origins(), tt.expect); diff != "" {
t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff)
}
})
}
}
func TestBool(t *testing.T) {
cases := map[string]bool{
"": false,
"true": true,
"false": false,
"1": true,
"0": false,
// invalid values
"random": true,
"something": true,
}
for k, v := range cases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_BOOL", k)
if b := Bool("OLLAMA_BOOL")(); b != v {
t.Errorf("%s: expected %t, got %t", k, v, b)
}
})
}
}
func TestUint(t *testing.T) {
cases := map[string]uint{
"0": 0,
"1": 1,
"1337": 1337,
// default values
"": 11434,
"-1": 11434,
"0o10": 11434,
"0x10": 11434,
"string": 11434,
}
for k, v := range cases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_UINT", k)
if i := Uint("OLLAMA_UINT", 11434)(); i != v {
t.Errorf("%s: expected %d, got %d", k, v, i)
}
})
}
}
func TestKeepAlive(t *testing.T) {
cases := map[string]time.Duration{
"": 5 * time.Minute,
"1s": time.Second,
"1m": time.Minute,
"1h": time.Hour,
"5m0s": 5 * time.Minute,
"1h2m3s": 1*time.Hour + 2*time.Minute + 3*time.Second,
"0": time.Duration(0),
"60": 60 * time.Second,
"120": 2 * time.Minute,
"3600": time.Hour,
"-0": time.Duration(0),
"-1": time.Duration(math.MaxInt64),
"-1m": time.Duration(math.MaxInt64),
// invalid values
" ": 5 * time.Minute,
"???": 5 * time.Minute,
"1d": 5 * time.Minute,
"1y": 5 * time.Minute,
"1w": 5 * time.Minute,
}
for tt, expect := range cases {
t.Run(tt, func(t *testing.T) {
t.Setenv("OLLAMA_KEEP_ALIVE", tt)
if actual := KeepAlive(); actual != expect {
t.Errorf("%s: expected %s, got %s", tt, expect, actual)
}
})
}
}
func TestVar(t *testing.T) {
cases := map[string]string{
"value": "value",
" value ": "value",
" 'value' ": "value",
` "value" `: "value",
" ' value ' ": " value ",
` " value " `: " value ",
}
for k, v := range cases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_VAR", k)
if s := Var("OLLAMA_VAR"); s != v {
t.Errorf("%s: expected %q, got %q", k, v, s)
			}
		})
	}
}
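The test changes above all follow one pattern: set the variable with t.Setenv and call the accessor directly, with no LoadConfig step in between. A hedged sketch of that pattern (this test is illustrative and not part of the diff; FlashAttention is one of the Bool(...) accessors defined in the package):

package envconfig

import "testing"

// Illustrative only: t.Setenv is sufficient because the accessor re-reads the
// environment on every call.
func TestFlashAttentionSketch(t *testing.T) {
	t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
	if !FlashAttention() {
		t.Error("expected flash attention to be enabled")
	}

	t.Setenv("OLLAMA_FLASH_ATTENTION", "0")
	if FlashAttention() {
		t.Error("expected flash attention to be disabled")
	}
}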

View file

@@ -61,9 +61,9 @@ func AMDGetGPUInfo() []RocmGPUInfo {
 	// Determine if the user has already pre-selected which GPUs to look at, then ignore the others
 	var visibleDevices []string
-	hipVD := envconfig.HipVisibleDevices   // zero based index only
-	rocrVD := envconfig.RocrVisibleDevices // zero based index or UUID, but consumer cards seem to not support UUID
-	gpuDO := envconfig.GpuDeviceOrdinal    // zero based index
+	hipVD := envconfig.HipVisibleDevices()   // zero based index only
+	rocrVD := envconfig.RocrVisibleDevices() // zero based index or UUID, but consumer cards seem to not support UUID
+	gpuDO := envconfig.GpuDeviceOrdinal()    // zero based index
 	switch {
 	// TODO is this priorty order right?
 	case hipVD != "":
@@ -76,7 +76,7 @@ func AMDGetGPUInfo() []RocmGPUInfo {
 		visibleDevices = strings.Split(gpuDO, ",")
 	}
-	gfxOverride := envconfig.HsaOverrideGfxVersion
+	gfxOverride := envconfig.HsaOverrideGfxVersion()
 	var supported []string
 	libDir := ""

View file

@@ -53,7 +53,7 @@ func AMDGetGPUInfo() []RocmGPUInfo {
 	}
 	var supported []string
-	gfxOverride := envconfig.HsaOverrideGfxVersion
+	gfxOverride := envconfig.HsaOverrideGfxVersion()
 	if gfxOverride == "" {
 		supported, err = GetSupportedGFX(libDir)
 		if err != nil {

View file

@@ -26,7 +26,7 @@ func PayloadsDir() (string, error) {
 	defer lock.Unlock()
 	var err error
 	if payloadsDir == "" {
-		runnersDir := envconfig.RunnersDir
+		runnersDir := envconfig.RunnersDir()
 		if runnersDir != "" {
 			payloadsDir = runnersDir
@@ -35,7 +35,7 @@ func PayloadsDir() (string, error) {
 		// The remainder only applies on non-windows where we still carry payloads in the main executable
 		cleanupTmpDirs()
-		tmpDir := envconfig.TmpDir
+		tmpDir := envconfig.TmpDir()
 		if tmpDir == "" {
 			tmpDir, err = os.MkdirTemp("", "ollama")
 			if err != nil {
@@ -105,7 +105,7 @@ func cleanupTmpDirs() {
 func Cleanup() {
 	lock.Lock()
 	defer lock.Unlock()
-	runnersDir := envconfig.RunnersDir
+	runnersDir := envconfig.RunnersDir()
 	if payloadsDir != "" && runnersDir == "" && runtime.GOOS != "windows" {
 		// We want to fully clean up the tmpdir parent of the payloads dir
 		tmpDir := filepath.Clean(filepath.Join(payloadsDir, ".."))

View file

@@ -230,8 +230,8 @@ func GetGPUInfo() GpuInfoList {
 		// On windows we bundle the nvidia library one level above the runner dir
 		depPath := ""
-		if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
-			depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir), "cuda")
+		if runtime.GOOS == "windows" && envconfig.RunnersDir() != "" {
+			depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir()), "cuda")
 		}
 		// Load ALL libraries
@@ -302,12 +302,12 @@ func GetGPUInfo() GpuInfoList {
 		}
 		// Intel
-		if envconfig.IntelGpu {
+		if envconfig.IntelGPU() {
 			oHandles = initOneAPIHandles()
 			// On windows we bundle the oneapi library one level above the runner dir
 			depPath = ""
-			if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
-				depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir), "oneapi")
+			if runtime.GOOS == "windows" && envconfig.RunnersDir() != "" {
+				depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir()), "oneapi")
 			}
 			for d := range oHandles.oneapi.num_drivers {
@@ -611,7 +611,7 @@ func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) {
 }
 func getVerboseState() C.uint16_t {
-	if envconfig.Debug {
+	if envconfig.Debug() {
 		return C.uint16_t(1)
 	}
 	return C.uint16_t(0)

View file

@@ -45,14 +45,7 @@ func TestUnicodeModelDir(t *testing.T) {
 	defer os.RemoveAll(modelDir)
 	slog.Info("unicode", "OLLAMA_MODELS", modelDir)
-	oldModelsDir := os.Getenv("OLLAMA_MODELS")
-	if oldModelsDir == "" {
-		defer os.Unsetenv("OLLAMA_MODELS")
-	} else {
-		defer os.Setenv("OLLAMA_MODELS", oldModelsDir)
-	}
-	err = os.Setenv("OLLAMA_MODELS", modelDir)
-	require.NoError(t, err)
+	t.Setenv("OLLAMA_MODELS", modelDir)
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()

View file

@@ -5,14 +5,16 @@ package integration
 import (
 	"context"
 	"log/slog"
 	"os"
 	"strconv"
 	"sync"
 	"testing"
 	"time"
-	"github.com/ollama/ollama/api"
 	"github.com/stretchr/testify/require"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/envconfig"
+	"github.com/ollama/ollama/format"
 )
 func TestMultiModelConcurrency(t *testing.T) {
@@ -106,13 +108,16 @@ func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
 // Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
 func TestMultiModelStress(t *testing.T) {
-	vram := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
-	if vram == "" {
+	s := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
+	if s == "" {
 		t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
 	}
-	max, err := strconv.ParseUint(vram, 10, 64)
-	require.NoError(t, err)
-	const MB = uint64(1024 * 1024)
+
+	maxVram, err := strconv.ParseUint(s, 10, 64)
+	if err != nil {
+		t.Fatal(err)
+	}
 	type model struct {
 		name string
 		size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM
@@ -121,83 +126,82 @@ func TestMultiModelStress(t *testing.T) {
 	smallModels := []model{
 		{
 			name: "orca-mini",
-			size: 2992 * MB,
+			size: 2992 * format.MebiByte,
 		},
 		{
 			name: "phi",
-			size: 2616 * MB,
+			size: 2616 * format.MebiByte,
 		},
 		{
 			name: "gemma:2b",
-			size: 2364 * MB,
+			size: 2364 * format.MebiByte,
 		},
 		{
 			name: "stable-code:3b",
-			size: 2608 * MB,
+			size: 2608 * format.MebiByte,
 		},
 		{
 			name: "starcoder2:3b",
-			size: 2166 * MB,
+			size: 2166 * format.MebiByte,
 		},
 	}
 	mediumModels := []model{
 		{
 			name: "llama2",
-			size: 5118 * MB,
+			size: 5118 * format.MebiByte,
 		},
 		{
 			name: "mistral",
-			size: 4620 * MB,
+			size: 4620 * format.MebiByte,
 		},
 		{
 			name: "orca-mini:7b",
-			size: 5118 * MB,
+			size: 5118 * format.MebiByte,
 		},
 		{
 			name: "dolphin-mistral",
-			size: 4620 * MB,
+			size: 4620 * format.MebiByte,
 		},
 		{
 			name: "gemma:7b",
-			size: 5000 * MB,
+			size: 5000 * format.MebiByte,
+		},
+		{
+			name: "codellama:7b",
+			size: 5118 * format.MebiByte,
 		},
-		// TODO - uncomment this once #3565 is merged and this is rebased on it
-		// {
-		// 	name: "codellama:7b",
-		// 	size: 5118 * MB,
-		// },
 	}
 	// These seem to be too slow to be useful...
 	// largeModels := []model{
 	// 	{
 	// 		name: "llama2:13b",
-	// 		size: 7400 * MB,
+	// 		size: 7400 * format.MebiByte,
 	// 	},
 	// 	{
 	// 		name: "codellama:13b",
-	// 		size: 7400 * MB,
+	// 		size: 7400 * format.MebiByte,
 	// 	},
 	// 	{
 	// 		name: "orca-mini:13b",
-	// 		size: 7400 * MB,
+	// 		size: 7400 * format.MebiByte,
 	// 	},
 	// 	{
 	// 		name: "gemma:7b",
-	// 		size: 5000 * MB,
+	// 		size: 5000 * format.MebiByte,
 	// 	},
 	// 	{
 	// 		name: "starcoder2:15b",
-	// 		size: 9100 * MB,
+	// 		size: 9100 * format.MebiByte,
 	// 	},
 	// }
 	var chosenModels []model
 	switch {
-	case max < 10000*MB:
+	case maxVram < 10000*format.MebiByte:
 		slog.Info("selecting small models")
 		chosenModels = smallModels
-	// case max < 30000*MB:
+	// case maxVram < 30000*format.MebiByte:
 	default:
 		slog.Info("selecting medium models")
 		chosenModels = mediumModels
@@ -226,15 +230,15 @@ func TestMultiModelStress(t *testing.T) {
 	}
 	var wg sync.WaitGroup
-	consumed := uint64(256 * MB) // Assume some baseline usage
+	consumed := uint64(256 * format.MebiByte) // Assume some baseline usage
 	for i := 0; i < len(req); i++ {
 		// Always get at least 2 models, but dont' overshoot VRAM too much or we'll take too long
-		if i > 1 && consumed > max {
-			slog.Info("achieved target vram exhaustion", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
+		if i > 1 && consumed > vram {
+			slog.Info("achieved target vram exhaustion", "count", i, "vram", format.HumanBytes2(vram), "models", format.HumanBytes2(consumed))
 			break
 		}
 		consumed += chosenModels[i].size
-		slog.Info("target vram", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
+		slog.Info("target vram", "count", i, "vram", format.HumanBytes2(vram), "models", format.HumanBytes2(consumed))
 		wg.Add(1)
 		go func(i int) {

View file

@@ -5,7 +5,6 @@ package integration
 import (
 	"context"
 	"errors"
-	"fmt"
 	"log/slog"
 	"os"
 	"strconv"
@@ -14,8 +13,10 @@ import (
 	"testing"
 	"time"
-	"github.com/ollama/ollama/api"
 	"github.com/stretchr/testify/require"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/envconfig"
 )
 func TestMaxQueue(t *testing.T) {
@@ -27,13 +28,10 @@ func TestMaxQueue(t *testing.T) {
 	// Note: This test can be quite slow when running in CPU mode, so keep the threadCount low unless your on GPU
 	// Also note that by default Darwin can't sustain > ~128 connections without adjusting limits
 	threadCount := 32
-	mq := os.Getenv("OLLAMA_MAX_QUEUE")
-	if mq != "" {
-		var err error
-		threadCount, err = strconv.Atoi(mq)
-		require.NoError(t, err)
+	if maxQueue := envconfig.MaxQueue(); maxQueue != 0 {
+		threadCount = maxQueue
 	} else {
-		os.Setenv("OLLAMA_MAX_QUEUE", fmt.Sprintf("%d", threadCount))
+		t.Setenv("OLLAMA_MAX_QUEUE", strconv.Itoa(threadCount))
 	}
 	req := api.GenerateRequest{

View file

@@ -8,14 +8,14 @@ import (
 	"testing"
 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/gpu"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
 func TestEstimateGPULayers(t *testing.T) {
-	envconfig.Debug = true
+	t.Setenv("OLLAMA_DEBUG", "1")
 	modelName := "dummy"
 	f, err := os.CreateTemp(t.TempDir(), modelName)
 	require.NoError(t, err)

View file

@@ -163,7 +163,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 	} else {
 		servers = serversForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant
 	}
-	demandLib := envconfig.LLMLibrary
+	demandLib := envconfig.LLMLibrary()
 	if demandLib != "" {
 		serverPath := availableServers[demandLib]
 		if serverPath == "" {
@@ -195,7 +195,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		params = append(params, "--n-gpu-layers", fmt.Sprintf("%d", opts.NumGPU))
 	}
-	if envconfig.Debug {
+	if envconfig.Debug() {
 		params = append(params, "--verbose")
 	}
@@ -221,7 +221,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		params = append(params, "--memory-f32")
 	}
-	flashAttnEnabled := envconfig.FlashAttention
+	flashAttnEnabled := envconfig.FlashAttention()
 	for _, g := range gpus {
 		// only cuda (compute capability 7+) and metal support flash attention
@@ -382,7 +382,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 	}
 	slog.Info("starting llama server", "cmd", s.cmd.String())
-	if envconfig.Debug {
+	if envconfig.Debug() {
 		filteredEnv := []string{}
 		for _, ev := range s.cmd.Env {
 			if strings.HasPrefix(ev, "CUDA_") ||

View file

@@ -646,7 +646,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 		return err
 	}
-	if !envconfig.NoPrune && old != nil {
+	if !envconfig.NoPrune() && old != nil {
 		if err := old.RemoveLayers(); err != nil {
 			return err
 		}
@@ -885,7 +885,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 	// build deleteMap to prune unused layers
 	deleteMap := make(map[string]struct{})
-	if !envconfig.NoPrune {
+	if !envconfig.NoPrune() {
 		manifest, _, err = GetManifest(mp)
 		if err != nil && !errors.Is(err, os.ErrNotExist) {
 			return err

View file

@@ -7,7 +7,6 @@ import (
 	"slices"
 	"testing"
-	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/types/model"
 )
@@ -108,7 +107,6 @@ func TestManifests(t *testing.T) {
 		t.Run(n, func(t *testing.T) {
 			d := t.TempDir()
 			t.Setenv("OLLAMA_MODELS", d)
-			envconfig.LoadConfig()
 			for _, p := range wants.ps {
 				createManifest(t, d, p)

View file

@@ -105,9 +105,7 @@ func (mp ModelPath) GetShortTagname() string {
 // GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
 func (mp ModelPath) GetManifestPath() (string, error) {
-	dir := envconfig.ModelsDir
-	return filepath.Join(dir, "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil
+	return filepath.Join(envconfig.Models(), "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil
 }
 func (mp ModelPath) BaseURL() *url.URL {
@@ -118,9 +116,7 @@ func (mp ModelPath) BaseURL() *url.URL {
 }
 func GetManifestPath() (string, error) {
-	dir := envconfig.ModelsDir
-	path := filepath.Join(dir, "manifests")
+	path := filepath.Join(envconfig.Models(), "manifests")
 	if err := os.MkdirAll(path, 0o755); err != nil {
 		return "", err
 	}
@@ -129,8 +125,6 @@ func GetManifestPath() (string, error) {
 }
 func GetBlobsPath(digest string) (string, error) {
-	dir := envconfig.ModelsDir
-
 	// only accept actual sha256 digests
 	pattern := "^sha256[:-][0-9a-fA-F]{64}$"
 	re := regexp.MustCompile(pattern)
@@ -140,7 +134,7 @@ func GetBlobsPath(digest string) (string, error) {
 	digest = strings.ReplaceAll(digest, ":", "-")
-	path := filepath.Join(dir, "blobs", digest)
+	path := filepath.Join(envconfig.Models(), "blobs", digest)
 	dirPath := filepath.Dir(path)
 	if digest == "" {
 		dirPath = path

View file

@@ -7,8 +7,6 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
-
-	"github.com/ollama/ollama/envconfig"
 )
 func TestGetBlobsPath(t *testing.T) {
@@ -63,7 +61,6 @@ func TestGetBlobsPath(t *testing.T) {
 	for _, tc := range tests {
 		t.Run(tc.name, func(t *testing.T) {
 			t.Setenv("OLLAMA_MODELS", dir)
-			envconfig.LoadConfig()
 			got, err := GetBlobsPath(tc.digest)

View file

@@ -1053,7 +1053,7 @@ func (s *Server) GenerateRoutes() http.Handler {
 	for _, prop := range openAIProperties {
 		config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
 	}
-	config.AllowOrigins = envconfig.AllowOrigins
+	config.AllowOrigins = envconfig.Origins()
 	r := gin.Default()
 	r.Use(
@@ -1098,7 +1098,7 @@ func (s *Server) GenerateRoutes() http.Handler {
 func Serve(ln net.Listener) error {
 	level := slog.LevelInfo
-	if envconfig.Debug {
+	if envconfig.Debug() {
 		level = slog.LevelDebug
 	}
@@ -1126,7 +1126,7 @@ func Serve(ln net.Listener) error {
 		return err
 	}
-	if !envconfig.NoPrune {
+	if !envconfig.NoPrune() {
 		// clean up unused layers and manifests
 		if err := PruneLayers(); err != nil {
 			return err

View file

@@ -15,7 +15,6 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/llm"
 )
@@ -89,7 +88,6 @@ func TestCreateFromBin(t *testing.T) {
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
-	envconfig.LoadConfig()
 	var s Server
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
@@ -117,7 +115,6 @@ func TestCreateFromModel(t *testing.T) {
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
-	envconfig.LoadConfig()
 	var s Server
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
@@ -160,7 +157,6 @@ func TestCreateRemovesLayers(t *testing.T) {
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
-	envconfig.LoadConfig()
 	var s Server
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
@@ -209,7 +205,6 @@ func TestCreateUnsetsSystem(t *testing.T) {
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
-	envconfig.LoadConfig()
 	var s Server
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
@@ -267,7 +262,6 @@ func TestCreateMergeParameters(t *testing.T) {
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
-	envconfig.LoadConfig()
 	var s Server
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
@@ -372,7 +366,6 @@ func TestCreateReplacesMessages(t *testing.T) {
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
-	envconfig.LoadConfig()
 	var s Server
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
@@ -450,7 +443,6 @@ func TestCreateTemplateSystem(t *testing.T) {
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
-	envconfig.LoadConfig()
 	var s Server
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
@@ -534,7 +526,6 @@ func TestCreateLicenses(t *testing.T) {
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
-	envconfig.LoadConfig()
 	var s Server
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
@@ -582,7 +573,6 @@ func TestCreateDetectTemplate(t *testing.T) {
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
-	envconfig.LoadConfig()
 	var s Server
 	t.Run("matched", func(t *testing.T) {

View file

@@ -10,7 +10,6 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/types/model"
 )
@@ -19,7 +18,6 @@ func TestDelete(t *testing.T) {
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
-	envconfig.LoadConfig()
 	var s Server

View file

@@ -9,14 +9,12 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/envconfig"
 )
 func TestList(t *testing.T) {
 	gin.SetMode(gin.TestMode)
 	t.Setenv("OLLAMA_MODELS", t.TempDir())
-	envconfig.LoadConfig()
 	expectNames := []string{
 		"mistral:7b-instruct-q4_0",

View file

@@ -19,7 +19,6 @@ import (
 	"github.com/stretchr/testify/require"
 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/parser"
@@ -347,7 +346,6 @@ func Test_Routes(t *testing.T) {
 	}
 	t.Setenv("OLLAMA_MODELS", t.TempDir())
-	envconfig.LoadConfig()
 	s := &Server{}
 	router := s.GenerateRoutes()
@@ -378,7 +376,6 @@ func Test_Routes(t *testing.T) {
 func TestCase(t *testing.T) {
 	t.Setenv("OLLAMA_MODELS", t.TempDir())
-	envconfig.LoadConfig()
 	cases := []string{
 		"mistral",
@@ -458,7 +455,6 @@ func TestCase(t *testing.T) {
 func TestShow(t *testing.T) {
 	t.Setenv("OLLAMA_MODELS", t.TempDir())
-	envconfig.LoadConfig()
 	var s Server

View file

@@ -5,9 +5,11 @@ import (
 	"errors"
 	"fmt"
 	"log/slog"
+	"os"
 	"reflect"
 	"runtime"
 	"sort"
+	"strconv"
 	"strings"
 	"sync"
 	"time"
@@ -59,11 +61,12 @@ var defaultParallel = 4
 var ErrMaxQueue = fmt.Errorf("server busy, please try again. maximum pending requests exceeded")
 func InitScheduler(ctx context.Context) *Scheduler {
+	maxQueue := envconfig.MaxQueue()
 	sched := &Scheduler{
-		pendingReqCh:  make(chan *LlmRequest, envconfig.MaxQueuedRequests),
-		finishedReqCh: make(chan *LlmRequest, envconfig.MaxQueuedRequests),
-		expiredCh:     make(chan *runnerRef, envconfig.MaxQueuedRequests),
-		unloadedCh:    make(chan interface{}, envconfig.MaxQueuedRequests),
+		pendingReqCh:  make(chan *LlmRequest, maxQueue),
+		finishedReqCh: make(chan *LlmRequest, maxQueue),
+		expiredCh:     make(chan *runnerRef, maxQueue),
+		unloadedCh:    make(chan interface{}, maxQueue),
 		loaded:        make(map[string]*runnerRef),
 		newServerFn:   llm.NewLlamaServer,
 		getGpuFn:      gpu.GetGPUInfo,
@@ -126,7 +129,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
 				slog.Debug("pending request cancelled or timed out, skipping scheduling")
 				continue
 			}
-			numParallel := envconfig.NumParallel
+			numParallel := int(envconfig.NumParallel())
 			// TODO (jmorganca): multimodal models don't support parallel yet
 			// see https://github.com/ollama/ollama/issues/4165
 			if len(pending.model.ProjectorPaths) > 0 && numParallel != 1 {
@@ -148,7 +151,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
 					pending.useLoadedRunner(runner, s.finishedReqCh)
 					break
 				}
-			} else if envconfig.MaxRunners > 0 && loadedCount >= envconfig.MaxRunners {
+			} else if envconfig.MaxRunners() > 0 && loadedCount >= int(envconfig.MaxRunners()) {
 				slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount)
 				runnerToExpire = s.findRunnerToUnload()
 			} else {
@@ -161,7 +164,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
 					gpus = s.getGpuFn()
 				}
-				if envconfig.MaxRunners <= 0 {
+				if envconfig.MaxRunners() <= 0 {
 					// No user specified MaxRunners, so figure out what automatic setting to use
 					// If all GPUs have reliable free memory reporting, defaultModelsPerGPU * the number of GPUs
 					// if any GPU has unreliable free memory reporting, 1x the number of GPUs
@@ -173,11 +176,13 @@ func (s *Scheduler) processPending(ctx context.Context) {
 						}
 					}
 					if allReliable {
-						envconfig.MaxRunners = defaultModelsPerGPU * len(gpus)
+						// HACK
+						os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(defaultModelsPerGPU*len(gpus)))
 						slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", envconfig.MaxRunners, "gpu_count", len(gpus))
 					} else {
+						// HACK
+						os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(len(gpus)))
 						slog.Info("one or more GPUs detected that are unable to accurately report free memory - disabling default concurrency")
-						envconfig.MaxRunners = len(gpus)
 					}
 				}
@@ -404,7 +409,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
 	if numParallel < 1 {
 		numParallel = 1
 	}
-	sessionDuration := envconfig.KeepAlive
+	sessionDuration := envconfig.KeepAlive()
 	if req.sessionDuration != nil {
 		sessionDuration = req.sessionDuration.Duration
 	}
@@ -699,7 +704,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoL
 	// First attempt to fit the model into a single GPU
 	for _, p := range numParallelToTry {
 		req.opts.NumCtx = req.origNumCtx * p
-		if !envconfig.SchedSpread {
+		if !envconfig.SchedSpread() {
 			for _, g := range sgl {
 				if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
 					slog.Info("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "parallel", p, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
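With no mutable envconfig.MaxRunners variable left to assign, the fallback above writes the computed default back into the process environment (the "HACK" comments), so later envconfig.MaxRunners() calls observe it. A rough, self-contained sketch of that round trip; the GPU count and per-GPU default below are assumed values, not taken from the diff.

package main

import (
	"fmt"
	"os"
	"strconv"

	"github.com/ollama/ollama/envconfig"
)

func main() {
	gpuCount := 2            // assumed number of GPUs with reliable memory reporting
	defaultModelsPerGPU := 3 // assumed per-GPU default, mirroring sched.go

	// Write the computed default back to the environment...
	os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(defaultModelsPerGPU*gpuCount))

	// ...and every later read observes it through the accessor.
	fmt.Println("effective max loaded models:", envconfig.MaxRunners()) // 6
}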

View file

@@ -12,7 +12,6 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/app/lifecycle"
-	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/gpu"
 	"github.com/ollama/ollama/llm"
@@ -272,7 +271,7 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
 	c.req.opts.NumGPU = 0                                       // CPU load, will be allowed
 	d := newScenarioRequest(t, ctx, "ollama-model-3c", 30, nil) // Needs prior unloaded
-	envconfig.MaxRunners = 1
+	t.Setenv("OLLAMA_MAX_LOADED_MODELS", "1")
 	s.newServerFn = a.newServer
 	slog.Info("a")
 	s.pendingReqCh <- a.req
@@ -291,7 +290,7 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
 	require.Len(t, s.loaded, 1)
 	s.loadedMu.Unlock()
-	envconfig.MaxRunners = 0
+	t.Setenv("OLLAMA_MAX_LOADED_MODELS", "0")
 	s.newServerFn = b.newServer
 	slog.Info("b")
 	s.pendingReqCh <- b.req
@@ -362,7 +361,7 @@ func TestGetRunner(t *testing.T) {
 	a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond})
 	b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond})
 	c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond})
-	envconfig.MaxQueuedRequests = 1
+	t.Setenv("OLLAMA_MAX_QUEUE", "1")
 	s := InitScheduler(ctx)
 	s.getGpuFn = getGpuFn
 	s.getCpuFn = getCpuFn