Merge pull request #3418 from dhiltgen/concurrency

Request and model concurrency
This commit is contained in:
Daniel Hiltgen 2024-04-23 08:31:38 -07:00 committed by GitHub
commit 5690e5ce99
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
30 changed files with 2615 additions and 1387 deletions

View file

@ -91,6 +91,13 @@ func ClientFromEnvironment() (*Client, error) {
}, nil }, nil
} }
func NewClient(base *url.URL, http *http.Client) *Client {
return &Client{
base: base,
http: http,
}
}
func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error { func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
var reqBody io.Reader var reqBody io.Reader
var data []byte var data []byte

View file

@ -15,6 +15,7 @@ const (
KibiByte = Byte * 1024 KibiByte = Byte * 1024
MebiByte = KibiByte * 1024 MebiByte = KibiByte * 1024
GibiByte = MebiByte * 1024
) )
func HumanBytes(b int64) string { func HumanBytes(b int64) string {

View file

@ -7,7 +7,7 @@ import (
"log/slog" "log/slog"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "runtime"
"strings" "strings"
) )
@ -35,22 +35,64 @@ func GetSupportedGFX(libDir string) ([]string, error) {
return ret, nil return ret, nil
} }
func amdSetVisibleDevices(ids []int, skip map[int]interface{}) { func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
// Set the visible devices if not already set ids := []string{}
// TODO - does sort order matter? for _, info := range gpuInfo {
devices := []string{} if info.Library != "rocm" {
for i := range ids { // TODO shouldn't happen if things are wired correctly...
if _, skipped := skip[i]; skipped { slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
continue continue
} }
devices = append(devices, strconv.Itoa(i)) ids = append(ids, info.ID)
}
return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",")
}
func commonAMDValidateLibDir() (string, error) {
// We try to favor system paths first, so that we can wire up the subprocess to use
// the system version. Only use our bundled version if the system version doesn't work
// This gives users a more recovery options if versions have subtle problems at runtime
// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")
if hipPath != "" {
hipLibDir := filepath.Join(hipPath, "bin")
if rocmLibUsable(hipLibDir) {
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
return hipLibDir, nil
}
} }
val := strings.Join(devices, ",") // Scan the LD_LIBRARY_PATH or PATH
err := os.Setenv("HIP_VISIBLE_DEVICES", val) pathEnv := "LD_LIBRARY_PATH"
if err != nil { if runtime.GOOS == "windows" {
slog.Warn(fmt.Sprintf("failed to set env: %s", err)) pathEnv = "PATH"
} else {
slog.Info("Setting HIP_VISIBLE_DEVICES=" + val)
} }
paths := os.Getenv(pathEnv)
for _, path := range filepath.SplitList(paths) {
d, err := filepath.Abs(path)
if err != nil {
continue
}
if rocmLibUsable(d) {
return d, nil
}
}
// Well known location(s)
if rocmLibUsable(RocmStandardLocation) {
return RocmStandardLocation, nil
}
// Installer payload location if we're running the installed binary
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
} }

View file

@ -69,7 +69,7 @@ func NewHipLib() (*HipLib, error) {
func (hl *HipLib) Release() { func (hl *HipLib) Release() {
err := windows.FreeLibrary(hl.dll) err := windows.FreeLibrary(hl.dll)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("failed to unload amdhip64.dll: %s", err)) slog.Warn("failed to unload amdhip64.dll", "error", err)
} }
hl.dll = 0 hl.dll = 0
} }
@ -98,7 +98,7 @@ func (hl *HipLib) HipGetDeviceCount() int {
return 0 return 0
} }
if status != hipSuccess { if status != hipSuccess {
slog.Warn(fmt.Sprintf("failed call to hipGetDeviceCount: %d %s", status, err)) slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err)
} }
return count return count
} }

View file

@ -11,6 +11,8 @@ import (
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
"github.com/ollama/ollama/format"
) )
// Discovery logic for AMD/ROCm GPUs // Discovery logic for AMD/ROCm GPUs
@ -24,9 +26,6 @@ const (
GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
GPUUsedMemoryFileGlob = "mem_banks/*/used_memory" GPUUsedMemoryFileGlob = "mem_banks/*/used_memory"
RocmStandardLocation = "/opt/rocm/lib" RocmStandardLocation = "/opt/rocm/lib"
// TODO find a better way to detect iGPU instead of minimum memory
IGPUMemLimit = 1024 * 1024 * 1024 // 512G is what they typically report, so anything less than 1G must be iGPU
) )
var ( var (
@ -35,14 +34,11 @@ var (
) )
// Gather GPU information from the amdgpu driver if any supported GPUs are detected // Gather GPU information from the amdgpu driver if any supported GPUs are detected
// HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices func AMDGetGPUInfo() []GpuInfo {
// and the user hasn't already set this variable resp := []GpuInfo{}
func AMDGetGPUInfo(resp *GpuInfo) {
// TODO - DRY this out with windows
if !AMDDetected() { if !AMDDetected() {
return return resp
} }
skip := map[int]interface{}{}
// Opportunistic logging of driver version to aid in troubleshooting // Opportunistic logging of driver version to aid in troubleshooting
ver, err := AMDDriverVersion() ver, err := AMDDriverVersion()
@ -50,160 +46,117 @@ func AMDGetGPUInfo(resp *GpuInfo) {
slog.Info("AMD Driver: " + ver) slog.Info("AMD Driver: " + ver)
} else { } else {
// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU // TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err)) slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
} }
// If the user has specified exactly which GPUs to use, look up their memory // Determine if the user has already pre-selected which GPUs to look at, then ignore the others
visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES") var visibleDevices []string
if visibleDevices != "" { hipVD := os.Getenv("HIP_VISIBLE_DEVICES") // zero based index only
ids := []int{} rocrVD := os.Getenv("ROCR_VISIBLE_DEVICES") // zero based index or UUID, but consumer cards seem to not support UUID
for _, idStr := range strings.Split(visibleDevices, ",") { gpuDO := os.Getenv("GPU_DEVICE_ORDINAL") // zero based index
id, err := strconv.Atoi(idStr) switch {
if err != nil { // TODO is this priorty order right?
slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err)) case hipVD != "":
} else { visibleDevices = strings.Split(hipVD, ",")
ids = append(ids, id) case rocrVD != "":
} visibleDevices = strings.Split(rocrVD, ",")
} // TODO - since we don't yet support UUIDs, consider detecting and reporting here
amdProcMemLookup(resp, nil, ids) // all our test systems show GPU-XX indicating UUID is not supported
return case gpuDO != "":
visibleDevices = strings.Split(gpuDO, ",")
} }
// Gather GFX version information from all detected cards
gfx := AMDGFXVersions()
verStrings := []string{}
for i, v := range gfx {
verStrings = append(verStrings, v.ToGFXString())
if v.Major == 0 {
// Silently skip CPUs
skip[i] = struct{}{}
continue
}
if v.Major < 9 {
// TODO consider this a build-time setting if we can support 8xx family GPUs
slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString()))
skip[i] = struct{}{}
}
}
slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings))
// Abort if all GPUs are skipped
if len(skip) >= len(gfx) {
slog.Info("all detected amdgpus are skipped, falling back to CPU")
return
}
// If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib
libDir, err := AMDValidateLibDir()
if err != nil {
slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
return
}
updateLibPath(libDir)
gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION") gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
if gfxOverride == "" { var supported []string
supported, err := GetSupportedGFX(libDir) libDir := ""
// The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
matches, _ := filepath.Glob(GPUPropertiesFileGlob)
cpuCount := 0
for _, match := range matches {
slog.Debug("evaluating amdgpu node " + match)
fp, err := os.Open(match)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err)) slog.Debug("failed to open sysfs node", "file", match, "error", err)
return
}
slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported))
for i, v := range gfx {
if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
// TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
skip[i] = struct{}{}
} else {
slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
}
}
} else {
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
}
if len(skip) >= len(gfx) {
slog.Info("all detected amdgpus are skipped, falling back to CPU")
return
}
ids := make([]int, len(gfx))
i := 0
for k := range gfx {
ids[i] = k
i++
}
amdProcMemLookup(resp, skip, ids)
if resp.memInfo.DeviceCount == 0 {
return
}
if len(skip) > 0 {
amdSetVisibleDevices(ids, skip)
}
}
func updateLibPath(libDir string) {
ldPaths := []string{}
if val, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
ldPaths = strings.Split(val, ":")
}
for _, d := range ldPaths {
if d == libDir {
return
}
}
val := strings.Join(append(ldPaths, libDir), ":")
slog.Debug("updated lib path", "LD_LIBRARY_PATH", val)
os.Setenv("LD_LIBRARY_PATH", val)
}
// Walk the sysfs nodes for the available GPUs and gather information from them
// skipping over any devices in the skip map
func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
resp.memInfo.DeviceCount = 0
resp.memInfo.TotalMemory = 0
resp.memInfo.FreeMemory = 0
slog.Debug("discovering VRAM for amdgpu devices")
if len(ids) == 0 {
entries, err := os.ReadDir(AMDNodesSysfsDir)
if err != nil {
slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err))
return
}
for _, node := range entries {
if !node.IsDir() {
continue
}
id, err := strconv.Atoi(node.Name())
if err != nil {
slog.Warn("malformed amdgpu sysfs node id " + node.Name())
continue
}
ids = append(ids, id)
}
}
slog.Debug(fmt.Sprintf("amdgpu devices %v", ids))
for _, id := range ids {
if _, skipped := skip[id]; skipped {
continue continue
} }
defer fp.Close()
nodeID, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
if err != nil {
slog.Debug("failed to parse node ID", "error", err)
continue
}
scanner := bufio.NewScanner(fp)
isCPU := false
var major, minor, patch uint64
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
// Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
if strings.HasPrefix(line, "gfx_target_version") {
ver := strings.Fields(line)
// Detect CPUs
if len(ver) == 2 && ver[1] == "0" {
slog.Debug("detected CPU " + match)
isCPU = true
break
}
if len(ver) != 2 || len(ver[1]) < 5 {
slog.Warn("malformed "+match, "gfx_target_version", line)
// If this winds up being a CPU, our offsets may be wrong
continue
}
l := len(ver[1])
var err1, err2, err3 error
patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32)
minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32)
if err1 != nil || err2 != nil || err3 != nil {
slog.Debug("malformed int " + line)
continue
}
}
// TODO - any other properties we want to extract and record?
// vendor_id + device_id -> pci lookup for "Name"
// Other metrics that may help us understand relative performance between multiple GPUs
}
if isCPU {
cpuCount++
continue
}
// CPUs are always first in the list
gpuID := nodeID - cpuCount
// Shouldn't happen, but just in case...
if gpuID < 0 {
slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue")
return []GpuInfo{}
}
if int(major) < RocmComputeMin {
slog.Warn(fmt.Sprintf("amdgpu too old gfx%d%d%d", major, minor, patch), "gpu", gpuID)
continue
}
// Look up the memory for the current node
totalMemory := uint64(0) totalMemory := uint64(0)
usedMemory := uint64(0) usedMemory := uint64(0)
// Adjust for sysfs vs HIP ids propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUTotalMemoryFileGlob)
propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id+1), GPUTotalMemoryFileGlob)
propFiles, err := filepath.Glob(propGlob) propFiles, err := filepath.Glob(propGlob)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err)) slog.Warn("error looking up total GPU memory", "glob", propGlob, "error", err)
} }
// 1 or more memory banks - sum the values of all of them // 1 or more memory banks - sum the values of all of them
for _, propFile := range propFiles { for _, propFile := range propFiles {
fp, err := os.Open(propFile) fp, err := os.Open(propFile)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err)) slog.Warn("failed to open sysfs node", "file", propFile, "erroir", err)
continue continue
} }
defer fp.Close() defer fp.Close()
@ -226,49 +179,113 @@ func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
} }
} }
if totalMemory == 0 { if totalMemory == 0 {
slog.Warn(fmt.Sprintf("amdgpu [%d] reports zero total memory, skipping", id)) slog.Warn("amdgpu reports zero total memory", "gpu", gpuID)
skip[id] = struct{}{}
continue continue
} }
if totalMemory < IGPUMemLimit { usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUUsedMemoryFileGlob)
slog.Info(fmt.Sprintf("amdgpu [%d] appears to be an iGPU with %dM reported total memory, skipping", id, totalMemory/1024/1024))
skip[id] = struct{}{}
continue
}
usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob)
usedFiles, err := filepath.Glob(usedGlob) usedFiles, err := filepath.Glob(usedGlob)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err)) slog.Warn("error looking up used GPU memory", "glob", usedGlob, "error", err)
continue continue
} }
for _, usedFile := range usedFiles { for _, usedFile := range usedFiles {
fp, err := os.Open(usedFile) fp, err := os.Open(usedFile)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err)) slog.Warn("failed to open sysfs node", "file", usedFile, "error", err)
continue continue
} }
defer fp.Close() defer fp.Close()
data, err := io.ReadAll(fp) data, err := io.ReadAll(fp)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err)) slog.Warn("failed to read sysfs node", "file", usedFile, "error", err)
continue continue
} }
used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64) used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err)) slog.Warn("malformed used memory", "data", string(data), "error", err)
continue continue
} }
usedMemory += used usedMemory += used
} }
slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %dM", id, totalMemory/1024/1024))
slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory %dM", id, (totalMemory-usedMemory)/1024/1024)) // iGPU detection, remove this check once we can support an iGPU variant of the rocm library
resp.memInfo.DeviceCount++ if totalMemory < IGPUMemLimit {
resp.memInfo.TotalMemory += totalMemory slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
resp.memInfo.FreeMemory += (totalMemory - usedMemory) continue
}
slog.Info("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
slog.Info("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
gpuInfo := GpuInfo{
Library: "rocm",
memInfo: memInfo{
TotalMemory: totalMemory,
FreeMemory: (totalMemory - usedMemory),
},
ID: fmt.Sprintf("%d", gpuID),
// Name: not exposed in sysfs directly, would require pci device id lookup
Major: int(major),
Minor: int(minor),
Patch: int(patch),
MinimumMemory: rocmMinimumMemory,
}
// If the user wants to filter to a subset of devices, filter out if we aren't a match
if len(visibleDevices) > 0 {
include := false
for _, visible := range visibleDevices {
if visible == gpuInfo.ID {
include = true
break
}
}
if !include {
slog.Info("filtering out device per user request", "id", gpuInfo.ID, "visible_devices", visibleDevices)
continue
}
}
// Final validation is gfx compatibility - load the library if we haven't already loaded it
// even if the user overrides, we still need to validate the library
if libDir == "" {
libDir, err = AMDValidateLibDir()
if err != nil {
slog.Warn("unable to verify rocm library, will use cpu", "error", err)
return []GpuInfo{}
}
}
gpuInfo.DependencyPath = libDir
if gfxOverride == "" {
// Only load supported list once
if len(supported) == 0 {
supported, err = GetSupportedGFX(libDir)
if err != nil {
slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
return []GpuInfo{}
}
slog.Debug("rocm supported GPUs", "types", supported)
}
gfx := fmt.Sprintf("gfx%d%d%d", gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch)
if !slices.Contains[[]string, string](supported, gfx) {
slog.Warn("amdgpu is not supported", "gpu", gpuInfo.ID, "gpu_type", gfx, "library", libDir, "supported_types", supported)
// TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
continue
} else {
slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx)
}
} else {
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
}
// The GPU has passed all the verification steps and is supported
resp = append(resp, gpuInfo)
} }
if resp.memInfo.DeviceCount > 0 { if len(resp) == 0 {
resp.Library = "rocm" slog.Info("no compatible amdgpu devices detected")
} }
return resp
} }
// Quick check for AMD driver so we can skip amdgpu discovery if not present // Quick check for AMD driver so we can skip amdgpu discovery if not present
@ -280,87 +297,24 @@ func AMDDetected() bool {
slog.Debug("amdgpu driver not detected " + sysfsDir) slog.Debug("amdgpu driver not detected " + sysfsDir)
return false return false
} else if err != nil { } else if err != nil {
slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err)) slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err)
return false return false
} }
return true return true
} }
func setupLink(source, target string) error {
if err := os.RemoveAll(target); err != nil {
return fmt.Errorf("failed to remove old rocm directory %s %w", target, err)
}
if err := os.Symlink(source, target); err != nil {
return fmt.Errorf("failed to create link %s => %s %w", source, target, err)
}
slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target))
return nil
}
// Ensure the AMD rocm lib dir is wired up
// Prefer to use host installed ROCm, as long as it meets our minimum requirements // Prefer to use host installed ROCm, as long as it meets our minimum requirements
// failing that, tell the user how to download it on their own // failing that, tell the user how to download it on their own
func AMDValidateLibDir() (string, error) { func AMDValidateLibDir() (string, error) {
// We rely on the rpath compiled into our library to find rocm libDir, err := commonAMDValidateLibDir()
// so we establish a symlink to wherever we find it on the system
// to <payloads>/rocm
payloadsDir, err := PayloadsDir()
if err != nil {
return "", err
}
// If we already have a rocm dependency wired, nothing more to do
rocmTargetDir := filepath.Clean(filepath.Join(payloadsDir, "..", "rocm"))
if rocmLibUsable(rocmTargetDir) {
return rocmTargetDir, nil
}
// next to the running binary
exe, err := os.Executable()
if err == nil { if err == nil {
peerDir := filepath.Dir(exe) return libDir, nil
if rocmLibUsable(peerDir) {
slog.Debug("detected ROCM next to ollama executable " + peerDir)
return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
}
peerDir = filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(peerDir) {
slog.Debug("detected ROCM next to ollama executable " + peerDir)
return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
}
} }
// Well known ollama installer path // Well known ollama installer path
installedRocmDir := "/usr/share/ollama/lib/rocm" installedRocmDir := "/usr/share/ollama/lib/rocm"
if rocmLibUsable(installedRocmDir) { if rocmLibUsable(installedRocmDir) {
return rocmTargetDir, setupLink(installedRocmDir, rocmTargetDir) return installedRocmDir, nil
}
// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")
if hipPath != "" {
hipLibDir := filepath.Join(hipPath, "lib")
if rocmLibUsable(hipLibDir) {
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir)
}
}
// Scan the library path for potential matches
ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
for _, ldPath := range ldPaths {
d, err := filepath.Abs(ldPath)
if err != nil {
continue
}
if rocmLibUsable(d) {
return rocmTargetDir, setupLink(d, rocmTargetDir)
}
}
// Well known location(s)
if rocmLibUsable("/opt/rocm/lib") {
return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir)
} }
// If we still haven't found a usable rocm, the user will have to install it on their own // If we still haven't found a usable rocm, the user will have to install it on their own
@ -384,68 +338,3 @@ func AMDDriverVersion() (string, error) {
} }
return strings.TrimSpace(string(verString)), nil return strings.TrimSpace(string(verString)), nil
} }
func AMDGFXVersions() map[int]Version {
// The amdgpu driver always exposes the host CPU as node 0, but we have to skip that and subtract one
// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
res := map[int]Version{}
matches, _ := filepath.Glob(GPUPropertiesFileGlob)
for _, match := range matches {
fp, err := os.Open(match)
if err != nil {
slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
continue
}
defer fp.Close()
i, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
if err != nil {
slog.Debug(fmt.Sprintf("failed to parse node ID %s", err))
continue
}
if i == 0 {
// Skipping the CPU
continue
}
// Align with HIP IDs (zero is first GPU, not CPU)
i -= 1
scanner := bufio.NewScanner(fp)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.HasPrefix(line, "gfx_target_version") {
ver := strings.Fields(line)
if len(ver) != 2 || len(ver[1]) < 5 {
if ver[1] != "0" {
slog.Debug("malformed " + line)
}
res[i] = Version{
Major: 0,
Minor: 0,
Patch: 0,
}
continue
}
l := len(ver[1])
patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
if err1 != nil || err2 != nil || err3 != nil {
slog.Debug("malformed int " + line)
continue
}
res[i] = Version{
Major: uint(major),
Minor: uint(minor),
Patch: uint(patch),
}
}
}
}
return res
}
func (v Version) ToGFXString() string {
return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
}

