412 lines
12 KiB
Go
412 lines
12 KiB
Go
|
package gpu
|
||
|
|
||
|
import (
|
||
|
"bufio"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"log/slog"
|
||
|
"os"
|
||
|
"path/filepath"
|
||
|
"slices"
|
||
|
"strconv"
|
||
|
"strings"
|
||
|
|
||
|
"github.com/jmorganca/ollama/version"
|
||
|
)
|
||
|
|
||
|
// Discovery logic for AMD/ROCm GPUs
|
||
|
|
||
|
// Locations of the amdgpu kernel driver and its KFD topology in sysfs.
const (
	// DriverVersionFile holds the amdgpu module version. Some (older?) driver
	// versions do not provide it, so its parent directory is used for detection.
	DriverVersionFile = "/sys/module/amdgpu/version"

	// AMDNodesSysfsDir contains one numerically named subdirectory per KFD node.
	AMDNodesSysfsDir = "/sys/class/kfd/kfd/topology/nodes/"

	// GPUPropertiesFileGlob matches the properties file of every KFD node.
	GPUPropertiesFileGlob = AMDNodesSysfsDir + "*/properties"
)

// Per-node globs; prefix them with a node directory under AMDNodesSysfsDir.
const (
	GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
	GPUUsedMemoryFileGlob  = "mem_banks/*/used_memory"
)

