diff --git a/api/client.go b/api/client.go index a1ebdcd4..101382ca 100644 --- a/api/client.go +++ b/api/client.go @@ -91,6 +91,13 @@ func ClientFromEnvironment() (*Client, error) { }, nil } +func NewClient(base *url.URL, http *http.Client) *Client { + return &Client{ + base: base, + http: http, + } +} + func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error { var reqBody io.Reader var data []byte diff --git a/format/bytes.go b/format/bytes.go index f4bcc8c5..9fdc8bcf 100644 --- a/format/bytes.go +++ b/format/bytes.go @@ -15,6 +15,7 @@ const ( KibiByte = Byte * 1024 MebiByte = KibiByte * 1024 + GibiByte = MebiByte * 1024 ) func HumanBytes(b int64) string { diff --git a/gpu/amd_common.go b/gpu/amd_common.go index cf3348a8..6fa4fce4 100644 --- a/gpu/amd_common.go +++ b/gpu/amd_common.go @@ -7,7 +7,7 @@ import ( "log/slog" "os" "path/filepath" - "strconv" + "runtime" "strings" ) @@ -35,22 +35,64 @@ func GetSupportedGFX(libDir string) ([]string, error) { return ret, nil } -func amdSetVisibleDevices(ids []int, skip map[int]interface{}) { - // Set the visible devices if not already set - // TODO - does sort order matter? - devices := []string{} - for i := range ids { - if _, skipped := skip[i]; skipped { +func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) { + ids := []string{} + for _, info := range gpuInfo { + if info.Library != "rocm" { + // TODO shouldn't happen if things are wired correctly... + slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library) continue } - devices = append(devices, strconv.Itoa(i)) + ids = append(ids, info.ID) + } + return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",") +} + +func commonAMDValidateLibDir() (string, error) { + // We try to favor system paths first, so that we can wire up the subprocess to use + // the system version. 
Only use our bundled version if the system version doesn't work + // This gives users a more recovery options if versions have subtle problems at runtime + + // Prefer explicit HIP env var + hipPath := os.Getenv("HIP_PATH") + if hipPath != "" { + hipLibDir := filepath.Join(hipPath, "bin") + if rocmLibUsable(hipLibDir) { + slog.Debug("detected ROCM via HIP_PATH=" + hipPath) + return hipLibDir, nil + } } - val := strings.Join(devices, ",") - err := os.Setenv("HIP_VISIBLE_DEVICES", val) - if err != nil { - slog.Warn(fmt.Sprintf("failed to set env: %s", err)) - } else { - slog.Info("Setting HIP_VISIBLE_DEVICES=" + val) + // Scan the LD_LIBRARY_PATH or PATH + pathEnv := "LD_LIBRARY_PATH" + if runtime.GOOS == "windows" { + pathEnv = "PATH" } + + paths := os.Getenv(pathEnv) + for _, path := range filepath.SplitList(paths) { + d, err := filepath.Abs(path) + if err != nil { + continue + } + if rocmLibUsable(d) { + return d, nil + } + } + + // Well known location(s) + if rocmLibUsable(RocmStandardLocation) { + return RocmStandardLocation, nil + } + + // Installer payload location if we're running the installed binary + exe, err := os.Executable() + if err == nil { + rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm") + if rocmLibUsable(rocmTargetDir) { + slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir) + return rocmTargetDir, nil + } + } + return "", fmt.Errorf("no suitable rocm found, falling back to CPU") } diff --git a/gpu/amd_hip_windows.go b/gpu/amd_hip_windows.go index 14a6c7d6..4e216132 100644 --- a/gpu/amd_hip_windows.go +++ b/gpu/amd_hip_windows.go @@ -69,7 +69,7 @@ func NewHipLib() (*HipLib, error) { func (hl *HipLib) Release() { err := windows.FreeLibrary(hl.dll) if err != nil { - slog.Warn(fmt.Sprintf("failed to unload amdhip64.dll: %s", err)) + slog.Warn("failed to unload amdhip64.dll", "error", err) } hl.dll = 0 } @@ -98,7 +98,7 @@ func (hl *HipLib) HipGetDeviceCount() int { return 0 } if status != hipSuccess { - 
slog.Warn(fmt.Sprintf("failed call to hipGetDeviceCount: %d %s", status, err)) + slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err) } return count } diff --git a/gpu/amd_linux.go b/gpu/amd_linux.go index 529fb8db..b049de0c 100644 --- a/gpu/amd_linux.go +++ b/gpu/amd_linux.go @@ -11,6 +11,8 @@ import ( "slices" "strconv" "strings" + + "github.com/ollama/ollama/format" ) // Discovery logic for AMD/ROCm GPUs @@ -24,9 +26,6 @@ const ( GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line GPUUsedMemoryFileGlob = "mem_banks/*/used_memory" RocmStandardLocation = "/opt/rocm/lib" - - // TODO find a better way to detect iGPU instead of minimum memory - IGPUMemLimit = 1024 * 1024 * 1024 // 512G is what they typically report, so anything less than 1G must be iGPU ) var ( @@ -35,14 +34,11 @@ var ( ) // Gather GPU information from the amdgpu driver if any supported GPUs are detected -// HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices -// and the user hasn't already set this variable -func AMDGetGPUInfo(resp *GpuInfo) { - // TODO - DRY this out with windows +func AMDGetGPUInfo() []GpuInfo { + resp := []GpuInfo{} if !AMDDetected() { - return + return resp } - skip := map[int]interface{}{} // Opportunistic logging of driver version to aid in troubleshooting ver, err := AMDDriverVersion() @@ -50,160 +46,117 @@ func AMDGetGPUInfo(resp *GpuInfo) { slog.Info("AMD Driver: " + ver) } else { // TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU - slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err)) + slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err) } - // If the user has specified exactly which GPUs to use, look up their memory - visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES") - if visibleDevices != "" { - ids := 
[]int{} - for _, idStr := range strings.Split(visibleDevices, ",") { - id, err := strconv.Atoi(idStr) - if err != nil { - slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err)) - } else { - ids = append(ids, id) - } - } - amdProcMemLookup(resp, nil, ids) - return + // Determine if the user has already pre-selected which GPUs to look at, then ignore the others + var visibleDevices []string + hipVD := os.Getenv("HIP_VISIBLE_DEVICES") // zero based index only + rocrVD := os.Getenv("ROCR_VISIBLE_DEVICES") // zero based index or UUID, but consumer cards seem to not support UUID + gpuDO := os.Getenv("GPU_DEVICE_ORDINAL") // zero based index + switch { + // TODO is this priorty order right? + case hipVD != "": + visibleDevices = strings.Split(hipVD, ",") + case rocrVD != "": + visibleDevices = strings.Split(rocrVD, ",") + // TODO - since we don't yet support UUIDs, consider detecting and reporting here + // all our test systems show GPU-XX indicating UUID is not supported + case gpuDO != "": + visibleDevices = strings.Split(gpuDO, ",") } - // Gather GFX version information from all detected cards - gfx := AMDGFXVersions() - verStrings := []string{} - for i, v := range gfx { - verStrings = append(verStrings, v.ToGFXString()) - if v.Major == 0 { - // Silently skip CPUs - skip[i] = struct{}{} - continue - } - if v.Major < 9 { - // TODO consider this a build-time setting if we can support 8xx family GPUs - slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString())) - skip[i] = struct{}{} - } - } - slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings)) - - // Abort if all GPUs are skipped - if len(skip) >= len(gfx) { - slog.Info("all detected amdgpus are skipped, falling back to CPU") - return - } - - // If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib - libDir, err := AMDValidateLibDir() - if err != nil { - slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", 
err)) - return - } - - updateLibPath(libDir) - gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION") - if gfxOverride == "" { - supported, err := GetSupportedGFX(libDir) + var supported []string + libDir := "" + + // The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract + // from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU) + matches, _ := filepath.Glob(GPUPropertiesFileGlob) + cpuCount := 0 + for _, match := range matches { + slog.Debug("evaluating amdgpu node " + match) + fp, err := os.Open(match) if err != nil { - slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err)) - return - } - slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported)) - - for i, v := range gfx { - if !slices.Contains[[]string, string](supported, v.ToGFXString()) { - slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported)) - // TODO - consider discrete markdown just for ROCM troubleshooting? 
- slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage") - skip[i] = struct{}{} - } else { - slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString())) - } - } - } else { - slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride) - } - - if len(skip) >= len(gfx) { - slog.Info("all detected amdgpus are skipped, falling back to CPU") - return - } - - ids := make([]int, len(gfx)) - i := 0 - for k := range gfx { - ids[i] = k - i++ - } - amdProcMemLookup(resp, skip, ids) - if resp.memInfo.DeviceCount == 0 { - return - } - if len(skip) > 0 { - amdSetVisibleDevices(ids, skip) - } -} - -func updateLibPath(libDir string) { - ldPaths := []string{} - if val, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok { - ldPaths = strings.Split(val, ":") - } - for _, d := range ldPaths { - if d == libDir { - return - } - } - val := strings.Join(append(ldPaths, libDir), ":") - slog.Debug("updated lib path", "LD_LIBRARY_PATH", val) - os.Setenv("LD_LIBRARY_PATH", val) -} - -// Walk the sysfs nodes for the available GPUs and gather information from them -// skipping over any devices in the skip map -func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) { - resp.memInfo.DeviceCount = 0 - resp.memInfo.TotalMemory = 0 - resp.memInfo.FreeMemory = 0 - slog.Debug("discovering VRAM for amdgpu devices") - if len(ids) == 0 { - entries, err := os.ReadDir(AMDNodesSysfsDir) - if err != nil { - slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err)) - return - } - for _, node := range entries { - if !node.IsDir() { - continue - } - id, err := strconv.Atoi(node.Name()) - if err != nil { - slog.Warn("malformed amdgpu sysfs node id " + node.Name()) - continue - } - ids = append(ids, id) - } - } - slog.Debug(fmt.Sprintf("amdgpu devices %v", ids)) - - for _, id := range ids { - if _, skipped := skip[id]; skipped { + slog.Debug("failed to open sysfs 
node", "file", match, "error", err) continue } + defer fp.Close() + nodeID, err := strconv.Atoi(filepath.Base(filepath.Dir(match))) + if err != nil { + slog.Debug("failed to parse node ID", "error", err) + continue + } + + scanner := bufio.NewScanner(fp) + isCPU := false + var major, minor, patch uint64 + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + // Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs + if strings.HasPrefix(line, "gfx_target_version") { + ver := strings.Fields(line) + + // Detect CPUs + if len(ver) == 2 && ver[1] == "0" { + slog.Debug("detected CPU " + match) + isCPU = true + break + } + + if len(ver) != 2 || len(ver[1]) < 5 { + slog.Warn("malformed "+match, "gfx_target_version", line) + // If this winds up being a CPU, our offsets may be wrong + continue + } + l := len(ver[1]) + var err1, err2, err3 error + patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32) + minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32) + major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32) + if err1 != nil || err2 != nil || err3 != nil { + slog.Debug("malformed int " + line) + continue + } + } + + // TODO - any other properties we want to extract and record? + // vendor_id + device_id -> pci lookup for "Name" + // Other metrics that may help us understand relative performance between multiple GPUs + } + + if isCPU { + cpuCount++ + continue + } + + // CPUs are always first in the list + gpuID := nodeID - cpuCount + + // Shouldn't happen, but just in case... 
+ if gpuID < 0 { + slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue") + return []GpuInfo{} + } + + if int(major) < RocmComputeMin { + slog.Warn(fmt.Sprintf("amdgpu too old gfx%d%d%d", major, minor, patch), "gpu", gpuID) + continue + } + + // Look up the memory for the current node totalMemory := uint64(0) usedMemory := uint64(0) - // Adjust for sysfs vs HIP ids - propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id+1), GPUTotalMemoryFileGlob) + propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUTotalMemoryFileGlob) propFiles, err := filepath.Glob(propGlob) if err != nil { - slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err)) + slog.Warn("error looking up total GPU memory", "glob", propGlob, "error", err) } // 1 or more memory banks - sum the values of all of them for _, propFile := range propFiles { fp, err := os.Open(propFile) if err != nil { - slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err)) + slog.Warn("failed to open sysfs node", "file", propFile, "erroir", err) continue } defer fp.Close() @@ -226,49 +179,113 @@ func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) { } } if totalMemory == 0 { - slog.Warn(fmt.Sprintf("amdgpu [%d] reports zero total memory, skipping", id)) - skip[id] = struct{}{} + slog.Warn("amdgpu reports zero total memory", "gpu", gpuID) continue } - if totalMemory < IGPUMemLimit { - slog.Info(fmt.Sprintf("amdgpu [%d] appears to be an iGPU with %dM reported total memory, skipping", id, totalMemory/1024/1024)) - skip[id] = struct{}{} - continue - } - usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob) + usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUUsedMemoryFileGlob) usedFiles, err := filepath.Glob(usedGlob) if err != nil { - slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err)) + 
slog.Warn("error looking up used GPU memory", "glob", usedGlob, "error", err) continue } for _, usedFile := range usedFiles { fp, err := os.Open(usedFile) if err != nil { - slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err)) + slog.Warn("failed to open sysfs node", "file", usedFile, "error", err) continue } defer fp.Close() data, err := io.ReadAll(fp) if err != nil { - slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err)) + slog.Warn("failed to read sysfs node", "file", usedFile, "error", err) continue } used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64) if err != nil { - slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err)) + slog.Warn("malformed used memory", "data", string(data), "error", err) continue } usedMemory += used } - slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %dM", id, totalMemory/1024/1024)) - slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory %dM", id, (totalMemory-usedMemory)/1024/1024)) - resp.memInfo.DeviceCount++ - resp.memInfo.TotalMemory += totalMemory - resp.memInfo.FreeMemory += (totalMemory - usedMemory) + + // iGPU detection, remove this check once we can support an iGPU variant of the rocm library + if totalMemory < IGPUMemLimit { + slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", gpuID, "total", format.HumanBytes2(totalMemory)) + continue + } + + slog.Info("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory)) + slog.Info("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory)) + gpuInfo := GpuInfo{ + Library: "rocm", + memInfo: memInfo{ + TotalMemory: totalMemory, + FreeMemory: (totalMemory - usedMemory), + }, + ID: fmt.Sprintf("%d", gpuID), + // Name: not exposed in sysfs directly, would require pci device id lookup + Major: int(major), + Minor: int(minor), + Patch: int(patch), + MinimumMemory: rocmMinimumMemory, + } + + // If the user wants to filter to a subset of devices, filter out if 
we aren't a match + if len(visibleDevices) > 0 { + include := false + for _, visible := range visibleDevices { + if visible == gpuInfo.ID { + include = true + break + } + } + if !include { + slog.Info("filtering out device per user request", "id", gpuInfo.ID, "visible_devices", visibleDevices) + continue + } + } + + // Final validation is gfx compatibility - load the library if we haven't already loaded it + // even if the user overrides, we still need to validate the library + if libDir == "" { + libDir, err = AMDValidateLibDir() + if err != nil { + slog.Warn("unable to verify rocm library, will use cpu", "error", err) + return []GpuInfo{} + } + } + gpuInfo.DependencyPath = libDir + + if gfxOverride == "" { + // Only load supported list once + if len(supported) == 0 { + supported, err = GetSupportedGFX(libDir) + if err != nil { + slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err) + return []GpuInfo{} + } + slog.Debug("rocm supported GPUs", "types", supported) + } + gfx := fmt.Sprintf("gfx%d%d%d", gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch) + if !slices.Contains[[]string, string](supported, gfx) { + slog.Warn("amdgpu is not supported", "gpu", gpuInfo.ID, "gpu_type", gfx, "library", libDir, "supported_types", supported) + // TODO - consider discrete markdown just for ROCM troubleshooting? 
+ slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage") + continue + } else { + slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx) + } + } else { + slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride) + } + + // The GPU has passed all the verification steps and is supported + resp = append(resp, gpuInfo) } - if resp.memInfo.DeviceCount > 0 { - resp.Library = "rocm" + if len(resp) == 0 { + slog.Info("no compatible amdgpu devices detected") } + return resp } // Quick check for AMD driver so we can skip amdgpu discovery if not present @@ -280,87 +297,24 @@ func AMDDetected() bool { slog.Debug("amdgpu driver not detected " + sysfsDir) return false } else if err != nil { - slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err)) + slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err) return false } return true } -func setupLink(source, target string) error { - if err := os.RemoveAll(target); err != nil { - return fmt.Errorf("failed to remove old rocm directory %s %w", target, err) - } - if err := os.Symlink(source, target); err != nil { - return fmt.Errorf("failed to create link %s => %s %w", source, target, err) - } - slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target)) - return nil -} - -// Ensure the AMD rocm lib dir is wired up // Prefer to use host installed ROCm, as long as it meets our minimum requirements // failing that, tell the user how to download it on their own func AMDValidateLibDir() (string, error) { - // We rely on the rpath compiled into our library to find rocm - // so we establish a symlink to wherever we find it on the system - // to /rocm - payloadsDir, err := PayloadsDir() - if err != nil { - return "", err - } - - // If we already have a rocm dependency wired, nothing more to do - rocmTargetDir := filepath.Clean(filepath.Join(payloadsDir, "..", "rocm")) - if 
rocmLibUsable(rocmTargetDir) { - return rocmTargetDir, nil - } - - // next to the running binary - exe, err := os.Executable() + libDir, err := commonAMDValidateLibDir() if err == nil { - peerDir := filepath.Dir(exe) - if rocmLibUsable(peerDir) { - slog.Debug("detected ROCM next to ollama executable " + peerDir) - return rocmTargetDir, setupLink(peerDir, rocmTargetDir) - } - peerDir = filepath.Join(filepath.Dir(exe), "rocm") - if rocmLibUsable(peerDir) { - slog.Debug("detected ROCM next to ollama executable " + peerDir) - return rocmTargetDir, setupLink(peerDir, rocmTargetDir) - } + return libDir, nil } // Well known ollama installer path installedRocmDir := "/usr/share/ollama/lib/rocm" if rocmLibUsable(installedRocmDir) { - return rocmTargetDir, setupLink(installedRocmDir, rocmTargetDir) - } - - // Prefer explicit HIP env var - hipPath := os.Getenv("HIP_PATH") - if hipPath != "" { - hipLibDir := filepath.Join(hipPath, "lib") - if rocmLibUsable(hipLibDir) { - slog.Debug("detected ROCM via HIP_PATH=" + hipPath) - return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir) - } - } - - // Scan the library path for potential matches - ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":") - for _, ldPath := range ldPaths { - d, err := filepath.Abs(ldPath) - if err != nil { - continue - } - if rocmLibUsable(d) { - return rocmTargetDir, setupLink(d, rocmTargetDir) - } - } - - // Well known location(s) - if rocmLibUsable("/opt/rocm/lib") { - return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir) + return installedRocmDir, nil } // If we still haven't found a usable rocm, the user will have to install it on their own @@ -384,68 +338,3 @@ func AMDDriverVersion() (string, error) { } return strings.TrimSpace(string(verString)), nil } - -func AMDGFXVersions() map[int]Version { - // The amdgpu driver always exposes the host CPU as node 0, but we have to skip that and subtract one - // from the other IDs to get alignment with the HIP libraries expectations (zero is 
the first GPU, not the CPU) - res := map[int]Version{} - matches, _ := filepath.Glob(GPUPropertiesFileGlob) - for _, match := range matches { - fp, err := os.Open(match) - if err != nil { - slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err)) - continue - } - defer fp.Close() - i, err := strconv.Atoi(filepath.Base(filepath.Dir(match))) - if err != nil { - slog.Debug(fmt.Sprintf("failed to parse node ID %s", err)) - continue - } - - if i == 0 { - // Skipping the CPU - continue - } - // Align with HIP IDs (zero is first GPU, not CPU) - i -= 1 - - scanner := bufio.NewScanner(fp) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if strings.HasPrefix(line, "gfx_target_version") { - ver := strings.Fields(line) - if len(ver) != 2 || len(ver[1]) < 5 { - if ver[1] != "0" { - slog.Debug("malformed " + line) - } - res[i] = Version{ - Major: 0, - Minor: 0, - Patch: 0, - } - continue - } - l := len(ver[1]) - patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32) - minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32) - major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32) - if err1 != nil || err2 != nil || err3 != nil { - slog.Debug("malformed int " + line) - continue - } - - res[i] = Version{ - Major: uint(major), - Minor: uint(minor), - Patch: uint(patch), - } - } - } - } - return res -} - -func (v Version) ToGFXString() string { - return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch) -} diff --git a/gpu/amd_windows.go b/gpu/amd_windows.go index be1be567..1bca3e10 100644 --- a/gpu/amd_windows.go +++ b/gpu/amd_windows.go @@ -7,7 +7,10 @@ import ( "os" "path/filepath" "slices" + "strconv" "strings" + + "github.com/ollama/ollama/format" ) const ( @@ -22,36 +25,32 @@ var ( ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here... 
) -func AMDGetGPUInfo(resp *GpuInfo) { +func AMDGetGPUInfo() []GpuInfo { + resp := []GpuInfo{} hl, err := NewHipLib() if err != nil { slog.Debug(err.Error()) - return + return nil } defer hl.Release() - skip := map[int]interface{}{} - ids := []int{} - resp.memInfo.DeviceCount = 0 - resp.memInfo.TotalMemory = 0 - resp.memInfo.FreeMemory = 0 ver, err := hl.AMDDriverVersion() if err == nil { slog.Info("AMD Driver: " + ver) } else { // For now this is benign, but we may eventually need to fail compatibility checks - slog.Debug(fmt.Sprintf("error looking up amd driver version: %s", err)) + slog.Debug("error looking up amd driver version", "error", err) } - // Note: the HIP library automatically handles HIP_VISIBLE_DEVICES + // Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified count := hl.HipGetDeviceCount() if count == 0 { - return + return nil } libDir, err := AMDValidateLibDir() if err != nil { - slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err)) - return + slog.Warn("unable to verify rocm library, will use cpu", "error", err) + return nil } var supported []string @@ -59,95 +58,117 @@ func AMDGetGPUInfo(resp *GpuInfo) { if gfxOverride == "" { supported, err = GetSupportedGFX(libDir) if err != nil { - slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err)) - return + slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err) + return nil } } else { slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride) } - slog.Info(fmt.Sprintf("detected %d hip devices", count)) + slog.Info("detected hip devices", "count", count) + // TODO how to determine the underlying device ID when visible devices is causing this to subset? 
for i := 0; i < count; i++ { - ids = append(ids, i) err = hl.HipSetDevice(i) if err != nil { - slog.Warn(fmt.Sprintf("[%d] %s", i, err)) - skip[i] = struct{}{} + slog.Warn("set device", "id", i, "error", err) continue } props, err := hl.HipGetDeviceProperties(i) if err != nil { - slog.Warn(fmt.Sprintf("[%d] %s", i, err)) - skip[i] = struct{}{} + slog.Warn("get properties", "id", i, "error", err) continue } n := bytes.IndexByte(props.Name[:], 0) name := string(props.Name[:n]) - slog.Info(fmt.Sprintf("[%d] Name: %s", i, name)) + // TODO is UUID actually populated on windows? + // Can luid be used on windows for setting visible devices (and is it actually set?) n = bytes.IndexByte(props.GcnArchName[:], 0) gfx := string(props.GcnArchName[:n]) - slog.Info(fmt.Sprintf("[%d] GcnArchName: %s", i, gfx)) + slog.Info("hip device", "id", i, "name", name, "gfx", gfx) + var major, minor, patch string + switch len(gfx) { + case 6: + major, minor, patch = gfx[3:4], gfx[4:5], gfx[5:] + case 7: + major, minor, patch = gfx[3:5], gfx[5:6], gfx[6:] + } //slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0 // TODO Why isn't props.iGPU accurate!? if strings.EqualFold(name, iGPUName) { - slog.Info(fmt.Sprintf("iGPU detected [%d] skipping", i)) - skip[i] = struct{}{} + slog.Info("iGPU detected skipping", "id", i) continue } if gfxOverride == "" { if !slices.Contains[[]string, string](supported, gfx) { - slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, gfx, libDir, supported)) + slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported) // TODO - consider discrete markdown just for ROCM troubleshooting? 
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage") - skip[i] = struct{}{} continue } else { - slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, gfx)) + slog.Info("amdgpu is supported", "gpu", i, "gpu_type", gfx) } } - totalMemory, freeMemory, err := hl.HipMemGetInfo() + freeMemory, totalMemory, err := hl.HipMemGetInfo() if err != nil { - slog.Warn(fmt.Sprintf("[%d] %s", i, err)) + slog.Warn("get mem info", "id", i, "error", err) continue } - // TODO according to docs, freeMem may lie on windows! - slog.Info(fmt.Sprintf("[%d] Total Mem: %d", i, totalMemory)) - slog.Info(fmt.Sprintf("[%d] Free Mem: %d", i, freeMemory)) - resp.memInfo.DeviceCount++ - resp.memInfo.TotalMemory += totalMemory - resp.memInfo.FreeMemory += freeMemory + // iGPU detection, remove this check once we can support an iGPU variant of the rocm library + if totalMemory < IGPUMemLimit { + slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", i, "total", format.HumanBytes2(totalMemory)) + continue + } + + // TODO revisit this once ROCm v6 is available on windows. 
+ // v5.7 only reports VRAM used by this process, so it's completely wrong and unusable + slog.Info("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory)) + slog.Info("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory)) + gpuInfo := GpuInfo{ + Library: "rocm", + memInfo: memInfo{ + TotalMemory: totalMemory, + FreeMemory: freeMemory, + }, + ID: fmt.Sprintf("%d", i), // TODO this is probably wrong if we specify visible devices + DependencyPath: libDir, + MinimumMemory: rocmMinimumMemory, + } + if major != "" { + gpuInfo.Major, err = strconv.Atoi(major) + if err != nil { + slog.Info("failed to parse version", "version", gfx, "error", err) + } + } + if minor != "" { + gpuInfo.Minor, err = strconv.Atoi(minor) + if err != nil { + slog.Info("failed to parse version", "version", gfx, "error", err) + } + } + if patch != "" { + gpuInfo.Patch, err = strconv.Atoi(patch) + if err != nil { + slog.Info("failed to parse version", "version", gfx, "error", err) + } + } + if gpuInfo.Major < RocmComputeMin { + slog.Warn(fmt.Sprintf("amdgpu [%s] too old gfx%d%d%d", gpuInfo.ID, gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch)) + continue + } + + resp = append(resp, gpuInfo) } - if resp.memInfo.DeviceCount > 0 { - resp.Library = "rocm" - } - // Abort if all GPUs are skipped - if len(skip) >= count { - slog.Info("all detected amdgpus are skipped, falling back to CPU") - return - } - if len(skip) > 0 { - amdSetVisibleDevices(ids, skip) - } - UpdatePath(libDir) + + return resp } func AMDValidateLibDir() (string, error) { - // On windows non-admins typically can't create links - // so instead of trying to rely on rpath and a link in - // $LibDir/rocm, we instead rely on setting PATH to point - // to the location of the ROCm library - - // Installer payload location if we're running the installed binary - exe, err := os.Executable() + libDir, err := commonAMDValidateLibDir() if err == nil { - rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm") - if 
rocmLibUsable(rocmTargetDir) { - slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir) - return rocmTargetDir, nil - } + return libDir, nil } // Installer payload (if we're running from some other location) @@ -159,21 +180,6 @@ func AMDValidateLibDir() (string, error) { return rocmTargetDir, nil } - // Prefer explicit HIP env var - hipPath := os.Getenv("HIP_PATH") - if hipPath != "" { - hipLibDir := filepath.Join(hipPath, "bin") - if rocmLibUsable(hipLibDir) { - slog.Debug("detected ROCM via HIP_PATH=" + hipPath) - return hipLibDir, nil - } - } - - // Well known location(s) - if rocmLibUsable(RocmStandardLocation) { - return RocmStandardLocation, nil - } - // Should not happen on windows since we include it in the installer, but stand-alone binary might hit this slog.Warn("amdgpu detected, but no compatible rocm library found. Please install ROCm") return "", fmt.Errorf("no suitable rocm found, falling back to CPU") diff --git a/gpu/assets.go b/gpu/assets.go index 085c05bc..4915471b 100644 --- a/gpu/assets.go +++ b/gpu/assets.go @@ -80,7 +80,7 @@ func cleanupTmpDirs() { } err = os.RemoveAll(d) if err != nil { - slog.Debug(fmt.Sprintf("unable to cleanup stale tmpdir %s: %s", d, err)) + slog.Debug("unable to cleanup stale tmpdir", "path", d, "error", err) } } } @@ -120,7 +120,7 @@ func UpdatePath(dir string) { } } newPath := strings.Join(append([]string{dir}, pathComponents...), ";") - slog.Info(fmt.Sprintf("Updating PATH to %s", newPath)) + slog.Info("updating", "PATH", newPath) os.Setenv("PATH", newPath) } // linux and darwin rely on rpath diff --git a/gpu/cuda_common.go b/gpu/cuda_common.go new file mode 100644 index 00000000..03c1a25b --- /dev/null +++ b/gpu/cuda_common.go @@ -0,0 +1,22 @@ +//go:build linux || windows + +package gpu + +import ( + "log/slog" + "strings" +) + +func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) { + ids := []string{} + for _, info := range gpuInfo { + if info.Library != "cuda" { + // TODO shouldn't 
happen if things are wired correctly... + slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library) + continue + } + ids = append(ids, info.ID) + } + return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",") + +} diff --git a/gpu/gpu.go b/gpu/gpu.go index 47d70ed0..9b915015 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -16,7 +16,6 @@ import ( "os" "path/filepath" "runtime" - "strconv" "strings" "sync" "unsafe" @@ -25,8 +24,8 @@ import ( ) type handles struct { - nvml *C.nvml_handle_t - cudart *C.cudart_handle_t + deviceCount int + cudart *C.cudart_handle_t } const ( @@ -39,26 +38,10 @@ var gpuMutex sync.Mutex // With our current CUDA compile flags, older than 5.0 will not work properly var CudaComputeMin = [2]C.int{5, 0} -// Possible locations for the nvidia-ml library -var NvmlLinuxGlobs = []string{ - "/usr/local/cuda/lib64/libnvidia-ml.so*", - "/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*", - "/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*", - "/usr/lib/wsl/lib/libnvidia-ml.so*", - "/usr/lib/wsl/drivers/*/libnvidia-ml.so*", - "/opt/cuda/lib64/libnvidia-ml.so*", - "/usr/lib*/libnvidia-ml.so*", - "/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*", - "/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*", - "/usr/local/lib*/libnvidia-ml.so*", +var RocmComputeMin = 9 - // TODO: are these stubs ever valid? 
- "/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*", -} - -var NvmlWindowsGlobs = []string{ - "c:\\Windows\\System32\\nvml.dll", -} +// TODO find a better way to detect iGPU instead of minimum memory +const IGPUMemLimit = 1 * format.GibiByte // 512G is what they typically report, so anything less than 1G must be iGPU var CudartLinuxGlobs = []string{ "/usr/local/cuda/lib64/libcudart.so*", @@ -88,26 +71,18 @@ func initGPUHandles() *handles { // TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing - gpuHandles := &handles{nil, nil} - var nvmlMgmtName string - var nvmlMgmtPatterns []string + gpuHandles := &handles{} var cudartMgmtName string var cudartMgmtPatterns []string tmpDir, _ := PayloadsDir() switch runtime.GOOS { case "windows": - nvmlMgmtName = "nvml.dll" - nvmlMgmtPatterns = make([]string, len(NvmlWindowsGlobs)) - copy(nvmlMgmtPatterns, NvmlWindowsGlobs) cudartMgmtName = "cudart64_*.dll" localAppData := os.Getenv("LOCALAPPDATA") cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)} cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...) 
case "linux": - nvmlMgmtName = "libnvidia-ml.so" - nvmlMgmtPatterns = make([]string, len(NvmlLinuxGlobs)) - copy(nvmlMgmtPatterns, NvmlLinuxGlobs) cudartMgmtName = "libcudart.so*" if tmpDir != "" { // TODO - add "payloads" for subprocess @@ -118,31 +93,21 @@ func initGPUHandles() *handles { return gpuHandles } - slog.Info("Detecting GPU type") + slog.Info("Detecting GPUs") cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns) if len(cudartLibPaths) > 0 { - cudart := LoadCUDARTMgmt(cudartLibPaths) + deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths) if cudart != nil { - slog.Info("Nvidia GPU detected via cudart") + slog.Info("detected GPUs", "library", libPath, "count", deviceCount) gpuHandles.cudart = cudart - return gpuHandles - } - } - - // TODO once we build confidence, remove this and the gpu_info_nvml.[ch] files - nvmlLibPaths := FindGPULibs(nvmlMgmtName, nvmlMgmtPatterns) - if len(nvmlLibPaths) > 0 { - nvml := LoadNVMLMgmt(nvmlLibPaths) - if nvml != nil { - slog.Info("Nvidia GPU detected via nvidia-ml") - gpuHandles.nvml = nvml + gpuHandles.deviceCount = deviceCount return gpuHandles } } return gpuHandles } -func GetGPUInfo() GpuInfo { +func GetGPUInfo() GpuInfoList { // TODO - consider exploring lspci (and equivalent on windows) to check for // GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries gpuMutex.Lock() @@ -150,9 +115,6 @@ func GetGPUInfo() GpuInfo { gpuHandles := initGPUHandles() defer func() { - if gpuHandles.nvml != nil { - C.nvml_release(*gpuHandles.nvml) - } if gpuHandles.cudart != nil { C.cudart_release(*gpuHandles.cudart) } @@ -165,72 +127,63 @@ func GetGPUInfo() GpuInfo { } var memInfo C.mem_info_t - resp := GpuInfo{} - if gpuHandles.nvml != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") { - C.nvml_check_vram(*gpuHandles.nvml, &memInfo) + resp := []GpuInfo{} + + // NVIDIA first + for i := 0; i < gpuHandles.deviceCount; i++ { + // TODO once we support CPU compilation variants of 
GPU libraries refine this... + if cpuVariant == "" && runtime.GOARCH == "amd64" { + continue + } + gpuInfo := GpuInfo{ + Library: "cuda", + } + C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo) if memInfo.err != nil { - slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU memory: %s", C.GoString(memInfo.err))) + slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err)) C.free(unsafe.Pointer(memInfo.err)) - } else if memInfo.count > 0 { - // Verify minimum compute capability - var cc C.nvml_compute_capability_t - C.nvml_compute_capability(*gpuHandles.nvml, &cc) - if cc.err != nil { - slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU compute capability: %s", C.GoString(cc.err))) - C.free(unsafe.Pointer(cc.err)) - } else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) { - slog.Info(fmt.Sprintf("[nvidia-ml] NVML CUDA Compute Capability detected: %d.%d", cc.major, cc.minor)) - resp.Library = "cuda" - resp.MinimumMemory = cudaMinimumMemory - } else { - slog.Info(fmt.Sprintf("[nvidia-ml] CUDA GPU is too old. Falling back to CPU mode. 
Compute Capability detected: %d.%d", cc.major, cc.minor)) - } + continue } - } else if gpuHandles.cudart != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") { - C.cudart_check_vram(*gpuHandles.cudart, &memInfo) - if memInfo.err != nil { - slog.Info(fmt.Sprintf("[cudart] error looking up CUDART GPU memory: %s", C.GoString(memInfo.err))) - C.free(unsafe.Pointer(memInfo.err)) - } else if memInfo.count > 0 { - // Verify minimum compute capability - var cc C.cudart_compute_capability_t - C.cudart_compute_capability(*gpuHandles.cudart, &cc) - if cc.err != nil { - slog.Info(fmt.Sprintf("[cudart] error looking up CUDA compute capability: %s", C.GoString(cc.err))) - C.free(unsafe.Pointer(cc.err)) - } else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) { - slog.Info(fmt.Sprintf("[cudart] CUDART CUDA Compute Capability detected: %d.%d", cc.major, cc.minor)) - resp.Library = "cuda" - resp.MinimumMemory = cudaMinimumMemory - } else { - slog.Info(fmt.Sprintf("[cudart] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor)) - } + if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) { + slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. 
Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor)) + continue } - } else { - AMDGetGPUInfo(&resp) - if resp.Library != "" { - resp.MinimumMemory = rocmMinimumMemory - return resp - } - } - if resp.Library == "" { - C.cpu_check_ram(&memInfo) - resp.Library = "cpu" - resp.Variant = cpuVariant - } - if memInfo.err != nil { - slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err))) - C.free(unsafe.Pointer(memInfo.err)) - return resp + gpuInfo.TotalMemory = uint64(memInfo.total) + gpuInfo.FreeMemory = uint64(memInfo.free) + gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) + gpuInfo.Major = int(memInfo.major) + gpuInfo.Minor = int(memInfo.minor) + gpuInfo.MinimumMemory = cudaMinimumMemory + + // TODO potentially sort on our own algorithm instead of what the underlying GPU library does... + resp = append(resp, gpuInfo) + } + + // Then AMD + resp = append(resp, AMDGetGPUInfo()...) + + if len(resp) == 0 { + C.cpu_check_ram(&memInfo) + if memInfo.err != nil { + slog.Info("error looking up CPU memory", "error", C.GoString(memInfo.err)) + C.free(unsafe.Pointer(memInfo.err)) + return resp + } + gpuInfo := GpuInfo{ + Library: "cpu", + Variant: cpuVariant, + } + gpuInfo.TotalMemory = uint64(memInfo.total) + gpuInfo.FreeMemory = uint64(memInfo.free) + gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) + + resp = append(resp, gpuInfo) } - resp.DeviceCount = uint32(memInfo.count) - resp.FreeMemory = uint64(memInfo.free) - resp.TotalMemory = uint64(memInfo.total) return resp } -func getCPUMem() (memInfo, error) { +func GetCPUMem() (memInfo, error) { var ret memInfo var info C.mem_info_t C.cpu_check_ram(&info) @@ -243,29 +196,11 @@ func getCPUMem() (memInfo, error) { return ret, nil } -func CheckVRAM() (uint64, error) { - userLimit := os.Getenv("OLLAMA_MAX_VRAM") - if userLimit != "" { - avail, err := strconv.ParseInt(userLimit, 10, 64) - if err != nil { - return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err) - } - 
slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail)) - return uint64(avail), nil - } - gpuInfo := GetGPUInfo() - if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") { - return gpuInfo.FreeMemory, nil - } - - return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation -} - func FindGPULibs(baseLibName string, patterns []string) []string { // Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them var ldPaths []string gpuLibPaths := []string{} - slog.Info(fmt.Sprintf("Searching for GPU management library %s", baseLibName)) + slog.Debug("Searching for GPU library", "name", baseLibName) switch runtime.GOOS { case "windows": @@ -283,7 +218,7 @@ func FindGPULibs(baseLibName string, patterns []string) []string { } patterns = append(patterns, filepath.Join(d, baseLibName+"*")) } - slog.Debug(fmt.Sprintf("gpu management search paths: %v", patterns)) + slog.Debug("gpu library search", "globs", patterns) for _, pattern := range patterns { // Ignore glob discovery errors matches, _ := filepath.Glob(pattern) @@ -311,28 +246,11 @@ func FindGPULibs(baseLibName string, patterns []string) []string { } } } - slog.Info(fmt.Sprintf("Discovered GPU libraries: %v", gpuLibPaths)) + slog.Debug("discovered GPU libraries", "paths", gpuLibPaths) return gpuLibPaths } -func LoadNVMLMgmt(nvmlLibPaths []string) *C.nvml_handle_t { - var resp C.nvml_init_resp_t - resp.ch.verbose = getVerboseState() - for _, libPath := range nvmlLibPaths { - lib := C.CString(libPath) - defer C.free(unsafe.Pointer(lib)) - C.nvml_init(lib, &resp) - if resp.err != nil { - slog.Info(fmt.Sprintf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err))) - C.free(unsafe.Pointer(resp.err)) - } else { - return &resp.ch - } - } - return nil -} - -func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t { +func LoadCUDARTMgmt(cudartLibPaths []string) (int, 
*C.cudart_handle_t, string) { var resp C.cudart_init_resp_t resp.ch.verbose = getVerboseState() for _, libPath := range cudartLibPaths { @@ -340,13 +258,13 @@ func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t { defer C.free(unsafe.Pointer(lib)) C.cudart_init(lib, &resp) if resp.err != nil { - slog.Info(fmt.Sprintf("Unable to load cudart CUDA management library %s: %s", libPath, C.GoString(resp.err))) + slog.Debug("Unable to load cudart", "library", libPath, "error", C.GoString(resp.err)) C.free(unsafe.Pointer(resp.err)) } else { - return &resp.ch + return int(resp.num_devices), &resp.ch, libPath } } - return nil + return 0, nil, "" } func getVerboseState() C.uint16_t { @@ -355,3 +273,22 @@ func getVerboseState() C.uint16_t { } return C.uint16_t(0) } + +// Given the list of GPUs this instantiation is targeted for, +// figure out the visible devices environment variable +// +// If different libraries are detected, the first one is what we use +func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) { + if len(l) == 0 { + return "", "" + } + switch l[0].Library { + case "cuda": + return cudaGetVisibleDevicesEnv(l) + case "rocm": + return rocmGetVisibleDevicesEnv(l) + default: + slog.Debug("no filter required for library " + l[0].Library) + return "", "" + } +} diff --git a/gpu/gpu_darwin.go b/gpu/gpu_darwin.go index bf764ce6..2ff6b351 100644 --- a/gpu/gpu_darwin.go +++ b/gpu/gpu_darwin.go @@ -9,52 +9,41 @@ package gpu */ import "C" import ( - "fmt" - "log/slog" - "os" "runtime" - "strconv" ) -// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs -func CheckVRAM() (uint64, error) { - userLimit := os.Getenv("OLLAMA_MAX_VRAM") - if userLimit != "" { - avail, err := strconv.ParseInt(userLimit, 10, 64) - if err != nil { - return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err) - } - slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail)) - return uint64(avail), nil - } - +func GetGPUInfo() 
GpuInfoList { + mem, _ := GetCPUMem() if runtime.GOARCH == "amd64" { - // gpu not supported, this may not be metal - return 0, nil - } - - return uint64(C.getRecommendedMaxVRAM()), nil -} - -func GetGPUInfo() GpuInfo { - mem, _ := getCPUMem() - if runtime.GOARCH == "amd64" { - return GpuInfo{ - Library: "cpu", - Variant: GetCPUVariant(), - memInfo: mem, + return []GpuInfo{ + { + Library: "cpu", + Variant: GetCPUVariant(), + memInfo: mem, + }, } } - return GpuInfo{ + info := GpuInfo{ Library: "metal", - memInfo: mem, + ID: "0", } + info.TotalMemory = uint64(C.getRecommendedMaxVRAM()) + + // TODO is there a way to gather actual allocated video memory? (currentAllocatedSize doesn't work) + info.FreeMemory = info.TotalMemory + + info.MinimumMemory = 0 + return []GpuInfo{info} } -func getCPUMem() (memInfo, error) { +func GetCPUMem() (memInfo, error) { return memInfo{ TotalMemory: uint64(C.getPhysicalMemory()), FreeMemory: 0, - DeviceCount: 1, }, nil } + +func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) { + // No-op on darwin + return "", "" +} diff --git a/gpu/gpu_info.h b/gpu/gpu_info.h index 4c449a60..0f67442f 100644 --- a/gpu/gpu_info.h +++ b/gpu/gpu_info.h @@ -38,12 +38,17 @@ extern "C" { #endif +#define GPU_ID_LEN 64 + typedef struct mem_info { + char *err; // If non-nill, caller responsible for freeing + char gpu_id[GPU_ID_LEN]; uint64_t total; uint64_t free; - unsigned int count; - int igpu_index; // If >= 0, we detected an integrated GPU to ignore - char *err; // If non-nill, caller responsible for freeing + + // Compute Capability + int major; + int minor; } mem_info_t; void cpu_check_ram(mem_info_t *resp); @@ -52,7 +57,6 @@ void cpu_check_ram(mem_info_t *resp); } #endif -#include "gpu_info_nvml.h" #include "gpu_info_cudart.h" #endif // __GPU_INFO_H__ diff --git a/gpu/gpu_info_cpu.c b/gpu/gpu_info_cpu.c index 0c4d62c5..81ba3de4 100644 --- a/gpu/gpu_info_cpu.c +++ b/gpu/gpu_info_cpu.c @@ -8,9 +8,11 @@ void cpu_check_ram(mem_info_t *resp) { 
MEMORYSTATUSEX info; info.dwLength = sizeof(info); if (GlobalMemoryStatusEx(&info) != 0) { - resp->count = 1; resp->total = info.ullTotalPhys; resp->free = info.ullAvailPhys; + resp->major = 0; + resp->minor = 0; + snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0"); } else { resp->err = LOAD_ERR(); } @@ -27,9 +29,11 @@ void cpu_check_ram(mem_info_t *resp) { if (sysinfo(&info) != 0) { resp->err = strdup(strerror(errno)); } else { - resp->count = 1; resp->total = info.totalram * info.mem_unit; resp->free = info.freeram * info.mem_unit; + resp->major = 0; + resp->minor = 0; + snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0"); } return; } diff --git a/gpu/gpu_info_cudart.c b/gpu/gpu_info_cudart.c index 27cd2342..8e9204ea 100644 --- a/gpu/gpu_info_cudart.c +++ b/gpu/gpu_info_cudart.c @@ -6,6 +6,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) { cudartReturn_t ret; resp->err = NULL; + resp->num_devices = 0; const int buflen = 256; char buf[buflen + 1]; int i; @@ -21,6 +22,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) { {"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount}, {"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute}, {"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion}, + {"cudaGetDeviceProperties", (void *)&resp->ch.cudaGetDeviceProperties}, {NULL, NULL}, }; @@ -36,13 +38,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) { return; } - // TODO once we've squashed the remaining corner cases remove this log - LOG(resp->ch.verbose, "wiring cudart library functions in %s\n", cudart_lib_path); - for (i = 0; l[i].s != NULL; i++) { - // TODO once we've squashed the remaining corner cases remove this log - LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s); - *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s); if (!l[i].p) { char *msg = LOAD_ERR(); @@ -63,7 +59,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) { UNLOAD_LIBRARY(resp->ch.handle); resp->ch.handle = NULL; if 
(ret == CUDA_ERROR_INSUFFICIENT_DRIVER) { - resp->err = strdup("your nvidia driver is too old or missing, please upgrade to run ollama"); + resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama"); return; } snprintf(buf, buflen, "cudart init failure: %d", ret); @@ -85,110 +81,95 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) { driverVersion.minor = (version - (driverVersion.major * 1000)) / 10; LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor); } + + ret = (*resp->ch.cudaGetDeviceCount)(&resp->num_devices); + if (ret != CUDART_SUCCESS) { + LOG(resp->ch.verbose, "cudaGetDeviceCount err: %d\n", ret); + UNLOAD_LIBRARY(resp->ch.handle); + resp->ch.handle = NULL; + snprintf(buf, buflen, "unable to get device count: %d", ret); + resp->err = strdup(buf); + return; + } } -void cudart_check_vram(cudart_handle_t h, mem_info_t *resp) { +void cudart_check_vram(cudart_handle_t h, int i, mem_info_t *resp) { resp->err = NULL; cudartMemory_t memInfo = {0,0,0}; cudartReturn_t ret; const int buflen = 256; char buf[buflen + 1]; - int i; if (h.handle == NULL) { resp->err = strdup("cudart handle isn't initialized"); return; } - // cudaGetDeviceCount takes int type, resp-> count is uint - int deviceCount; - ret = (*h.cudaGetDeviceCount)(&deviceCount); + ret = (*h.cudaSetDevice)(i); if (ret != CUDART_SUCCESS) { - snprintf(buf, buflen, "unable to get device count: %d", ret); + snprintf(buf, buflen, "cudart device failed to initialize"); resp->err = strdup(buf); return; + } + + cudaDeviceProp_t props; + ret = (*h.cudaGetDeviceProperties)(&props, i); + if (ret != CUDART_SUCCESS) { + LOG(h.verbose, "[%d] device properties lookup failure: %d\n", i, ret); + snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); + resp->major = 0; + resp->minor = 0; } else { - resp->count = (unsigned int)deviceCount; - } - - resp->total = 0; - resp->free = 0; - for (i = 0; i < resp-> count; 
i++) { - ret = (*h.cudaSetDevice)(i); - if (ret != CUDART_SUCCESS) { - snprintf(buf, buflen, "cudart device failed to initialize"); - resp->err = strdup(buf); - return; + int allNull = 1; + for (int j = 0; j < 16; j++) { + if (props.uuid.bytes[j] != 0) { + allNull = 0; + break; + } } - ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total); - if (ret != CUDART_SUCCESS) { - snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret); - resp->err = strdup(buf); - return; + if (allNull != 0) { + snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); + } else { + // GPU-d110a105-ac29-1d54-7b49-9c90440f215b + snprintf(&resp->gpu_id[0], GPU_ID_LEN, + "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", + props.uuid.bytes[0], + props.uuid.bytes[1], + props.uuid.bytes[2], + props.uuid.bytes[3], + props.uuid.bytes[4], + props.uuid.bytes[5], + props.uuid.bytes[6], + props.uuid.bytes[7], + props.uuid.bytes[8], + props.uuid.bytes[9], + props.uuid.bytes[10], + props.uuid.bytes[11], + props.uuid.bytes[12], + props.uuid.bytes[13], + props.uuid.bytes[14], + props.uuid.bytes[15] + ); } + resp->major = props.major; + resp->minor = props.minor; - LOG(h.verbose, "[%d] CUDA totalMem %lu\n", i, memInfo.total); - LOG(h.verbose, "[%d] CUDA freeMem %lu\n", i, memInfo.free); - - resp->total += memInfo.total; - resp->free += memInfo.free; + // TODO add other useful properties from props } -} - -void cudart_compute_capability(cudart_handle_t h, cudart_compute_capability_t *resp) { - resp->err = NULL; - resp->major = 0; - resp->minor = 0; - int major = 0; - int minor = 0; - cudartReturn_t ret; - const int buflen = 256; - char buf[buflen + 1]; - int i; - - if (h.handle == NULL) { - resp->err = strdup("cudart handle not initialized"); - return; - } - - int devices; - ret = (*h.cudaGetDeviceCount)(&devices); + ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total); if (ret != CUDART_SUCCESS) { - snprintf(buf, buflen, "unable to get cudart device count: %d", 
ret); + snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret); resp->err = strdup(buf); return; } - for (i = 0; i < devices; i++) { - ret = (*h.cudaSetDevice)(i); - if (ret != CUDART_SUCCESS) { - snprintf(buf, buflen, "cudart device failed to initialize"); - resp->err = strdup(buf); - return; - } + resp->total = memInfo.total; + resp->free = memInfo.free; - ret = (*h.cudaDeviceGetAttribute)(&major, cudartDevAttrComputeCapabilityMajor, i); - if (ret != CUDART_SUCCESS) { - snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret); - resp->err = strdup(buf); - return; - } - ret = (*h.cudaDeviceGetAttribute)(&minor, cudartDevAttrComputeCapabilityMinor, i); - if (ret != CUDART_SUCCESS) { - snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret); - resp->err = strdup(buf); - return; - } - - // Report the lowest major.minor we detect as that limits our compatibility - if (resp->major == 0 || resp->major > major ) { - resp->major = major; - resp->minor = minor; - } else if ( resp->major == major && resp->minor > minor ) { - resp->minor = minor; - } - } + LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total); + LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free); + LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor); } void cudart_release(cudart_handle_t h) { diff --git a/gpu/gpu_info_cudart.h b/gpu/gpu_info_cudart.h index eb9336ec..ae2579a2 100644 --- a/gpu/gpu_info_cudart.h +++ b/gpu/gpu_info_cudart.h @@ -6,7 +6,8 @@ // Just enough typedef's to dlopen/dlsym for memory information typedef enum cudartReturn_enum { CUDART_SUCCESS = 0, - CUDART_UNSUPPORTED = 1, + CUDA_ERROR_INVALID_VALUE = 1, + CUDA_ERROR_MEMORY_ALLOCATION = 2, CUDA_ERROR_INSUFFICIENT_DRIVER = 35, // Other values omitted for now... 
} cudartReturn_t; @@ -14,6 +15,11 @@ typedef enum cudartReturn_enum { typedef enum cudartDeviceAttr_enum { cudartDevAttrComputeCapabilityMajor = 75, cudartDevAttrComputeCapabilityMinor = 76, + + // TODO - not yet wired up but may be useful for Jetson or other + // integrated GPU scenarios with shared memory + cudaDevAttrIntegrated = 18 + } cudartDeviceAttr_t; typedef void *cudartDevice_t; // Opaque is sufficient @@ -28,6 +34,92 @@ typedef struct cudartDriverVersion { int minor; } cudartDriverVersion_t; +typedef struct cudaUUID { + unsigned char bytes[16]; +} cudaUUID_t; +typedef struct cudaDeviceProp { + char name[256]; /**< ASCII string identifying device */ + cudaUUID_t uuid; /**< 16-byte unique identifier */ + char luid[8]; /**< 8-byte locally unique identifier. Value is undefined on TCC and non-Windows platforms */ + unsigned int luidDeviceNodeMask; /**< LUID device node mask. Value is undefined on TCC and non-Windows platforms */ + size_t totalGlobalMem; /**< Global memory available on device in bytes */ + size_t sharedMemPerBlock; /**< Shared memory available per block in bytes */ + int regsPerBlock; /**< 32-bit registers available per block */ + int warpSize; /**< Warp size in threads */ + size_t memPitch; /**< Maximum pitch in bytes allowed by memory copies */ + int maxThreadsPerBlock; /**< Maximum number of threads per block */ + int maxThreadsDim[3]; /**< Maximum size of each dimension of a block */ + int maxGridSize[3]; /**< Maximum size of each dimension of a grid */ + int clockRate; /**< Clock frequency in kilohertz */ + size_t totalConstMem; /**< Constant memory available on device in bytes */ + int major; /**< Major compute capability */ + int minor; /**< Minor compute capability */ + size_t textureAlignment; /**< Alignment requirement for textures */ + size_t texturePitchAlignment; /**< Pitch alignment requirement for texture references bound to pitched memory */ + int deviceOverlap; /**< Device can concurrently copy memory and execute a kernel. 
Deprecated. Use instead asyncEngineCount. */ + int multiProcessorCount; /**< Number of multiprocessors on device */ + int kernelExecTimeoutEnabled; /**< Specified whether there is a run time limit on kernels */ + int integrated; /**< Device is integrated as opposed to discrete */ + int canMapHostMemory; /**< Device can map host memory with cudaHostAlloc/cudaHostGetDevicePointer */ + int computeMode; /**< Compute mode (See ::cudaComputeMode) */ + int maxTexture1D; /**< Maximum 1D texture size */ + int maxTexture1DMipmap; /**< Maximum 1D mipmapped texture size */ + int maxTexture1DLinear; /**< Deprecated, do not use. Use cudaDeviceGetTexture1DLinearMaxWidth() or cuDeviceGetTexture1DLinearMaxWidth() instead. */ + int maxTexture2D[2]; /**< Maximum 2D texture dimensions */ + int maxTexture2DMipmap[2]; /**< Maximum 2D mipmapped texture dimensions */ + int maxTexture2DLinear[3]; /**< Maximum dimensions (width, height, pitch) for 2D textures bound to pitched memory */ + int maxTexture2DGather[2]; /**< Maximum 2D texture dimensions if texture gather operations have to be performed */ + int maxTexture3D[3]; /**< Maximum 3D texture dimensions */ + int maxTexture3DAlt[3]; /**< Maximum alternate 3D texture dimensions */ + int maxTextureCubemap; /**< Maximum Cubemap texture dimensions */ + int maxTexture1DLayered[2]; /**< Maximum 1D layered texture dimensions */ + int maxTexture2DLayered[3]; /**< Maximum 2D layered texture dimensions */ + int maxTextureCubemapLayered[2];/**< Maximum Cubemap layered texture dimensions */ + int maxSurface1D; /**< Maximum 1D surface size */ + int maxSurface2D[2]; /**< Maximum 2D surface dimensions */ + int maxSurface3D[3]; /**< Maximum 3D surface dimensions */ + int maxSurface1DLayered[2]; /**< Maximum 1D layered surface dimensions */ + int maxSurface2DLayered[3]; /**< Maximum 2D layered surface dimensions */ + int maxSurfaceCubemap; /**< Maximum Cubemap surface dimensions */ + int maxSurfaceCubemapLayered[2];/**< Maximum Cubemap layered surface 
dimensions */ + size_t surfaceAlignment; /**< Alignment requirements for surfaces */ + int concurrentKernels; /**< Device can possibly execute multiple kernels concurrently */ + int ECCEnabled; /**< Device has ECC support enabled */ + int pciBusID; /**< PCI bus ID of the device */ + int pciDeviceID; /**< PCI device ID of the device */ + int pciDomainID; /**< PCI domain ID of the device */ + int tccDriver; /**< 1 if device is a Tesla device using TCC driver, 0 otherwise */ + int asyncEngineCount; /**< Number of asynchronous engines */ + int unifiedAddressing; /**< Device shares a unified address space with the host */ + int memoryClockRate; /**< Peak memory clock frequency in kilohertz */ + int memoryBusWidth; /**< Global memory bus width in bits */ + int l2CacheSize; /**< Size of L2 cache in bytes */ + int persistingL2CacheMaxSize; /**< Device's maximum l2 persisting lines capacity setting in bytes */ + int maxThreadsPerMultiProcessor;/**< Maximum resident threads per multiprocessor */ + int streamPrioritiesSupported; /**< Device supports stream priorities */ + int globalL1CacheSupported; /**< Device supports caching globals in L1 */ + int localL1CacheSupported; /**< Device supports caching locals in L1 */ + size_t sharedMemPerMultiprocessor; /**< Shared memory available per multiprocessor in bytes */ + int regsPerMultiprocessor; /**< 32-bit registers available per multiprocessor */ + int managedMemory; /**< Device supports allocating managed memory on this system */ + int isMultiGpuBoard; /**< Device is on a multi-GPU board */ + int multiGpuBoardGroupID; /**< Unique identifier for a group of devices on the same multi-GPU board */ + int hostNativeAtomicSupported; /**< Link between the device and the host supports native atomic operations */ + int singleToDoublePrecisionPerfRatio; /**< Ratio of single precision performance (in floating-point operations per second) to double precision performance */ + int pageableMemoryAccess; /**< Device supports coherently 
accessing pageable memory without calling cudaHostRegister on it */ + int concurrentManagedAccess; /**< Device can coherently access managed memory concurrently with the CPU */ + int computePreemptionSupported; /**< Device supports Compute Preemption */ + int canUseHostPointerForRegisteredMem; /**< Device can access host registered memory at the same virtual address as the CPU */ + int cooperativeLaunch; /**< Device supports launching cooperative kernels via ::cudaLaunchCooperativeKernel */ + int cooperativeMultiDeviceLaunch; /**< Deprecated, cudaLaunchCooperativeKernelMultiDevice is deprecated. */ + size_t sharedMemPerBlockOptin; /**< Per device maximum shared memory per block usable by special opt in */ + int pageableMemoryAccessUsesHostPageTables; /**< Device accesses pageable memory via the host's page tables */ + int directManagedMemAccessFromHost; /**< Host can directly access managed memory on the device without migration. */ + int maxBlocksPerMultiProcessor; /**< Maximum number of resident blocks per multiprocessor */ + int accessPolicyMaxWindowSize; /**< The maximum value of ::cudaAccessPolicyWindow::num_bytes. 
*/ + size_t reservedSharedMemPerBlock; /**< Shared memory reserved by CUDA driver per block in bytes */ + } cudaDeviceProp_t; + typedef struct cudart_handle { void *handle; uint16_t verbose; @@ -38,23 +130,17 @@ typedef struct cudart_handle { cudartReturn_t (*cudaGetDeviceCount)(int *); cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device); cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion); + cudartReturn_t (*cudaGetDeviceProperties) (cudaDeviceProp_t* prop, int device); } cudart_handle_t; typedef struct cudart_init_resp { char *err; // If err is non-null handle is invalid cudart_handle_t ch; + int num_devices; } cudart_init_resp_t; -typedef struct cudart_compute_capability { - char *err; - int major; - int minor; -} cudart_compute_capability_t; - - void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp); -void cudart_check_vram(cudart_handle_t ch, mem_info_t *resp); -void cudart_compute_capability(cudart_handle_t th, cudart_compute_capability_t *cc); +void cudart_check_vram(cudart_handle_t ch, int device_id, mem_info_t *resp); void cudart_release(cudart_handle_t ch); #endif // __GPU_INFO_CUDART_H__ diff --git a/gpu/gpu_info_nvml.c b/gpu/gpu_info_nvml.c deleted file mode 100644 index 67c80b0f..00000000 --- a/gpu/gpu_info_nvml.c +++ /dev/null @@ -1,221 +0,0 @@ -#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs? 
- -#include - -#include "gpu_info_nvml.h" - -void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) { - nvmlReturn_t ret; - resp->err = NULL; - const int buflen = 256; - char buf[buflen + 1]; - int i; - - struct lookup { - char *s; - void **p; - } l[] = { - {"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2}, - {"nvmlShutdown", (void *)&resp->ch.nvmlShutdown}, - {"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.nvmlDeviceGetHandleByIndex}, - {"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo}, - {"nvmlDeviceGetCount_v2", (void *)&resp->ch.nvmlDeviceGetCount_v2}, - {"nvmlDeviceGetCudaComputeCapability", (void *)&resp->ch.nvmlDeviceGetCudaComputeCapability}, - {"nvmlSystemGetDriverVersion", (void *)&resp->ch.nvmlSystemGetDriverVersion}, - {"nvmlDeviceGetName", (void *)&resp->ch.nvmlDeviceGetName}, - {"nvmlDeviceGetSerial", (void *)&resp->ch.nvmlDeviceGetSerial}, - {"nvmlDeviceGetVbiosVersion", (void *)&resp->ch.nvmlDeviceGetVbiosVersion}, - {"nvmlDeviceGetBoardPartNumber", (void *)&resp->ch.nvmlDeviceGetBoardPartNumber}, - {"nvmlDeviceGetBrand", (void *)&resp->ch.nvmlDeviceGetBrand}, - {NULL, NULL}, - }; - - resp->ch.handle = LOAD_LIBRARY(nvml_lib_path, RTLD_LAZY); - if (!resp->ch.handle) { - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "library %s load err: %s\n", nvml_lib_path, msg); - snprintf(buf, buflen, - "Unable to load %s library to query for Nvidia GPUs: %s", - nvml_lib_path, msg); - free(msg); - resp->err = strdup(buf); - return; - } - - // TODO once we've squashed the remaining corner cases remove this log - LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", nvml_lib_path); - - for (i = 0; l[i].s != NULL; i++) { - // TODO once we've squashed the remaining corner cases remove this log - LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s); - - *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s); - if (!l[i].p) { - resp->ch.handle = NULL; - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "dlerr: %s\n", msg); - 
UNLOAD_LIBRARY(resp->ch.handle); - snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, - msg); - free(msg); - resp->err = strdup(buf); - return; - } - } - - ret = (*resp->ch.nvmlInit_v2)(); - if (ret != NVML_SUCCESS) { - LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - snprintf(buf, buflen, "nvml vram init failure: %d", ret); - resp->err = strdup(buf); - return; - } - - // Report driver version if we're in verbose mode, ignore errors - ret = (*resp->ch.nvmlSystemGetDriverVersion)(buf, buflen); - if (ret != NVML_SUCCESS) { - LOG(resp->ch.verbose, "nvmlSystemGetDriverVersion failed: %d\n", ret); - } else { - LOG(resp->ch.verbose, "CUDA driver version: %s\n", buf); - } -} - -void nvml_check_vram(nvml_handle_t h, mem_info_t *resp) { - resp->err = NULL; - nvmlDevice_t device; - nvmlMemory_t memInfo = {0}; - nvmlReturn_t ret; - const int buflen = 256; - char buf[buflen + 1]; - int i; - - if (h.handle == NULL) { - resp->err = strdup("nvml handle isn't initialized"); - return; - } - - ret = (*h.nvmlDeviceGetCount_v2)(&resp->count); - if (ret != NVML_SUCCESS) { - snprintf(buf, buflen, "unable to get device count: %d", ret); - resp->err = strdup(buf); - return; - } - - resp->total = 0; - resp->free = 0; - for (i = 0; i < resp->count; i++) { - ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device); - if (ret != NVML_SUCCESS) { - snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret); - resp->err = strdup(buf); - return; - } - - ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo); - if (ret != NVML_SUCCESS) { - snprintf(buf, buflen, "device memory info lookup failure %d: %d", i, ret); - resp->err = strdup(buf); - return; - } - if (h.verbose) { - nvmlBrandType_t brand = 0; - // When in verbose mode, report more information about - // the card we discover, but don't fail on error - ret = (*h.nvmlDeviceGetName)(device, buf, buflen); - if (ret != NVML_SUCCESS) { - LOG(h.verbose, 
"nvmlDeviceGetName failed: %d\n", ret); - } else { - LOG(h.verbose, "[%d] CUDA device name: %s\n", i, buf); - } - ret = (*h.nvmlDeviceGetBoardPartNumber)(device, buf, buflen); - if (ret != NVML_SUCCESS) { - LOG(h.verbose, "nvmlDeviceGetBoardPartNumber failed: %d\n", ret); - } else { - LOG(h.verbose, "[%d] CUDA part number: %s\n", i, buf); - } - ret = (*h.nvmlDeviceGetSerial)(device, buf, buflen); - if (ret != NVML_SUCCESS) { - LOG(h.verbose, "nvmlDeviceGetSerial failed: %d\n", ret); - } else { - LOG(h.verbose, "[%d] CUDA S/N: %s\n", i, buf); - } - ret = (*h.nvmlDeviceGetVbiosVersion)(device, buf, buflen); - if (ret != NVML_SUCCESS) { - LOG(h.verbose, "nvmlDeviceGetVbiosVersion failed: %d\n", ret); - } else { - LOG(h.verbose, "[%d] CUDA vbios version: %s\n", i, buf); - } - ret = (*h.nvmlDeviceGetBrand)(device, &brand); - if (ret != NVML_SUCCESS) { - LOG(h.verbose, "nvmlDeviceGetBrand failed: %d\n", ret); - } else { - LOG(h.verbose, "[%d] CUDA brand: %d\n", i, brand); - } - } - - LOG(h.verbose, "[%d] CUDA totalMem %ld\n", i, memInfo.total); - LOG(h.verbose, "[%d] CUDA freeMem %ld\n", i, memInfo.free); - - resp->total += memInfo.total; - resp->free += memInfo.free; - } -} - -void nvml_compute_capability(nvml_handle_t h, nvml_compute_capability_t *resp) { - resp->err = NULL; - resp->major = 0; - resp->minor = 0; - nvmlDevice_t device; - int major = 0; - int minor = 0; - nvmlReturn_t ret; - const int buflen = 256; - char buf[buflen + 1]; - int i; - - if (h.handle == NULL) { - resp->err = strdup("nvml handle not initialized"); - return; - } - - unsigned int devices; - ret = (*h.nvmlDeviceGetCount_v2)(&devices); - if (ret != NVML_SUCCESS) { - snprintf(buf, buflen, "unable to get device count: %d", ret); - resp->err = strdup(buf); - return; - } - - for (i = 0; i < devices; i++) { - ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device); - if (ret != NVML_SUCCESS) { - snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret); - resp->err = strdup(buf); - return; - } - 
- ret = (*h.nvmlDeviceGetCudaComputeCapability)(device, &major, &minor); - if (ret != NVML_SUCCESS) { - snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret); - resp->err = strdup(buf); - return; - } - // Report the lowest major.minor we detect as that limits our compatibility - if (resp->major == 0 || resp->major > major ) { - resp->major = major; - resp->minor = minor; - } else if ( resp->major == major && resp->minor > minor ) { - resp->minor = minor; - } - } -} - -void nvml_release(nvml_handle_t h) { - LOG(h.verbose, "releasing nvml library\n"); - UNLOAD_LIBRARY(h.handle); - h.handle = NULL; -} - -#endif // __APPLE__ \ No newline at end of file diff --git a/gpu/gpu_info_nvml.h b/gpu/gpu_info_nvml.h deleted file mode 100644 index bd1d6001..00000000 --- a/gpu/gpu_info_nvml.h +++ /dev/null @@ -1,57 +0,0 @@ -#ifndef __APPLE__ -#ifndef __GPU_INFO_NVML_H__ -#define __GPU_INFO_NVML_H__ -#include "gpu_info.h" - -// Just enough typedef's to dlopen/dlsym for memory information -typedef enum nvmlReturn_enum { - NVML_SUCCESS = 0, - // Other values omitted for now... 
-} nvmlReturn_t; -typedef void *nvmlDevice_t; // Opaque is sufficient -typedef struct nvmlMemory_st { - unsigned long long total; - unsigned long long free; - unsigned long long used; -} nvmlMemory_t; - -typedef enum nvmlBrandType_enum -{ - NVML_BRAND_UNKNOWN = 0, -} nvmlBrandType_t; - -typedef struct nvml_handle { - void *handle; - uint16_t verbose; - nvmlReturn_t (*nvmlInit_v2)(void); - nvmlReturn_t (*nvmlShutdown)(void); - nvmlReturn_t (*nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t *); - nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *); - nvmlReturn_t (*nvmlDeviceGetCount_v2)(unsigned int *); - nvmlReturn_t (*nvmlDeviceGetCudaComputeCapability)(nvmlDevice_t, int* major, int* minor); - nvmlReturn_t (*nvmlSystemGetDriverVersion) (char* version, unsigned int length); - nvmlReturn_t (*nvmlDeviceGetName) (nvmlDevice_t device, char* name, unsigned int length); - nvmlReturn_t (*nvmlDeviceGetSerial) (nvmlDevice_t device, char* serial, unsigned int length); - nvmlReturn_t (*nvmlDeviceGetVbiosVersion) (nvmlDevice_t device, char* version, unsigned int length); - nvmlReturn_t (*nvmlDeviceGetBoardPartNumber) (nvmlDevice_t device, char* partNumber, unsigned int length); - nvmlReturn_t (*nvmlDeviceGetBrand) (nvmlDevice_t device, nvmlBrandType_t* type); -} nvml_handle_t; - -typedef struct nvml_init_resp { - char *err; // If err is non-null handle is invalid - nvml_handle_t ch; -} nvml_init_resp_t; - -typedef struct nvml_compute_capability { - char *err; - int major; - int minor; -} nvml_compute_capability_t; - -void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp); -void nvml_check_vram(nvml_handle_t ch, mem_info_t *resp); -void nvml_compute_capability(nvml_handle_t ch, nvml_compute_capability_t *cc); -void nvml_release(nvml_handle_t ch); - -#endif // __GPU_INFO_NVML_H__ -#endif // __APPLE__ \ No newline at end of file diff --git a/gpu/gpu_test.go b/gpu/gpu_test.go index f57597b5..a28cbe8c 100644 --- a/gpu/gpu_test.go +++ b/gpu/gpu_test.go 
@@ -9,23 +9,16 @@ import ( func TestBasicGetGPUInfo(t *testing.T) { info := GetGPUInfo() - assert.Contains(t, "cuda rocm cpu metal", info.Library) - - switch runtime.GOOS { - case "darwin": - // TODO - remove this once MacOS returns some size for CPU - return - case "linux", "windows": - assert.Greater(t, info.TotalMemory, uint64(0)) - assert.Greater(t, info.FreeMemory, uint64(0)) - assert.Greater(t, info.DeviceCount, uint32(0)) - default: - return + assert.Greater(t, len(info), 0) + assert.Contains(t, "cuda rocm cpu metal", info[0].Library) + if info[0].Library != "cpu" { + assert.Greater(t, info[0].TotalMemory, uint64(0)) + assert.Greater(t, info[0].FreeMemory, uint64(0)) } } func TestCPUMemInfo(t *testing.T) { - info, err := getCPUMem() + info, err := GetCPUMem() assert.NoError(t, err) switch runtime.GOOS { case "darwin": diff --git a/gpu/types.go b/gpu/types.go index 7fe6c40c..7a5d5ba7 100644 --- a/gpu/types.go +++ b/gpu/types.go @@ -3,7 +3,6 @@ package gpu type memInfo struct { TotalMemory uint64 `json:"total_memory,omitempty"` FreeMemory uint64 `json:"free_memory,omitempty"` - DeviceCount uint32 `json:"device_count,omitempty"` } // Beginning of an `ollama info` command @@ -17,11 +16,49 @@ type GpuInfo struct { // MinimumMemory represents the minimum memory required to use the GPU MinimumMemory uint64 `json:"-"` - // TODO add other useful attributes about the card here for discovery information + // Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly + DependencyPath string `json:"lib_path,omitempty"` + + // GPU information + ID string `json:"gpu_id"` // string to use for selection of this specific GPU + Name string `json:"name"` // user friendly name if available + Major int `json:"major,omitempty"` // Major compatibility version (CC or gfx) + Minor int `json:"minor,omitempty"` // Minor compatibility version (CC or gfx) + Patch int `json:"patch,omitempty"` // Patch compatibility only matters on AMD + + // TODO other 
performance capability info to help in scheduling decisions } -type Version struct { - Major uint - Minor uint - Patch uint +type GpuInfoList []GpuInfo + +// Split up the set of gpu info's by Library and variant +func (l GpuInfoList) ByLibrary() []GpuInfoList { + resp := []GpuInfoList{} + libs := []string{} + for _, info := range l { + found := false + requested := info.Library + if info.Variant != "" { + requested += "_" + info.Variant + } + for i, lib := range libs { + if lib == requested { + resp[i] = append(resp[i], info) + found = true + break + } + } + if !found { + libs = append(libs, requested) + resp = append(resp, []GpuInfo{info}) + } + } + return resp } + +// Sort by Free Space +type ByFreeMemory []GpuInfo + +func (a ByFreeMemory) Len() int { return len(a) } +func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory } diff --git a/integration/basic_test.go b/integration/basic_test.go index 40bde03c..6e632a1c 100644 --- a/integration/basic_test.go +++ b/integration/basic_test.go @@ -4,11 +4,14 @@ package integration import ( "context" - "net/http" + "log/slog" + "os" + "runtime" "testing" "time" "github.com/ollama/ollama/api" + "github.com/stretchr/testify/require" ) func TestOrcaMiniBlueSky(t *testing.T) { @@ -24,5 +27,44 @@ func TestOrcaMiniBlueSky(t *testing.T) { "seed": 123, }, } - GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh", "scattering"}) + GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"}) +} + +func TestUnicodeModelDir(t *testing.T) { + // This is only useful for Windows with utf-16 characters, so skip this test for other platforms + if runtime.GOOS != "windows" { + t.Skip("Unicode test only applicable to windows") + } + // Only works for local testing + if os.Getenv("OLLAMA_TEST_EXISTING") != "" { + t.Skip("TestUnicodeModelDir only works for local testing, skipping") + } + + modelDir, err := os.MkdirTemp("",
"ollama_埃") + require.NoError(t, err) + defer os.RemoveAll(modelDir) + slog.Info("unicode", "OLLAMA_MODELS", modelDir) + + oldModelsDir := os.Getenv("OLLAMA_MODELS") + if oldModelsDir == "" { + defer os.Unsetenv("OLLAMA_MODELS") + } else { + defer os.Setenv("OLLAMA_MODELS", oldModelsDir) + } + err = os.Setenv("OLLAMA_MODELS", modelDir) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + req := api.GenerateRequest{ + Model: "orca-mini", + Prompt: "why is the sky blue?", + Stream: &stream, + Options: map[string]interface{}{ + "temperature": 0, + "seed": 123, + }, + } + GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"}) } diff --git a/integration/concurrency_test.go b/integration/concurrency_test.go new file mode 100644 index 00000000..110301ab --- /dev/null +++ b/integration/concurrency_test.go @@ -0,0 +1,225 @@ +//go:build integration + +package integration + +import ( + "context" + "log/slog" + "os" + "strconv" + "sync" + "testing" + "time" + + "github.com/ollama/ollama/api" + "github.com/stretchr/testify/require" +) + +func TestMultiModelConcurrency(t *testing.T) { + var ( + req = [2]api.GenerateRequest{ + { + Model: "orca-mini", + Prompt: "why is the ocean blue?", + Stream: &stream, + Options: map[string]interface{}{ + "seed": 42, + "temperature": 0.0, + }, + }, { + Model: "tinydolphin", + Prompt: "what is the origin of the us thanksgiving holiday?", + Stream: &stream, + Options: map[string]interface{}{ + "seed": 42, + "temperature": 0.0, + }, + }, + } + resp = [2][]string{ + []string{"sunlight"}, + []string{"england", "english", "massachusetts", "pilgrims"}, + } + ) + var wg sync.WaitGroup + wg.Add(len(req)) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*120) + defer cancel() + for i := 0; i < len(req); i++ { + go func(i int) { + defer wg.Done() + GenerateTestHelper(ctx, t, req[i], resp[i]) + }(i) + } + wg.Wait() +} + +func 
TestIntegrationConcurrentPredictOrcaMini(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) // GTX 750 2G card takes ~9 minutes + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + req, resp := GenerateRequests() + // Get the server running (if applicable) warm the model up with a single initial request + DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 5*time.Second) + + var wg sync.WaitGroup + wg.Add(len(req)) + for i := 0; i < len(req); i++ { + go func(i int) { + defer wg.Done() + for j := 0; j < 5; j++ { + slog.Info("Starting", "req", i, "iter", j) + // On slower GPUs it can take a while to process the 4 concurrent requests + // so we allow a much longer initial timeout + DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second) + } + }(i) + } + wg.Wait() +} + +// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit +func TestMultiModelStress(t *testing.T) { + vram := os.Getenv("OLLAMA_MAX_VRAM") + if vram == "" { + t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test") + } + max, err := strconv.ParseUint(vram, 10, 64) + require.NoError(t, err) + const MB = uint64(1024 * 1024) + type model struct { + name string + size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM + } + + smallModels := []model{ + { + name: "orca-mini", + size: 2992 * MB, + }, + { + name: "phi", + size: 2616 * MB, + }, + { + name: "gemma:2b", + size: 2364 * MB, + }, + { + name: "stable-code:3b", + size: 2608 * MB, + }, + { + name: "starcoder2:3b", + size: 2166 * MB, + }, + } + mediumModels := []model{ + { + name: "llama2", + size: 5118 * MB, + }, + { + name: "mistral", + size: 4620 * MB, + }, + { + name: "orca-mini:7b", + size: 5118 * MB, + }, + { + name: "dolphin-mistral", + size: 4620 * MB, + }, + { + name: "gemma:7b", + size: 5000 * MB, + }, + // TODO - uncomment this once 
#3565 is merged and this is rebased on it + // { + // name: "codellama:7b", + // size: 5118 * MB, + // }, + } + + // These seem to be too slow to be useful... + // largeModels := []model{ + // { + // name: "llama2:13b", + // size: 7400 * MB, + // }, + // { + // name: "codellama:13b", + // size: 7400 * MB, + // }, + // { + // name: "orca-mini:13b", + // size: 7400 * MB, + // }, + // { + // name: "gemma:7b", + // size: 5000 * MB, + // }, + // { + // name: "starcoder2:15b", + // size: 9100 * MB, + // }, + // } + + var chosenModels []model + switch { + case max < 10000*MB: + slog.Info("selecting small models") + chosenModels = smallModels + // case max < 30000*MB: + default: + slog.Info("selecting medium models") + chosenModels = mediumModels + // default: + // slog.Info("selecting large models") + // chosenModels = largeModels + } + + req, resp := GenerateRequests() + + for i := range req { + if i >= len(chosenModels) { + break + } + req[i].Model = chosenModels[i].name + } + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + // Make sure all the models are pulled before we get started + for _, r := range req { + require.NoError(t, PullIfMissing(ctx, client, r.Model)) + } + + var wg sync.WaitGroup + consumed := uint64(256 * MB) // Assume some baseline usage + for i := 0; i < len(req); i++ { + // Always get at least 2 models, but don't overshoot VRAM too much or we'll take too long + if i > 1 && consumed > max { + slog.Info("achieved target vram exhaustion", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024) + break + } + consumed += chosenModels[i].size + slog.Info("target vram", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024) + + wg.Add(1) + go func(i int) { + defer wg.Done() + for j := 0; j < 3; j++ { + slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model) +
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second) + } + }(i) + } + wg.Wait() +} diff --git a/integration/context_test.go b/integration/context_test.go index 80ea540b..08033125 100644 --- a/integration/context_test.go +++ b/integration/context_test.go @@ -4,7 +4,6 @@ package integration import ( "context" - "net/http" "testing" "time" @@ -25,5 +24,5 @@ func TestContextExhaustion(t *testing.T) { "num_ctx": 128, }, } - GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"once", "upon", "lived"}) + GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"}) } diff --git a/integration/llm_image_test.go b/integration/llm_image_test.go index 94082d6e..77319aef 100644 --- a/integration/llm_image_test.go +++ b/integration/llm_image_test.go @@ -5,7 +5,6 @@ package integration import ( "context" "encoding/base64" - "net/http" "testing" "time" @@ -29,10 +28,11 @@ func TestIntegrationMultimodal(t *testing.T) { }, } - resp := "the ollamas" + // Note: sometimes it returns "the ollamas" sometimes "the ollams" + resp := "the ollam" ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) defer cancel() - GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp}) + GenerateTestHelper(ctx, t, req, []string{resp}) } const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb diff --git a/integration/llm_test.go b/integration/llm_test.go index bcc169d6..4952b072 100644 --- a/integration/llm_test.go +++ b/integration/llm_test.go @@ -4,8 +4,6 @@ package integration import ( "context" - "net/http" - "sync" "testing" "time" @@ -45,25 +43,5 @@ var ( func TestIntegrationSimpleOrcaMini(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*120) defer cancel() - GenerateTestHelper(ctx, t, &http.Client{}, req[0], resp[0]) + GenerateTestHelper(ctx, t, req[0], resp[0]) } - -// TODO -// The server always loads a new runner and 
closes the old one, which forces serial execution -// At present this test case fails with concurrency problems. Eventually we should try to -// get true concurrency working with n_parallel support in the backend -func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) { - var wg sync.WaitGroup - wg.Add(len(req)) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*120) - defer cancel() - for i := 0; i < len(req); i++ { - go func(i int) { - defer wg.Done() - GenerateTestHelper(ctx, t, &http.Client{}, req[i], resp[i]) - }(i) - } - wg.Wait() -} - -// TODO - create a parallel test with 2 different models once we support concurrency diff --git a/integration/utils_test.go b/integration/utils_test.go index 0f712271..3e91187a 100644 --- a/integration/utils_test.go +++ b/integration/utils_test.go @@ -5,13 +5,14 @@ package integration import ( "bytes" "context" - "encoding/json" + "errors" "fmt" "io" "log/slog" "math/rand" "net" "net/http" + "net/url" "os" "path/filepath" "runtime" @@ -23,9 +24,13 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/app/lifecycle" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +func Init() { + lifecycle.InitLogging() +} + func FindPort() string { port := 0 if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { @@ -41,7 +46,7 @@ func FindPort() string { return strconv.Itoa(port) } -func GetTestEndpoint() (string, string) { +func GetTestEndpoint() (*api.Client, string) { defaultPort := "11434" ollamaHost := os.Getenv("OLLAMA_HOST") @@ -67,16 +72,20 @@ func GetTestEndpoint() (string, string) { port = FindPort() } - url := fmt.Sprintf("%s:%s", host, port) - slog.Info("server connection", "url", url) - return scheme, url + slog.Info("server connection", "host", host, "port", port) + + return api.NewClient( + &url.URL{ + Scheme: scheme, + Host: net.JoinHostPort(host, port), + }, + http.DefaultClient), fmt.Sprintf("%s:%s", host, port) } -// TODO make fanicier, 
grab logs, etc. var serverMutex sync.Mutex var serverReady bool -func StartServer(ctx context.Context, ollamaHost string) error { +func startServer(ctx context.Context, ollamaHost string) error { // Make sure the server has been built CLIName, err := filepath.Abs("../ollama") if err != nil { @@ -125,67 +134,76 @@ func StartServer(ctx context.Context, ollamaHost string) error { return nil } -func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error { +func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error { slog.Info("checking status of model", "model", modelName) showReq := &api.ShowRequest{Name: modelName} - requestJSON, err := json.Marshal(showReq) - if err != nil { - return err - } - req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/show", bytes.NewReader(requestJSON)) - if err != nil { + showCtx, cancel := context.WithDeadlineCause( + ctx, + time.Now().Add(5*time.Second), + fmt.Errorf("show for existing model %s took too long", modelName), + ) + defer cancel() + _, err := client.Show(showCtx, showReq) + var statusError api.StatusError + switch { + case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound: + break + case err != nil: return err - } - - // Make the request with the HTTP client - response, err := client.Do(req.WithContext(ctx)) - if err != nil { - return err - } - defer response.Body.Close() - if response.StatusCode == 200 { + default: slog.Info("model already present", "model", modelName) return nil } - slog.Info("model missing", "status", response.StatusCode) + slog.Info("model missing", "model", modelName) + stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models + stallTimer := time.NewTimer(stallDuration) + fn := func(resp api.ProgressResponse) error { + // fmt.Print(".") + if !stallTimer.Reset(stallDuration) { + return fmt.Errorf("stall was detected, aborting status 
reporting") + } + return nil + } + + stream := true pullReq := &api.PullRequest{Name: modelName, Stream: &stream} - requestJSON, err = json.Marshal(pullReq) - if err != nil { - return err - } - req, err = http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/pull", bytes.NewReader(requestJSON)) - if err != nil { - return err - } - slog.Info("pulling", "model", modelName) + var pullError error - response, err = client.Do(req.WithContext(ctx)) - if err != nil { - return err + done := make(chan int, 1) + go func() { + pullError = client.Pull(ctx, pullReq, fn) + done <- 0 + }() + + select { + case <-stallTimer.C: + return fmt.Errorf("download stalled") + case <-done: + return pullError } - defer response.Body.Close() - if response.StatusCode != 200 { - return fmt.Errorf("failed to pull model") // TODO more details perhaps - } - slog.Info("model pulled", "model", modelName) - return nil } var serverProcMutex sync.Mutex -func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) { - - // TODO maybe stuff in an init routine?
- lifecycle.InitLogging() - - requestJSON, err := json.Marshal(genReq) - if err != nil { - t.Fatalf("Error serializing request: %v", err) +// Returns an Client, the testEndpoint, and a cleanup function, fails the test on errors +// Starts the server if needed +func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, string, func()) { + client, testEndpoint := GetTestEndpoint() + if os.Getenv("OLLAMA_TEST_EXISTING") == "" { + serverProcMutex.Lock() + fp, err := os.CreateTemp("", "ollama-server-*.log") + if err != nil { + t.Fatalf("failed to generate log file: %s", err) + } + lifecycle.ServerLogFile = fp.Name() + fp.Close() + require.NoError(t, startServer(ctx, testEndpoint)) } - defer func() { + + return client, testEndpoint, func() { if os.Getenv("OLLAMA_TEST_EXISTING") == "" { defer serverProcMutex.Unlock() if t.Failed() { @@ -203,63 +221,118 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, os.Stderr.Write(data) slog.Warn("END OF SERVER") } - err = os.Remove(lifecycle.ServerLogFile) + err := os.Remove(lifecycle.ServerLogFile) if err != nil && !os.IsNotExist(err) { slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err) } } - }() - scheme, testEndpoint := GetTestEndpoint() - - if os.Getenv("OLLAMA_TEST_EXISTING") == "" { - serverProcMutex.Lock() - fp, err := os.CreateTemp("", "ollama-server-*.log") - if err != nil { - t.Fatalf("failed to generate log file: %s", err) - } - lifecycle.ServerLogFile = fp.Name() - fp.Close() - assert.NoError(t, StartServer(ctx, testEndpoint)) } - - err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model) - if err != nil { - t.Fatalf("Error pulling model: %v", err) - } - - // Make the request and get the response - req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON)) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - // Set the content type for the request - 
req.Header.Set("Content-Type", "application/json") - - // Make the request with the HTTP client - response, err := client.Do(req.WithContext(ctx)) - if err != nil { - t.Fatalf("Error making request: %v", err) - } - defer response.Body.Close() - body, err := io.ReadAll(response.Body) - assert.NoError(t, err) - assert.Equal(t, response.StatusCode, 200, string(body)) - - // Verify the response is valid JSON - var payload api.GenerateResponse - err = json.Unmarshal(body, &payload) - if err != nil { - assert.NoError(t, err, body) - } - - // Verify the response contains the expected data - atLeastOne := false - for _, resp := range anyResp { - if strings.Contains(strings.ToLower(payload.Response), resp) { - atLeastOne = true - break - } - } - assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response) +} + +func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) { + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + require.NoError(t, PullIfMissing(ctx, client, genReq.Model)) + DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second) +} + +func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) { + stallTimer := time.NewTimer(initialTimeout) + var buf bytes.Buffer + fn := func(response api.GenerateResponse) error { + // fmt.Print(".") + buf.Write([]byte(response.Response)) + if !stallTimer.Reset(streamTimeout) { + return fmt.Errorf("stall was detected while streaming response, aborting") + } + return nil + } + + stream := true + genReq.Stream = &stream + done := make(chan int, 1) + var genErr error + go func() { + genErr = client.Generate(ctx, &genReq, fn) + done <- 0 + }() + + select { + case <-stallTimer.C: + if buf.Len() == 0 { + t.Errorf("generate never started. Timed out after :%s", initialTimeout.String()) + } else { + t.Errorf("generate stalled.
Response so far:%s", buf.String()) + } + case <-done: + require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt) + // Verify the response contains the expected data + response := buf.String() + atLeastOne := false + for _, resp := range anyResp { + if strings.Contains(strings.ToLower(response), resp) { + atLeastOne = true + break + } + } + require.True(t, atLeastOne, "none of %v found in %s", anyResp, response) + slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response) + case <-ctx.Done(): + t.Error("outer test context done while waiting for generate") + } +} + +// Generate a set of requests +// By default each request uses orca-mini as the model +func GenerateRequests() ([]api.GenerateRequest, [][]string) { + return []api.GenerateRequest{ + { + Model: "orca-mini", + Prompt: "why is the ocean blue?", + Stream: &stream, + Options: map[string]interface{}{ + "seed": 42, + "temperature": 0.0, + }, + }, { + Model: "orca-mini", + Prompt: "why is the color of dirt brown?", + Stream: &stream, + Options: map[string]interface{}{ + "seed": 42, + "temperature": 0.0, + }, + }, { + Model: "orca-mini", + Prompt: "what is the origin of the us thanksgiving holiday?", + Stream: &stream, + Options: map[string]interface{}{ + "seed": 42, + "temperature": 0.0, + }, + }, { + Model: "orca-mini", + Prompt: "what is the origin of independence day?", + Stream: &stream, + Options: map[string]interface{}{ + "seed": 42, + "temperature": 0.0, + }, + }, { + Model: "orca-mini", + Prompt: "what is the composition of air?", + Stream: &stream, + Options: map[string]interface{}{ + "seed": 42, + "temperature": 0.0, + }, + }, + }, + [][]string{ + []string{"sunlight"}, + []string{"soil", "organic", "earth", "black", "tan"}, + []string{"england", "english", "massachusetts", "pilgrims"}, + []string{"fourth", "july", "declaration", "independence"}, + []string{"nitrogen", "oxygen", "carbon", "dioxide"}, + } } diff 
--git a/llm/memory.go b/llm/memory.go new file mode 100644 index 00000000..0dff54d3 --- /dev/null +++ b/llm/memory.go @@ -0,0 +1,162 @@ +package llm + +import ( + "fmt" + "log/slog" + "strings" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/format" + "github.com/ollama/ollama/gpu" +) + +// This algorithm looks for a complete fit to determine if we need to unload other models +func PredictServerFit(allGpus gpu.GpuInfoList, ggml *GGML, adapters, projectors []string, opts api.Options) (bool, uint64) { + var estimatedVRAM uint64 + if opts.NumCtx > int(ggml.KV().ContextLength()) { + slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength()) + opts.NumCtx = int(ggml.KV().ContextLength()) + } + + if opts.NumCtx < 4 { + opts.NumCtx = 4 + } + + // Split up the GPUs by type and try them + for _, gpus := range allGpus.ByLibrary() { + var layerCount int + layerCount, estimatedVRAM = EstimateGPULayers(gpus, ggml, projectors, opts) + if opts.NumGPU < 0 { + if layerCount > 0 && layerCount >= int(ggml.KV().BlockCount()+1) { + return true, estimatedVRAM + } + } else { + if layerCount > 0 && layerCount >= opts.NumGPU { + return true, estimatedVRAM + } + } + } + return false, estimatedVRAM +} + +// Given a model and one or more GPU targets, predict how many layers and bytes we can load +// The GPUs provided must all be the same Library +func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts api.Options) (int, uint64) { + if gpus[0].Library == "cpu" { + return 0, 0 + } + var memoryAvailable uint64 + for _, info := range gpus { + memoryAvailable += info.FreeMemory + } + slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", format.HumanBytes2(memoryAvailable)) + + // TODO - this is probably wrong, first GPU vs secondaries will have different overheads + memoryMinimum := gpus[0].MinimumMemory + + for _, projector := range 
projectors { + memoryMinimum += projectorMemoryRequirements(projector) + + // multimodal models require at least 2048 context + opts.NumCtx = max(opts.NumCtx, 2048) + } + + // fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv + var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV() + + graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch))) + if graphPartialOffload == 0 { + graphPartialOffload = ggml.KV().GQA() * kv / 6 + } + + if graphFullOffload == 0 { + graphFullOffload = graphPartialOffload + } + + graphFullOffload *= uint64(len(gpus)) + graphPartialOffload *= uint64(len(gpus)) + + // memoryRequiredTotal represents the memory required for full GPU offloading (all layers) + memoryRequiredTotal := memoryMinimum + graphFullOffload + + // memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers) + memoryRequiredPartial := memoryMinimum + graphPartialOffload + + if memoryRequiredPartial > memoryAvailable { + slog.Debug("insufficient VRAM to load any model layers") + return 0, 0 + } + + var layerCount int + layers := ggml.Tensors().Layers() + for i := 0; i < int(ggml.KV().BlockCount()); i++ { + memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size() + + // KV is proportional to the number of layers + memoryLayer += kv / ggml.KV().BlockCount() + + memoryRequiredTotal += memoryLayer + if memoryAvailable > memoryRequiredPartial+memoryLayer { + memoryRequiredPartial += memoryLayer + layerCount++ + } + } + + var memoryLayerOutput uint64 + for k, v := range layers { + if !strings.HasPrefix(k, "blk.") { + memoryLayerOutput += v.size() + } + } + + memoryRequiredTotal += memoryLayerOutput + + if memoryAvailable > memoryRequiredTotal { + layerCount = int(ggml.KV().BlockCount()) + 1 + memoryRequiredPartial = memoryRequiredTotal + } + + 
memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv + + slog.Info( + "offload to gpu", + slog.Group( + "layers", + // actual number of layers offloaded + "real", opts.NumGPU, + // estimated number of layers that can be offloaded + "estimate", layerCount, + ), + slog.Group( + "memory", + // memory available for offloading + "available", format.HumanBytes2(memoryAvailable), + slog.Group( + "required", + // memory required for full offloading + "full", format.HumanBytes2(memoryRequiredTotal), + // memory required to offload layers.estimate layers + "partial", format.HumanBytes2(memoryRequiredPartial), + // memory of KV cache + "kv", format.HumanBytes2(kv), + ), + slog.Group( + "weights", + // memory of the weights + "total", format.HumanBytes2(memoryWeights), + // memory of repeating layers + "repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput), + // memory of non-repeating layers + "nonrepeating", format.HumanBytes2(memoryLayerOutput), + ), + slog.Group( + "graph", + // memory of graph when fully offloaded + "full", format.HumanBytes2(graphFullOffload), + // memory of graph when not fully offloaded + "partial", format.HumanBytes2(graphPartialOffload), + ), + ), + ) + return layerCount, uint64(memoryRequiredPartial) +} diff --git a/llm/payload.go b/llm/payload.go index 46713c43..c81c2784 100644 --- a/llm/payload.go +++ b/llm/payload.go @@ -9,6 +9,7 @@ import ( "log/slog" "os" "path/filepath" + "runtime" "strings" "golang.org/x/exp/slices" @@ -138,6 +139,23 @@ func serversForGpu(info gpu.GpuInfo) []string { return servers } +// Return the optimal server for this CPU architecture +func serverForCpu() string { + if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" { + return "metal" + } + variant := gpu.GetCPUVariant() + availableServers := availableServers() + if variant != "" { + for cmp := range availableServers { + if cmp == "cpu_"+variant { + return cmp + } + } + } + return "cpu" +} + // extract extracts the embedded files to 
the target directory func extractFiles(targetDir string, glob string) error { files, err := fs.Glob(libEmbed, glob) diff --git a/llm/server.go b/llm/server.go index fcae39f3..01a712c3 100644 --- a/llm/server.go +++ b/llm/server.go @@ -21,21 +21,43 @@ import ( "strings" "time" + "golang.org/x/sync/semaphore" + "github.com/ollama/ollama/api" "github.com/ollama/ollama/format" "github.com/ollama/ollama/gpu" ) -// LlamaServer is an instance of the llama.cpp server -type LlamaServer struct { +type LlamaServer interface { + Ping(ctx context.Context) error + WaitUntilRunning(ctx context.Context) error + Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error + Embedding(ctx context.Context, prompt string) ([]float64, error) + Tokenize(ctx context.Context, content string) ([]int, error) + Detokenize(ctx context.Context, tokens []int) (string, error) + Close() error + EstimatedVRAM() uint64 +} + +// llmServer is an instance of the llama.cpp server +type llmServer struct { port int cmd *exec.Cmd done chan error // Channel to signal when the process exits status *StatusWriter options api.Options + + // TODO - this should be broken down by GPU + estimatedVRAM uint64 // Estimated usage of VRAM by the loaded model + + sem *semaphore.Weighted } -func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) { +func LoadModel(model string) (*GGML, error) { + if _, err := os.Stat(model); err != nil { + return nil, err + } + f, err := os.Open(model) if err != nil { return nil, err @@ -43,10 +65,13 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option defer f.Close() ggml, _, err := DecodeGGML(f) - if err != nil { - return nil, err - } + return ggml, err +} +// NewLlamaServer will run a server for the given GPUs +// The gpu list must be a single family. 
+func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) { + var err error if opts.NumCtx > int(ggml.KV().ContextLength()) { slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength()) opts.NumCtx = int(ggml.KV().ContextLength()) @@ -56,130 +81,50 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option opts.NumCtx = 4 } - memoryAvailable, _ := gpu.CheckVRAM() - info := gpu.GetGPUInfo() + cpuRunner := "" + var estimatedVRAM uint64 + var systemMemory uint64 + if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 { - memoryMinimum := info.MinimumMemory - for _, projector := range projectors { - memoryMinimum += projectorMemoryRequirements(projector) + // TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner - // multimodal models require at least 2048 context - opts.NumCtx = max(opts.NumCtx, 2048) - } + cpuRunner = serverForCpu() + } else { + if gpus[0].Library == "metal" { + memInfo, err := gpu.GetCPUMem() + if err != nil { + slog.Error("failed to lookup system memory", "error", err) + } else { + systemMemory = memInfo.TotalMemory + slog.Debug("system memory", "total", format.HumanBytes2(systemMemory)) + } + } + var layers int + layers, estimatedVRAM = EstimateGPULayers(gpus, ggml, projectors, opts) - // fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv - var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV() - - graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch))) - if graphPartialOffload == 0 { - graphPartialOffload = ggml.KV().GQA() * kv / 6 - } - - if graphFullOffload == 0 { - graphFullOffload = 
graphPartialOffload - } - - graphFullOffload *= uint64(info.DeviceCount) - graphPartialOffload *= uint64(info.DeviceCount) - - // memoryRequiredTotal represents the memory required for full GPU offloading (all layers) - memoryRequiredTotal := memoryMinimum + graphFullOffload - - // memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers) - memoryRequiredPartial := memoryMinimum + graphPartialOffload - - if info.Library != "metal" { - if memoryRequiredPartial > memoryAvailable { - info.Library = "cpu" + if gpus[0].Library == "metal" && estimatedVRAM > systemMemory { + // disable partial offloading when model is greater than total system memory as this + // can lead to locking up the system + opts.NumGPU = 0 + } else if opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu" { + opts.NumGPU = layers } } - var layerCount int - layers := ggml.Tensors().Layers() - for i := 0; i < int(ggml.KV().BlockCount()); i++ { - memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size() - - // KV is proportional to the number of layers - memoryLayer += kv / ggml.KV().BlockCount() - - memoryRequiredTotal += memoryLayer - if memoryAvailable > memoryRequiredPartial+memoryLayer { - memoryRequiredPartial += memoryLayer - layerCount++ - } - } - - var memoryLayerOutput uint64 - for k, v := range layers { - if !strings.HasPrefix(k, "blk.") { - memoryLayerOutput += v.size() - } - } - - memoryRequiredTotal += memoryLayerOutput - - if info.Library == "metal" && memoryRequiredTotal > info.TotalMemory { - // disable partial offloading when model is greater than total system memory - opts.NumGPU = 0 - } else if memoryAvailable > memoryRequiredTotal { - layerCount = int(ggml.KV().BlockCount()) + 1 - memoryRequiredPartial = memoryRequiredTotal - } - - if opts.NumGPU < 0 { - opts.NumGPU = layerCount - } - - memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv - - slog.Info( - "offload to gpu", - slog.Group( - "layers", - // actual 
number of layers offloaded - "real", opts.NumGPU, - // estimated number of layers that can be offloaded - "estimate", layerCount, - ), - slog.Group( - "memory", - // memory available for offloading - "available", format.HumanBytes2(memoryAvailable), - slog.Group( - "required", - // memory required for full offloading - "full", format.HumanBytes2(memoryRequiredTotal), - // memory required to offload layers.estimate layers - "partial", format.HumanBytes2(memoryRequiredPartial), - // memory of KV cache - "kv", format.HumanBytes2(kv), - ), - slog.Group( - "weights", - // memory of the weights - "total", format.HumanBytes2(memoryWeights), - // memory of repeating layers - "repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput), - // memory of non-repeating layers - "nonrepeating", format.HumanBytes2(memoryLayerOutput), - ), - slog.Group( - "graph", - // memory of graph when fully offloaded - "full", format.HumanBytes2(graphFullOffload), - // memory of graph when not fully offloaded - "partial", format.HumanBytes2(graphPartialOffload), - ), - ), - ) + // Loop through potential servers + finalErr := fmt.Errorf("no suitable llama servers found") if len(adapters) > 1 { return nil, errors.New("ollama supports only one lora adapter, but multiple were provided") } availableServers := availableServers() - servers := serversForGpu(info) - + var servers []string + if cpuRunner != "" { + servers = []string{cpuRunner} + } else { + servers = serversForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant + } demandLib := strings.Trim(os.Getenv("OLLAMA_LLM_LIBRARY"), "\"' ") if demandLib != "" { serverPath := availableServers[demandLib] @@ -192,7 +137,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option } if len(servers) == 0 { - return nil, fmt.Errorf("no servers found for %v", info) + return nil, fmt.Errorf("no servers found for %v", gpus) } params := []string{ @@ -249,8 +194,18 @@ func NewLlamaServer(model string, adapters, 
projectors []string, opts api.Option params = append(params, "--numa") } - // Loop through potential servers - var finalErr error + // "--cont-batching", // TODO - doesn't seem to have any noticeable perf change for multiple requests + numParallel := 1 + if onp := os.Getenv("OLLAMA_NUM_PARALLEL"); onp != "" { + numParallel, err = strconv.Atoi(onp) + if err != nil || numParallel <= 0 { + err = fmt.Errorf("invalid OLLAMA_NUM_PARALLEL=%s must be greater than zero - %w", onp, err) + slog.Error("misconfiguration", "error", err) + return nil, err + } + } + params = append(params, "--parallel", fmt.Sprintf("%d", numParallel)) + for i := 0; i < len(servers); i++ { dir := availableServers[servers[i]] @@ -275,30 +230,49 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option } // append the server directory to LD_LIBRARY_PATH/PATH libraryPaths := []string{dir} + if libraryPath, ok := os.LookupEnv(pathEnv); ok { // Append our runner directory to the path // This will favor system libraries over our bundled library dependencies libraryPaths = append(filepath.SplitList(libraryPath), libraryPaths...) } + // Note: we always put the dependency path first + // since this was the exact version we verified for AMD GPUs + // and we favor what the user had in their path + if gpus[0].DependencyPath != "" { + // TODO refine for multi-gpu support + libraryPaths = append([]string{gpus[0].DependencyPath}, libraryPaths...) 
+ } + server := filepath.Join(dir, "ollama_llama_server") if runtime.GOOS == "windows" { server = server + ".exe" } - s := &LlamaServer{ - port: port, - cmd: exec.Command(server, finalParams...), - status: NewStatusWriter(os.Stderr), - options: opts, + s := &llmServer{ + port: port, + cmd: exec.Command(server, finalParams...), + status: NewStatusWriter(os.Stderr), + options: opts, + estimatedVRAM: estimatedVRAM, + sem: semaphore.NewWeighted(int64(numParallel)), } + libEnv := fmt.Sprintf("%s=%s", pathEnv, strings.Join(libraryPaths, string(filepath.ListSeparator))) - slog.Debug(libEnv) s.cmd.Env = append(os.Environ(), libEnv) s.cmd.Stdout = os.Stdout s.cmd.Stderr = s.status + // TODO - multiple GPU selection logic... + key, val := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv() + if key != "" { + s.cmd.Env = append(s.cmd.Env, key+"="+val) + } + slog.Info("starting llama server", "cmd", s.cmd.String()) + // Log at debug as the environment is inherited and might contain sensitive information + slog.Debug("subprocess", "environment", s.cmd.Env) if err = s.cmd.Start(); err != nil { msg := "" @@ -316,6 +290,13 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option _ = s.cmd.Wait() }() + // TODO - make sure this is all wired up correctly + // if err = s.WaitUntilRunning(); err != nil { + // slog.Error("error starting llama server", "server", servers[i], "error", err) + // s.Close() + // finalErr = err + // continue + // } return s, nil } @@ -353,6 +334,21 @@ const ( // iota is reset to 0 ServerStatusError ) +func (s ServerStatus) ToString() string { + switch s { + case ServerStatusReady: + return "llm server ready" + case ServerStatusNoSlotsAvaialble: + return "llm busy - no slots available" + case ServerStatusLoadingModel: + return "llm server loading model" + case ServerStatusNotResponding: + return "llm server not responding" + default: + return "llm server error" + } +} + type ServerStatusResp struct { Status string `json:"status"` SlotsIdle 
int `json:"slots_idle"` @@ -360,7 +356,7 @@ type ServerStatusResp struct { Error string `json:"error"` } -func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error) { +func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) { // Fail fast if its exited if s.cmd.ProcessState != nil { msg := "" @@ -407,7 +403,7 @@ func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error) } } -func (s *LlamaServer) Ping(ctx context.Context) error { +func (s *llmServer) Ping(ctx context.Context) error { _, err := s.getServerStatus(ctx) if err != nil { slog.Debug("server unhealthy", "error", err) @@ -416,7 +412,7 @@ func (s *LlamaServer) Ping(ctx context.Context) error { return nil } -func (s *LlamaServer) WaitUntilRunning() error { +func (s *llmServer) WaitUntilRunning(ctx context.Context) error { start := time.Now() // TODO we need to wire up a better way to detect hangs during model load and startup of the server expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load @@ -427,6 +423,9 @@ func (s *LlamaServer) WaitUntilRunning() error { var lastStatus ServerStatus = -1 for { select { + case <-ctx.Done(): + slog.Info("context expired before server started") + return fmt.Errorf("timed out waiting for llama runner to start") case err := <-s.done: msg := "" if s.status != nil && s.status.LastErrMsg != "" { @@ -450,9 +449,9 @@ func (s *LlamaServer) WaitUntilRunning() error { return fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg) } - ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + c, cancel := context.WithTimeout(ctx, 200*time.Millisecond) defer cancel() - status, err := s.getServerStatus(ctx) + status, err := s.getServerStatus(c) if err != nil && lastStatus != status { slog.Debug("server not yet available", "error", err) lastStatus = status @@ -538,7 +537,12 @@ type CompletionResponse 
struct { EvalDuration time.Duration } -func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { +func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { + if err := s.sem.Acquire(ctx, 1); err != nil { + slog.Error("Failed to acquire semaphore", "error", err) + return err + } + defer s.sem.Release(1) request := map[string]any{ "prompt": req.Prompt, "stream": true, @@ -569,7 +573,7 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn if err != nil { return err } else if status != ServerStatusReady { - return fmt.Errorf("unexpected server status: %d", status) + return fmt.Errorf("unexpected server status: %s", status.ToString()) } if req.Format == "json" { @@ -716,13 +720,18 @@ type EmbeddingResponse struct { Embedding []float64 `json:"embedding"` } -func (s *LlamaServer) Embedding(ctx context.Context, prompt string) ([]float64, error) { +func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) { + if err := s.sem.Acquire(ctx, 1); err != nil { + slog.Error("Failed to acquire semaphore", "error", err) + return nil, err + } + defer s.sem.Release(1) // Make sure the server is ready status, err := s.getServerStatus(ctx) if err != nil { return nil, err } else if status != ServerStatusReady { - return nil, fmt.Errorf("unexpected server status: %d", status) + return nil, fmt.Errorf("unexpected server status: %s", status.ToString()) } data, err := json.Marshal(TokenizeRequest{Content: prompt}) @@ -768,13 +777,13 @@ type TokenizeResponse struct { Tokens []int `json:"tokens"` } -func (s *LlamaServer) Tokenize(ctx context.Context, content string) ([]int, error) { +func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) { // Make sure the server is ready status, err := s.getServerStatus(ctx) if err != nil { return nil, err - } else if status != ServerStatusReady { - return nil, 
fmt.Errorf("unexpected server status: %d", status) + } else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble { + return nil, fmt.Errorf("unexpected server status: %s", status.ToString()) } data, err := json.Marshal(TokenizeRequest{Content: content}) @@ -820,13 +829,13 @@ type DetokenizeResponse struct { Content string `json:"content"` } -func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) { +func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) { // Make sure the server is ready status, err := s.getServerStatus(ctx) if err != nil { return "", err - } else if status != ServerStatusReady { - return "", fmt.Errorf("unexpected server status: %d", status) + } else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble { + return "", fmt.Errorf("unexpected server status: %s", status.ToString()) } data, err := json.Marshal(DetokenizeRequest{Tokens: tokens}) @@ -864,7 +873,7 @@ func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, err return decoded.Content, nil } -func (s *LlamaServer) Close() error { +func (s *llmServer) Close() error { if s.cmd != nil { slog.Debug("stopping llama server") return s.cmd.Process.Kill() @@ -873,6 +882,10 @@ func (s *LlamaServer) Close() error { return nil } +func (s *llmServer) EstimatedVRAM() uint64 { + return s.estimatedVRAM +} + func parseDurationMs(ms float64) time.Duration { dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms)) if err != nil { diff --git a/server/routes.go b/server/routes.go index b0d36b14..016deb34 100644 --- a/server/routes.go +++ b/server/routes.go @@ -15,11 +15,8 @@ import ( "os" "os/signal" "path/filepath" - "reflect" - "runtime" "strconv" "strings" - "sync" "syscall" "time" @@ -38,7 +35,8 @@ import ( var mode string = gin.DebugMode type Server struct { - addr net.Addr + addr net.Addr + sched *Scheduler } func init() { @@ -53,88 +51,8 @@ func init() { gin.SetMode(mode) } -var loaded struct 
{ - mu sync.Mutex - - llama *llm.LlamaServer - - expireTimer *time.Timer - - model string - adapters []string - projectors []string - *api.Options -} - var defaultSessionDuration = 5 * time.Minute -func unload() { - if loaded.llama != nil { - loaded.llama.Close() - } - - loaded.llama = nil - loaded.model = "" - loaded.adapters = nil - loaded.projectors = nil - loaded.Options = nil -} - -// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function -func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error { - ctx, cancel := context.WithTimeout(c, 10*time.Second) - defer cancel() - - needLoad := loaded.llama == nil || // is there a model loaded? - loaded.model != model.ModelPath || // has the base model changed? - !reflect.DeepEqual(loaded.adapters, model.AdapterPaths) || // have the adapters changed? - !reflect.DeepEqual(loaded.projectors, model.ProjectorPaths) || // have the adapters changed? - !reflect.DeepEqual(loaded.Options.Runner, opts.Runner) || // have the runner options changed? - loaded.llama.Ping(ctx) != nil - - if needLoad { - if loaded.llama != nil { - slog.Info("changing loaded model") - unload() - } - - llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts) - if err != nil { - // some older models are not compatible with newer versions of llama.cpp - // show a generalized compatibility error until there is a better way to - // check for model compatibility - if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") { - err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. 
If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName) - } - - return err - } - - loaded.model = model.ModelPath - loaded.adapters = model.AdapterPaths - loaded.projectors = model.ProjectorPaths - loaded.llama = llama - loaded.Options = &opts - - if err = llama.WaitUntilRunning(); err != nil { - slog.Error("error loading llama server", "error", err) - unload() - return err - } - } - - if loaded.expireTimer == nil { - loaded.expireTimer = time.AfterFunc(sessionDuration, func() { - loaded.mu.Lock() - defer loaded.mu.Unlock() - unload() - }) - } - - loaded.expireTimer.Reset(sessionDuration) - return nil -} - func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) { opts := api.DefaultOptions() if err := opts.FromMap(model.Options); err != nil { @@ -154,9 +72,7 @@ func isSupportedImageType(image []byte) bool { return slices.Contains(allowedTypes, contentType) } -func GenerateHandler(c *gin.Context) { - loaded.mu.Lock() - defer loaded.mu.Unlock() +func (s *Server) GenerateHandler(c *gin.Context) { checkpointStart := time.Now() var req api.GenerateRequest @@ -224,7 +140,11 @@ func GenerateHandler(c *gin.Context) { sessionDuration = req.KeepAlive.Duration } - if err := load(c, model, opts, sessionDuration); err != nil { + rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration) + var runner *runnerRef + select { + case runner = <-rCh: + case err = <-eCh: c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -275,7 +195,7 @@ func GenerateHandler(c *gin.Context) { sb.Reset() if req.Context != nil { - prev, err := loaded.llama.Detokenize(c.Request.Context(), req.Context) + prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -297,9 +217,6 @@ func GenerateHandler(c *gin.Context) { defer close(ch) fn := func(r llm.CompletionResponse) { 
- // Update model expiration - loaded.expireTimer.Reset(sessionDuration) - // Build up the full response if _, err := generated.WriteString(r.Content); err != nil { ch <- gin.H{"error": err.Error()} @@ -331,7 +248,7 @@ func GenerateHandler(c *gin.Context) { } // TODO (jmorganca): encode() should not strip special tokens - tokens, err := loaded.llama.Tokenize(c.Request.Context(), p) + tokens, err := runner.llama.Tokenize(c.Request.Context(), p) if err != nil { ch <- gin.H{"error": err.Error()} return @@ -359,7 +276,7 @@ func GenerateHandler(c *gin.Context) { Images: images, Options: opts, } - if err := loaded.llama.Completion(c.Request.Context(), req, fn); err != nil { + if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() @@ -421,10 +338,7 @@ func getDefaultSessionDuration() time.Duration { return defaultSessionDuration } -func EmbeddingsHandler(c *gin.Context) { - loaded.mu.Lock() - defer loaded.mu.Unlock() - +func (s *Server) EmbeddingsHandler(c *gin.Context) { var req api.EmbeddingRequest err := c.ShouldBindJSON(&req) switch { @@ -469,7 +383,11 @@ func EmbeddingsHandler(c *gin.Context) { sessionDuration = req.KeepAlive.Duration } - if err := load(c, model, opts, sessionDuration); err != nil { + rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration) + var runner *runnerRef + select { + case runner = <-rCh: + case err = <-eCh: c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -480,7 +398,7 @@ func EmbeddingsHandler(c *gin.Context) { return } - embedding, err := loaded.llama.Embedding(c.Request.Context(), req.Prompt) + embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt) if err != nil { slog.Info(fmt.Sprintf("embedding generation failed: %v", err)) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) @@ -493,7 +411,7 @@ func EmbeddingsHandler(c *gin.Context) { c.JSON(http.StatusOK, resp) } 
-func PullModelHandler(c *gin.Context) { +func (s *Server) PullModelHandler(c *gin.Context) { var req api.PullRequest err := c.ShouldBindJSON(&req) switch { @@ -542,7 +460,7 @@ func PullModelHandler(c *gin.Context) { streamResponse(c, ch) } -func PushModelHandler(c *gin.Context) { +func (s *Server) PushModelHandler(c *gin.Context) { var req api.PushRequest err := c.ShouldBindJSON(&req) switch { @@ -591,7 +509,7 @@ func PushModelHandler(c *gin.Context) { streamResponse(c, ch) } -func CreateModelHandler(c *gin.Context) { +func (s *Server) CreateModelHandler(c *gin.Context) { var req api.CreateRequest err := c.ShouldBindJSON(&req) switch { @@ -664,7 +582,7 @@ func CreateModelHandler(c *gin.Context) { streamResponse(c, ch) } -func DeleteModelHandler(c *gin.Context) { +func (s *Server) DeleteModelHandler(c *gin.Context) { var req api.DeleteRequest err := c.ShouldBindJSON(&req) switch { @@ -709,7 +627,7 @@ func DeleteModelHandler(c *gin.Context) { c.JSON(http.StatusOK, nil) } -func ShowModelHandler(c *gin.Context) { +func (s *Server) ShowModelHandler(c *gin.Context) { var req api.ShowRequest err := c.ShouldBindJSON(&req) switch { @@ -809,7 +727,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { return resp, nil } -func ListModelsHandler(c *gin.Context) { +func (s *Server) ListModelsHandler(c *gin.Context) { models := make([]api.ModelResponse, 0) manifestsPath, err := GetManifestPath() if err != nil { @@ -869,7 +787,7 @@ func ListModelsHandler(c *gin.Context) { c.JSON(http.StatusOK, api.ListResponse{Models: models}) } -func CopyModelHandler(c *gin.Context) { +func (s *Server) CopyModelHandler(c *gin.Context) { var req api.CopyRequest err := c.ShouldBindJSON(&req) switch { @@ -901,7 +819,7 @@ func CopyModelHandler(c *gin.Context) { } } -func HeadBlobHandler(c *gin.Context) { +func (s *Server) HeadBlobHandler(c *gin.Context) { path, err := GetBlobsPath(c.Param("digest")) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": 
err.Error()}) @@ -916,7 +834,7 @@ func HeadBlobHandler(c *gin.Context) { c.Status(http.StatusOK) } -func CreateBlobHandler(c *gin.Context) { +func (s *Server) CreateBlobHandler(c *gin.Context) { path, err := GetBlobsPath(c.Param("digest")) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -1063,27 +981,27 @@ func (s *Server) GenerateRoutes() http.Handler { allowedHostsMiddleware(s.addr), ) - r.POST("/api/pull", PullModelHandler) - r.POST("/api/generate", GenerateHandler) - r.POST("/api/chat", ChatHandler) - r.POST("/api/embeddings", EmbeddingsHandler) - r.POST("/api/create", CreateModelHandler) - r.POST("/api/push", PushModelHandler) - r.POST("/api/copy", CopyModelHandler) - r.DELETE("/api/delete", DeleteModelHandler) - r.POST("/api/show", ShowModelHandler) - r.POST("/api/blobs/:digest", CreateBlobHandler) - r.HEAD("/api/blobs/:digest", HeadBlobHandler) + r.POST("/api/pull", s.PullModelHandler) + r.POST("/api/generate", s.GenerateHandler) + r.POST("/api/chat", s.ChatHandler) + r.POST("/api/embeddings", s.EmbeddingsHandler) + r.POST("/api/create", s.CreateModelHandler) + r.POST("/api/push", s.PushModelHandler) + r.POST("/api/copy", s.CopyModelHandler) + r.DELETE("/api/delete", s.DeleteModelHandler) + r.POST("/api/show", s.ShowModelHandler) + r.POST("/api/blobs/:digest", s.CreateBlobHandler) + r.HEAD("/api/blobs/:digest", s.HeadBlobHandler) // Compatibility endpoints - r.POST("/v1/chat/completions", openai.Middleware(), ChatHandler) + r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler) for _, method := range []string{http.MethodGet, http.MethodHead} { r.Handle(method, "/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") }) - r.Handle(method, "/api/tags", ListModelsHandler) + r.Handle(method, "/api/tags", s.ListModelsHandler) r.Handle(method, "/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) }) @@ -1137,7 +1055,9 @@ func Serve(ln net.Listener) 
error { } } - s := &Server{addr: ln.Addr()} + ctx, done := context.WithCancel(context.Background()) + sched := InitScheduler(ctx) + s := &Server{addr: ln.Addr(), sched: sched} r := s.GenerateRoutes() slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version)) @@ -1150,7 +1070,8 @@ func Serve(ln net.Listener) error { signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) go func() { <-signals - unload() + done() + sched.unloadAllRunners() gpu.Cleanup() os.Exit(0) }() @@ -1158,12 +1079,12 @@ func Serve(ln net.Listener) error { if err := llm.Init(); err != nil { return fmt.Errorf("unable to initialize llm library %w", err) } - if runtime.GOOS == "linux" { // TODO - windows too - // check compatibility to log warnings - if _, err := gpu.CheckVRAM(); err != nil { - slog.Info(err.Error()) - } - } + + s.sched.Run(ctx) + + // At startup we retrieve GPU information so we can get log messages before loading a model + // This will log warnings to the log in case we have problems with detected GPUs + _ = gpu.GetGPUInfo() return srvr.Serve(ln) } @@ -1219,9 +1140,9 @@ func streamResponse(c *gin.Context, ch chan any) { } // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model -func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) { +func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) { encode := func(s string) ([]int, error) { - return loaded.llama.Tokenize(ctx, s) + return runner.llama.Tokenize(ctx, s) } prompt, err := ChatPrompt(template, messages, numCtx, encode) @@ -1232,10 +1153,7 @@ func chatPrompt(ctx context.Context, template string, messages []api.Message, nu return prompt, nil } -func ChatHandler(c *gin.Context) { - loaded.mu.Lock() - defer loaded.mu.Unlock() - +func (s *Server) ChatHandler(c *gin.Context) { checkpointStart := time.Now() var req api.ChatRequest @@ -1292,7 +1210,11 @@ func 
ChatHandler(c *gin.Context) { sessionDuration = req.KeepAlive.Duration } - if err := load(c, model, opts, sessionDuration); err != nil { + rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration) + var runner *runnerRef + select { + case runner = <-rCh: + case err = <-eCh: c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -1309,7 +1231,7 @@ func ChatHandler(c *gin.Context) { }, req.Messages...) } - prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx) + prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return @@ -1352,8 +1274,6 @@ func ChatHandler(c *gin.Context) { defer close(ch) fn := func(r llm.CompletionResponse) { - // Update model expiration - loaded.expireTimer.Reset(sessionDuration) resp := api.ChatResponse{ Model: req.Model, @@ -1376,7 +1296,7 @@ func ChatHandler(c *gin.Context) { ch <- resp } - if err := loaded.llama.Completion(c.Request.Context(), llm.CompletionRequest{ + if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: prompt, Format: req.Format, Images: images, diff --git a/server/sched.go b/server/sched.go new file mode 100644 index 00000000..37e83694 --- /dev/null +++ b/server/sched.go @@ -0,0 +1,525 @@ +package server + +import ( + "context" + "errors" + "fmt" + "log/slog" + "os" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/format" + "github.com/ollama/ollama/gpu" + "github.com/ollama/ollama/llm" + "golang.org/x/exp/slices" +) + +type LlmRequest struct { + ctx context.Context //nolint:containedctx + model *Model + ggml *llm.GGML // TODO - how large is this, and do we need to free it after we've finished loading? 
+ opts api.Options + sessionDuration time.Duration + successCh chan *runnerRef + errCh chan error +} + +type Scheduler struct { + pendingReqCh chan *LlmRequest + finishedReqCh chan *LlmRequest + expiredCh chan *runnerRef + unloadedCh chan interface{} + + loaded map[string]*runnerRef + loadedMu sync.Mutex + + loadFn func(req *LlmRequest, gpus gpu.GpuInfoList) + newServerFn func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) + getGpuFn func() gpu.GpuInfoList +} + +// TODO set this to zero after a release or two, to enable multiple models by default +var loadedMax = 1 // Maximum runners; < 1 maps to as many as will fit in VRAM (unlimited for CPU runners) +var maxQueuedRequests = 10 // TODO configurable + +func InitScheduler(ctx context.Context) *Scheduler { + maxRunners := os.Getenv("OLLAMA_MAX_LOADED_MODELS") + if maxRunners != "" { + m, err := strconv.Atoi(maxRunners) + if err != nil { + slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err) + } else { + loadedMax = m + } + } + + sched := &Scheduler{ + pendingReqCh: make(chan *LlmRequest, maxQueuedRequests), + finishedReqCh: make(chan *LlmRequest, maxQueuedRequests), + expiredCh: make(chan *runnerRef, maxQueuedRequests), + unloadedCh: make(chan interface{}, maxQueuedRequests), + loaded: make(map[string]*runnerRef), + newServerFn: llm.NewLlamaServer, + getGpuFn: gpu.GetGPUInfo, + } + sched.loadFn = sched.load + return sched +} + +// context must be canceled to decrement ref count and release the runner +func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) { + ggml, err := llm.LoadModel(model.ModelPath) + req := &LlmRequest{ + ctx: c, + model: model, + ggml: ggml, + opts: opts, + sessionDuration: sessionDuration, + successCh: make(chan *runnerRef), + errCh: make(chan error, 1), + } + if err != nil { + req.errCh <- 
err + return req.successCh, req.errCh + } + select { + case s.pendingReqCh <- req: + default: + req.errCh <- fmt.Errorf("server busy, please try again. maximum pending requests exceeded") + } + return req.successCh, req.errCh +} + +// Returns immediately, spawns go routines for the scheduler which will shutdown when ctx is done +func (s *Scheduler) Run(ctx context.Context) { + slog.Debug("starting llm scheduler") + go func() { + s.processPending(ctx) + }() + + go func() { + s.processCompleted(ctx) + }() +} + +func (s *Scheduler) processPending(ctx context.Context) { + for { + select { + case <-ctx.Done(): + slog.Debug("shutting down scheduler pending loop") + return + case pending := <-s.pendingReqCh: + // Block other requests until we get this pending request running + for { + var runnerToExpire *runnerRef + s.loadedMu.Lock() + runner := s.loaded[pending.model.ModelPath] + loadedCount := len(s.loaded) + s.loadedMu.Unlock() + if runner != nil { + if runner.needsReload(ctx, pending) { + runnerToExpire = runner + } else { + // Runner is usable, return it + pending.useLoadedRunner(runner, s.finishedReqCh) + break + } + } else if loadedCount == 0 { + slog.Debug("loading first model", "model", pending.model.ModelPath) + gpus := s.getGpuFn() + g := pickBestFitGPUs(pending, gpus) + if g != nil { + gpus = g + } + s.loadFn(pending, gpus) + break + } else if loadedMax > 0 && loadedCount >= loadedMax { + slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount) + runnerToExpire = s.findRunnerToUnload(pending) + } else { + // More than one loaded model, so we have to see if the new one fits + // Get a refreshed GPU list + gpus := s.getGpuFn() + // Update free memory from currently loaded models + s.updateFreeSpace(gpus) + gpus = pickBestFitGPUs(pending, gpus) + if gpus != nil { + slog.Debug("new model fits with existing models, loading") + s.loadFn(pending, gpus) + break + } + runnerToExpire = s.findRunnerToUnload(pending) + } + + if 
runnerToExpire == nil { + // Shouildn't happen + slog.Error("runner to expire was nil!") + continue + } + // Trigger an expiration to unload once it's done + runnerToExpire.refMu.Lock() + slog.Debug("resetting model to expire immediately to make room", "model", runnerToExpire.model, "refCount", runnerToExpire.refCount) + if runnerToExpire.expireTimer != nil { + runnerToExpire.expireTimer.Stop() + runnerToExpire.expireTimer = nil + } + runnerToExpire.sessionDuration = 0 + if runnerToExpire.refCount <= 0 { + s.expiredCh <- runnerToExpire + } + runnerToExpire.refMu.Unlock() + // Wait for the unload to happen + // Note: at this point we're queueing up all incoming requests, even if they were for + // a different model that's loaded and not scheduled to be removed. + slog.Debug("waiting for pending requests to complete and unload to occur", "model", runnerToExpire.model) + select { + case <-ctx.Done(): + slog.Debug("shutting down scheduler pending loop") + return + case <-s.unloadedCh: + slog.Debug("unload completed", "model", runnerToExpire.model) + continue + } + } + case <-s.unloadedCh: + // An unload request when there are no pending request can be ignored + slog.Debug("ignoring unload event with no pending requests") + } + } +} + +func (s *Scheduler) processCompleted(ctx context.Context) { + // Process completed requests, expired timers, and unloading models + for { + select { + case <-ctx.Done(): + slog.Debug("shutting down scheduler completed loop") + return + case finished := <-s.finishedReqCh: + s.loadedMu.Lock() + runner := s.loaded[finished.model.ModelPath] + s.loadedMu.Unlock() + if runner == nil { + slog.Error("finished requeset signal received after model unloaded", "model", finished.model.ModelPath) + continue + } + runner.refMu.Lock() + runner.refCount-- + if runner.refCount <= 0 { + if runner.sessionDuration <= 0 { + slog.Debug("runner with zero duration has gone idle, expiring to unload", "model", runner.model) + if runner.expireTimer != nil { + 
runner.expireTimer.Stop() + runner.expireTimer = nil + } + s.expiredCh <- runner + } else if runner.expireTimer == nil { + slog.Debug("runner with non-zero duration has gone idle, adding timer", "model", runner.model, "duration", runner.sessionDuration) + runner.expireTimer = time.AfterFunc(runner.sessionDuration, func() { + slog.Debug("timer expired, expiring to unload", "model", runner.model) + runner.refMu.Lock() + defer runner.refMu.Unlock() + if runner.expireTimer != nil { + runner.expireTimer.Stop() + } + s.expiredCh <- runner + }) + } else { + slog.Debug("runner with non-zero duration has gone idle, resetting timer", "model", runner.model, "duration", runner.sessionDuration) + runner.expireTimer.Reset(runner.sessionDuration) + } + } + slog.Debug("after processing request finished event", "model", runner.model, "refCount", runner.refCount) + runner.refMu.Unlock() + case runner := <-s.expiredCh: + slog.Debug("runner expired event received", "model", runner.model) + runner.refMu.Lock() + if runner.refCount > 0 { + // Shouldn't happen, but safeguard to ensure no leaked runners + slog.Debug("expired event with positive ref count, retrying", "model", runner.model, "refCount", runner.refCount) + go func(runner *runnerRef) { + // We can't unload yet, but want to as soon as the current request completes + // So queue up another expired event + time.Sleep(10 * time.Millisecond) + s.expiredCh <- runner + }(runner) + runner.refMu.Unlock() + continue + } + + slog.Debug("got lock to unload", "model", runner.model) + runner.unload() + s.loadedMu.Lock() + delete(s.loaded, runner.model) + s.loadedMu.Unlock() + slog.Debug("runner released", "model", runner.model) + runner.refMu.Unlock() + slog.Debug("sending an unloaded event", "model", runner.model) + s.unloadedCh <- struct{}{} + } + } +} + +// Complete the pending request and send the runner back to the requester +// Wires up a finished event after the request context is completed +// Updates session duration, and resets 
expiration timer +func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *LlmRequest) { + runner.refMu.Lock() + defer runner.refMu.Unlock() + runner.refCount++ + runner.sessionDuration = pending.sessionDuration + pending.successCh <- runner + go func() { + <-pending.ctx.Done() + slog.Debug("context for request finished") + finished <- pending + }() +} + +func (s *Scheduler) load(req *LlmRequest, gpus gpu.GpuInfoList) { + llama, err := s.newServerFn(gpus, req.model.ModelPath, req.ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts) + if err != nil { + // some older models are not compatible with newer versions of llama.cpp + // show a generalized compatibility error until there is a better way to + // check for model compatibility + if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") { + err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName) + } + slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err) + req.errCh <- err + return + } + runner := &runnerRef{} + runner.model = req.model.ModelPath + runner.adapters = req.model.AdapterPaths + runner.projectors = req.model.ProjectorPaths + runner.llama = llama + runner.Options = &req.opts + runner.sessionDuration = req.sessionDuration + runner.gpus = gpus + runner.estimatedVRAM = llama.EstimatedVRAM() + runner.loading = true + runner.refCount = 1 + runner.refMu.Lock() + s.loadedMu.Lock() + s.loaded[req.model.ModelPath] = runner + slog.Info("loaded runners", "count", len(s.loaded)) + s.loadedMu.Unlock() + + go func() { + defer runner.refMu.Unlock() + if err = llama.WaitUntilRunning(req.ctx); err != nil { + slog.Error("error loading llama server", "error", err) + runner.refCount-- + req.errCh <- err + slog.Debug("triggering expiration for failed load", "model", runner.model) + 
s.expiredCh <- runner + return + } + slog.Debug("finished setting up runner", "model", req.model.ModelPath) + runner.loading = false + go func() { + <-req.ctx.Done() + slog.Debug("context for request finished") + s.finishedReqCh <- req + }() + req.successCh <- runner + }() +} + +func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) { + type predKey struct { + Library string + ID string + } + predMap := map[predKey]uint64{} // Sum up the total predicted usage per GPU for all runners + s.loadedMu.Lock() + for _, r := range s.loaded { + r.refMu.Lock() + gpuIDs := make([]string, 0, len(r.gpus)) + if r.llama != nil { + + // TODO this should be broken down by GPU instead of assuming uniform spread + estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus)) + for _, gpu := range r.gpus { + gpuIDs = append(gpuIDs, gpu.ID) + } + for _, gpu := range allGpus { + if slices.Contains(gpuIDs, gpu.ID) { + predMap[predKey{gpu.Library, gpu.ID}] += estimatedVRAMPerGPU + } + } + } else { + slog.Warn("unexpected nil runner reference, memory prediction may be incorrect") + } + r.refMu.Unlock() + } + s.loadedMu.Unlock() + + // Now that we've summed up all the GPU usage predictions across all the loaded runners, update the gpu list + for i := range allGpus { + if p, ok := predMap[predKey{allGpus[i].Library, allGpus[i].ID}]; ok { + slog.Debug("gpu reported", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "available", format.HumanBytes2(allGpus[i].FreeMemory)) + if p > allGpus[i].TotalMemory { + // Shouldn't happen + slog.Warn("predicted usage exceeds VRAM", "gpu", allGpus[i].ID, "totalMemory", allGpus[i].TotalMemory, "predicted", p) + allGpus[i].FreeMemory = 0 + } else if (allGpus[i].TotalMemory - p) < allGpus[i].FreeMemory { // predicted free is smaller than reported free, use it + // TODO maybe we should just always trust our numbers, since cuda's free memory reporting is laggy + // and we might unload models we didn't actually need to. 
The risk is if some other GPU intensive app is loaded + // after we start our first runner, then we'll never acount for that, so picking the smallest free value seems prudent. + allGpus[i].FreeMemory = allGpus[i].TotalMemory - p + } + slog.Info("updated VRAM", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "total", format.HumanBytes2(allGpus[i].TotalMemory), "available", format.HumanBytes2(allGpus[i].FreeMemory)) + } + } +} + +type runnerRef struct { + refMu sync.Mutex + // refCond sync.Cond // Signaled on transition from 1 -> 0 refCount + refCount uint // prevent unloading if > 0 + // unloading bool // set to true when we are trying to unload the runner + + llama llm.LlamaServer + loading bool // True only during initial load, then false forever + gpus gpu.GpuInfoList // Recorded at time of provisioning + estimatedVRAM uint64 + + sessionDuration time.Duration + expireTimer *time.Timer + + model string + adapters []string + projectors []string + *api.Options +} + +// The refMu must already be held when calling unload +func (runner *runnerRef) unload() { + if runner.llama != nil { + runner.llama.Close() + } + runner.llama = nil + runner.adapters = nil + runner.projectors = nil + runner.Options = nil + runner.gpus = nil +} + +func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool { + slog.Debug("evaluating already loaded", "model", req.model.ModelPath) + runner.refMu.Lock() + defer runner.refMu.Unlock() + // Ignore the NumGPU settings for comparison + optsExisting := runner.Options.Runner + optsExisting.NumGPU = -1 + optsNew := req.opts.Runner + optsNew.NumGPU = -1 + timeout := 10 * time.Second + if runner.loading { + timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems... + } + ctx, cancel := context.WithTimeout(ctx, timeout) // BUG - + defer cancel() + if !reflect.DeepEqual(runner.adapters, req.model.AdapterPaths) || // have the adapters changed? 
+ !reflect.DeepEqual(runner.projectors, req.model.ProjectorPaths) || // have the projectors changed? + !reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed? + runner.llama.Ping(ctx) != nil { + return true + } + return false +} + +type ByDuration []*runnerRef + +func (a ByDuration) Len() int { return len(a) } +func (a ByDuration) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a ByDuration) Less(i, j int) bool { + // uint64 to turn negative time (never unload) to largest + return uint64(a[i].sessionDuration) < uint64(a[j].sessionDuration) +} + +// TODO - future consideration to pick runners based on size +// type BySize []*runnerRef +// func (a BySize) Len() int { return len(a) } +// func (a BySize) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +// func (a BySize) Less(i, j int) bool { return a[i].estimatedVRAM < a[j].estimatedVRAM } + +// pickBestFitGPUs will try to find the optimal placement of the model in the available GPUs where the model fully fits +// If the model can not be fit fully within the available GPU(s) nil is returned +func pickBestFitGPUs(req *LlmRequest, gpus gpu.GpuInfoList) gpu.GpuInfoList { + var estimatedVRAM uint64 + for _, gl := range gpus.ByLibrary() { + var ok bool + sgl := append(make(gpu.GpuInfoList, 0, len(gl)), gl...) + + // TODO - potentially sort by performance capability, existing models loaded, etc. 
+ // Note: at present, this will favor more VRAM over faster GPU speed in mixed setups + sort.Sort(sort.Reverse(gpu.ByFreeMemory(sgl))) + + // First attempt to fit the model into a single GPU + for _, g := range sgl { + if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, req.ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok { + slog.Debug("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM)) + return []gpu.GpuInfo{g} + } + } + + // TODO future refinements + // - if multiple Libraries, see if any single GPU in any Library will fit + // - try subsets of GPUs instead of just falling back to 1 or all in a family + + // Now try all the GPUs + if ok, estimatedVRAM = llm.PredictServerFit(gl, req.ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok { + slog.Debug("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", gl[0].Library, "required", format.HumanBytes2(estimatedVRAM)) + return gl + } + } + return nil +} + +// findRunnerToUnload finds a runner to unload to make room for a new model +func (s *Scheduler) findRunnerToUnload(req *LlmRequest) *runnerRef { + s.loadedMu.Lock() + runnerList := make([]*runnerRef, 0, len(s.loaded)) + for _, r := range s.loaded { + runnerList = append(runnerList, r) + } + s.loadedMu.Unlock() + + // In the future we can enhance the algorithm to be smarter about picking the optimal runner to unload + // e.g., if we have multiple options, will one make room for the request? 
+ sort.Sort(ByDuration(runnerList)) + + // First try to find a runner that's already idle + for _, runner := range runnerList { + runner.refMu.Lock() + rc := runner.refCount + runner.refMu.Unlock() + if rc == 0 { + slog.Debug("found an idle runner to unload") + return runner + } + } + // None appear idle, just wait for the one with the shortest duration + slog.Debug("no idle runners, picking the shortest duration", "count", len(runnerList)) + return runnerList[0] +} + +func (s *Scheduler) unloadAllRunners() { + s.loadedMu.Lock() + defer s.loadedMu.Unlock() + for model, runner := range s.loaded { + if runner.llama != nil { + slog.Debug("shutting down runner", "model", model) + runner.llama.Close() + } + } +} diff --git a/server/sched_test.go b/server/sched_test.go new file mode 100644 index 00000000..b5117631 --- /dev/null +++ b/server/sched_test.go @@ -0,0 +1,553 @@ +package server + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "log/slog" + "os" + "testing" + "time" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/app/lifecycle" + "github.com/ollama/ollama/format" + "github.com/ollama/ollama/gpu" + "github.com/ollama/ollama/llm" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func init() { + os.Setenv("OLLAMA_DEBUG", "1") + lifecycle.InitLogging() +} + +func TestInitScheduler(t *testing.T) { + ctx, done := context.WithCancel(context.Background()) + defer done() + initialMax := loadedMax + s := InitScheduler(ctx) + require.Equal(t, initialMax, loadedMax) + require.NotNil(t, s.loaded) + + os.Setenv("OLLAMA_MAX_LOADED_MODELS", "blue") + s = InitScheduler(ctx) + require.Equal(t, initialMax, loadedMax) + require.NotNil(t, s.loaded) + + os.Setenv("OLLAMA_MAX_LOADED_MODELS", "0") + s = InitScheduler(ctx) + require.Equal(t, 0, loadedMax) + require.NotNil(t, s.loaded) +} + +func TestLoad(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer done() + s := 
InitScheduler(ctx) + req := &LlmRequest{ + ctx: ctx, + model: &Model{ModelPath: "foo"}, + successCh: make(chan *runnerRef, 1), + errCh: make(chan error, 1), + sessionDuration: 2, + } + // Fail to load model first + s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) { + return nil, fmt.Errorf("something failed to load model blah") + } + gpus := gpu.GpuInfoList{} + s.load(req, gpus) + require.Len(t, req.successCh, 0) + require.Len(t, req.errCh, 1) + require.Len(t, s.loaded, 0) + err := <-req.errCh + require.Contains(t, err.Error(), "this model may be incompatible") + + server := &mockLlm{estimatedVRAM: 10} + s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) { + return server, nil + } + s.load(req, gpus) + select { + case err := <-req.errCh: + require.NoError(t, err) + case resp := <-req.successCh: + require.Equal(t, uint64(10), resp.estimatedVRAM) + require.Equal(t, uint(1), resp.refCount) + require.Len(t, s.loaded, 1) + } + + req.model.ModelPath = "dummy_model_path" + server.waitResp = fmt.Errorf("wait failure") + s.load(req, gpus) + select { + case err := <-req.errCh: + require.Contains(t, err.Error(), "wait failure") + case resp := <-req.successCh: + t.Errorf("unexpected success %v", resp) + } + runner := s.loaded["dummy_model_path"] + require.NotNil(t, runner) + require.Equal(t, uint(0), runner.refCount) + require.Len(t, s.expiredCh, 1) +} + +type bundle struct { + ctx context.Context //nolint:containedctx + ctxDone func() + srv *mockLlm + req *LlmRequest +} + +func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) { + return scenario.srv, nil +} + +func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) 
*bundle { + scenario := &bundle{} + scenario.ctx, scenario.ctxDone = context.WithCancel(ctx) + t.Helper() + + f, err := os.CreateTemp(t.TempDir(), modelName) + assert.Nil(t, err) + defer f.Close() + + gguf := llm.NewGGUFV3(binary.LittleEndian) + err = gguf.Encode(f, llm.KV{ + "general.architecture": "llama", + "general.name": "name", + "llama.context_length": uint32(32), + "llama.embedding_length": uint32(4096), + "llama.block_count": uint32(1), + "llama.attention.head_count": uint32(32), + "llama.attention.head_count_kv": uint32(32), + "tokenizer.ggml.tokens": []string{" "}, + "tokenizer.ggml.scores": []float32{0}, + "tokenizer.ggml.token_type": []int32{0}, + }, []llm.Tensor{ + {Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}}, + }) + assert.Nil(t, err) + fname := f.Name() + model := &Model{Name: modelName, ModelPath: fname} + ggml, err := llm.LoadModel(model.ModelPath) + require.NoError(t, err) + scenario.req = &LlmRequest{ + ctx: scenario.ctx, + model: model, + ggml: ggml, + sessionDuration: 5 * time.Millisecond, + successCh: make(chan *runnerRef, 1), + errCh: make(chan error, 1), + } + scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM} + return scenario +} + +func TestRequests(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer done() + + // Same model, same request + scenario1a := newScenario(t, ctx, "ollama-model-1", 10) + scenario1a.req.sessionDuration = 0 + scenario1b := newScenario(t, ctx, "ollama-model-1", 11) + scenario1b.req.model = scenario1a.req.model + scenario1b.req.ggml = scenario1a.req.ggml + scenario1b.req.sessionDuration = 0 + + // simple reload of same model + scenario2a := newScenario(t, ctx, "ollama-model-1", 20) + scenario2a.req.model = scenario1a.req.model + scenario2a.req.ggml = scenario1a.req.ggml + + // Multiple loaded models + scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte) + scenario3b := 
newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte) + scenario3c := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded + + s := InitScheduler(ctx) + s.getGpuFn = func() gpu.GpuInfoList { + g := gpu.GpuInfo{Library: "metal"} + g.TotalMemory = 24 * format.GigaByte + g.FreeMemory = 12 * format.GigaByte + return []gpu.GpuInfo{g} + } + s.newServerFn = scenario1a.newServer + slog.Info("scenario1a") + s.pendingReqCh <- scenario1a.req + require.Len(t, s.pendingReqCh, 1) + s.Run(ctx) + select { + case resp := <-scenario1a.req.successCh: + require.Equal(t, resp.llama, scenario1a.srv) + require.Len(t, s.pendingReqCh, 0) + require.Len(t, scenario1a.req.errCh, 0) + case <-ctx.Done(): + t.Errorf("timeout") + } + + // Same runner as first request due to not needing a reload + s.newServerFn = scenario1b.newServer + slog.Info("scenario1b") + s.pendingReqCh <- scenario1b.req + select { + case resp := <-scenario1b.req.successCh: + require.Equal(t, resp.llama, scenario1a.srv) + require.Len(t, s.pendingReqCh, 0) + require.Len(t, scenario1b.req.errCh, 0) + case <-ctx.Done(): + t.Errorf("timeout") + } + + // Trigger a reload + s.newServerFn = scenario2a.newServer + scenario2a.req.model.AdapterPaths = []string{"new"} + slog.Info("scenario2a") + s.pendingReqCh <- scenario2a.req + // finish first two requests, so model can reload + time.Sleep(1 * time.Millisecond) + scenario1a.ctxDone() + scenario1b.ctxDone() + select { + case resp := <-scenario2a.req.successCh: + require.Equal(t, resp.llama, scenario2a.srv) + require.Len(t, s.pendingReqCh, 0) + require.Len(t, scenario2a.req.errCh, 0) + case <-ctx.Done(): + t.Errorf("timeout") + } + + loadedMax = 1 + s.newServerFn = scenario3a.newServer + slog.Info("scenario3a") + s.pendingReqCh <- scenario3a.req + // finish prior request, so new model can load + time.Sleep(1 * time.Millisecond) + scenario2a.ctxDone() + select { + case resp := <-scenario3a.req.successCh: + require.Equal(t, resp.llama, scenario3a.srv) + require.Len(t, 
s.pendingReqCh, 0) + require.Len(t, scenario3a.req.errCh, 0) + case <-ctx.Done(): + t.Errorf("timeout") + } + require.Len(t, s.loaded, 1) + + loadedMax = 0 + s.newServerFn = scenario3b.newServer + slog.Info("scenario3b") + s.pendingReqCh <- scenario3b.req + select { + case resp := <-scenario3b.req.successCh: + require.Equal(t, resp.llama, scenario3b.srv) + require.Len(t, s.pendingReqCh, 0) + require.Len(t, scenario3b.req.errCh, 0) + case <-ctx.Done(): + t.Errorf("timeout") + } + require.Len(t, s.loaded, 2) + + // Try to load a model that wont fit + s.newServerFn = scenario3c.newServer + slog.Info("scenario3c") + require.Len(t, s.loaded, 2) + scenario3a.ctxDone() // Won't help since this one isn't big enough to make room + time.Sleep(2 * time.Millisecond) + s.pendingReqCh <- scenario3c.req + // finish prior request, so new model can load + time.Sleep(6 * time.Millisecond) + require.Len(t, s.loaded, 1) + scenario3b.ctxDone() + select { + case resp := <-scenario3c.req.successCh: + require.Equal(t, resp.llama, scenario3c.srv) + require.Len(t, s.pendingReqCh, 0) + require.Len(t, scenario3c.req.errCh, 0) + case <-ctx.Done(): + t.Errorf("timeout") + } + require.Len(t, s.loaded, 1) +} + +func TestGetRunner(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer done() + + // Same model, same request + scenario1a := newScenario(t, ctx, "ollama-model-1a", 10) + scenario1a.req.sessionDuration = 0 + scenario1b := newScenario(t, ctx, "ollama-model-1b", 10) + scenario1b.req.sessionDuration = 0 + scenario1c := newScenario(t, ctx, "ollama-model-1c", 10) + scenario1c.req.sessionDuration = 0 + maxQueuedRequests = 1 + s := InitScheduler(ctx) + s.getGpuFn = func() gpu.GpuInfoList { + g := gpu.GpuInfo{Library: "metal"} + g.TotalMemory = 24 * format.GigaByte + g.FreeMemory = 12 * format.GigaByte + return []gpu.GpuInfo{g} + } + s.newServerFn = scenario1a.newServer + slog.Info("scenario1a") + successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, 
scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration) + require.Len(t, s.pendingReqCh, 1) + slog.Info("scenario1b") + successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration) + require.Len(t, s.pendingReqCh, 1) + require.Len(t, successCh1b, 0) + require.Len(t, errCh1b, 1) + err := <-errCh1b + require.Contains(t, err.Error(), "server busy") + s.Run(ctx) + select { + case resp := <-successCh1a: + require.Equal(t, resp.llama, scenario1a.srv) + require.Len(t, s.pendingReqCh, 0) + require.Len(t, errCh1a, 0) + case <-ctx.Done(): + t.Errorf("timeout") + } + scenario1a.ctxDone() + require.Len(t, s.loaded, 1) + + scenario1c.req.model.ModelPath = "bad path" + slog.Info("scenario1c") + successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration) + require.Len(t, s.pendingReqCh, 0) + require.Len(t, successCh1c, 0) + require.Len(t, errCh1c, 1) + err = <-errCh1c + require.Contains(t, err.Error(), "bad path") + scenario1b.ctxDone() + + time.Sleep(5 * time.Millisecond) + require.Len(t, s.loaded, 0) +} + +// TODO - add one scenario that triggers the bogus finished event with positive ref count +func TestPrematureExpired(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer done() + + // Same model, same request + scenario1a := newScenario(t, ctx, "ollama-model-1a", 10) + s := InitScheduler(ctx) + s.getGpuFn = func() gpu.GpuInfoList { + g := gpu.GpuInfo{Library: "metal"} + g.TotalMemory = 24 * format.GigaByte + g.FreeMemory = 12 * format.GigaByte + return []gpu.GpuInfo{g} + } + s.newServerFn = scenario1a.newServer + successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration) + require.Len(t, s.pendingReqCh, 1) + s.Run(ctx) + select { + case resp := <-successCh1a: + require.Equal(t, resp.llama, scenario1a.srv) + 
require.Len(t, s.pendingReqCh, 0) + require.Len(t, errCh1a, 0) + require.Len(t, s.loaded, 1) + slog.Info("sending premature expired event now") + s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe + case <-ctx.Done(): + t.Errorf("timeout") + } + time.Sleep(scenario1a.req.sessionDuration) + scenario1a.ctxDone() + time.Sleep(20 * time.Millisecond) + require.LessOrEqual(t, len(s.finishedReqCh), 1) + time.Sleep(10 * time.Millisecond) + require.Len(t, s.finishedReqCh, 0) + require.Len(t, s.loaded, 0) + + // also shouldn't happen in real life + s.finishedReqCh <- scenario1a.req + time.Sleep(5 * time.Millisecond) +} + +func TestUseLoadedRunner(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond) + req := &LlmRequest{ + ctx: ctx, + successCh: make(chan *runnerRef, 1), + sessionDuration: 2, + } + finished := make(chan *LlmRequest) + llm1 := &mockLlm{} + r1 := &runnerRef{llama: llm1, sessionDuration: 1} + req.useLoadedRunner(r1, finished) + require.Equal(t, uint(1), r1.refCount) + require.Equal(t, time.Duration(2), r1.sessionDuration) + select { + case success := <-req.successCh: + require.Equal(t, r1, success) + case <-ctx.Done(): + t.Errorf("timeout") + } + done() + fin := <-finished + require.Equal(t, req, fin) +} + +func TestUpdateFreeSpace(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer done() + gpus := gpu.GpuInfoList{ + { + Library: "a", + ID: "1", + }, + { + Library: "a", + ID: "2", + }, + } + gpus[0].TotalMemory = 1000 + gpus[0].FreeMemory = 900 + gpus[1].TotalMemory = 2000 + gpus[1].FreeMemory = 1900 + llm1 := &mockLlm{estimatedVRAM: 100} + llm2 := &mockLlm{estimatedVRAM: 200} + r1 := &runnerRef{llama: llm1, gpus: gpus} + r2 := &runnerRef{llama: llm2, gpus: gpus} + + s := InitScheduler(ctx) + s.loaded["a"] = r1 + s.loaded["b"] = r2 + + s.updateFreeSpace(gpus) + require.Equal(t, uint64(850), gpus[0].FreeMemory) + require.Equal(t, uint64(1850), 
gpus[1].FreeMemory) + +} + +func TestFindRunnerToUnload(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer done() + req := &LlmRequest{ctx: ctx} + r1 := &runnerRef{refCount: 1, sessionDuration: 1} + r2 := &runnerRef{sessionDuration: 2} + + s := InitScheduler(ctx) + s.loaded["a"] = r1 + s.loaded["b"] = r2 + + resp := s.findRunnerToUnload(req) + require.Equal(t, r2, resp) + r2.refCount = 1 + resp = s.findRunnerToUnload(req) + require.Equal(t, r1, resp) + +} + +func TestNeedsReload(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer done() + + llm := &mockLlm{} + runner := &runnerRef{ + adapters: []string{"adapter1"}, + projectors: []string{"projector1"}, + Options: &api.Options{}, + llama: llm, + } + req := &LlmRequest{ + model: &Model{ + AdapterPaths: []string{"adapter2"}, + ProjectorPaths: []string{"projector2"}, + }, + opts: api.Options{}, + } + resp := runner.needsReload(ctx, req) + require.True(t, resp) + req.model.AdapterPaths = runner.adapters + resp = runner.needsReload(ctx, req) + require.True(t, resp) + req.model.ProjectorPaths = runner.projectors + runner.loading = true + req.opts.NumBatch = 1234 + resp = runner.needsReload(ctx, req) + require.True(t, resp) + req.opts.NumBatch = runner.Options.NumBatch + llm.pingResp = fmt.Errorf("foo") + resp = runner.needsReload(ctx, req) + require.True(t, resp) + llm.pingResp = nil + resp = runner.needsReload(ctx, req) + require.False(t, resp) + req.opts.NumGPU = 99 + resp = runner.needsReload(ctx, req) + require.False(t, resp) +} + +func TestUnloadAllRunners(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer done() + + llm1 := &mockLlm{} + llm2 := &mockLlm{} + s := InitScheduler(ctx) + s.unloadAllRunners() + + r1 := &runnerRef{llama: llm1} + r2 := &runnerRef{llama: llm2} + + s.loaded["a"] = r1 + s.loaded["b"] = r2 + s.unloadAllRunners() + + require.True(t, 
llm1.closeCalled) + require.True(t, llm2.closeCalled) +} + +func TestUnload(t *testing.T) { + llm1 := &mockLlm{} + r1 := &runnerRef{llama: llm1} + r2 := &runnerRef{adapters: []string{"A"}} + r1.unload() + require.True(t, llm1.closeCalled) + r2.unload() + require.Nil(t, r2.adapters) +} + +type mockLlm struct { + pingResp error + waitResp error + completionResp error + embeddingResp []float64 + embeddingRespErr error + tokenizeResp []int + tokenizeRespErr error + detokenizeResp string + detonekizeRespErr error + closeResp error + closeCalled bool + estimatedVRAM uint64 +} + +func (s *mockLlm) Ping(ctx context.Context) error { return s.pingResp } +func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitResp } +func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error { + return s.completionResp +} +func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) { + return s.embeddingResp, s.embeddingRespErr +} +func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) { + return s.tokenizeResp, s.tokenizeRespErr +} +func (s *mockLlm) Detokenize(ctx context.Context, tokens []int) (string, error) { + return s.detokenizeResp, s.detonekizeRespErr +} +func (s *mockLlm) Close() error { + s.closeCalled = true + return s.closeResp +} +func (s *mockLlm) EstimatedVRAM() uint64 { return s.estimatedVRAM }