View file

@ -7,7 +7,10 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"slices" "slices"
"strconv"
"strings" "strings"
"github.com/ollama/ollama/format"
) )
const ( const (
@ -22,36 +25,32 @@ var (
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here... ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
) )
func AMDGetGPUInfo(resp *GpuInfo) { func AMDGetGPUInfo() []GpuInfo {
resp := []GpuInfo{}
hl, err := NewHipLib() hl, err := NewHipLib()
if err != nil { if err != nil {
slog.Debug(err.Error()) slog.Debug(err.Error())
return return nil
} }
defer hl.Release() defer hl.Release()
skip := map[int]interface{}{}
ids := []int{}
resp.memInfo.DeviceCount = 0
resp.memInfo.TotalMemory = 0
resp.memInfo.FreeMemory = 0
ver, err := hl.AMDDriverVersion() ver, err := hl.AMDDriverVersion()
if err == nil { if err == nil {
slog.Info("AMD Driver: " + ver) slog.Info("AMD Driver: " + ver)
} else { } else {
// For now this is benign, but we may eventually need to fail compatibility checks // For now this is benign, but we may eventually need to fail compatibility checks
slog.Debug(fmt.Sprintf("error looking up amd driver version: %s", err)) slog.Debug("error looking up amd driver version", "error", err)
} }
// Note: the HIP library automatically handles HIP_VISIBLE_DEVICES // Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
count := hl.HipGetDeviceCount() count := hl.HipGetDeviceCount()
if count == 0 { if count == 0 {
return return nil
} }
libDir, err := AMDValidateLibDir() libDir, err := AMDValidateLibDir()
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err)) slog.Warn("unable to verify rocm library, will use cpu", "error", err)
return return nil
} }
var supported []string var supported []string
@ -59,95 +58,117 @@ func AMDGetGPUInfo(resp *GpuInfo) {
if gfxOverride == "" { if gfxOverride == "" {
supported, err = GetSupportedGFX(libDir) supported, err = GetSupportedGFX(libDir)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err)) slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
return return nil
} }
} else { } else {
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride) slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
} }
slog.Info(fmt.Sprintf("detected %d hip devices", count)) slog.Info("detected hip devices", "count", count)
// TODO how to determine the underlying device ID when visible devices is causing this to subset?
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
ids = append(ids, i)
err = hl.HipSetDevice(i) err = hl.HipSetDevice(i)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("[%d] %s", i, err)) slog.Warn("set device", "id", i, "error", err)
skip[i] = struct{}{}
continue continue
} }
props, err := hl.HipGetDeviceProperties(i) props, err := hl.HipGetDeviceProperties(i)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("[%d] %s", i, err)) slog.Warn("get properties", "id", i, "error", err)
skip[i] = struct{}{}
continue continue
} }
n := bytes.IndexByte(props.Name[:], 0) n := bytes.IndexByte(props.Name[:], 0)
name := string(props.Name[:n]) name := string(props.Name[:n])
slog.Info(fmt.Sprintf("[%d] Name: %s", i, name)) // TODO is UUID actually populated on windows?
// Can luid be used on windows for setting visible devices (and is it actually set?)
n = bytes.IndexByte(props.GcnArchName[:], 0) n = bytes.IndexByte(props.GcnArchName[:], 0)
gfx := string(props.GcnArchName[:n]) gfx := string(props.GcnArchName[:n])
slog.Info(fmt.Sprintf("[%d] GcnArchName: %s", i, gfx)) slog.Info("hip device", "id", i, "name", name, "gfx", gfx)
var major, minor, patch string
switch len(gfx) {
case 6:
major, minor, patch = gfx[3:4], gfx[4:5], gfx[5:]
case 7:
major, minor, patch = gfx[3:5], gfx[5:6], gfx[6:]
}
//slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0 //slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0
// TODO Why isn't props.iGPU accurate!? // TODO Why isn't props.iGPU accurate!?
if strings.EqualFold(name, iGPUName) { if strings.EqualFold(name, iGPUName) {
slog.Info(fmt.Sprintf("iGPU detected [%d] skipping", i)) slog.Info("iGPU detected skipping", "id", i)
skip[i] = struct{}{}
continue continue
} }
if gfxOverride == "" { if gfxOverride == "" {
if !slices.Contains[[]string, string](supported, gfx) { if !slices.Contains[[]string, string](supported, gfx) {
slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, gfx, libDir, supported)) slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
// TODO - consider discrete markdown just for ROCM troubleshooting? // TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage") slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
skip[i] = struct{}{}
continue continue
} else { } else {
slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, gfx)) slog.Info("amdgpu is supported", "gpu", i, "gpu_type", gfx)
} }
} }
totalMemory, freeMemory, err := hl.HipMemGetInfo() freeMemory, totalMemory, err := hl.HipMemGetInfo()
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("[%d] %s", i, err)) slog.Warn("get mem info", "id", i, "error", err)
continue continue
} }
// TODO according to docs, freeMem may lie on windows! // iGPU detection, remove this check once we can support an iGPU variant of the rocm library
slog.Info(fmt.Sprintf("[%d] Total Mem: %d", i, totalMemory)) if totalMemory < IGPUMemLimit {
slog.Info(fmt.Sprintf("[%d] Free Mem: %d", i, freeMemory)) slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", i, "total", format.HumanBytes2(totalMemory))
resp.memInfo.DeviceCount++ continue
resp.memInfo.TotalMemory += totalMemory }
resp.memInfo.FreeMemory += freeMemory
// TODO revisit this once ROCm v6 is available on windows.
// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable
slog.Info("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory))
slog.Info("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory))
gpuInfo := GpuInfo{
Library: "rocm",
memInfo: memInfo{
TotalMemory: totalMemory,
FreeMemory: freeMemory,
},
ID: fmt.Sprintf("%d", i), // TODO this is probably wrong if we specify visible devices
DependencyPath: libDir,
MinimumMemory: rocmMinimumMemory,
}
if major != "" {
gpuInfo.Major, err = strconv.Atoi(major)
if err != nil {
slog.Info("failed to parse version", "version", gfx, "error", err)
}
}
if minor != "" {
gpuInfo.Minor, err = strconv.Atoi(minor)
if err != nil {
slog.Info("failed to parse version", "version", gfx, "error", err)
}
}
if patch != "" {
gpuInfo.Patch, err = strconv.Atoi(patch)
if err != nil {
slog.Info("failed to parse version", "version", gfx, "error", err)
}
}
if gpuInfo.Major < RocmComputeMin {
slog.Warn(fmt.Sprintf("amdgpu [%s] too old gfx%d%d%d", gpuInfo.ID, gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch))
continue
}
resp = append(resp, gpuInfo)
} }
if resp.memInfo.DeviceCount > 0 {
resp.Library = "rocm" return resp
}
// Abort if all GPUs are skipped
if len(skip) >= count {
slog.Info("all detected amdgpus are skipped, falling back to CPU")
return
}
if len(skip) > 0 {
amdSetVisibleDevices(ids, skip)
}
UpdatePath(libDir)
} }
func AMDValidateLibDir() (string, error) { func AMDValidateLibDir() (string, error) {
// On windows non-admins typically can't create links libDir, err := commonAMDValidateLibDir()
// so instead of trying to rely on rpath and a link in
// $LibDir/rocm, we instead rely on setting PATH to point
// to the location of the ROCm library
// Installer payload location if we're running the installed binary
exe, err := os.Executable()
if err == nil { if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm") return libDir, nil
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
} }
// Installer payload (if we're running from some other location) // Installer payload (if we're running from some other location)
@ -159,21 +180,6 @@ func AMDValidateLibDir() (string, error) {
return rocmTargetDir, nil return rocmTargetDir, nil
} }
// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")
if hipPath != "" {
hipLibDir := filepath.Join(hipPath, "bin")
if rocmLibUsable(hipLibDir) {
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
return hipLibDir, nil
}
}
// Well known location(s)
if rocmLibUsable(RocmStandardLocation) {
return RocmStandardLocation, nil
}
// Should not happen on windows since we include it in the installer, but stand-alone binary might hit this // Should not happen on windows since we include it in the installer, but stand-alone binary might hit this
slog.Warn("amdgpu detected, but no compatible rocm library found. Please install ROCm") slog.Warn("amdgpu detected, but no compatible rocm library found. Please install ROCm")
return "", fmt.Errorf("no suitable rocm found, falling back to CPU") return "", fmt.Errorf("no suitable rocm found, falling back to CPU")

View file

@ -80,7 +80,7 @@ func cleanupTmpDirs() {
} }
err = os.RemoveAll(d) err = os.RemoveAll(d)
if err != nil { if err != nil {
slog.Debug(fmt.Sprintf("unable to cleanup stale tmpdir %s: %s", d, err)) slog.Debug("unable to cleanup stale tmpdir", "path", d, "error", err)
} }
} }
} }
@ -120,7 +120,7 @@ func UpdatePath(dir string) {
} }
} }
newPath := strings.Join(append([]string{dir}, pathComponents...), ";") newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
slog.Info(fmt.Sprintf("Updating PATH to %s", newPath)) slog.Info("updating", "PATH", newPath)
os.Setenv("PATH", newPath) os.Setenv("PATH", newPath)
} }
// linux and darwin rely on rpath // linux and darwin rely on rpath

22
gpu/cuda_common.go Normal file
View file

@ -0,0 +1,22 @@
//go:build linux || windows
package gpu
import (
"log/slog"
"strings"
)
func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
ids := []string{}
for _, info := range gpuInfo {
if info.Library != "cuda" {
// TODO shouldn't happen if things are wired correctly...
slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
continue
}
ids = append(ids, info.ID)
}
return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
}

View file

@ -16,7 +16,6 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv"
"strings" "strings"
"sync" "sync"
"unsafe" "unsafe"
@ -25,8 +24,8 @@ import (
) )
type handles struct { type handles struct {
nvml *C.nvml_handle_t deviceCount int
cudart *C.cudart_handle_t cudart *C.cudart_handle_t
} }
const ( const (
@ -39,26 +38,10 @@ var gpuMutex sync.Mutex
// With our current CUDA compile flags, older than 5.0 will not work properly // With our current CUDA compile flags, older than 5.0 will not work properly
var CudaComputeMin = [2]C.int{5, 0} var CudaComputeMin = [2]C.int{5, 0}
// Possible locations for the nvidia-ml library var RocmComputeMin = 9
var NvmlLinuxGlobs = []string{
"/usr/local/cuda/lib64/libnvidia-ml.so*",
"/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*",
"/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*",
"/usr/lib/wsl/lib/libnvidia-ml.so*",
"/usr/lib/wsl/drivers/*/libnvidia-ml.so*",
"/opt/cuda/lib64/libnvidia-ml.so*",
"/usr/lib*/libnvidia-ml.so*",
"/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*",
"/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*",
"/usr/local/lib*/libnvidia-ml.so*",
// TODO: are these stubs ever valid? // TODO find a better way to detect iGPU instead of minimum memory
"/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*", const IGPUMemLimit = 1 * format.GibiByte // 512G is what they typically report, so anything less than 1G must be iGPU
}
var NvmlWindowsGlobs = []string{
"c:\\Windows\\System32\\nvml.dll",
}
var CudartLinuxGlobs = []string{ var CudartLinuxGlobs = []string{
"/usr/local/cuda/lib64/libcudart.so*", "/usr/local/cuda/lib64/libcudart.so*",
@ -88,26 +71,18 @@ func initGPUHandles() *handles {
// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing // TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
gpuHandles := &handles{nil, nil} gpuHandles := &handles{}
var nvmlMgmtName string
var nvmlMgmtPatterns []string
var cudartMgmtName string var cudartMgmtName string
var cudartMgmtPatterns []string var cudartMgmtPatterns []string
tmpDir, _ := PayloadsDir() tmpDir, _ := PayloadsDir()
switch runtime.GOOS { switch runtime.GOOS {
case "windows": case "windows":
nvmlMgmtName = "nvml.dll"
nvmlMgmtPatterns = make([]string, len(NvmlWindowsGlobs))
copy(nvmlMgmtPatterns, NvmlWindowsGlobs)
cudartMgmtName = "cudart64_*.dll" cudartMgmtName = "cudart64_*.dll"
localAppData := os.Getenv("LOCALAPPDATA") localAppData := os.Getenv("LOCALAPPDATA")
cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)} cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)}
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...) cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...)
case "linux": case "linux":
nvmlMgmtName = "libnvidia-ml.so"
nvmlMgmtPatterns = make([]string, len(NvmlLinuxGlobs))
copy(nvmlMgmtPatterns, NvmlLinuxGlobs)
cudartMgmtName = "libcudart.so*" cudartMgmtName = "libcudart.so*"
if tmpDir != "" { if tmpDir != "" {
// TODO - add "payloads" for subprocess // TODO - add "payloads" for subprocess
@ -118,31 +93,21 @@ func initGPUHandles() *handles {
return gpuHandles return gpuHandles
} }
slog.Info("Detecting GPU type") slog.Info("Detecting GPUs")
cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns) cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns)
if len(cudartLibPaths) > 0 { if len(cudartLibPaths) > 0 {
cudart := LoadCUDARTMgmt(cudartLibPaths) deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths)
if cudart != nil { if cudart != nil {
slog.Info("Nvidia GPU detected via cudart") slog.Info("detected GPUs", "library", libPath, "count", deviceCount)
gpuHandles.cudart = cudart gpuHandles.cudart = cudart
return gpuHandles gpuHandles.deviceCount = deviceCount
}
}
// TODO once we build confidence, remove this and the gpu_info_nvml.[ch] files
nvmlLibPaths := FindGPULibs(nvmlMgmtName, nvmlMgmtPatterns)
if len(nvmlLibPaths) > 0 {
nvml := LoadNVMLMgmt(nvmlLibPaths)
if nvml != nil {
slog.Info("Nvidia GPU detected via nvidia-ml")
gpuHandles.nvml = nvml
return gpuHandles return gpuHandles
} }
} }
return gpuHandles return gpuHandles
} }
func GetGPUInfo() GpuInfo { func GetGPUInfo() GpuInfoList {
// TODO - consider exploring lspci (and equivalent on windows) to check for // TODO - consider exploring lspci (and equivalent on windows) to check for
// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries // GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
gpuMutex.Lock() gpuMutex.Lock()
@ -150,9 +115,6 @@ func GetGPUInfo() GpuInfo {
gpuHandles := initGPUHandles() gpuHandles := initGPUHandles()
defer func() { defer func() {
if gpuHandles.nvml != nil {
C.nvml_release(*gpuHandles.nvml)
}
if gpuHandles.cudart != nil { if gpuHandles.cudart != nil {
C.cudart_release(*gpuHandles.cudart) C.cudart_release(*gpuHandles.cudart)
} }
@ -165,72 +127,63 @@ func GetGPUInfo() GpuInfo {
} }
var memInfo C.mem_info_t var memInfo C.mem_info_t
resp := GpuInfo{} resp := []GpuInfo{}
if gpuHandles.nvml != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
C.nvml_check_vram(*gpuHandles.nvml, &memInfo) // NVIDIA first
for i := 0; i < gpuHandles.deviceCount; i++ {
// TODO once we support CPU compilation variants of GPU libraries refine this...
if cpuVariant == "" && runtime.GOARCH == "amd64" {
continue
}
gpuInfo := GpuInfo{
Library: "cuda",
}
C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo)
if memInfo.err != nil { if memInfo.err != nil {
slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU memory: %s", C.GoString(memInfo.err))) slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err)) C.free(unsafe.Pointer(memInfo.err))
} else if memInfo.count > 0 { continue
// Verify minimum compute capability
var cc C.nvml_compute_capability_t
C.nvml_compute_capability(*gpuHandles.nvml, &cc)
if cc.err != nil {
slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU compute capability: %s", C.GoString(cc.err)))
C.free(unsafe.Pointer(cc.err))
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
slog.Info(fmt.Sprintf("[nvidia-ml] NVML CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
resp.Library = "cuda"
resp.MinimumMemory = cudaMinimumMemory
} else {
slog.Info(fmt.Sprintf("[nvidia-ml] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
}
} }
} else if gpuHandles.cudart != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") { if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
C.cudart_check_vram(*gpuHandles.cudart, &memInfo) slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
if memInfo.err != nil { continue
slog.Info(fmt.Sprintf("[cudart] error looking up CUDART GPU memory: %s", C.GoString(memInfo.err)))
C.free(unsafe.Pointer(memInfo.err))
} else if memInfo.count > 0 {
// Verify minimum compute capability
var cc C.cudart_compute_capability_t
C.cudart_compute_capability(*gpuHandles.cudart, &cc)
if cc.err != nil {
slog.Info(fmt.Sprintf("[cudart] error looking up CUDA compute capability: %s", C.GoString(cc.err)))
C.free(unsafe.Pointer(cc.err))
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
slog.Info(fmt.Sprintf("[cudart] CUDART CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
resp.Library = "cuda"
resp.MinimumMemory = cudaMinimumMemory
} else {
slog.Info(fmt.Sprintf("[cudart] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
}
} }
} else { gpuInfo.TotalMemory = uint64(memInfo.total)
AMDGetGPUInfo(&resp) gpuInfo.FreeMemory = uint64(memInfo.free)
if resp.Library != "" { gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
resp.MinimumMemory = rocmMinimumMemory gpuInfo.Major = int(memInfo.major)
return resp gpuInfo.Minor = int(memInfo.minor)
} gpuInfo.MinimumMemory = cudaMinimumMemory
}
if resp.Library == "" { // TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
C.cpu_check_ram(&memInfo) resp = append(resp, gpuInfo)
resp.Library = "cpu" }
resp.Variant = cpuVariant
} // Then AMD
if memInfo.err != nil { resp = append(resp, AMDGetGPUInfo()...)
slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
C.free(unsafe.Pointer(memInfo.err)) if len(resp) == 0 {
return resp C.cpu_check_ram(&memInfo)
if memInfo.err != nil {
slog.Info("error looking up CPU memory", "error", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err))
return resp
}
gpuInfo := GpuInfo{
Library: "cpu",
Variant: cpuVariant,
}
gpuInfo.TotalMemory = uint64(memInfo.total)
gpuInfo.FreeMemory = uint64(memInfo.free)
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
resp = append(resp, gpuInfo)
} }
resp.DeviceCount = uint32(memInfo.count)
resp.FreeMemory = uint64(memInfo.free)
resp.TotalMemory = uint64(memInfo.total)
return resp return resp
} }
func getCPUMem() (memInfo, error) { func GetCPUMem() (memInfo, error) {
var ret memInfo var ret memInfo
var info C.mem_info_t var info C.mem_info_t
C.cpu_check_ram(&info) C.cpu_check_ram(&info)
@ -243,29 +196,11 @@ func getCPUMem() (memInfo, error) {
return ret, nil return ret, nil
} }
func CheckVRAM() (uint64, error) {
userLimit := os.Getenv("OLLAMA_MAX_VRAM")
if userLimit != "" {
avail, err := strconv.ParseInt(userLimit, 10, 64)
if err != nil {
return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
}
slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
return uint64(avail), nil
}
gpuInfo := GetGPUInfo()
if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
return gpuInfo.FreeMemory, nil
}
return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
}
func FindGPULibs(baseLibName string, patterns []string) []string { func FindGPULibs(baseLibName string, patterns []string) []string {
// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them // Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
var ldPaths []string var ldPaths []string
gpuLibPaths := []string{} gpuLibPaths := []string{}
slog.Info(fmt.Sprintf("Searching for GPU management library %s", baseLibName)) slog.Debug("Searching for GPU library", "name", baseLibName)
switch runtime.GOOS { switch runtime.GOOS {
case "windows": case "windows":
@ -283,7 +218,7 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
} }
patterns = append(patterns, filepath.Join(d, baseLibName+"*")) patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
} }
slog.Debug(fmt.Sprintf("gpu management search paths: %v", patterns)) slog.Debug("gpu library search", "globs", patterns)
for _, pattern := range patterns { for _, pattern := range patterns {
// Ignore glob discovery errors // Ignore glob discovery errors
matches, _ := filepath.Glob(pattern) matches, _ := filepath.Glob(pattern)
@ -311,28 +246,11 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
} }
} }
} }
slog.Info(fmt.Sprintf("Discovered GPU libraries: %v", gpuLibPaths)) slog.Debug("discovered GPU libraries", "paths", gpuLibPaths)
return gpuLibPaths return gpuLibPaths
} }
func LoadNVMLMgmt(nvmlLibPaths []string) *C.nvml_handle_t { func LoadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string) {
var resp C.nvml_init_resp_t
resp.ch.verbose = getVerboseState()
for _, libPath := range nvmlLibPaths {
lib := C.CString(libPath)
defer C.free(unsafe.Pointer(lib))
C.nvml_init(lib, &resp)
if resp.err != nil {
slog.Info(fmt.Sprintf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err)))
C.free(unsafe.Pointer(resp.err))
} else {
return &resp.ch
}
}
return nil
}
func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t {
var resp C.cudart_init_resp_t var resp C.cudart_init_resp_t
resp.ch.verbose = getVerboseState() resp.ch.verbose = getVerboseState()
for _, libPath := range cudartLibPaths { for _, libPath := range cudartLibPaths {
@ -340,13 +258,13 @@ func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t {
defer C.free(unsafe.Pointer(lib)) defer C.free(unsafe.Pointer(lib))
C.cudart_init(lib, &resp) C.cudart_init(lib, &resp)
if resp.err != nil { if resp.err != nil {
slog.Info(fmt.Sprintf("Unable to load cudart CUDA management library %s: %s", libPath, C.GoString(resp.err))) slog.Debug("Unable to load cudart", "library", libPath, "error", C.GoString(resp.err))
C.free(unsafe.Pointer(resp.err)) C.free(unsafe.Pointer(resp.err))
} else { } else {
return &resp.ch return int(resp.num_devices), &resp.ch, libPath
} }
} }
return nil return 0, nil, ""
} }
func getVerboseState() C.uint16_t { func getVerboseState() C.uint16_t {
@ -355,3 +273,22 @@ func getVerboseState() C.uint16_t {
} }
return C.uint16_t(0) return C.uint16_t(0)
} }
// Given the list of GPUs this instantiation is targeted for,
// figure out the visible devices environment variable
//
// If different libraries are detected, the first one is what we use
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
if len(l) == 0 {
return "", ""
}
switch l[0].Library {
case "cuda":
return cudaGetVisibleDevicesEnv(l)
case "rocm":
return rocmGetVisibleDevicesEnv(l)
default:
slog.Debug("no filter required for library " + l[0].Library)
return "", ""
}
}