// ROCm library locations and user guidance.
const (
	// RocmStandardLocation is the well-known host ROCm library directory.
	RocmStandardLocation = "/opt/rocm/lib"

	// curlMsg is a fmt template taking (release version, destination dir) that
	// tells the user how to fetch the ROCm dependencies manually.
	curlMsg = "curl -fsSL https://github.com/ollama/ollama/releases/download/v%s/rocm-amd64-deps.tgz | tar -zxf - -C %s"
)
|
||
|
|
||
|
var (
	// ROCmLibGlobs lists the file-name patterns used to validate whether a
	// given directory holds a usable ROCm library.
	// TODO - probably include more coverage of files here...
	ROCmLibGlobs = []string{
		"libhipblas.so.2*",
		"rocblas",
	}
)
|
||
|
|
||
|
// Gather GPU information from the amdgpu driver if any supported GPUs are detected
|
||
|
// HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices
|
||
|
// and the user hasn't already set this variable
|
||
|
func AMDGetGPUInfo(resp *GpuInfo) {
|
||
|
// TODO - DRY this out with windows
|
||
|
if !AMDDetected() {
|
||
|
return
|
||
|
}
|
||
|
skip := map[int]interface{}{}
|
||
|
|
||
|
// Opportunistic logging of driver version to aid in troubleshooting
|
||
|
ver, err := AMDDriverVersion()
|
||
|
if err == nil {
|
||
|
slog.Info("AMD Driver: " + ver)
|
||
|
} else {
|
||
|
// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
|
||
|
slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err))
|
||
|
}
|
||
|
|
||
|
// If the user has specified exactly which GPUs to use, look up their memory
|
||
|
visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES")
|
||
|
if visibleDevices != "" {
|
||
|
ids := []int{}
|
||
|
for _, idStr := range strings.Split(visibleDevices, ",") {
|
||
|
id, err := strconv.Atoi(idStr)
|
||
|
if err != nil {
|
||
|
slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err))
|
||
|
} else {
|
||
|
ids = append(ids, id)
|
||
|
}
|
||
|
}
|
||
|
amdProcMemLookup(resp, nil, ids)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Gather GFX version information from all detected cards
|
||
|
gfx := AMDGFXVersions()
|
||
|
verStrings := []string{}
|
||
|
for i, v := range gfx {
|
||
|
verStrings = append(verStrings, v.ToGFXString())
|
||
|
if v.Major == 0 {
|
||
|
// Silently skip CPUs
|
||
|
skip[i] = struct{}{}
|
||
|
continue
|
||
|
}
|
||
|
if v.Major < 9 {
|
||
|
// TODO consider this a build-time setting if we can support 8xx family GPUs
|
||
|
slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString()))
|
||
|
skip[i] = struct{}{}
|
||
|
}
|
||
|
}
|
||
|
slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings))
|
||
|
|
||
|
// Abort if all GPUs are skipped
|
||
|
if len(skip) >= len(gfx) {
|
||
|
slog.Info("all detected amdgpus are skipped, falling back to CPU")
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib
|
||
|
libDir, err := AMDValidateLibDir()
|
||
|
if err != nil {
|
||
|
slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
|
||
|
return
|
||
|
}
|
||
|
|
||
|
gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
|
||
|
if gfxOverride == "" {
|
||
|
supported, err := GetSupportedGFX(libDir)
|
||
|
if err != nil {
|
||
|
slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
|
||
|
return
|
||
|
}
|
||
|
slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported))
|
||
|
|
||
|
for i, v := range gfx {
|
||
|
if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
|
||
|
slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
|
||
|
// TODO - consider discrete markdown just for ROCM troubleshooting?
|
||
|
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
|
||
|
skip[i] = struct{}{}
|
||
|
} else {
|
||
|
slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
|
||
|
}
|
||
|
}
|
||
|
} else {
|
||
|
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
|
||
|
}
|
||
|
|
||
|
if len(skip) >= len(gfx) {
|
||
|
slog.Info("all detected amdgpus are skipped, falling back to CPU")
|
||
|
return
|
||
|
}
|
||
|
|
||
|
ids := make([]int, len(gfx))
|
||
|
i := 0
|
||
|
for k := range gfx {
|
||
|
ids[i] = k
|
||
|
i++
|
||
|
}
|
||
|
amdProcMemLookup(resp, skip, ids)
|
||
|
if resp.memInfo.DeviceCount == 0 {
|
||
|
return
|
||
|
}
|
||
|
if len(skip) > 0 {
|
||
|
amdSetVisibleDevices(ids, skip)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Walk the sysfs nodes for the available GPUs and gather information from them
|
||
|
// skipping over any devices in the skip map
|
||
|
func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
|
||
|
resp.memInfo.DeviceCount = 0
|
||
|
resp.memInfo.TotalMemory = 0
|
||
|
resp.memInfo.FreeMemory = 0
|
||
|
if len(ids) == 0 {
|
||
|
slog.Debug("discovering all amdgpu devices")
|
||
|
entries, err := os.ReadDir(AMDNodesSysfsDir)
|
||
|
if err != nil {
|
||
|
slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err))
|
||
|
return
|
||
|
}
|
||
|
for _, node := range entries {
|
||
|
if !node.IsDir() {
|
||
|
continue
|
||
|
}
|
||
|
id, err := strconv.Atoi(node.Name())
|
||
|
if err != nil {
|
||
|
slog.Warn("malformed amdgpu sysfs node id " + node.Name())
|
||
|
continue
|
||
|
}
|
||
|
ids = append(ids, id)
|
||
|
}
|
||
|
}
|
||
|
slog.Debug(fmt.Sprintf("discovering amdgpu devices %v", ids))
|
||
|
|
||
|
for _, id := range ids {
|
||
|
if _, skipped := skip[id]; skipped {
|
||
|
continue
|
||
|
}
|
||
|
totalMemory := uint64(0)
|
||
|
usedMemory := uint64(0)
|
||
|
propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUTotalMemoryFileGlob)
|
||
|
propFiles, err := filepath.Glob(propGlob)
|
||
|
if err != nil {
|
||
|
slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err))
|
||
|
}
|
||
|
// 1 or more memory banks - sum the values of all of them
|
||
|
for _, propFile := range propFiles {
|
||
|
fp, err := os.Open(propFile)
|
||
|
if err != nil {
|
||
|
slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err))
|
||
|
continue
|
||
|
}
|
||
|
defer fp.Close()
|
||
|
scanner := bufio.NewScanner(fp)
|
||
|
for scanner.Scan() {
|
||
|
line := strings.TrimSpace(scanner.Text())
|
||
|
if strings.HasPrefix(line, "size_in_bytes") {
|
||
|
ver := strings.Fields(line)
|
||
|
if len(ver) != 2 {
|
||
|
slog.Warn("malformed " + line)
|
||
|
continue
|
||
|
}
|
||
|
bankSizeInBytes, err := strconv.ParseUint(ver[1], 10, 64)
|
||
|
if err != nil {
|
||
|
slog.Warn("malformed int " + line)
|
||
|
continue
|
||
|
}
|
||
|
totalMemory += bankSizeInBytes
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
if totalMemory == 0 {
|
||
|
continue
|
||
|
}
|
||
|
usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob)
|
||
|
usedFiles, err := filepath.Glob(usedGlob)
|
||
|
if err != nil {
|
||
|
slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err))
|
||
|
continue
|
||
|
}
|
||
|
for _, usedFile := range usedFiles {
|
||
|
fp, err := os.Open(usedFile)
|
||
|
if err != nil {
|
||
|
slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err))
|
||
|
continue
|
||
|
}
|
||
|
defer fp.Close()
|
||
|
data, err := io.ReadAll(fp)
|
||
|
if err != nil {
|
||
|
slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err))
|
||
|
continue
|
||
|
}
|
||
|
used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
|
||
|
if err != nil {
|
||
|
slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err))
|
||
|
continue
|
||
|
}
|
||
|
usedMemory += used
|
||
|
}
|
||
|
slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %d", id, totalMemory))
|
||
|
slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory %d", id, (totalMemory - usedMemory)))
|
||
|
resp.memInfo.DeviceCount++
|
||
|
resp.memInfo.TotalMemory += totalMemory
|
||
|
resp.memInfo.FreeMemory += (totalMemory - usedMemory)
|
||
|
}
|
||
|
if resp.memInfo.DeviceCount > 0 {
|
||
|
resp.Library = "rocm"
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Quick check for AMD driver so we can skip amdgpu discovery if not present
|
||
|
func AMDDetected() bool {
|
||
|
// Some driver versions (older?) don't have a version file, so just lookup the parent dir
|
||
|
sysfsDir := filepath.Dir(DriverVersionFile)
|
||
|
_, err := os.Stat(sysfsDir)
|
||
|
if errors.Is(err, os.ErrNotExist) {
|
||
|
slog.Debug("amdgpu driver not detected " + sysfsDir)
|
||
|
return false
|
||
|
} else if err != nil {
|
||
|
slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err))
|
||
|
return false
|
||
|
}
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
// setupLink replaces whatever currently exists at target with a symlink
// pointing to source.
//
// The existing target (file, directory tree, or stale link) is removed first.
// The returned error wraps the underlying os error from whichever step failed.
func setupLink(source, target string) error {
	err := os.RemoveAll(target)
	if err != nil {
		return fmt.Errorf("failed to remove old rocm directory %s %w", target, err)
	}

	err = os.Symlink(source, target)
	if err != nil {
		return fmt.Errorf("failed to create link %s => %s %w", source, target, err)
	}

	slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target))
	return nil
}
|
||
|
|
||
|
// Ensure the AMD rocm lib dir is wired up
|
||
|
// Prefer to use host installed ROCm, as long as it meets our minimum requirements
|
||
|
// failing that, tell the user how to download it on their own
|
||
|
func AMDValidateLibDir() (string, error) {
|
||
|
// We rely on the rpath compiled into our library to find rocm
|
||
|
// so we establish a symlink to wherever we find it on the system
|
||
|
// to $AssetsDir/rocm
|
||
|
|
||
|
// If we already have a rocm dependency wired, nothing more to do
|
||
|
assetsDir, err := AssetsDir()
|
||
|
if err != nil {
|
||
|
return "", fmt.Errorf("unable to lookup lib dir: %w", err)
|
||
|
}
|
||
|
// Versioned directory
|
||
|
rocmTargetDir := filepath.Join(assetsDir, "rocm")
|
||
|
if rocmLibUsable(rocmTargetDir) {
|
||
|
return rocmTargetDir, nil
|
||
|
}
|
||
|
// Parent dir (unversioned)
|
||
|
commonRocmDir := filepath.Join(filepath.Dir(assetsDir), "rocm")
|
||
|
if rocmLibUsable(commonRocmDir) {
|
||
|
return rocmTargetDir, setupLink(commonRocmDir, rocmTargetDir)
|
||
|
}
|
||
|
|
||
|
// Prefer explicit HIP env var
|
||
|
hipPath := os.Getenv("HIP_PATH")
|
||
|
if hipPath != "" {
|
||
|
hipLibDir := filepath.Join(hipPath, "lib")
|
||
|
if rocmLibUsable(hipLibDir) {
|
||
|
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
|
||
|
return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Scan the library path for potential matches
|
||
|
ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
|
||
|
for _, ldPath := range ldPaths {
|
||
|
d, err := filepath.Abs(ldPath)
|
||
|
if err != nil {
|
||
|
continue
|
||
|
}
|
||
|
if rocmLibUsable(d) {
|
||
|
return rocmTargetDir, setupLink(d, rocmTargetDir)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Well known location(s)
|
||
|
if rocmLibUsable("/opt/rocm/lib") {
|
||
|
return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir)
|
||
|
}
|
||
|
err = os.MkdirAll(rocmTargetDir, 0755)
|
||
|
if err != nil {
|
||
|
return "", fmt.Errorf("failed to create empty rocm dir %s %w", rocmTargetDir, err)
|
||
|
}
|
||
|
|
||
|
// If we still haven't found a usable rocm, the user will have to download it on their own
|
||
|
slog.Warn("amdgpu detected, but no compatible rocm library found. Either install rocm v6, or run the following")
|
||
|
slog.Warn(fmt.Sprintf(curlMsg, version.Version, rocmTargetDir))
|
||
|
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
|
||
|
}
|
||
|
|
||
|
func AMDDriverVersion() (string, error) {
|
||
|
_, err := os.Stat(DriverVersionFile)
|
||
|
if err != nil {
|
||
|
return "", fmt.Errorf("amdgpu version file missing: %s %w", DriverVersionFile, err)
|
||
|
}
|
||
|
fp, err := os.Open(DriverVersionFile)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
defer fp.Close()
|
||
|
verString, err := io.ReadAll(fp)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
return strings.TrimSpace(string(verString)), nil
|
||
|
}
|
||
|
|
||
|
func AMDGFXVersions() map[int]Version {
|
||
|
res := map[int]Version{}
|
||
|
matches, _ := filepath.Glob(GPUPropertiesFileGlob)
|
||
|
for _, match := range matches {
|
||
|
fp, err := os.Open(match)
|
||
|
if err != nil {
|
||
|
slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
|
||
|
continue
|
||
|
}
|
||
|
defer fp.Close()
|
||
|
i, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
|
||
|
if err != nil {
|
||
|
slog.Debug(fmt.Sprintf("failed to parse node ID %s", err))
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
scanner := bufio.NewScanner(fp)
|
||
|
for scanner.Scan() {
|
||
|
line := strings.TrimSpace(scanner.Text())
|
||
|
if strings.HasPrefix(line, "gfx_target_version") {
|
||
|
ver := strings.Fields(line)
|
||
|
if len(ver) != 2 || len(ver[1]) < 5 {
|
||
|
|
||
|
if ver[1] == "0" {
|
||
|
// Silently skip the CPU
|
||
|
continue
|
||
|
} else {
|
||
|
slog.Debug("malformed " + line)
|
||
|
}
|
||
|
res[i] = Version{
|
||
|
Major: 0,
|
||
|
Minor: 0,
|
||
|
Patch: 0,
|
||
|
}
|
||
|
continue
|
||
|
}
|
||
|
l := len(ver[1])
|
||
|
patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
|
||
|
minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
|
||
|
major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
|
||
|
if err1 != nil || err2 != nil || err3 != nil {
|
||
|
slog.Debug("malformed int " + line)
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
res[i] = Version{
|
||
|
Major: uint(major),
|
||
|
Minor: uint(minor),
|
||
|
Patch: uint(patch),
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return res
|
||
|
}
|
||
|
|
||
|
func (v Version) ToGFXString() string {
|
||
|
return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
|
||
|
}
|