View file

@ -9,52 +9,41 @@ package gpu
*/ */
import "C" import "C"
import ( import (
"fmt"
"log/slog"
"os"
"runtime" "runtime"
"strconv"
) )
// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs func GetGPUInfo() GpuInfoList {
func CheckVRAM() (uint64, error) { mem, _ := GetCPUMem()
userLimit := os.Getenv("OLLAMA_MAX_VRAM")
if userLimit != "" {
avail, err := strconv.ParseInt(userLimit, 10, 64)
if err != nil {
return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
}
slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
return uint64(avail), nil
}
if runtime.GOARCH == "amd64" { if runtime.GOARCH == "amd64" {
// gpu not supported, this may not be metal return []GpuInfo{
return 0, nil {
} Library: "cpu",
Variant: GetCPUVariant(),
return uint64(C.getRecommendedMaxVRAM()), nil memInfo: mem,
} },
func GetGPUInfo() GpuInfo {
mem, _ := getCPUMem()
if runtime.GOARCH == "amd64" {
return GpuInfo{
Library: "cpu",
Variant: GetCPUVariant(),
memInfo: mem,
} }
} }
return GpuInfo{ info := GpuInfo{
Library: "metal", Library: "metal",
memInfo: mem, ID: "0",
} }
info.TotalMemory = uint64(C.getRecommendedMaxVRAM())
// TODO is there a way to gather actual allocated video memory? (currentAllocatedSize doesn't work)
info.FreeMemory = info.TotalMemory
info.MinimumMemory = 0
return []GpuInfo{info}
} }
func getCPUMem() (memInfo, error) { func GetCPUMem() (memInfo, error) {
return memInfo{ return memInfo{
TotalMemory: uint64(C.getPhysicalMemory()), TotalMemory: uint64(C.getPhysicalMemory()),
FreeMemory: 0, FreeMemory: 0,
DeviceCount: 1,
}, nil }, nil
} }
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
// No-op on darwin
return "", ""
}

View file

@ -38,12 +38,17 @@
extern "C" { extern "C" {
#endif #endif
#define GPU_ID_LEN 64
typedef struct mem_info { typedef struct mem_info {
char *err; // If non-nill, caller responsible for freeing
char gpu_id[GPU_ID_LEN];
uint64_t total; uint64_t total;
uint64_t free; uint64_t free;
unsigned int count;
int igpu_index; // If >= 0, we detected an integrated GPU to ignore // Compute Capability
char *err; // If non-nill, caller responsible for freeing int major;
int minor;
} mem_info_t; } mem_info_t;
void cpu_check_ram(mem_info_t *resp); void cpu_check_ram(mem_info_t *resp);
@ -52,7 +57,6 @@ void cpu_check_ram(mem_info_t *resp);
} }
#endif #endif
#include "gpu_info_nvml.h"
#include "gpu_info_cudart.h" #include "gpu_info_cudart.h"
#endif // __GPU_INFO_H__ #endif // __GPU_INFO_H__

View file

@ -8,9 +8,11 @@ void cpu_check_ram(mem_info_t *resp) {
MEMORYSTATUSEX info; MEMORYSTATUSEX info;
info.dwLength = sizeof(info); info.dwLength = sizeof(info);
if (GlobalMemoryStatusEx(&info) != 0) { if (GlobalMemoryStatusEx(&info) != 0) {
resp->count = 1;
resp->total = info.ullTotalPhys; resp->total = info.ullTotalPhys;
resp->free = info.ullAvailPhys; resp->free = info.ullAvailPhys;
resp->major = 0;
resp->minor = 0;
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
} else { } else {
resp->err = LOAD_ERR(); resp->err = LOAD_ERR();
} }
@ -27,9 +29,11 @@ void cpu_check_ram(mem_info_t *resp) {
if (sysinfo(&info) != 0) { if (sysinfo(&info) != 0) {
resp->err = strdup(strerror(errno)); resp->err = strdup(strerror(errno));
} else { } else {
resp->count = 1;
resp->total = info.totalram * info.mem_unit; resp->total = info.totalram * info.mem_unit;
resp->free = info.freeram * info.mem_unit; resp->free = info.freeram * info.mem_unit;
resp->major = 0;
resp->minor = 0;
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
} }
return; return;
} }

View file

@ -6,6 +6,7 @@
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) { void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
cudartReturn_t ret; cudartReturn_t ret;
resp->err = NULL; resp->err = NULL;
resp->num_devices = 0;
const int buflen = 256; const int buflen = 256;
char buf[buflen + 1]; char buf[buflen + 1];
int i; int i;
@ -21,6 +22,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
{"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount}, {"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount},
{"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute}, {"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute},
{"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion}, {"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion},
{"cudaGetDeviceProperties", (void *)&resp->ch.cudaGetDeviceProperties},
{NULL, NULL}, {NULL, NULL},
}; };
@ -36,13 +38,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
return; return;
} }
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "wiring cudart library functions in %s\n", cudart_lib_path);
for (i = 0; l[i].s != NULL; i++) { for (i = 0; l[i].s != NULL; i++) {
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s); *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
if (!l[i].p) { if (!l[i].p) {
char *msg = LOAD_ERR(); char *msg = LOAD_ERR();
@ -63,7 +59,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
UNLOAD_LIBRARY(resp->ch.handle); UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL; resp->ch.handle = NULL;
if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) { if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
resp->err = strdup("your nvidia driver is too old or missing, please upgrade to run ollama"); resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama");
return; return;
} }
snprintf(buf, buflen, "cudart init failure: %d", ret); snprintf(buf, buflen, "cudart init failure: %d", ret);
@ -85,110 +81,95 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
driverVersion.minor = (version - (driverVersion.major * 1000)) / 10; driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor); LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
} }
ret = (*resp->ch.cudaGetDeviceCount)(&resp->num_devices);
if (ret != CUDART_SUCCESS) {
LOG(resp->ch.verbose, "cudaGetDeviceCount err: %d\n", ret);
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
return;
}
} }
void cudart_check_vram(cudart_handle_t h, mem_info_t *resp) { void cudart_check_vram(cudart_handle_t h, int i, mem_info_t *resp) {
resp->err = NULL; resp->err = NULL;
cudartMemory_t memInfo = {0,0,0}; cudartMemory_t memInfo = {0,0,0};
cudartReturn_t ret; cudartReturn_t ret;
const int buflen = 256; const int buflen = 256;
char buf[buflen + 1]; char buf[buflen + 1];
int i;
if (h.handle == NULL) { if (h.handle == NULL) {
resp->err = strdup("cudart handle isn't initialized"); resp->err = strdup("cudart handle isn't initialized");
return; return;
} }
// cudaGetDeviceCount takes int type, resp-> count is uint ret = (*h.cudaSetDevice)(i);
int deviceCount;
ret = (*h.cudaGetDeviceCount)(&deviceCount);
if (ret != CUDART_SUCCESS) { if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "unable to get device count: %d", ret); snprintf(buf, buflen, "cudart device failed to initialize");
resp->err = strdup(buf); resp->err = strdup(buf);
return; return;
}
cudaDeviceProp_t props;
ret = (*h.cudaGetDeviceProperties)(&props, i);
if (ret != CUDART_SUCCESS) {
LOG(h.verbose, "[%d] device properties lookup failure: %d\n", i, ret);
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
resp->major = 0;
resp->minor = 0;
} else { } else {
resp->count = (unsigned int)deviceCount; int allNull = 1;
} for (int j = 0; j < 16; j++) {
if (props.uuid.bytes[j] != 0) {
resp->total = 0; allNull = 0;
resp->free = 0; break;
for (i = 0; i < resp-> count; i++) { }
ret = (*h.cudaSetDevice)(i);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "cudart device failed to initialize");
resp->err = strdup(buf);
return;
} }
ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total); if (allNull != 0) {
if (ret != CUDART_SUCCESS) { snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret); } else {
resp->err = strdup(buf); // GPU-d110a105-ac29-1d54-7b49-9c90440f215b
return; snprintf(&resp->gpu_id[0], GPU_ID_LEN,
"GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
props.uuid.bytes[0],
props.uuid.bytes[1],
props.uuid.bytes[2],
props.uuid.bytes[3],
props.uuid.bytes[4],
props.uuid.bytes[5],
props.uuid.bytes[6],
props.uuid.bytes[7],
props.uuid.bytes[8],
props.uuid.bytes[9],
props.uuid.bytes[10],
props.uuid.bytes[11],
props.uuid.bytes[12],
props.uuid.bytes[13],
props.uuid.bytes[14],
props.uuid.bytes[15]
);
} }
resp->major = props.major;
resp->minor = props.minor;
LOG(h.verbose, "[%d] CUDA totalMem %lu\n", i, memInfo.total); // TODO add other useful properties from props
LOG(h.verbose, "[%d] CUDA freeMem %lu\n", i, memInfo.free);
resp->total += memInfo.total;
resp->free += memInfo.free;
} }
} ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
void cudart_compute_capability(cudart_handle_t h, cudart_compute_capability_t *resp) {
resp->err = NULL;
resp->major = 0;
resp->minor = 0;
int major = 0;
int minor = 0;
cudartReturn_t ret;
const int buflen = 256;
char buf[buflen + 1];
int i;
if (h.handle == NULL) {
resp->err = strdup("cudart handle not initialized");
return;
}
int devices;
ret = (*h.cudaGetDeviceCount)(&devices);
if (ret != CUDART_SUCCESS) { if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "unable to get cudart device count: %d", ret); snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
resp->err = strdup(buf); resp->err = strdup(buf);
return; return;
} }
for (i = 0; i < devices; i++) { resp->total = memInfo.total;
ret = (*h.cudaSetDevice)(i); resp->free = memInfo.free;
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "cudart device failed to initialize");
resp->err = strdup(buf);
return;
}
ret = (*h.cudaDeviceGetAttribute)(&major, cudartDevAttrComputeCapabilityMajor, i); LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total);
if (ret != CUDART_SUCCESS) { LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free);
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret); LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
resp->err = strdup(buf);
return;
}
ret = (*h.cudaDeviceGetAttribute)(&minor, cudartDevAttrComputeCapabilityMinor, i);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
// Report the lowest major.minor we detect as that limits our compatibility
if (resp->major == 0 || resp->major > major ) {
resp->major = major;
resp->minor = minor;
} else if ( resp->major == major && resp->minor > minor ) {
resp->minor = minor;
}
}
} }
void cudart_release(cudart_handle_t h) { void cudart_release(cudart_handle_t h) {

View file

@ -6,7 +6,8 @@
// Just enough typedef's to dlopen/dlsym for memory information // Just enough typedef's to dlopen/dlsym for memory information
typedef enum cudartReturn_enum { typedef enum cudartReturn_enum {
CUDART_SUCCESS = 0, CUDART_SUCCESS = 0,
CUDART_UNSUPPORTED = 1, CUDA_ERROR_INVALID_VALUE = 1,
CUDA_ERROR_MEMORY_ALLOCATION = 2,
CUDA_ERROR_INSUFFICIENT_DRIVER = 35, CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
// Other values omitted for now... // Other values omitted for now...
} cudartReturn_t; } cudartReturn_t;
@ -14,6 +15,11 @@ typedef enum cudartReturn_enum {
typedef enum cudartDeviceAttr_enum { typedef enum cudartDeviceAttr_enum {
cudartDevAttrComputeCapabilityMajor = 75, cudartDevAttrComputeCapabilityMajor = 75,
cudartDevAttrComputeCapabilityMinor = 76, cudartDevAttrComputeCapabilityMinor = 76,
// TODO - not yet wired up but may be useful for Jetson or other
// integrated GPU scenarios with shared memory
cudaDevAttrIntegrated = 18
} cudartDeviceAttr_t; } cudartDeviceAttr_t;
typedef void *cudartDevice_t; // Opaque is sufficient typedef void *cudartDevice_t; // Opaque is sufficient
@ -28,6 +34,92 @@ typedef struct cudartDriverVersion {
int minor; int minor;
} cudartDriverVersion_t; } cudartDriverVersion_t;
typedef struct cudaUUID {
unsigned char bytes[16];
} cudaUUID_t;
typedef struct cudaDeviceProp {
char name[256]; /**< ASCII string identifying device */
cudaUUID_t uuid; /**< 16-byte unique identifier */
char luid[8]; /**< 8-byte locally unique identifier. Value is undefined on TCC and non-Windows platforms */
unsigned int luidDeviceNodeMask; /**< LUID device node mask. Value is undefined on TCC and non-Windows platforms */
size_t totalGlobalMem; /**< Global memory available on device in bytes */
size_t sharedMemPerBlock; /**< Shared memory available per block in bytes */
int regsPerBlock; /**< 32-bit registers available per block */
int warpSize; /**< Warp size in threads */
size_t memPitch; /**< Maximum pitch in bytes allowed by memory copies */
int maxThreadsPerBlock; /**< Maximum number of threads per block */
int maxThreadsDim[3]; /**< Maximum size of each dimension of a block */
int maxGridSize[3]; /**< Maximum size of each dimension of a grid */
int clockRate; /**< Clock frequency in kilohertz */
size_t totalConstMem; /**< Constant memory available on device in bytes */
int major; /**< Major compute capability */
int minor; /**< Minor compute capability */
size_t textureAlignment; /**< Alignment requirement for textures */
size_t texturePitchAlignment; /**< Pitch alignment requirement for texture references bound to pitched memory */
int deviceOverlap; /**< Device can concurrently copy memory and execute a kernel. Deprecated. Use instead asyncEngineCount. */
int multiProcessorCount; /**< Number of multiprocessors on device */
int kernelExecTimeoutEnabled; /**< Specified whether there is a run time limit on kernels */
int integrated; /**< Device is integrated as opposed to discrete */
int canMapHostMemory; /**< Device can map host memory with cudaHostAlloc/cudaHostGetDevicePointer */
int computeMode; /**< Compute mode (See ::cudaComputeMode) */
int maxTexture1D; /**< Maximum 1D texture size */
int maxTexture1DMipmap; /**< Maximum 1D mipmapped texture size */
int maxTexture1DLinear; /**< Deprecated, do not use. Use cudaDeviceGetTexture1DLinearMaxWidth() or cuDeviceGetTexture1DLinearMaxWidth() instead. */
int maxTexture2D[2]; /**< Maximum 2D texture dimensions */
int maxTexture2DMipmap[2]; /**< Maximum 2D mipmapped texture dimensions */
int maxTexture2DLinear[3]; /**< Maximum dimensions (width, height, pitch) for 2D textures bound to pitched memory */
int maxTexture2DGather[2]; /**< Maximum 2D texture dimensions if texture gather operations have to be performed */
int maxTexture3D[3]; /**< Maximum 3D texture dimensions */
int maxTexture3DAlt[3]; /**< Maximum alternate 3D texture dimensions */
int maxTextureCubemap; /**< Maximum Cubemap texture dimensions */
int maxTexture1DLayered[2]; /**< Maximum 1D layered texture dimensions */
int maxTexture2DLayered[3]; /**< Maximum 2D layered texture dimensions */
int maxTextureCubemapLayered[2];/**< Maximum Cubemap layered texture dimensions */
int maxSurface1D; /**< Maximum 1D surface size */
int maxSurface2D[2]; /**< Maximum 2D surface dimensions */
int maxSurface3D[3]; /**< Maximum 3D surface dimensions */
int maxSurface1DLayered[2]; /**< Maximum 1D layered surface dimensions */
int maxSurface2DLayered[3]; /**< Maximum 2D layered surface dimensions */
int maxSurfaceCubemap; /**< Maximum Cubemap surface dimensions */
int maxSurfaceCubemapLayered[2];/**< Maximum Cubemap layered surface dimensions */
size_t surfaceAlignment; /**< Alignment requirements for surfaces */
int concurrentKernels; /**< Device can possibly execute multiple kernels concurrently */
int ECCEnabled; /**< Device has ECC support enabled */
int pciBusID; /**< PCI bus ID of the device */
int pciDeviceID; /**< PCI device ID of the device */
int pciDomainID; /**< PCI domain ID of the device */
int tccDriver; /**< 1 if device is a Tesla device using TCC driver, 0 otherwise */
int asyncEngineCount; /**< Number of asynchronous engines */
int unifiedAddressing; /**< Device shares a unified address space with the host */
int memoryClockRate; /**< Peak memory clock frequency in kilohertz */
int memoryBusWidth; /**< Global memory bus width in bits */
int l2CacheSize; /**< Size of L2 cache in bytes */
int persistingL2CacheMaxSize; /**< Device's maximum l2 persisting lines capacity setting in bytes */
int maxThreadsPerMultiProcessor;/**< Maximum resident threads per multiprocessor */
int streamPrioritiesSupported; /**< Device supports stream priorities */
int globalL1CacheSupported; /**< Device supports caching globals in L1 */
int localL1CacheSupported; /**< Device supports caching locals in L1 */
size_t sharedMemPerMultiprocessor; /**< Shared memory available per multiprocessor in bytes */
int regsPerMultiprocessor; /**< 32-bit registers available per multiprocessor */
int managedMemory; /**< Device supports allocating managed memory on this system */
int isMultiGpuBoard; /**< Device is on a multi-GPU board */
int multiGpuBoardGroupID; /**< Unique identifier for a group of devices on the same multi-GPU board */
int hostNativeAtomicSupported; /**< Link between the device and the host supports native atomic operations */
int singleToDoublePrecisionPerfRatio; /**< Ratio of single precision performance (in floating-point operations per second) to double precision performance */
int pageableMemoryAccess; /**< Device supports coherently accessing pageable memory without calling cudaHostRegister on it */
int concurrentManagedAccess; /**< Device can coherently access managed memory concurrently with the CPU */
int computePreemptionSupported; /**< Device supports Compute Preemption */
int canUseHostPointerForRegisteredMem; /**< Device can access host registered memory at the same virtual address as the CPU */
int cooperativeLaunch; /**< Device supports launching cooperative kernels via ::cudaLaunchCooperativeKernel */
int cooperativeMultiDeviceLaunch; /**< Deprecated, cudaLaunchCooperativeKernelMultiDevice is deprecated. */
size_t sharedMemPerBlockOptin; /**< Per device maximum shared memory per block usable by special opt in */
int pageableMemoryAccessUsesHostPageTables; /**< Device accesses pageable memory via the host's page tables */
int directManagedMemAccessFromHost; /**< Host can directly access managed memory on the device without migration. */
int maxBlocksPerMultiProcessor; /**< Maximum number of resident blocks per multiprocessor */
int accessPolicyMaxWindowSize; /**< The maximum value of ::cudaAccessPolicyWindow::num_bytes. */
size_t reservedSharedMemPerBlock; /**< Shared memory reserved by CUDA driver per block in bytes */
} cudaDeviceProp_t;
typedef struct cudart_handle { typedef struct cudart_handle {
void *handle; void *handle;
uint16_t verbose; uint16_t verbose;
@ -38,23 +130,17 @@ typedef struct cudart_handle {
cudartReturn_t (*cudaGetDeviceCount)(int *); cudartReturn_t (*cudaGetDeviceCount)(int *);
cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device); cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device);
cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion); cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion);
cudartReturn_t (*cudaGetDeviceProperties) (cudaDeviceProp_t* prop, int device);
} cudart_handle_t; } cudart_handle_t;
typedef struct cudart_init_resp { typedef struct cudart_init_resp {
char *err; // If err is non-null handle is invalid char *err; // If err is non-null handle is invalid
cudart_handle_t ch; cudart_handle_t ch;
int num_devices;
} cudart_init_resp_t; } cudart_init_resp_t;
typedef struct cudart_compute_capability {
char *err;
int major;
int minor;
} cudart_compute_capability_t;
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp); void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp);
void cudart_check_vram(cudart_handle_t ch, mem_info_t *resp); void cudart_check_vram(cudart_handle_t ch, int device_id, mem_info_t *resp);
void cudart_compute_capability(cudart_handle_t th, cudart_compute_capability_t *cc);
void cudart_release(cudart_handle_t ch); void cudart_release(cudart_handle_t ch);
#endif // __GPU_INFO_CUDART_H__ #endif // __GPU_INFO_CUDART_H__

View file

@ -1,221 +0,0 @@
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
#include <string.h>
#include "gpu_info_nvml.h"
void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) {
nvmlReturn_t ret;
resp->err = NULL;
const int buflen = 256;
char buf[buflen + 1];
int i;
struct lookup {
char *s;
void **p;
} l[] = {
{"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2},
{"nvmlShutdown", (void *)&resp->ch.nvmlShutdown},
{"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.nvmlDeviceGetHandleByIndex},
{"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo},
{"nvmlDeviceGetCount_v2", (void *)&resp->ch.nvmlDeviceGetCount_v2},
{"nvmlDeviceGetCudaComputeCapability", (void *)&resp->ch.nvmlDeviceGetCudaComputeCapability},
{"nvmlSystemGetDriverVersion", (void *)&resp->ch.nvmlSystemGetDriverVersion},
{"nvmlDeviceGetName", (void *)&resp->ch.nvmlDeviceGetName},
{"nvmlDeviceGetSerial", (void *)&resp->ch.nvmlDeviceGetSerial},
{"nvmlDeviceGetVbiosVersion", (void *)&resp->ch.nvmlDeviceGetVbiosVersion},
{"nvmlDeviceGetBoardPartNumber", (void *)&resp->ch.nvmlDeviceGetBoardPartNumber},
{"nvmlDeviceGetBrand", (void *)&resp->ch.nvmlDeviceGetBrand},
{NULL, NULL},
};
resp->ch.handle = LOAD_LIBRARY(nvml_lib_path, RTLD_LAZY);
if (!resp->ch.handle) {
char *msg = LOAD_ERR();
LOG(resp->ch.verbose, "library %s load err: %s\n", nvml_lib_path, msg);
snprintf(buf, buflen,
"Unable to load %s library to query for Nvidia GPUs: %s",
nvml_lib_path, msg);
free(msg);
resp->err = strdup(buf);
return;
}
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", nvml_lib_path);
for (i = 0; l[i].s != NULL; i++) {
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
if (!l[i].p) {
resp->ch.handle = NULL;
char *msg = LOAD_ERR();
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
UNLOAD_LIBRARY(resp->ch.handle);
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
msg);
free(msg);
resp->err = strdup(buf);
return;
}
}
ret = (*resp->ch.nvmlInit_v2)();
if (ret != NVML_SUCCESS) {
LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret);
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
snprintf(buf, buflen, "nvml vram init failure: %d", ret);
resp->err = strdup(buf);
return;
}
// Report driver version if we're in verbose mode, ignore errors
ret = (*resp->ch.nvmlSystemGetDriverVersion)(buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(resp->ch.verbose, "nvmlSystemGetDriverVersion failed: %d\n", ret);
} else {
LOG(resp->ch.verbose, "CUDA driver version: %s\n", buf);
}
}
void nvml_check_vram(nvml_handle_t h, mem_info_t *resp) {
resp->err = NULL;
nvmlDevice_t device;
nvmlMemory_t memInfo = {0};
nvmlReturn_t ret;
const int buflen = 256;
char buf[buflen + 1];
int i;
if (h.handle == NULL) {
resp->err = strdup("nvml handle isn't initialized");
return;
}
ret = (*h.nvmlDeviceGetCount_v2)(&resp->count);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
return;
}
resp->total = 0;
resp->free = 0;
for (i = 0; i < resp->count; i++) {
ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "device memory info lookup failure %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
if (h.verbose) {
nvmlBrandType_t brand = 0;
// When in verbose mode, report more information about
// the card we discover, but don't fail on error
ret = (*h.nvmlDeviceGetName)(device, buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetName failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA device name: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetBoardPartNumber)(device, buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetBoardPartNumber failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA part number: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetSerial)(device, buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetSerial failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA S/N: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetVbiosVersion)(device, buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetVbiosVersion failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA vbios version: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetBrand)(device, &brand);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetBrand failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA brand: %d\n", i, brand);
}
}
LOG(h.verbose, "[%d] CUDA totalMem %ld\n", i, memInfo.total);
LOG(h.verbose, "[%d] CUDA freeMem %ld\n", i, memInfo.free);
resp->total += memInfo.total;
resp->free += memInfo.free;
}
}
void nvml_compute_capability(nvml_handle_t h, nvml_compute_capability_t *resp) {
resp->err = NULL;
resp->major = 0;
resp->minor = 0;
nvmlDevice_t device;
int major = 0;
int minor = 0;
nvmlReturn_t ret;
const int buflen = 256;
char buf[buflen + 1];
int i;
if (h.handle == NULL) {
resp->err = strdup("nvml handle not initialized");
return;
}
unsigned int devices;
ret = (*h.nvmlDeviceGetCount_v2)(&devices);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
return;
}
for (i = 0; i < devices; i++) {
ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
ret = (*h.nvmlDeviceGetCudaComputeCapability)(device, &major, &minor);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
// Report the lowest major.minor we detect as that limits our compatibility
if (resp->major == 0 || resp->major > major ) {
resp->major = major;
resp->minor = minor;
} else if ( resp->major == major && resp->minor > minor ) {
resp->minor = minor;
}
}
}
void nvml_release(nvml_handle_t h) {
LOG(h.verbose, "releasing nvml library\n");
UNLOAD_LIBRARY(h.handle);
h.handle = NULL;
}
#endif // __APPLE__

View file

@ -1,57 +0,0 @@
#ifndef __APPLE__
#ifndef __GPU_INFO_NVML_H__
#define __GPU_INFO_NVML_H__
#include "gpu_info.h"
// Just enough typedef's to dlopen/dlsym for memory information
typedef enum nvmlReturn_enum {
NVML_SUCCESS = 0,
// Other values omitted for now...
} nvmlReturn_t;
typedef void *nvmlDevice_t; // Opaque is sufficient
typedef struct nvmlMemory_st {
unsigned long long total;
unsigned long long free;
unsigned long long used;
} nvmlMemory_t;
typedef enum nvmlBrandType_enum
{
NVML_BRAND_UNKNOWN = 0,
} nvmlBrandType_t;
typedef struct nvml_handle {
void *handle;
uint16_t verbose;
nvmlReturn_t (*nvmlInit_v2)(void);
nvmlReturn_t (*nvmlShutdown)(void);
nvmlReturn_t (*nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t *);
nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
nvmlReturn_t (*nvmlDeviceGetCount_v2)(unsigned int *);
nvmlReturn_t (*nvmlDeviceGetCudaComputeCapability)(nvmlDevice_t, int* major, int* minor);
nvmlReturn_t (*nvmlSystemGetDriverVersion) (char* version, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetName) (nvmlDevice_t device, char* name, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetSerial) (nvmlDevice_t device, char* serial, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetVbiosVersion) (nvmlDevice_t device, char* version, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetBoardPartNumber) (nvmlDevice_t device, char* partNumber, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetBrand) (nvmlDevice_t device, nvmlBrandType_t* type);
} nvml_handle_t;
typedef struct nvml_init_resp {
char *err; // If err is non-null handle is invalid
nvml_handle_t ch;
} nvml_init_resp_t;
typedef struct nvml_compute_capability {
char *err;
int major;
int minor;
} nvml_compute_capability_t;
void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp);
void nvml_check_vram(nvml_handle_t ch, mem_info_t *resp);
void nvml_compute_capability(nvml_handle_t ch, nvml_compute_capability_t *cc);
void nvml_release(nvml_handle_t ch);
#endif // __GPU_INFO_NVML_H__
#endif // __APPLE__

View file

@ -9,23 +9,16 @@ import (
func TestBasicGetGPUInfo(t *testing.T) { func TestBasicGetGPUInfo(t *testing.T) {
info := GetGPUInfo() info := GetGPUInfo()
assert.Contains(t, "cuda rocm cpu metal", info.Library) assert.Greater(t, len(info), 0)
assert.Contains(t, "cuda rocm cpu metal", info[0].Library)
switch runtime.GOOS { if info[0].Library != "cpu" {
case "darwin": assert.Greater(t, info[0].TotalMemory, uint64(0))
// TODO - remove this once MacOS returns some size for CPU assert.Greater(t, info[0].FreeMemory, uint64(0))
return
case "linux", "windows":
assert.Greater(t, info.TotalMemory, uint64(0))
assert.Greater(t, info.FreeMemory, uint64(0))
assert.Greater(t, info.DeviceCount, uint32(0))
default:
return
} }
} }
func TestCPUMemInfo(t *testing.T) { func TestCPUMemInfo(t *testing.T) {
info, err := getCPUMem() info, err := GetCPUMem()
assert.NoError(t, err) assert.NoError(t, err)
switch runtime.GOOS { switch runtime.GOOS {
case "darwin": case "darwin":

View file

@ -3,7 +3,6 @@ package gpu
type memInfo struct { type memInfo struct {
TotalMemory uint64 `json:"total_memory,omitempty"` TotalMemory uint64 `json:"total_memory,omitempty"`
FreeMemory uint64 `json:"free_memory,omitempty"` FreeMemory uint64 `json:"free_memory,omitempty"`
DeviceCount uint32 `json:"device_count,omitempty"`
} }
// Beginning of an `ollama info` command // Beginning of an `ollama info` command
@ -17,11 +16,49 @@ type GpuInfo struct {
// MinimumMemory represents the minimum memory required to use the GPU // MinimumMemory represents the minimum memory required to use the GPU
MinimumMemory uint64 `json:"-"` MinimumMemory uint64 `json:"-"`
// TODO add other useful attributes about the card here for discovery information // Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
DependencyPath string `json:"lib_path,omitempty"`
// GPU information
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
Name string `json:"name"` // user friendly name if available
Major int `json:"major,omitempty"` // Major compatibility version (CC or gfx)
Minor int `json:"minor,omitempty"` // Minor compatibility version (CC or gfx)
Patch int `json:"patch,omitempty"` // Patch compatibility only matters on AMD
// TODO other performance capability info to help in scheduling decisions
} }
type Version struct { type GpuInfoList []GpuInfo
Major uint
Minor uint // Split up the set of gpu info's by Library and variant
Patch uint func (l GpuInfoList) ByLibrary() []GpuInfoList {
resp := []GpuInfoList{}
libs := []string{}
for _, info := range l {
found := false
requested := info.Library
if info.Variant != "" {
requested += "_" + info.Variant
}
for i, lib := range libs {
if lib == requested {
resp[i] = append(resp[i], info)
found = true
break
}
}
if !found {
libs = append(libs, info.Library)
resp = append(resp, []GpuInfo{info})
}
}
return resp
} }
// Sort by Free Space
type ByFreeMemory []GpuInfo
func (a ByFreeMemory) Len() int { return len(a) }
func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory }

View file

@ -4,11 +4,14 @@ package integration
import ( import (
"context" "context"
"net/http" "log/slog"
"os"
"runtime"
"testing" "testing"
"time" "time"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/stretchr/testify/require"
) )
func TestOrcaMiniBlueSky(t *testing.T) { func TestOrcaMiniBlueSky(t *testing.T) {
@ -24,5 +27,44 @@ func TestOrcaMiniBlueSky(t *testing.T) {
"seed": 123, "seed": 123,
}, },
} }
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh", "scattering"}) GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
}
func TestUnicodeModelDir(t *testing.T) {
// This is only useful for Windows with utf-16 characters, so skip this test for other platforms
if runtime.GOOS != "windows" {
t.Skip("Unicode test only applicable to windows")
}
// Only works for local testing
if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
t.Skip("TestUnicodeModelDir only works for local testing, skipping")
}
modelDir, err := os.MkdirTemp("", "ollama_埃")
require.NoError(t, err)
defer os.RemoveAll(modelDir)
slog.Info("unicode", "OLLAMA_MODELS", modelDir)
oldModelsDir := os.Getenv("OLLAMA_MODELS")
if oldModelsDir == "" {
defer os.Unsetenv("OLLAMA_MODELS")
} else {
defer os.Setenv("OLLAMA_MODELS", oldModelsDir)
}
err = os.Setenv("OLLAMA_MODELS", modelDir)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
req := api.GenerateRequest{
Model: "orca-mini",
Prompt: "why is the sky blue?",
Stream: &stream,
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
} }

View file

@ -0,0 +1,225 @@
//go:build integration
package integration
import (
"context"
"log/slog"
"os"
"strconv"
"sync"
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/stretchr/testify/require"
)
func TestMultiModelConcurrency(t *testing.T) {
var (
req = [2]api.GenerateRequest{
{
Model: "orca-mini",
Prompt: "why is the ocean blue?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "tinydolphin",
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
},
}
resp = [2][]string{
[]string{"sunlight"},
[]string{"england", "english", "massachusetts", "pilgrims"},
}
)
var wg sync.WaitGroup
wg.Add(len(req))
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel()
for i := 0; i < len(req); i++ {
go func(i int) {
defer wg.Done()
GenerateTestHelper(ctx, t, req[i], resp[i])
}(i)
}
wg.Wait()
}
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) // GTX 750 2G card takes ~9 minutes
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req, resp := GenerateRequests()
// Get the server running (if applicable) warm the model up with a single initial request
DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 5*time.Second)
var wg sync.WaitGroup
wg.Add(len(req))
for i := 0; i < len(req); i++ {
go func(i int) {
defer wg.Done()
for j := 0; j < 5; j++ {
slog.Info("Starting", "req", i, "iter", j)
// On slower GPUs it can take a while to process the 4 concurrent requests
// so we allow a much longer initial timeout
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
}
}(i)
}
wg.Wait()
}
// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
func TestMultiModelStress(t *testing.T) {
vram := os.Getenv("OLLAMA_MAX_VRAM")
if vram == "" {
t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
}
max, err := strconv.ParseUint(vram, 10, 64)
require.NoError(t, err)
const MB = uint64(1024 * 1024)
type model struct {
name string
size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM
}
smallModels := []model{
{
name: "orca-mini",
size: 2992 * MB,
},
{
name: "phi",
size: 2616 * MB,
},
{
name: "gemma:2b",
size: 2364 * MB,
},
{
name: "stable-code:3b",
size: 2608 * MB,
},
{
name: "starcoder2:3b",
size: 2166 * MB,
},
}
mediumModels := []model{
{
name: "llama2",
size: 5118 * MB,
},
{
name: "mistral",
size: 4620 * MB,
},
{
name: "orca-mini:7b",
size: 5118 * MB,
},
{
name: "dolphin-mistral",
size: 4620 * MB,
},
{
name: "gemma:7b",
size: 5000 * MB,
},
// TODO - uncomment this once #3565 is merged and this is rebased on it
// {
// name: "codellama:7b",
// size: 5118 * MB,
// },
}
// These seem to be too slow to be useful...
// largeModels := []model{
// {
// name: "llama2:13b",
// size: 7400 * MB,
// },
// {
// name: "codellama:13b",
// size: 7400 * MB,
// },
// {
// name: "orca-mini:13b",
// size: 7400 * MB,
// },
// {
// name: "gemma:7b",
// size: 5000 * MB,
// },
// {
// name: "starcoder2:15b",
// size: 9100 * MB,
// },
// }
var chosenModels []model
switch {
case max < 10000*MB:
slog.Info("selecting small models")
chosenModels = smallModels
// case max < 30000*MB:
default:
slog.Info("selecting medium models")
chosenModels = mediumModels
// default:
// slog.Info("selecting large models")
// chosenModels = largModels
}
req, resp := GenerateRequests()
for i := range req {
if i > len(chosenModels) {
break
}
req[i].Model = chosenModels[i].name
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Make sure all the models are pulled before we get started
for _, r := range req {
require.NoError(t, PullIfMissing(ctx, client, r.Model))
}
var wg sync.WaitGroup
consumed := uint64(256 * MB) // Assume some baseline usage
for i := 0; i < len(req); i++ {
// Always get at least 2 models, but dont' overshoot VRAM too much or we'll take too long
if i > 1 && consumed > max {
slog.Info("achieved target vram exhaustion", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
break
}
consumed += chosenModels[i].size
slog.Info("target vram", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
wg.Add(1)
go func(i int) {
defer wg.Done()
for j := 0; j < 3; j++ {
slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
}
}(i)
}
wg.Wait()
}

View file

@ -4,7 +4,6 @@ package integration
import ( import (
"context" "context"
"net/http"
"testing" "testing"
"time" "time"
@ -25,5 +24,5 @@ func TestContextExhaustion(t *testing.T) {
"num_ctx": 128, "num_ctx": 128,
}, },
} }
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"once", "upon", "lived"}) GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
} }

View file

@ -5,7 +5,6 @@ package integration
import ( import (
"context" "context"
"encoding/base64" "encoding/base64"
"net/http"
"testing" "testing"
"time" "time"
@ -29,10 +28,11 @@ func TestIntegrationMultimodal(t *testing.T) {
}, },
} }
resp := "the ollamas" // Note: sometimes it returns "the ollamas" sometimes "the ollams"
resp := "the ollam"
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel() defer cancel()
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp}) GenerateTestHelper(ctx, t, req, []string{resp})
} }
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb

View file

@ -4,8 +4,6 @@ package integration
import ( import (
"context" "context"
"net/http"
"sync"
"testing" "testing"
"time" "time"
@ -45,25 +43,5 @@ var (
func TestIntegrationSimpleOrcaMini(t *testing.T) { func TestIntegrationSimpleOrcaMini(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120) ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel() defer cancel()
GenerateTestHelper(ctx, t, &http.Client{}, req[0], resp[0]) GenerateTestHelper(ctx, t, req[0], resp[0])
} }
// TODO
// The server always loads a new runner and closes the old one, which forces serial execution
// At present this test case fails with concurrency problems. Eventually we should try to
// get true concurrency working with n_parallel support in the backend
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
var wg sync.WaitGroup
wg.Add(len(req))
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel()
for i := 0; i < len(req); i++ {
go func(i int) {
defer wg.Done()
GenerateTestHelper(ctx, t, &http.Client{}, req[i], resp[i])
}(i)
}
wg.Wait()
}
// TODO - create a parallel test with 2 different models once we support concurrency

View file

@ -5,13 +5,14 @@ package integration
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json" "errors"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"math/rand" "math/rand"
"net" "net"
"net/http" "net/http"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
@ -23,9 +24,13 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle" "github.com/ollama/ollama/app/lifecycle"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func Init() {
lifecycle.InitLogging()
}
func FindPort() string { func FindPort() string {
port := 0 port := 0
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
@ -41,7 +46,7 @@ func FindPort() string {
return strconv.Itoa(port) return strconv.Itoa(port)
} }
func GetTestEndpoint() (string, string) { func GetTestEndpoint() (*api.Client, string) {
defaultPort := "11434" defaultPort := "11434"
ollamaHost := os.Getenv("OLLAMA_HOST") ollamaHost := os.Getenv("OLLAMA_HOST")
@ -67,16 +72,20 @@ func GetTestEndpoint() (string, string) {
port = FindPort() port = FindPort()
} }
url := fmt.Sprintf("%s:%s", host, port) slog.Info("server connection", "host", host, "port", port)
slog.Info("server connection", "url", url)
return scheme, url return api.NewClient(
&url.URL{
Scheme: scheme,
Host: net.JoinHostPort(host, port),
},
http.DefaultClient), fmt.Sprintf("%s:%s", host, port)
} }
// TODO make fanicier, grab logs, etc.
var serverMutex sync.Mutex var serverMutex sync.Mutex
var serverReady bool var serverReady bool
func StartServer(ctx context.Context, ollamaHost string) error { func startServer(ctx context.Context, ollamaHost string) error {
// Make sure the server has been built // Make sure the server has been built
CLIName, err := filepath.Abs("../ollama") CLIName, err := filepath.Abs("../ollama")
if err != nil { if err != nil {
@ -125,67 +134,76 @@ func StartServer(ctx context.Context, ollamaHost string) error {
return nil return nil
} }
func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error { func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
slog.Info("checking status of model", "model", modelName) slog.Info("checking status of model", "model", modelName)
showReq := &api.ShowRequest{Name: modelName} showReq := &api.ShowRequest{Name: modelName}
requestJSON, err := json.Marshal(showReq)
if err != nil {
return err
}
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/show", bytes.NewReader(requestJSON)) showCtx, cancel := context.WithDeadlineCause(
if err != nil { ctx,
time.Now().Add(5*time.Second),
fmt.Errorf("show for existing model %s took too long", modelName),
)
defer cancel()
_, err := client.Show(showCtx, showReq)
var statusError api.StatusError
switch {
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
break
case err != nil:
return err return err
} default:
// Make the request with the HTTP client
response, err := client.Do(req.WithContext(ctx))
if err != nil {
return err
}
defer response.Body.Close()
if response.StatusCode == 200 {
slog.Info("model already present", "model", modelName) slog.Info("model already present", "model", modelName)
return nil return nil
} }
slog.Info("model missing", "status", response.StatusCode) slog.Info("model missing", "model", modelName)
stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
stallTimer := time.NewTimer(stallDuration)
fn := func(resp api.ProgressResponse) error {
// fmt.Print(".")
if !stallTimer.Reset(stallDuration) {
return fmt.Errorf("stall was detected, aborting status reporting")
}
return nil
}
stream := true
pullReq := &api.PullRequest{Name: modelName, Stream: &stream} pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
requestJSON, err = json.Marshal(pullReq)
if err != nil {
return err
}
req, err = http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/pull", bytes.NewReader(requestJSON)) var pullError error
if err != nil {
return err
}
slog.Info("pulling", "model", modelName)
response, err = client.Do(req.WithContext(ctx)) done := make(chan int)
if err != nil { go func() {
return err pullError = client.Pull(ctx, pullReq, fn)
done <- 0
}()
select {
case <-stallTimer.C:
return fmt.Errorf("download stalled")
case <-done:
return pullError
} }
defer response.Body.Close()
if response.StatusCode != 200 {
return fmt.Errorf("failed to pull model") // TODO more details perhaps
}
slog.Info("model pulled", "model", modelName)
return nil
} }
var serverProcMutex sync.Mutex var serverProcMutex sync.Mutex
func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) { // Returns an Client, the testEndpoint, and a cleanup function, fails the test on errors
// Starts the server if needed
// TODO maybe stuff in an init routine? func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, string, func()) {
lifecycle.InitLogging() client, testEndpoint := GetTestEndpoint()
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
requestJSON, err := json.Marshal(genReq) serverProcMutex.Lock()
if err != nil { fp, err := os.CreateTemp("", "ollama-server-*.log")
t.Fatalf("Error serializing request: %v", err) if err != nil {
t.Fatalf("failed to generate log file: %s", err)
}
lifecycle.ServerLogFile = fp.Name()
fp.Close()
require.NoError(t, startServer(ctx, testEndpoint))
} }
defer func() {
return client, testEndpoint, func() {
if os.Getenv("OLLAMA_TEST_EXISTING") == "" { if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
defer serverProcMutex.Unlock() defer serverProcMutex.Unlock()
if t.Failed() { if t.Failed() {
@ -203,63 +221,118 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
os.Stderr.Write(data) os.Stderr.Write(data)
slog.Warn("END OF SERVER") slog.Warn("END OF SERVER")
} }
err = os.Remove(lifecycle.ServerLogFile) err := os.Remove(lifecycle.ServerLogFile)
if err != nil && !os.IsNotExist(err) { if err != nil && !os.IsNotExist(err) {
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err) slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
} }
} }
}()
scheme, testEndpoint := GetTestEndpoint()
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
serverProcMutex.Lock()
fp, err := os.CreateTemp("", "ollama-server-*.log")
if err != nil {
t.Fatalf("failed to generate log file: %s", err)
}
lifecycle.ServerLogFile = fp.Name()
fp.Close()
assert.NoError(t, StartServer(ctx, testEndpoint))
} }
}
err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model)
if err != nil { func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
t.Fatalf("Error pulling model: %v", err) client, _, cleanup := InitServerConnection(ctx, t)
} defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
// Make the request and get the response DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON)) }
if err != nil {
t.Fatalf("Error creating request: %v", err) func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
} stallTimer := time.NewTimer(initialTimeout)
var buf bytes.Buffer
// Set the content type for the request fn := func(response api.GenerateResponse) error {
req.Header.Set("Content-Type", "application/json") // fmt.Print(".")
buf.Write([]byte(response.Response))
// Make the request with the HTTP client if !stallTimer.Reset(streamTimeout) {
response, err := client.Do(req.WithContext(ctx)) return fmt.Errorf("stall was detected while streaming response, aborting")
if err != nil { }
t.Fatalf("Error making request: %v", err) return nil
} }
defer response.Body.Close()
body, err := io.ReadAll(response.Body) stream := true
assert.NoError(t, err) genReq.Stream = &stream
assert.Equal(t, response.StatusCode, 200, string(body)) done := make(chan int)
var genErr error
// Verify the response is valid JSON go func() {
var payload api.GenerateResponse genErr = client.Generate(ctx, &genReq, fn)
err = json.Unmarshal(body, &payload) done <- 0
if err != nil { }()
assert.NoError(t, err, body)
} select {
case <-stallTimer.C:
// Verify the response contains the expected data if buf.Len() == 0 {
atLeastOne := false t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
for _, resp := range anyResp { } else {
if strings.Contains(strings.ToLower(payload.Response), resp) { t.Errorf("generate stalled. Response so far:%s", buf.String())
atLeastOne = true }
break case <-done:
} require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
} // Verify the response contains the expected data
assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response) response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
require.True(t, atLeastOne, "none of %v found in %s", anyResp, response)
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")
}
}
// Generate a set of requests
// By default each request uses orca-mini as the model
func GenerateRequests() ([]api.GenerateRequest, [][]string) {
return []api.GenerateRequest{
{
Model: "orca-mini",
Prompt: "why is the ocean blue?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "orca-mini",
Prompt: "why is the color of dirt brown?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "orca-mini",
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "orca-mini",
Prompt: "what is the origin of independence day?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "orca-mini",
Prompt: "what is the composition of air?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
},
},
[][]string{
[]string{"sunlight"},
[]string{"soil", "organic", "earth", "black", "tan"},
[]string{"england", "english", "massachusetts", "pilgrims"},
[]string{"fourth", "july", "declaration", "independence"},
[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
}
} }

162
llm/memory.go Normal file
View file

@ -0,0 +1,162 @@
package llm
import (
"fmt"
"log/slog"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
)
// This algorithm looks for a complete fit to determine if we need to unload other models
func PredictServerFit(allGpus gpu.GpuInfoList, ggml *GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
var estimatedVRAM uint64
if opts.NumCtx > int(ggml.KV().ContextLength()) {
slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
opts.NumCtx = int(ggml.KV().ContextLength())
}
if opts.NumCtx < 4 {
opts.NumCtx = 4
}
// Split up the GPUs by type and try them
for _, gpus := range allGpus.ByLibrary() {
var layerCount int
layerCount, estimatedVRAM = EstimateGPULayers(gpus, ggml, projectors, opts)
if opts.NumGPU < 0 {
if layerCount > 0 && layerCount >= int(ggml.KV().BlockCount()+1) {
return true, estimatedVRAM
}
} else {
if layerCount > 0 && layerCount >= opts.NumGPU {
return true, estimatedVRAM
}
}
}
return false, estimatedVRAM
}
// Given a model and one or more GPU targets, predict how many layers and bytes we can load
// The GPUs provided must all be the same Library
func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts api.Options) (int, uint64) {
if gpus[0].Library == "cpu" {
return 0, 0
}
var memoryAvailable uint64
for _, info := range gpus {
memoryAvailable += info.FreeMemory
}
slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", format.HumanBytes2(memoryAvailable))
// TODO - this is probably wrong, first GPU vs secondaries will have different overheads
memoryMinimum := gpus[0].MinimumMemory
for _, projector := range projectors {
memoryMinimum += projectorMemoryRequirements(projector)
// multimodal models require at least 2048 context
opts.NumCtx = max(opts.NumCtx, 2048)
}
// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
if graphPartialOffload == 0 {
graphPartialOffload = ggml.KV().GQA() * kv / 6
}
if graphFullOffload == 0 {
graphFullOffload = graphPartialOffload
}
graphFullOffload *= uint64(len(gpus))
graphPartialOffload *= uint64(len(gpus))
// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
memoryRequiredTotal := memoryMinimum + graphFullOffload
// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
memoryRequiredPartial := memoryMinimum + graphPartialOffload
if memoryRequiredPartial > memoryAvailable {
slog.Debug("insufficient VRAM to load any model layers")
return 0, 0
}
var layerCount int
layers := ggml.Tensors().Layers()
for i := 0; i < int(ggml.KV().BlockCount()); i++ {
memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
// KV is proportional to the number of layers
memoryLayer += kv / ggml.KV().BlockCount()
memoryRequiredTotal += memoryLayer
if memoryAvailable > memoryRequiredPartial+memoryLayer {
memoryRequiredPartial += memoryLayer
layerCount++
}
}
var memoryLayerOutput uint64
for k, v := range layers {
if !strings.HasPrefix(k, "blk.") {
memoryLayerOutput += v.size()
}
}
memoryRequiredTotal += memoryLayerOutput
if memoryAvailable > memoryRequiredTotal {
layerCount = int(ggml.KV().BlockCount()) + 1
memoryRequiredPartial = memoryRequiredTotal
}
memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
slog.Info(
"offload to gpu",
slog.Group(
"layers",
// actual number of layers offloaded
"real", opts.NumGPU,
// estimated number of layers that can be offloaded
"estimate", layerCount,
),
slog.Group(
"memory",
// memory available for offloading
"available", format.HumanBytes2(memoryAvailable),
slog.Group(
"required",
// memory required for full offloading
"full", format.HumanBytes2(memoryRequiredTotal),
// memory required to offload layers.estimate layers
"partial", format.HumanBytes2(memoryRequiredPartial),
// memory of KV cache
"kv", format.HumanBytes2(kv),
),
slog.Group(
"weights",
// memory of the weights
"total", format.HumanBytes2(memoryWeights),
// memory of repeating layers
"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
// memory of non-repeating layers
"nonrepeating", format.HumanBytes2(memoryLayerOutput),
),
slog.Group(
"graph",
// memory of graph when fully offloaded
"full", format.HumanBytes2(graphFullOffload),
// memory of graph when not fully offloaded
"partial", format.HumanBytes2(graphPartialOffload),
),
),
)
return layerCount, uint64(memoryRequiredPartial)
}

View file

@ -9,6 +9,7 @@ import (
"log/slog" "log/slog"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"strings" "strings"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
@ -138,6 +139,23 @@ func serversForGpu(info gpu.GpuInfo) []string {
return servers return servers
} }
// Return the optimal server for this CPU architecture
func serverForCpu() string {
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
return "metal"
}
variant := gpu.GetCPUVariant()
availableServers := availableServers()
if variant != "" {
for cmp := range availableServers {
if cmp == "cpu_"+variant {
return cmp
}
}
}
return "cpu"
}
// extract extracts the embedded files to the target directory // extract extracts the embedded files to the target directory
func extractFiles(targetDir string, glob string) error { func extractFiles(targetDir string, glob string) error {
files, err := fs.Glob(libEmbed, glob) files, err := fs.Glob(libEmbed, glob)

View file

@ -21,21 +21,43 @@ import (
"strings" "strings"
"time" "time"
"golang.org/x/sync/semaphore"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu" "github.com/ollama/ollama/gpu"
) )
// LlamaServer is an instance of the llama.cpp server type LlamaServer interface {
type LlamaServer struct { Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embedding(ctx context.Context, prompt string) ([]float64, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
EstimatedVRAM() uint64
}
// llmServer is an instance of the llama.cpp server
type llmServer struct {
port int port int
cmd *exec.Cmd cmd *exec.Cmd
done chan error // Channel to signal when the process exits done chan error // Channel to signal when the process exits
status *StatusWriter status *StatusWriter
options api.Options options api.Options
// TODO - this should be broken down by GPU
estimatedVRAM uint64 // Estimated usage of VRAM by the loaded model
sem *semaphore.Weighted
} }
func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) { func LoadModel(model string) (*GGML, error) {
if _, err := os.Stat(model); err != nil {
return nil, err
}
f, err := os.Open(model) f, err := os.Open(model)
if err != nil { if err != nil {
return nil, err return nil, err
@ -43,10 +65,13 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
defer f.Close() defer f.Close()
ggml, _, err := DecodeGGML(f) ggml, _, err := DecodeGGML(f)
if err != nil { return ggml, err
return nil, err }
}
// NewLlamaServer will run a server for the given GPUs
// The gpu list must be a single family.
func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
var err error
if opts.NumCtx > int(ggml.KV().ContextLength()) { if opts.NumCtx > int(ggml.KV().ContextLength()) {
slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength()) slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
opts.NumCtx = int(ggml.KV().ContextLength()) opts.NumCtx = int(ggml.KV().ContextLength())
@ -56,130 +81,50 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
opts.NumCtx = 4 opts.NumCtx = 4
} }
memoryAvailable, _ := gpu.CheckVRAM() cpuRunner := ""
info := gpu.GetGPUInfo() var estimatedVRAM uint64
var systemMemory uint64
if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 {
memoryMinimum := info.MinimumMemory // TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner
for _, projector := range projectors {
memoryMinimum += projectorMemoryRequirements(projector)
// multimodal models require at least 2048 context cpuRunner = serverForCpu()
opts.NumCtx = max(opts.NumCtx, 2048) } else {
} if gpus[0].Library == "metal" {
memInfo, err := gpu.GetCPUMem()
if err != nil {
slog.Error("failed to lookup system memory", "error", err)
} else {
systemMemory = memInfo.TotalMemory
slog.Debug("system memory", "total", format.HumanBytes2(systemMemory))
}
}
var layers int
layers, estimatedVRAM = EstimateGPULayers(gpus, ggml, projectors, opts)
// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv if gpus[0].Library == "metal" && estimatedVRAM > systemMemory {
var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV() // disable partial offloading when model is greater than total system memory as this
// can lead to locking up the system
graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch))) opts.NumGPU = 0
if graphPartialOffload == 0 { } else if opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu" {
graphPartialOffload = ggml.KV().GQA() * kv / 6 opts.NumGPU = layers
}
if graphFullOffload == 0 {
graphFullOffload = graphPartialOffload
}
graphFullOffload *= uint64(info.DeviceCount)
graphPartialOffload *= uint64(info.DeviceCount)
// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
memoryRequiredTotal := memoryMinimum + graphFullOffload
// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
memoryRequiredPartial := memoryMinimum + graphPartialOffload
if info.Library != "metal" {
if memoryRequiredPartial > memoryAvailable {
info.Library = "cpu"
} }
} }
var layerCount int // Loop through potential servers
layers := ggml.Tensors().Layers() finalErr := fmt.Errorf("no suitable llama servers found")
for i := 0; i < int(ggml.KV().BlockCount()); i++ {
memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
// KV is proportional to the number of layers
memoryLayer += kv / ggml.KV().BlockCount()
memoryRequiredTotal += memoryLayer
if memoryAvailable > memoryRequiredPartial+memoryLayer {
memoryRequiredPartial += memoryLayer
layerCount++
}
}
var memoryLayerOutput uint64
for k, v := range layers {
if !strings.HasPrefix(k, "blk.") {
memoryLayerOutput += v.size()
}
}
memoryRequiredTotal += memoryLayerOutput
if info.Library == "metal" && memoryRequiredTotal > info.TotalMemory {
// disable partial offloading when model is greater than total system memory
opts.NumGPU = 0
} else if memoryAvailable > memoryRequiredTotal {
layerCount = int(ggml.KV().BlockCount()) + 1
memoryRequiredPartial = memoryRequiredTotal
}
if opts.NumGPU < 0 {
opts.NumGPU = layerCount
}
memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
slog.Info(
"offload to gpu",
slog.Group(
"layers",
// actual number of layers offloaded
"real", opts.NumGPU,
// estimated number of layers that can be offloaded
"estimate", layerCount,
),
slog.Group(
"memory",
// memory available for offloading
"available", format.HumanBytes2(memoryAvailable),
slog.Group(
"required",
// memory required for full offloading
"full", format.HumanBytes2(memoryRequiredTotal),
// memory required to offload layers.estimate layers
"partial", format.HumanBytes2(memoryRequiredPartial),
// memory of KV cache
"kv", format.HumanBytes2(kv),
),
slog.Group(
"weights",
// memory of the weights
"total", format.HumanBytes2(memoryWeights),
// memory of repeating layers
"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
// memory of non-repeating layers
"nonrepeating", format.HumanBytes2(memoryLayerOutput),
),
slog.Group(
"graph",
// memory of graph when fully offloaded
"full", format.HumanBytes2(graphFullOffload),
// memory of graph when not fully offloaded
"partial", format.HumanBytes2(graphPartialOffload),
),
),
)
if len(adapters) > 1 { if len(adapters) > 1 {
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided") return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
} }
availableServers := availableServers() availableServers := availableServers()
servers := serversForGpu(info) var servers []string
if cpuRunner != "" {
servers = []string{cpuRunner}
} else {
servers = serversForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant
}
demandLib := strings.Trim(os.Getenv("OLLAMA_LLM_LIBRARY"), "\"' ") demandLib := strings.Trim(os.Getenv("OLLAMA_LLM_LIBRARY"), "\"' ")
if demandLib != "" { if demandLib != "" {
serverPath := availableServers[demandLib] serverPath := availableServers[demandLib]
@ -192,7 +137,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
} }
if len(servers) == 0 { if len(servers) == 0 {
return nil, fmt.Errorf("no servers found for %v", info) return nil, fmt.Errorf("no servers found for %v", gpus)
} }
params := []string{ params := []string{
@ -249,8 +194,18 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
params = append(params, "--numa") params = append(params, "--numa")
} }
// Loop through potential servers // "--cont-batching", // TODO - doesn't seem to have any noticeable perf change for multiple requests
var finalErr error numParallel := 1
if onp := os.Getenv("OLLAMA_NUM_PARALLEL"); onp != "" {
numParallel, err = strconv.Atoi(onp)
if err != nil || numParallel <= 0 {
err = fmt.Errorf("invalid OLLAMA_NUM_PARALLEL=%s must be greater than zero - %w", onp, err)
slog.Error("misconfiguration", "error", err)
return nil, err
}
}
params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
for i := 0; i < len(servers); i++ { for i := 0; i < len(servers); i++ {
dir := availableServers[servers[i]] dir := availableServers[servers[i]]
@ -275,30 +230,49 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
} }
// append the server directory to LD_LIBRARY_PATH/PATH // append the server directory to LD_LIBRARY_PATH/PATH
libraryPaths := []string{dir} libraryPaths := []string{dir}
if libraryPath, ok := os.LookupEnv(pathEnv); ok { if libraryPath, ok := os.LookupEnv(pathEnv); ok {
// Append our runner directory to the path // Append our runner directory to the path
// This will favor system libraries over our bundled library dependencies // This will favor system libraries over our bundled library dependencies
libraryPaths = append(filepath.SplitList(libraryPath), libraryPaths...) libraryPaths = append(filepath.SplitList(libraryPath), libraryPaths...)
} }
// Note: we always put the dependency path first
// since this was the exact version we verified for AMD GPUs
// and we favor what the user had in their path
if gpus[0].DependencyPath != "" {
// TODO refine for multi-gpu support
libraryPaths = append([]string{gpus[0].DependencyPath}, libraryPaths...)
}
server := filepath.Join(dir, "ollama_llama_server") server := filepath.Join(dir, "ollama_llama_server")
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
server = server + ".exe" server = server + ".exe"
} }
s := &LlamaServer{ s := &llmServer{
port: port, port: port,
cmd: exec.Command(server, finalParams...), cmd: exec.Command(server, finalParams...),
status: NewStatusWriter(os.Stderr), status: NewStatusWriter(os.Stderr),
options: opts, options: opts,
estimatedVRAM: estimatedVRAM,
sem: semaphore.NewWeighted(int64(numParallel)),
} }
libEnv := fmt.Sprintf("%s=%s", pathEnv, strings.Join(libraryPaths, string(filepath.ListSeparator))) libEnv := fmt.Sprintf("%s=%s", pathEnv, strings.Join(libraryPaths, string(filepath.ListSeparator)))
slog.Debug(libEnv)
s.cmd.Env = append(os.Environ(), libEnv) s.cmd.Env = append(os.Environ(), libEnv)
s.cmd.Stdout = os.Stdout s.cmd.Stdout = os.Stdout
s.cmd.Stderr = s.status s.cmd.Stderr = s.status
// TODO - multiple GPU selection logic...
key, val := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv()
if key != "" {
s.cmd.Env = append(s.cmd.Env, key+"="+val)
}
slog.Info("starting llama server", "cmd", s.cmd.String()) slog.Info("starting llama server", "cmd", s.cmd.String())
// Log at debug as the environment is inherited and might contain sensitive information
slog.Debug("subprocess", "environment", s.cmd.Env)
if err = s.cmd.Start(); err != nil { if err = s.cmd.Start(); err != nil {
msg := "" msg := ""
@ -316,6 +290,13 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
_ = s.cmd.Wait() _ = s.cmd.Wait()
}() }()
// TODO - make sure this is all wired up correctly
// if err = s.WaitUntilRunning(); err != nil {
// slog.Error("error starting llama server", "server", servers[i], "error", err)
// s.Close()
// finalErr = err
// continue
// }
return s, nil return s, nil
} }
@ -353,6 +334,21 @@ const ( // iota is reset to 0
ServerStatusError ServerStatusError
) )
func (s ServerStatus) ToString() string {
switch s {
case ServerStatusReady:
return "llm server ready"
case ServerStatusNoSlotsAvaialble:
return "llm busy - no slots available"
case ServerStatusLoadingModel:
return "llm server loading model"
case ServerStatusNotResponding:
return "llm server not responding"
default:
return "llm server error"
}
}
type ServerStatusResp struct { type ServerStatusResp struct {
Status string `json:"status"` Status string `json:"status"`
SlotsIdle int `json:"slots_idle"` SlotsIdle int `json:"slots_idle"`
@ -360,7 +356,7 @@ type ServerStatusResp struct {
Error string `json:"error"` Error string `json:"error"`
} }
func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error) { func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
// Fail fast if its exited // Fail fast if its exited
if s.cmd.ProcessState != nil { if s.cmd.ProcessState != nil {
msg := "" msg := ""
@ -407,7 +403,7 @@ func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error)
} }
} }
func (s *LlamaServer) Ping(ctx context.Context) error { func (s *llmServer) Ping(ctx context.Context) error {
_, err := s.getServerStatus(ctx) _, err := s.getServerStatus(ctx)
if err != nil { if err != nil {
slog.Debug("server unhealthy", "error", err) slog.Debug("server unhealthy", "error", err)
@ -416,7 +412,7 @@ func (s *LlamaServer) Ping(ctx context.Context) error {
return nil return nil
} }
func (s *LlamaServer) WaitUntilRunning() error { func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
start := time.Now() start := time.Now()
// TODO we need to wire up a better way to detect hangs during model load and startup of the server // TODO we need to wire up a better way to detect hangs during model load and startup of the server
expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
@ -427,6 +423,9 @@ func (s *LlamaServer) WaitUntilRunning() error {
var lastStatus ServerStatus = -1 var lastStatus ServerStatus = -1
for { for {
select { select {
case <-ctx.Done():
slog.Info("context expired before server started")
return fmt.Errorf("timed out waiting for llama runner to start")
case err := <-s.done: case err := <-s.done:
msg := "" msg := ""
if s.status != nil && s.status.LastErrMsg != "" { if s.status != nil && s.status.LastErrMsg != "" {
@ -450,9 +449,9 @@ func (s *LlamaServer) WaitUntilRunning() error {
return fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg) return fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
} }
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) c, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
defer cancel() defer cancel()
status, err := s.getServerStatus(ctx) status, err := s.getServerStatus(c)
if err != nil && lastStatus != status { if err != nil && lastStatus != status {
slog.Debug("server not yet available", "error", err) slog.Debug("server not yet available", "error", err)
lastStatus = status lastStatus = status
@ -538,7 +537,12 @@ type CompletionResponse struct {
EvalDuration time.Duration EvalDuration time.Duration
} }
func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return err
}
defer s.sem.Release(1)
request := map[string]any{ request := map[string]any{
"prompt": req.Prompt, "prompt": req.Prompt,
"stream": true, "stream": true,
@ -569,7 +573,7 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
if err != nil { if err != nil {
return err return err
} else if status != ServerStatusReady { } else if status != ServerStatusReady {
return fmt.Errorf("unexpected server status: %d", status) return fmt.Errorf("unexpected server status: %s", status.ToString())
} }
if req.Format == "json" { if req.Format == "json" {
@ -716,13 +720,18 @@ type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"` Embedding []float64 `json:"embedding"`
} }
func (s *LlamaServer) Embedding(ctx context.Context, prompt string) ([]float64, error) { func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return nil, err
}
defer s.sem.Release(1)
// Make sure the server is ready // Make sure the server is ready
status, err := s.getServerStatus(ctx) status, err := s.getServerStatus(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} else if status != ServerStatusReady { } else if status != ServerStatusReady {
return nil, fmt.Errorf("unexpected server status: %d", status) return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
} }
data, err := json.Marshal(TokenizeRequest{Content: prompt}) data, err := json.Marshal(TokenizeRequest{Content: prompt})
@ -768,13 +777,13 @@ type TokenizeResponse struct {
Tokens []int `json:"tokens"` Tokens []int `json:"tokens"`
} }
func (s *LlamaServer) Tokenize(ctx context.Context, content string) ([]int, error) { func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
// Make sure the server is ready // Make sure the server is ready
status, err := s.getServerStatus(ctx) status, err := s.getServerStatus(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} else if status != ServerStatusReady { } else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble {
return nil, fmt.Errorf("unexpected server status: %d", status) return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
} }
data, err := json.Marshal(TokenizeRequest{Content: content}) data, err := json.Marshal(TokenizeRequest{Content: content})
@ -820,13 +829,13 @@ type DetokenizeResponse struct {
Content string `json:"content"` Content string `json:"content"`
} }
func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) { func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
// Make sure the server is ready // Make sure the server is ready
status, err := s.getServerStatus(ctx) status, err := s.getServerStatus(ctx)
if err != nil { if err != nil {
return "", err return "", err
} else if status != ServerStatusReady { } else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble {
return "", fmt.Errorf("unexpected server status: %d", status) return "", fmt.Errorf("unexpected server status: %s", status.ToString())
} }
data, err := json.Marshal(DetokenizeRequest{Tokens: tokens}) data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
@ -864,7 +873,7 @@ func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, err
return decoded.Content, nil return decoded.Content, nil
} }
func (s *LlamaServer) Close() error { func (s *llmServer) Close() error {
if s.cmd != nil { if s.cmd != nil {
slog.Debug("stopping llama server") slog.Debug("stopping llama server")
return s.cmd.Process.Kill() return s.cmd.Process.Kill()
@ -873,6 +882,10 @@ func (s *LlamaServer) Close() error {
return nil return nil
} }
func (s *llmServer) EstimatedVRAM() uint64 {
return s.estimatedVRAM
}
func parseDurationMs(ms float64) time.Duration { func parseDurationMs(ms float64) time.Duration {
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms)) dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
if err != nil { if err != nil {

View file

@ -15,11 +15,8 @@ import (
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"reflect"
"runtime"
"strconv" "strconv"
"strings" "strings"
"sync"
"syscall" "syscall"
"time" "time"
@ -38,7 +35,8 @@ import (
var mode string = gin.DebugMode var mode string = gin.DebugMode
type Server struct { type Server struct {
addr net.Addr addr net.Addr
sched *Scheduler
} }
func init() { func init() {
@ -53,88 +51,8 @@ func init() {
gin.SetMode(mode) gin.SetMode(mode)
} }
var loaded struct {
mu sync.Mutex
llama *llm.LlamaServer
expireTimer *time.Timer
model string
adapters []string
projectors []string
*api.Options
}
var defaultSessionDuration = 5 * time.Minute var defaultSessionDuration = 5 * time.Minute
func unload() {
if loaded.llama != nil {
loaded.llama.Close()
}
loaded.llama = nil
loaded.model = ""
loaded.adapters = nil
loaded.projectors = nil
loaded.Options = nil
}
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
ctx, cancel := context.WithTimeout(c, 10*time.Second)
defer cancel()
needLoad := loaded.llama == nil || // is there a model loaded?
loaded.model != model.ModelPath || // has the base model changed?
!reflect.DeepEqual(loaded.adapters, model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(loaded.projectors, model.ProjectorPaths) || // have the adapters changed?
!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) || // have the runner options changed?
loaded.llama.Ping(ctx) != nil
if needLoad {
if loaded.llama != nil {
slog.Info("changing loaded model")
unload()
}
llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
// check for model compatibility
if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
}
return err
}
loaded.model = model.ModelPath
loaded.adapters = model.AdapterPaths
loaded.projectors = model.ProjectorPaths
loaded.llama = llama
loaded.Options = &opts
if err = llama.WaitUntilRunning(); err != nil {
slog.Error("error loading llama server", "error", err)
unload()
return err
}
}
if loaded.expireTimer == nil {
loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
loaded.mu.Lock()
defer loaded.mu.Unlock()
unload()
})
}
loaded.expireTimer.Reset(sessionDuration)
return nil
}
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) { func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions() opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil { if err := opts.FromMap(model.Options); err != nil {
@ -154,9 +72,7 @@ func isSupportedImageType(image []byte) bool {
return slices.Contains(allowedTypes, contentType) return slices.Contains(allowedTypes, contentType)
} }
func GenerateHandler(c *gin.Context) { func (s *Server) GenerateHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
checkpointStart := time.Now() checkpointStart := time.Now()
var req api.GenerateRequest var req api.GenerateRequest
@ -224,7 +140,11 @@ func GenerateHandler(c *gin.Context) {
sessionDuration = req.KeepAlive.Duration sessionDuration = req.KeepAlive.Duration
} }
if err := load(c, model, opts, sessionDuration); err != nil { rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
@ -275,7 +195,7 @@ func GenerateHandler(c *gin.Context) {
sb.Reset() sb.Reset()
if req.Context != nil { if req.Context != nil {
prev, err := loaded.llama.Detokenize(c.Request.Context(), req.Context) prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@ -297,9 +217,6 @@ func GenerateHandler(c *gin.Context) {
defer close(ch) defer close(ch)
fn := func(r llm.CompletionResponse) { fn := func(r llm.CompletionResponse) {
// Update model expiration
loaded.expireTimer.Reset(sessionDuration)
// Build up the full response // Build up the full response
if _, err := generated.WriteString(r.Content); err != nil { if _, err := generated.WriteString(r.Content); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
@ -331,7 +248,7 @@ func GenerateHandler(c *gin.Context) {
} }
// TODO (jmorganca): encode() should not strip special tokens // TODO (jmorganca): encode() should not strip special tokens
tokens, err := loaded.llama.Tokenize(c.Request.Context(), p) tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
if err != nil { if err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
return return
@ -359,7 +276,7 @@ func GenerateHandler(c *gin.Context) {
Images: images, Images: images,
Options: opts, Options: opts,
} }
if err := loaded.llama.Completion(c.Request.Context(), req, fn); err != nil { if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
@ -421,10 +338,7 @@ func getDefaultSessionDuration() time.Duration {
return defaultSessionDuration return defaultSessionDuration
} }
func EmbeddingsHandler(c *gin.Context) { func (s *Server) EmbeddingsHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
var req api.EmbeddingRequest var req api.EmbeddingRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
switch { switch {
@ -469,7 +383,11 @@ func EmbeddingsHandler(c *gin.Context) {
sessionDuration = req.KeepAlive.Duration sessionDuration = req.KeepAlive.Duration
} }
if err := load(c, model, opts, sessionDuration); err != nil { rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
@ -480,7 +398,7 @@ func EmbeddingsHandler(c *gin.Context) {
return return
} }
embedding, err := loaded.llama.Embedding(c.Request.Context(), req.Prompt) embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
if err != nil { if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err)) slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@ -493,7 +411,7 @@ func EmbeddingsHandler(c *gin.Context) {
c.JSON(http.StatusOK, resp) c.JSON(http.StatusOK, resp)
} }
func PullModelHandler(c *gin.Context) { func (s *Server) PullModelHandler(c *gin.Context) {
var req api.PullRequest var req api.PullRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
switch { switch {
@ -542,7 +460,7 @@ func PullModelHandler(c *gin.Context) {
streamResponse(c, ch) streamResponse(c, ch)
} }
func PushModelHandler(c *gin.Context) { func (s *Server) PushModelHandler(c *gin.Context) {
var req api.PushRequest var req api.PushRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
switch { switch {
@ -591,7 +509,7 @@ func PushModelHandler(c *gin.Context) {
streamResponse(c, ch) streamResponse(c, ch)
} }
func CreateModelHandler(c *gin.Context) { func (s *Server) CreateModelHandler(c *gin.Context) {
var req api.CreateRequest var req api.CreateRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
switch { switch {
@ -664,7 +582,7 @@ func CreateModelHandler(c *gin.Context) {
streamResponse(c, ch) streamResponse(c, ch)
} }
func DeleteModelHandler(c *gin.Context) { func (s *Server) DeleteModelHandler(c *gin.Context) {
var req api.DeleteRequest var req api.DeleteRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
switch { switch {
@ -709,7 +627,7 @@ func DeleteModelHandler(c *gin.Context) {
c.JSON(http.StatusOK, nil) c.JSON(http.StatusOK, nil)
} }
func ShowModelHandler(c *gin.Context) { func (s *Server) ShowModelHandler(c *gin.Context) {
var req api.ShowRequest var req api.ShowRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
switch { switch {
@ -809,7 +727,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
return resp, nil return resp, nil
} }
func ListModelsHandler(c *gin.Context) { func (s *Server) ListModelsHandler(c *gin.Context) {
models := make([]api.ModelResponse, 0) models := make([]api.ModelResponse, 0)
manifestsPath, err := GetManifestPath() manifestsPath, err := GetManifestPath()
if err != nil { if err != nil {
@ -869,7 +787,7 @@ func ListModelsHandler(c *gin.Context) {
c.JSON(http.StatusOK, api.ListResponse{Models: models}) c.JSON(http.StatusOK, api.ListResponse{Models: models})
} }
func CopyModelHandler(c *gin.Context) { func (s *Server) CopyModelHandler(c *gin.Context) {
var req api.CopyRequest var req api.CopyRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
switch { switch {
@ -901,7 +819,7 @@ func CopyModelHandler(c *gin.Context) {
} }
} }
func HeadBlobHandler(c *gin.Context) { func (s *Server) HeadBlobHandler(c *gin.Context) {
path, err := GetBlobsPath(c.Param("digest")) path, err := GetBlobsPath(c.Param("digest"))
if err != nil { if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@ -916,7 +834,7 @@ func HeadBlobHandler(c *gin.Context) {
c.Status(http.StatusOK) c.Status(http.StatusOK)
} }
func CreateBlobHandler(c *gin.Context) { func (s *Server) CreateBlobHandler(c *gin.Context) {
path, err := GetBlobsPath(c.Param("digest")) path, err := GetBlobsPath(c.Param("digest"))
if err != nil { if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@ -1063,27 +981,27 @@ func (s *Server) GenerateRoutes() http.Handler {
allowedHostsMiddleware(s.addr), allowedHostsMiddleware(s.addr),
) )
r.POST("/api/pull", PullModelHandler) r.POST("/api/pull", s.PullModelHandler)
r.POST("/api/generate", GenerateHandler) r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", ChatHandler) r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embeddings", EmbeddingsHandler) r.POST("/api/embeddings", s.EmbeddingsHandler)
r.POST("/api/create", CreateModelHandler) r.POST("/api/create", s.CreateModelHandler)
r.POST("/api/push", PushModelHandler) r.POST("/api/push", s.PushModelHandler)
r.POST("/api/copy", CopyModelHandler) r.POST("/api/copy", s.CopyModelHandler)
r.DELETE("/api/delete", DeleteModelHandler) r.DELETE("/api/delete", s.DeleteModelHandler)
r.POST("/api/show", ShowModelHandler) r.POST("/api/show", s.ShowModelHandler)
r.POST("/api/blobs/:digest", CreateBlobHandler) r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", HeadBlobHandler) r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
// Compatibility endpoints // Compatibility endpoints
r.POST("/v1/chat/completions", openai.Middleware(), ChatHandler) r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler)
for _, method := range []string{http.MethodGet, http.MethodHead} { for _, method := range []string{http.MethodGet, http.MethodHead} {
r.Handle(method, "/", func(c *gin.Context) { r.Handle(method, "/", func(c *gin.Context) {
c.String(http.StatusOK, "Ollama is running") c.String(http.StatusOK, "Ollama is running")
}) })
r.Handle(method, "/api/tags", ListModelsHandler) r.Handle(method, "/api/tags", s.ListModelsHandler)
r.Handle(method, "/api/version", func(c *gin.Context) { r.Handle(method, "/api/version", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"version": version.Version}) c.JSON(http.StatusOK, gin.H{"version": version.Version})
}) })
@ -1137,7 +1055,9 @@ func Serve(ln net.Listener) error {
} }
} }
s := &Server{addr: ln.Addr()} ctx, done := context.WithCancel(context.Background())
sched := InitScheduler(ctx)
s := &Server{addr: ln.Addr(), sched: sched}
r := s.GenerateRoutes() r := s.GenerateRoutes()
slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version)) slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
@ -1150,7 +1070,8 @@ func Serve(ln net.Listener) error {
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
go func() { go func() {
<-signals <-signals
unload() done()
sched.unloadAllRunners()
gpu.Cleanup() gpu.Cleanup()
os.Exit(0) os.Exit(0)
}() }()
@ -1158,12 +1079,12 @@ func Serve(ln net.Listener) error {
if err := llm.Init(); err != nil { if err := llm.Init(); err != nil {
return fmt.Errorf("unable to initialize llm library %w", err) return fmt.Errorf("unable to initialize llm library %w", err)
} }
if runtime.GOOS == "linux" { // TODO - windows too
// check compatibility to log warnings s.sched.Run(ctx)
if _, err := gpu.CheckVRAM(); err != nil {
slog.Info(err.Error()) // At startup we retrieve GPU information so we can get log messages before loading a model
} // This will log warnings to the log in case we have problems with detected GPUs
} _ = gpu.GetGPUInfo()
return srvr.Serve(ln) return srvr.Serve(ln)
} }
@ -1219,9 +1140,9 @@ func streamResponse(c *gin.Context, ch chan any) {
} }
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) { func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
encode := func(s string) ([]int, error) { encode := func(s string) ([]int, error) {
return loaded.llama.Tokenize(ctx, s) return runner.llama.Tokenize(ctx, s)
} }
prompt, err := ChatPrompt(template, messages, numCtx, encode) prompt, err := ChatPrompt(template, messages, numCtx, encode)
@ -1232,10 +1153,7 @@ func chatPrompt(ctx context.Context, template string, messages []api.Message, nu
return prompt, nil return prompt, nil
} }
func ChatHandler(c *gin.Context) { func (s *Server) ChatHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
checkpointStart := time.Now() checkpointStart := time.Now()
var req api.ChatRequest var req api.ChatRequest
@ -1292,7 +1210,11 @@ func ChatHandler(c *gin.Context) {
sessionDuration = req.KeepAlive.Duration sessionDuration = req.KeepAlive.Duration
} }
if err := load(c, model, opts, sessionDuration); err != nil { rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
@ -1309,7 +1231,7 @@ func ChatHandler(c *gin.Context) {
}, req.Messages...) }, req.Messages...)
} }
prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx) prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
@ -1352,8 +1274,6 @@ func ChatHandler(c *gin.Context) {
defer close(ch) defer close(ch)
fn := func(r llm.CompletionResponse) { fn := func(r llm.CompletionResponse) {
// Update model expiration
loaded.expireTimer.Reset(sessionDuration)
resp := api.ChatResponse{ resp := api.ChatResponse{
Model: req.Model, Model: req.Model,
@ -1376,7 +1296,7 @@ func ChatHandler(c *gin.Context) {
ch <- resp ch <- resp
} }
if err := loaded.llama.Completion(c.Request.Context(), llm.CompletionRequest{ if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt, Prompt: prompt,
Format: req.Format, Format: req.Format,
Images: images, Images: images,

525
server/sched.go Normal file
View file

@ -0,0 +1,525 @@
package server
import (
"context"
"errors"
"fmt"
"log/slog"
"os"
"reflect"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"golang.org/x/exp/slices"
)
type LlmRequest struct {
ctx context.Context //nolint:containedctx
model *Model
ggml *llm.GGML // TODO - how large is this, and do we need to free it after we've finished loading?
opts api.Options
sessionDuration time.Duration
successCh chan *runnerRef
errCh chan error
}
type Scheduler struct {
pendingReqCh chan *LlmRequest
finishedReqCh chan *LlmRequest
expiredCh chan *runnerRef
unloadedCh chan interface{}
loaded map[string]*runnerRef
loadedMu sync.Mutex
loadFn func(req *LlmRequest, gpus gpu.GpuInfoList)
newServerFn func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error)
getGpuFn func() gpu.GpuInfoList
}
// TODO set this to zero after a release or two, to enable multiple models by default
var loadedMax = 1 // Maximum runners; < 1 maps to as many as will fit in VRAM (unlimited for CPU runners)
var maxQueuedRequests = 10 // TODO configurable
func InitScheduler(ctx context.Context) *Scheduler {
maxRunners := os.Getenv("OLLAMA_MAX_LOADED_MODELS")
if maxRunners != "" {
m, err := strconv.Atoi(maxRunners)
if err != nil {
slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
} else {
loadedMax = m
}
}
sched := &Scheduler{
pendingReqCh: make(chan *LlmRequest, maxQueuedRequests),
finishedReqCh: make(chan *LlmRequest, maxQueuedRequests),
expiredCh: make(chan *runnerRef, maxQueuedRequests),
unloadedCh: make(chan interface{}, maxQueuedRequests),
loaded: make(map[string]*runnerRef),
newServerFn: llm.NewLlamaServer,
getGpuFn: gpu.GetGPUInfo,
}
sched.loadFn = sched.load
return sched
}
// context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
ggml, err := llm.LoadModel(model.ModelPath)
req := &LlmRequest{
ctx: c,
model: model,
ggml: ggml,
opts: opts,
sessionDuration: sessionDuration,
successCh: make(chan *runnerRef),
errCh: make(chan error, 1),
}
if err != nil {
req.errCh <- err
return req.successCh, req.errCh
}
select {
case s.pendingReqCh <- req:
default:
req.errCh <- fmt.Errorf("server busy, please try again. maximum pending requests exceeded")
}
return req.successCh, req.errCh
}
// Returns immediately, spawns go routines for the scheduler which will shutdown when ctx is done
func (s *Scheduler) Run(ctx context.Context) {
slog.Debug("starting llm scheduler")
go func() {
s.processPending(ctx)
}()
go func() {
s.processCompleted(ctx)
}()
}
func (s *Scheduler) processPending(ctx context.Context) {
for {
select {
case <-ctx.Done():
slog.Debug("shutting down scheduler pending loop")
return
case pending := <-s.pendingReqCh:
// Block other requests until we get this pending request running
for {
var runnerToExpire *runnerRef
s.loadedMu.Lock()
runner := s.loaded[pending.model.ModelPath]
loadedCount := len(s.loaded)
s.loadedMu.Unlock()
if runner != nil {
if runner.needsReload(ctx, pending) {
runnerToExpire = runner
} else {
// Runner is usable, return it
pending.useLoadedRunner(runner, s.finishedReqCh)
break
}
} else if loadedCount == 0 {
slog.Debug("loading first model", "model", pending.model.ModelPath)
gpus := s.getGpuFn()
g := pickBestFitGPUs(pending, gpus)
if g != nil {
gpus = g
}
s.loadFn(pending, gpus)
break
} else if loadedMax > 0 && loadedCount >= loadedMax {
slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount)
runnerToExpire = s.findRunnerToUnload(pending)
} else {
// More than one loaded model, so we have to see if the new one fits
// Get a refreshed GPU list
gpus := s.getGpuFn()
// Update free memory from currently loaded models
s.updateFreeSpace(gpus)
gpus = pickBestFitGPUs(pending, gpus)
if gpus != nil {
slog.Debug("new model fits with existing models, loading")
s.loadFn(pending, gpus)
break
}
runnerToExpire = s.findRunnerToUnload(pending)
}
if runnerToExpire == nil {
// Shouildn't happen
slog.Error("runner to expire was nil!")
continue
}
// Trigger an expiration to unload once it's done
runnerToExpire.refMu.Lock()
slog.Debug("resetting model to expire immediately to make room", "model", runnerToExpire.model, "refCount", runnerToExpire.refCount)
if runnerToExpire.expireTimer != nil {
runnerToExpire.expireTimer.Stop()
runnerToExpire.expireTimer = nil
}
runnerToExpire.sessionDuration = 0
if runnerToExpire.refCount <= 0 {
s.expiredCh <- runnerToExpire
}
runnerToExpire.refMu.Unlock()
// Wait for the unload to happen
// Note: at this point we're queueing up all incoming requests, even if they were for
// a different model that's loaded and not scheduled to be removed.
slog.Debug("waiting for pending requests to complete and unload to occur", "model", runnerToExpire.model)
select {
case <-ctx.Done():
slog.Debug("shutting down scheduler pending loop")
return
case <-s.unloadedCh:
slog.Debug("unload completed", "model", runnerToExpire.model)
continue
}
}
case <-s.unloadedCh:
// An unload request when there are no pending request can be ignored
slog.Debug("ignoring unload event with no pending requests")
}
}
}
func (s *Scheduler) processCompleted(ctx context.Context) {
// Process completed requests, expired timers, and unloading models
for {
select {
case <-ctx.Done():
slog.Debug("shutting down scheduler completed loop")
return
case finished := <-s.finishedReqCh:
s.loadedMu.Lock()
runner := s.loaded[finished.model.ModelPath]
s.loadedMu.Unlock()
if runner == nil {
slog.Error("finished requeset signal received after model unloaded", "model", finished.model.ModelPath)
continue
}
runner.refMu.Lock()
runner.refCount--
if runner.refCount <= 0 {
if runner.sessionDuration <= 0 {
slog.Debug("runner with zero duration has gone idle, expiring to unload", "model", runner.model)
if runner.expireTimer != nil {
runner.expireTimer.Stop()
runner.expireTimer = nil
}
s.expiredCh <- runner
} else if runner.expireTimer == nil {
slog.Debug("runner with non-zero duration has gone idle, adding timer", "model", runner.model, "duration", runner.sessionDuration)
runner.expireTimer = time.AfterFunc(runner.sessionDuration, func() {
slog.Debug("timer expired, expiring to unload", "model", runner.model)
runner.refMu.Lock()
defer runner.refMu.Unlock()
if runner.expireTimer != nil {
runner.expireTimer.Stop()
}
s.expiredCh <- runner
})
} else {
slog.Debug("runner with non-zero duration has gone idle, resetting timer", "model", runner.model, "duration", runner.sessionDuration)
runner.expireTimer.Reset(runner.sessionDuration)
}
}
slog.Debug("after processing request finished event", "model", runner.model, "refCount", runner.refCount)
runner.refMu.Unlock()
case runner := <-s.expiredCh:
slog.Debug("runner expired event received", "model", runner.model)
runner.refMu.Lock()
if runner.refCount > 0 {
// Shouldn't happen, but safeguard to ensure no leaked runners
slog.Debug("expired event with positive ref count, retrying", "model", runner.model, "refCount", runner.refCount)
go func(runner *runnerRef) {
// We can't unload yet, but want to as soon as the current request completes
// So queue up another expired event
time.Sleep(10 * time.Millisecond)
s.expiredCh <- runner
}(runner)
runner.refMu.Unlock()
continue
}
slog.Debug("got lock to unload", "model", runner.model)
runner.unload()
s.loadedMu.Lock()
delete(s.loaded, runner.model)
s.loadedMu.Unlock()
slog.Debug("runner released", "model", runner.model)
runner.refMu.Unlock()
slog.Debug("sending an unloaded event", "model", runner.model)
s.unloadedCh <- struct{}{}
}
}
}
// Complete the pending request and send the runner back to the requester
// Wires up a finished event after the request context is completed
// Updates session duration, and resets expiration timer
func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *LlmRequest) {
runner.refMu.Lock()
defer runner.refMu.Unlock()
runner.refCount++
runner.sessionDuration = pending.sessionDuration
pending.successCh <- runner
go func() {
<-pending.ctx.Done()
slog.Debug("context for request finished")
finished <- pending
}()
}
func (s *Scheduler) load(req *LlmRequest, gpus gpu.GpuInfoList) {
llama, err := s.newServerFn(gpus, req.model.ModelPath, req.ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
// check for model compatibility
if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
}
slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)
req.errCh <- err
return
}
runner := &runnerRef{}
runner.model = req.model.ModelPath
runner.adapters = req.model.AdapterPaths
runner.projectors = req.model.ProjectorPaths
runner.llama = llama
runner.Options = &req.opts
runner.sessionDuration = req.sessionDuration
runner.gpus = gpus
runner.estimatedVRAM = llama.EstimatedVRAM()
runner.loading = true
runner.refCount = 1
runner.refMu.Lock()
s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner
slog.Info("loaded runners", "count", len(s.loaded))
s.loadedMu.Unlock()
go func() {
defer runner.refMu.Unlock()
if err = llama.WaitUntilRunning(req.ctx); err != nil {
slog.Error("error loading llama server", "error", err)
runner.refCount--
req.errCh <- err
slog.Debug("triggering expiration for failed load", "model", runner.model)
s.expiredCh <- runner
return
}
slog.Debug("finished setting up runner", "model", req.model.ModelPath)
runner.loading = false
go func() {
<-req.ctx.Done()
slog.Debug("context for request finished")
s.finishedReqCh <- req
}()
req.successCh <- runner
}()
}
func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) {
type predKey struct {
Library string
ID string
}
predMap := map[predKey]uint64{} // Sum up the total predicted usage per GPU for all runners
s.loadedMu.Lock()
for _, r := range s.loaded {
r.refMu.Lock()
gpuIDs := make([]string, 0, len(r.gpus))
if r.llama != nil {
// TODO this should be broken down by GPU instead of assuming uniform spread
estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus))
for _, gpu := range r.gpus {
gpuIDs = append(gpuIDs, gpu.ID)
}
for _, gpu := range allGpus {
if slices.Contains(gpuIDs, gpu.ID) {
predMap[predKey{gpu.Library, gpu.ID}] += estimatedVRAMPerGPU
}
}
} else {
slog.Warn("unexpected nil runner reference, memory prediction may be incorrect")
}
r.refMu.Unlock()
}
s.loadedMu.Unlock()
// Now that we've summed up all the GPU usage predictions across all the loaded runners, update the gpu list
for i := range allGpus {
if p, ok := predMap[predKey{allGpus[i].Library, allGpus[i].ID}]; ok {
slog.Debug("gpu reported", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "available", format.HumanBytes2(allGpus[i].FreeMemory))
if p > allGpus[i].TotalMemory {
// Shouldn't happen
slog.Warn("predicted usage exceeds VRAM", "gpu", allGpus[i].ID, "totalMemory", allGpus[i].TotalMemory, "predicted", p)
allGpus[i].FreeMemory = 0
} else if (allGpus[i].TotalMemory - p) < allGpus[i].FreeMemory { // predicted free is smaller than reported free, use it
// TODO maybe we should just always trust our numbers, since cuda's free memory reporting is laggy
// and we might unload models we didn't actually need to. The risk is if some other GPU intensive app is loaded
// after we start our first runner, then we'll never acount for that, so picking the smallest free value seems prudent.
allGpus[i].FreeMemory = allGpus[i].TotalMemory - p
}
slog.Info("updated VRAM", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "total", format.HumanBytes2(allGpus[i].TotalMemory), "available", format.HumanBytes2(allGpus[i].FreeMemory))
}
}
}
type runnerRef struct {
refMu sync.Mutex
// refCond sync.Cond // Signaled on transition from 1 -> 0 refCount
refCount uint // prevent unloading if > 0
// unloading bool // set to true when we are trying to unload the runner
llama llm.LlamaServer
loading bool // True only during initial load, then false forever
gpus gpu.GpuInfoList // Recorded at time of provisioning
estimatedVRAM uint64
sessionDuration time.Duration
expireTimer *time.Timer
model string
adapters []string
projectors []string
*api.Options
}
// The refMu must already be held when calling unload
func (runner *runnerRef) unload() {
if runner.llama != nil {
runner.llama.Close()
}
runner.llama = nil
runner.adapters = nil
runner.projectors = nil
runner.Options = nil
runner.gpus = nil
}
func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
runner.refMu.Lock()
defer runner.refMu.Unlock()
// Ignore the NumGPU settings for comparison
optsExisting := runner.Options.Runner
optsExisting.NumGPU = -1
optsNew := req.opts.Runner
optsNew.NumGPU = -1
timeout := 10 * time.Second
if runner.loading {
timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems...
}
ctx, cancel := context.WithTimeout(ctx, timeout) // BUG -
defer cancel()
if !reflect.DeepEqual(runner.adapters, req.model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(runner.projectors, req.model.ProjectorPaths) || // have the projectors changed?
!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
runner.llama.Ping(ctx) != nil {
return true
}
return false
}
type ByDuration []*runnerRef
func (a ByDuration) Len() int { return len(a) }
func (a ByDuration) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByDuration) Less(i, j int) bool {
// uint64 to turn negative time (never unload) to largest
return uint64(a[i].sessionDuration) < uint64(a[j].sessionDuration)
}
// TODO - future consideration to pick runners based on size
// type BySize []*runnerRef
// func (a BySize) Len() int { return len(a) }
// func (a BySize) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
// func (a BySize) Less(i, j int) bool { return a[i].estimatedVRAM < a[j].estimatedVRAM }
// pickBestFitGPUs will try to find the optimal placement of the model in the available GPUs where the model fully fits
// If the model can not be fit fully within the available GPU(s) nil is returned
func pickBestFitGPUs(req *LlmRequest, gpus gpu.GpuInfoList) gpu.GpuInfoList {
var estimatedVRAM uint64
for _, gl := range gpus.ByLibrary() {
var ok bool
sgl := append(make(gpu.GpuInfoList, 0, len(gl)), gl...)
// TODO - potentially sort by performance capability, existing models loaded, etc.
// Note: at present, this will favor more VRAM over faster GPU speed in mixed setups
sort.Sort(sort.Reverse(gpu.ByFreeMemory(sgl)))
// First attempt to fit the model into a single GPU
for _, g := range sgl {
if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, req.ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
slog.Debug("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
return []gpu.GpuInfo{g}
}
}
// TODO future refinements
// - if multiple Libraries, see if any single GPU in any Library will fit
// - try subsets of GPUs instead of just falling back to 1 or all in a family
// Now try all the GPUs
if ok, estimatedVRAM = llm.PredictServerFit(gl, req.ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
slog.Debug("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", gl[0].Library, "required", format.HumanBytes2(estimatedVRAM))
return gl
}
}
return nil
}
// findRunnerToUnload finds a runner to unload to make room for a new model
func (s *Scheduler) findRunnerToUnload(req *LlmRequest) *runnerRef {
s.loadedMu.Lock()
runnerList := make([]*runnerRef, 0, len(s.loaded))
for _, r := range s.loaded {
runnerList = append(runnerList, r)
}
s.loadedMu.Unlock()
// In the future we can enhance the algorithm to be smarter about picking the optimal runner to unload
// e.g., if we have multiple options, will one make room for the request?
sort.Sort(ByDuration(runnerList))
// First try to find a runner that's already idle
for _, runner := range runnerList {
runner.refMu.Lock()
rc := runner.refCount
runner.refMu.Unlock()
if rc == 0 {
slog.Debug("found an idle runner to unload")
return runner
}
}
// None appear idle, just wait for the one with the shortest duration
slog.Debug("no idle runners, picking the shortest duration", "count", len(runnerList))
return runnerList[0]
}
func (s *Scheduler) unloadAllRunners() {
s.loadedMu.Lock()
defer s.loadedMu.Unlock()
for model, runner := range s.loaded {
if runner.llama != nil {
slog.Debug("shutting down runner", "model", model)
runner.llama.Close()
}
}
}

553
server/sched_test.go Normal file
View file

@ -0,0 +1,553 @@
package server
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"log/slog"
"os"
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func init() {
os.Setenv("OLLAMA_DEBUG", "1")
lifecycle.InitLogging()
}
func TestInitScheduler(t *testing.T) {
ctx, done := context.WithCancel(context.Background())
defer done()
initialMax := loadedMax
s := InitScheduler(ctx)
require.Equal(t, initialMax, loadedMax)
require.NotNil(t, s.loaded)
os.Setenv("OLLAMA_MAX_LOADED_MODELS", "blue")
s = InitScheduler(ctx)
require.Equal(t, initialMax, loadedMax)
require.NotNil(t, s.loaded)
os.Setenv("OLLAMA_MAX_LOADED_MODELS", "0")
s = InitScheduler(ctx)
require.Equal(t, 0, loadedMax)
require.NotNil(t, s.loaded)
}
func TestLoad(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done()
s := InitScheduler(ctx)
req := &LlmRequest{
ctx: ctx,
model: &Model{ModelPath: "foo"},
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
sessionDuration: 2,
}
// Fail to load model first
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
return nil, fmt.Errorf("something failed to load model blah")
}
gpus := gpu.GpuInfoList{}
s.load(req, gpus)
require.Len(t, req.successCh, 0)
require.Len(t, req.errCh, 1)
require.Len(t, s.loaded, 0)
err := <-req.errCh
require.Contains(t, err.Error(), "this model may be incompatible")
server := &mockLlm{estimatedVRAM: 10}
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
return server, nil
}
s.load(req, gpus)
select {
case err := <-req.errCh:
require.NoError(t, err)
case resp := <-req.successCh:
require.Equal(t, uint64(10), resp.estimatedVRAM)
require.Equal(t, uint(1), resp.refCount)
require.Len(t, s.loaded, 1)
}
req.model.ModelPath = "dummy_model_path"
server.waitResp = fmt.Errorf("wait failure")
s.load(req, gpus)
select {
case err := <-req.errCh:
require.Contains(t, err.Error(), "wait failure")
case resp := <-req.successCh:
t.Errorf("unexpected success %v", resp)
}
runner := s.loaded["dummy_model_path"]
require.NotNil(t, runner)
require.Equal(t, uint(0), runner.refCount)
require.Len(t, s.expiredCh, 1)
}
type bundle struct {
ctx context.Context //nolint:containedctx
ctxDone func()
srv *mockLlm
req *LlmRequest
}
func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
return scenario.srv, nil
}
func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle {
scenario := &bundle{}
scenario.ctx, scenario.ctxDone = context.WithCancel(ctx)
t.Helper()
f, err := os.CreateTemp(t.TempDir(), modelName)
assert.Nil(t, err)
defer f.Close()
gguf := llm.NewGGUFV3(binary.LittleEndian)
err = gguf.Encode(f, llm.KV{
"general.architecture": "llama",
"general.name": "name",
"llama.context_length": uint32(32),
"llama.embedding_length": uint32(4096),
"llama.block_count": uint32(1),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(32),
"tokenizer.ggml.tokens": []string{" "},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []llm.Tensor{
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
})
assert.Nil(t, err)
fname := f.Name()
model := &Model{Name: modelName, ModelPath: fname}
ggml, err := llm.LoadModel(model.ModelPath)
require.NoError(t, err)
scenario.req = &LlmRequest{
ctx: scenario.ctx,
model: model,
ggml: ggml,
sessionDuration: 5 * time.Millisecond,
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
}
scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM}
return scenario
}
func TestRequests(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = 0
scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
scenario1b.req.model = scenario1a.req.model
scenario1b.req.ggml = scenario1a.req.ggml
scenario1b.req.sessionDuration = 0
// simple reload of same model
scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
scenario2a.req.model = scenario1a.req.model
scenario2a.req.ggml = scenario1a.req.ggml
// Multiple loaded models
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
scenario3c := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.newServerFn = scenario1a.newServer
slog.Info("scenario1a")
s.pendingReqCh <- scenario1a.req
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {
case resp := <-scenario1a.req.successCh:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario1a.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
// Same runner as first request due to not needing a reload
s.newServerFn = scenario1b.newServer
slog.Info("scenario1b")
s.pendingReqCh <- scenario1b.req
select {
case resp := <-scenario1b.req.successCh:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario1b.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
// Trigger a reload
s.newServerFn = scenario2a.newServer
scenario2a.req.model.AdapterPaths = []string{"new"}
slog.Info("scenario2a")
s.pendingReqCh <- scenario2a.req
// finish first two requests, so model can reload
time.Sleep(1 * time.Millisecond)
scenario1a.ctxDone()
scenario1b.ctxDone()
select {
case resp := <-scenario2a.req.successCh:
require.Equal(t, resp.llama, scenario2a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario2a.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
loadedMax = 1
s.newServerFn = scenario3a.newServer
slog.Info("scenario3a")
s.pendingReqCh <- scenario3a.req
// finish prior request, so new model can load
time.Sleep(1 * time.Millisecond)
scenario2a.ctxDone()
select {
case resp := <-scenario3a.req.successCh:
require.Equal(t, resp.llama, scenario3a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario3a.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
require.Len(t, s.loaded, 1)
loadedMax = 0
s.newServerFn = scenario3b.newServer
slog.Info("scenario3b")
s.pendingReqCh <- scenario3b.req
select {
case resp := <-scenario3b.req.successCh:
require.Equal(t, resp.llama, scenario3b.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario3b.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
require.Len(t, s.loaded, 2)
// Try to load a model that wont fit
s.newServerFn = scenario3c.newServer
slog.Info("scenario3c")
require.Len(t, s.loaded, 2)
scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
time.Sleep(2 * time.Millisecond)
s.pendingReqCh <- scenario3c.req
// finish prior request, so new model can load
time.Sleep(6 * time.Millisecond)
require.Len(t, s.loaded, 1)
scenario3b.ctxDone()
select {
case resp := <-scenario3c.req.successCh:
require.Equal(t, resp.llama, scenario3c.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario3c.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
require.Len(t, s.loaded, 1)
}
func TestGetRunner(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
scenario1a.req.sessionDuration = 0
scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
scenario1b.req.sessionDuration = 0
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
scenario1c.req.sessionDuration = 0
maxQueuedRequests = 1
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.newServerFn = scenario1a.newServer
slog.Info("scenario1a")
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
slog.Info("scenario1b")
successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
require.Len(t, successCh1b, 0)
require.Len(t, errCh1b, 1)
err := <-errCh1b
require.Contains(t, err.Error(), "server busy")
s.Run(ctx)
select {
case resp := <-successCh1a:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, errCh1a, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
scenario1a.ctxDone()
require.Len(t, s.loaded, 1)
scenario1c.req.model.ModelPath = "bad path"
slog.Info("scenario1c")
successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, successCh1c, 0)
require.Len(t, errCh1c, 1)
err = <-errCh1c
require.Contains(t, err.Error(), "bad path")
scenario1b.ctxDone()
time.Sleep(5 * time.Millisecond)
require.Len(t, s.loaded, 0)
}
// TODO - add one scenario that triggers the bogus finished event with positive ref count
func TestPrematureExpired(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.newServerFn = scenario1a.newServer
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {
case resp := <-successCh1a:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, errCh1a, 0)
require.Len(t, s.loaded, 1)
slog.Info("sending premature expired event now")
s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
case <-ctx.Done():
t.Errorf("timeout")
}
time.Sleep(scenario1a.req.sessionDuration)
scenario1a.ctxDone()
time.Sleep(20 * time.Millisecond)
require.LessOrEqual(t, len(s.finishedReqCh), 1)
time.Sleep(10 * time.Millisecond)
require.Len(t, s.finishedReqCh, 0)
require.Len(t, s.loaded, 0)
// also shouldn't happen in real life
s.finishedReqCh <- scenario1a.req
time.Sleep(5 * time.Millisecond)
}
func TestUseLoadedRunner(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
req := &LlmRequest{
ctx: ctx,
successCh: make(chan *runnerRef, 1),
sessionDuration: 2,
}
finished := make(chan *LlmRequest)
llm1 := &mockLlm{}
r1 := &runnerRef{llama: llm1, sessionDuration: 1}
req.useLoadedRunner(r1, finished)
require.Equal(t, uint(1), r1.refCount)
require.Equal(t, time.Duration(2), r1.sessionDuration)
select {
case success := <-req.successCh:
require.Equal(t, r1, success)
case <-ctx.Done():
t.Errorf("timeout")
}
done()
fin := <-finished
require.Equal(t, req, fin)
}
func TestUpdateFreeSpace(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done()
gpus := gpu.GpuInfoList{
{
Library: "a",
ID: "1",
},
{
Library: "a",
ID: "2",
},
}
gpus[0].TotalMemory = 1000
gpus[0].FreeMemory = 900
gpus[1].TotalMemory = 2000
gpus[1].FreeMemory = 1900
llm1 := &mockLlm{estimatedVRAM: 100}
llm2 := &mockLlm{estimatedVRAM: 200}
r1 := &runnerRef{llama: llm1, gpus: gpus}
r2 := &runnerRef{llama: llm2, gpus: gpus}
s := InitScheduler(ctx)
s.loaded["a"] = r1
s.loaded["b"] = r2
s.updateFreeSpace(gpus)
require.Equal(t, uint64(850), gpus[0].FreeMemory)
require.Equal(t, uint64(1850), gpus[1].FreeMemory)
}
func TestFindRunnerToUnload(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done()
req := &LlmRequest{ctx: ctx}
r1 := &runnerRef{refCount: 1, sessionDuration: 1}
r2 := &runnerRef{sessionDuration: 2}
s := InitScheduler(ctx)
s.loaded["a"] = r1
s.loaded["b"] = r2
resp := s.findRunnerToUnload(req)
require.Equal(t, r2, resp)
r2.refCount = 1
resp = s.findRunnerToUnload(req)
require.Equal(t, r1, resp)
}
func TestNeedsReload(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done()
llm := &mockLlm{}
runner := &runnerRef{
adapters: []string{"adapter1"},
projectors: []string{"projector1"},
Options: &api.Options{},
llama: llm,
}
req := &LlmRequest{
model: &Model{
AdapterPaths: []string{"adapter2"},
ProjectorPaths: []string{"projector2"},
},
opts: api.Options{},
}
resp := runner.needsReload(ctx, req)
require.True(t, resp)
req.model.AdapterPaths = runner.adapters
resp = runner.needsReload(ctx, req)
require.True(t, resp)
req.model.ProjectorPaths = runner.projectors
runner.loading = true
req.opts.NumBatch = 1234
resp = runner.needsReload(ctx, req)
require.True(t, resp)
req.opts.NumBatch = runner.Options.NumBatch
llm.pingResp = fmt.Errorf("foo")
resp = runner.needsReload(ctx, req)
require.True(t, resp)
llm.pingResp = nil
resp = runner.needsReload(ctx, req)
require.False(t, resp)
req.opts.NumGPU = 99
resp = runner.needsReload(ctx, req)
require.False(t, resp)
}
func TestUnloadAllRunners(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done()
llm1 := &mockLlm{}
llm2 := &mockLlm{}
s := InitScheduler(ctx)
s.unloadAllRunners()
r1 := &runnerRef{llama: llm1}
r2 := &runnerRef{llama: llm2}
s.loaded["a"] = r1
s.loaded["b"] = r2
s.unloadAllRunners()
require.True(t, llm1.closeCalled)
require.True(t, llm2.closeCalled)
}
func TestUnload(t *testing.T) {
llm1 := &mockLlm{}
r1 := &runnerRef{llama: llm1}
r2 := &runnerRef{adapters: []string{"A"}}
r1.unload()
require.True(t, llm1.closeCalled)
r2.unload()
require.Nil(t, r2.adapters)
}
type mockLlm struct {
pingResp error
waitResp error
completionResp error
embeddingResp []float64
embeddingRespErr error
tokenizeResp []int
tokenizeRespErr error
detokenizeResp string
detonekizeRespErr error
closeResp error
closeCalled bool
estimatedVRAM uint64
}
func (s *mockLlm) Ping(ctx context.Context) error { return s.pingResp }
func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitResp }
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
return s.completionResp
}
func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
return s.embeddingResp, s.embeddingRespErr
}
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
return s.tokenizeResp, s.tokenizeRespErr
}
func (s *mockLlm) Detokenize(ctx context.Context, tokens []int) (string, error) {
return s.detokenizeResp, s.detonekizeRespErr
}
func (s *mockLlm) Close() error {
s.closeCalled = true
return s.closeResp
}
func (s *mockLlm) EstimatedVRAM() uint64 { return s.estimatedVRAM }