2024-02-16 01:15:09 +00:00
package gpu
import (
"bufio"
"errors"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
)
// Discovery logic for AMD/ROCm GPUs
// Filesystem locations consulted during AMD/ROCm GPU discovery.
const (
	// DriverVersionFile is read to report the amdgpu driver version; its
	// parent directory is what AMDDetected checks for.
	DriverVersionFile = "/sys/module/amdgpu/version"

	// AMDNodesSysfsDir holds one numbered sub-directory per KFD topology node.
	AMDNodesSysfsDir = "/sys/class/kfd/kfd/topology/nodes/"

	// GPUPropertiesFileGlob matches the properties file of every node; the
	// gfx_target_version line is parsed out of those files.
	GPUPropertiesFileGlob = AMDNodesSysfsDir + "*/properties"

	// The two globs below are relative: prefix them with the node dir.

	// GPUTotalMemoryFileGlob matches the per-memory-bank properties files
	// (the size_in_bytes line carries the bank size).
	GPUTotalMemoryFileGlob = "mem_banks/*/properties"

	// GPUUsedMemoryFileGlob matches the per-memory-bank used_memory files.
	GPUUsedMemoryFileGlob = "mem_banks/*/used_memory"

	// RocmStandardLocation is the well known host directory for ROCm libraries.
	RocmStandardLocation = "/opt/rocm/lib"
)
// ROCmLibGlobs lists the glob patterns used to validate if a given ROCm lib
// directory is usable.
// TODO - probably include more coverage of files here...
var ROCmLibGlobs = []string{
	"libhipblas.so.2*",
	"rocblas",
}
// AMDGetGPUInfo gathers GPU information from the amdgpu driver if any
// supported GPUs are detected, filling resp through amdProcMemLookup.
//
// If HIP_VISIBLE_DEVICES is already set, only the listed device IDs are looked
// up and no compatibility filtering is applied. Otherwise every detected
// device is checked against a minimum GFX major version and (unless
// HSA_OVERRIDE_GFX_VERSION is set) against the GFX types supported by the ROCm
// library that was found. When some, but not all, devices end up skipped,
// amdSetVisibleDevices is called with the full ID list and the skip set
// (the original note says this sets HIP_VISIBLE_DEVICES).
//
// resp is left untouched on every early return that happens before the memory
// lookup (no driver, no usable device, no usable ROCm library).
func AMDGetGPUInfo(resp *GpuInfo) {
	// TODO - DRY this out with windows
	if !AMDDetected() {
		return
	}
	skip := map[int]interface{}{}

	// Opportunistic logging of driver version to aid in troubleshooting
	ver, err := AMDDriverVersion()
	if err == nil {
		slog.Info("AMD Driver: " + ver)
	} else {
		// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
		slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err))
	}

	// If the user has specified exactly which GPUs to use, look up their memory
	visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES")
	if visibleDevices != "" {
		ids := []int{}
		for _, idStr := range strings.Split(visibleDevices, ",") {
			// Tolerate whitespace around entries, e.g. "0, 1"
			id, err := strconv.Atoi(strings.TrimSpace(idStr))
			if err != nil {
				slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err))
			} else {
				ids = append(ids, id)
			}
		}
		amdProcMemLookup(resp, nil, ids)
		return
	}

	// Gather GFX version information from all detected cards.
	// Map iteration order is random, so walk the node IDs in sorted order to
	// keep the logs and the ID list handed to later steps deterministic.
	gfx := AMDGFXVersions()
	ids := make([]int, 0, len(gfx))
	for id := range gfx {
		ids = append(ids, id)
	}
	slices.Sort(ids)

	verStrings := make([]string, 0, len(ids))
	for _, i := range ids {
		v := gfx[i]
		verStrings = append(verStrings, v.ToGFXString())
		switch {
		case v.Major == 0:
			// Silently skip CPUs
			skip[i] = struct{}{}
		case v.Major < 9:
			// TODO consider this a build-time setting if we can support 8xx family GPUs
			slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString()))
			skip[i] = struct{}{}
		}
	}
	slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings))

	// Abort if all GPUs are skipped
	if len(skip) >= len(gfx) {
		slog.Info("all detected amdgpus are skipped, falling back to CPU")
		return
	}

	// If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib
	libDir, err := AMDValidateLibDir()
	if err != nil {
		slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
		return
	}

	gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
	if gfxOverride == "" {
		supported, err := GetSupportedGFX(libDir)
		if err != nil {
			slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
			return
		}
		slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported))
		for _, i := range ids {
			// Devices already rejected above were reported there; do not
			// warn about them a second time.
			if _, skipped := skip[i]; skipped {
				continue
			}
			name := gfx[i].ToGFXString()
			if !slices.Contains(supported, name) {
				slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, name, libDir, supported))
				// TODO - consider discrete markdown just for ROCM troubleshooting?
				slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
				skip[i] = struct{}{}
			} else {
				slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, name))
			}
		}
	} else {
		slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
	}

	if len(skip) >= len(gfx) {
		slog.Info("all detected amdgpus are skipped, falling back to CPU")
		return
	}

	amdProcMemLookup(resp, skip, ids)
	if resp.memInfo.DeviceCount == 0 {
		return
	}
	if len(skip) > 0 {
		amdSetVisibleDevices(ids, skip)
	}
}
// amdProcMemLookup walks the sysfs nodes for the given GPU ids and accumulates
// their memory information into resp, skipping over any devices in the skip map.
//
// resp.memInfo's DeviceCount, TotalMemory and FreeMemory are reset to zero
// first. When ids is empty, every numerically named directory under
// AMDNodesSysfsDir is examined. Nodes whose memory banks add up to zero bytes
// are ignored. resp.Library is set to "rocm" once at least one device has been
// counted. skip may be nil.
func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
	resp.memInfo.DeviceCount = 0
	resp.memInfo.TotalMemory = 0
	resp.memInfo.FreeMemory = 0

	if len(ids) == 0 {
		slog.Debug("discovering all amdgpu devices")
		entries, err := os.ReadDir(AMDNodesSysfsDir)
		if err != nil {
			slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err))
			return
		}
		for _, node := range entries {
			if !node.IsDir() {
				continue
			}
			id, err := strconv.Atoi(node.Name())
			if err != nil {
				slog.Warn("malformed amdgpu sysfs node id " + node.Name())
				continue
			}
			ids = append(ids, id)
		}
	}
	slog.Debug(fmt.Sprintf("discovering amdgpu devices %v", ids))

	for _, id := range ids {
		if _, skipped := skip[id]; skipped {
			continue
		}
		totalMemory := amdNodeTotalMemory(id)
		if totalMemory == 0 {
			continue
		}
		usedMemory, ok := amdNodeUsedMemory(id)
		if !ok {
			continue
		}
		// Both values are unsigned; clamp so that a used figure larger than
		// the total cannot wrap around into a huge "free" value.
		freeMemory := uint64(0)
		if usedMemory < totalMemory {
			freeMemory = totalMemory - usedMemory
		}
		slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %d", id, totalMemory))
		slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory %d", id, freeMemory))
		resp.memInfo.DeviceCount++
		resp.memInfo.TotalMemory += totalMemory
		resp.memInfo.FreeMemory += freeMemory
	}
	if resp.memInfo.DeviceCount > 0 {
		resp.Library = "rocm"
	}
}

// amdNodeTotalMemory sums size_in_bytes over all memory banks of sysfs node id.
// It returns 0 when nothing could be read.
func amdNodeTotalMemory(id int) uint64 {
	propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUTotalMemoryFileGlob)
	propFiles, err := filepath.Glob(propGlob)
	if err != nil {
		slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err))
		return 0
	}
	// 1 or more memory banks - sum the values of all of them
	var total uint64
	for _, propFile := range propFiles {
		total += amdBankSize(propFile)
	}
	return total
}

// amdBankSize adds up every size_in_bytes line found in one memory bank
// properties file. The file is closed before returning.
func amdBankSize(propFile string) uint64 {
	fp, err := os.Open(propFile)
	if err != nil {
		slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err))
		return 0
	}
	defer fp.Close()

	var size uint64
	scanner := bufio.NewScanner(fp)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if !strings.HasPrefix(line, "size_in_bytes") {
			continue
		}
		fields := strings.Fields(line)
		if len(fields) != 2 {
			slog.Warn("malformed " + line)
			continue
		}
		bankSizeInBytes, err := strconv.ParseUint(fields[1], 10, 64)
		if err != nil {
			slog.Warn("malformed int " + line)
			continue
		}
		size += bankSizeInBytes
	}
	if err := scanner.Err(); err != nil {
		slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", propFile, err))
	}
	return size
}

// amdNodeUsedMemory sums the used_memory value over all memory banks of sysfs
// node id. The boolean is false only when the glob lookup itself failed;
// unreadable or malformed bank files are logged and left out of the sum.
func amdNodeUsedMemory(id int) (uint64, bool) {
	usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob)
	usedFiles, err := filepath.Glob(usedGlob)
	if err != nil {
		slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err))
		return 0, false
	}
	var used uint64
	for _, usedFile := range usedFiles {
		data, err := os.ReadFile(usedFile)
		if err != nil {
			slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err))
			continue
		}
		bankUsed, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
		if err != nil {
			slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err))
			continue
		}
		used += bankUsed
	}
	return used, true
}
// AMDDetected is a quick check for the AMD driver so amdgpu discovery can be
// skipped when it is not present. It reports true only when the directory
// containing DriverVersionFile can be stat'ed without error.
func AMDDetected() bool {
	// Some driver versions (older?) don't have a version file, so just lookup the parent dir
	driverDir := filepath.Dir(DriverVersionFile)

	switch _, statErr := os.Stat(driverDir); {
	case statErr == nil:
		return true
	case errors.Is(statErr, os.ErrNotExist):
		slog.Debug("amdgpu driver not detected " + driverDir)
	default:
		slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", driverDir, statErr))
	}
	return false
}
// setupLink removes whatever currently exists at target (recursively) and
// then creates a symlink at target that points to source. Either failure is
// returned wrapped with a message naming the paths involved.
func setupLink(source, target string) error {
	removeErr := os.RemoveAll(target)
	if removeErr != nil {
		return fmt.Errorf("failed to remove old rocm directory %s %w", target, removeErr)
	}

	linkErr := os.Symlink(source, target)
	if linkErr != nil {
		return fmt.Errorf("failed to create link %s => %s %w", source, target, linkErr)
	}

	slog.Debug("host rocm linked " + source + " => " + target)
	return nil
}
// Ensure the AMD rocm lib dir is wired up
// Prefer to use host installed ROCm, as long as it meets our minimum requirements
// failing that, tell the user how to download it on their own
func AMDValidateLibDir ( ) ( string , error ) {
// We rely on the rpath compiled into our library to find rocm
// so we establish a symlink to wherever we find it on the system
2024-03-08 17:45:55 +00:00
// to <payloads>/rocm
payloadsDir , err := PayloadsDir ( )
2024-02-16 01:15:09 +00:00
if err != nil {
2024-03-08 17:45:55 +00:00
return "" , err
2024-02-16 01:15:09 +00:00
}
2024-03-08 17:45:55 +00:00
// If we already have a rocm dependency wired, nothing more to do
rocmTargetDir := filepath . Join ( payloadsDir , "rocm" )
2024-02-16 01:15:09 +00:00
if rocmLibUsable ( rocmTargetDir ) {
return rocmTargetDir , nil
}
2024-03-08 17:45:55 +00:00
// Well known ollama installer path
installedRocmDir := "/usr/share/ollama/lib/rocm"
if rocmLibUsable ( installedRocmDir ) {
return rocmTargetDir , setupLink ( installedRocmDir , rocmTargetDir )
2024-02-16 01:15:09 +00:00
}
// Prefer explicit HIP env var
hipPath := os . Getenv ( "HIP_PATH" )
if hipPath != "" {
hipLibDir := filepath . Join ( hipPath , "lib" )
if rocmLibUsable ( hipLibDir ) {
slog . Debug ( "detected ROCM via HIP_PATH=" + hipPath )
return rocmTargetDir , setupLink ( hipLibDir , rocmTargetDir )
}
}
// Scan the library path for potential matches
ldPaths := strings . Split ( os . Getenv ( "LD_LIBRARY_PATH" ) , ":" )
for _ , ldPath := range ldPaths {
d , err := filepath . Abs ( ldPath )
if err != nil {
continue
}
if rocmLibUsable ( d ) {
return rocmTargetDir , setupLink ( d , rocmTargetDir )
}
}
// Well known location(s)
if rocmLibUsable ( "/opt/rocm/lib" ) {
return rocmTargetDir , setupLink ( "/opt/rocm/lib" , rocmTargetDir )
}
2024-03-08 17:45:55 +00:00
// If we still haven't found a usable rocm, the user will have to install it on their own
slog . Warn ( "amdgpu detected, but no compatible rocm library found. Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install" )
2024-02-16 01:15:09 +00:00
return "" , fmt . Errorf ( "no suitable rocm found, falling back to CPU" )
}
// AMDDriverVersion returns the contents of DriverVersionFile with surrounding
// whitespace removed. A failed stat of that file is reported as a wrapped
// "version file missing" error; a failure while opening or reading it is
// returned as-is.
func AMDDriverVersion() (string, error) {
	if _, statErr := os.Stat(DriverVersionFile); statErr != nil {
		return "", fmt.Errorf("amdgpu version file missing: %s %w", DriverVersionFile, statErr)
	}

	raw, readErr := os.ReadFile(DriverVersionFile)
	if readErr != nil {
		return "", readErr
	}

	version := strings.TrimSpace(string(raw))
	return version, nil
}
// AMDGFXVersions parses the gfx_target_version line from every node
// properties file matched by GPUPropertiesFileGlob and returns the result
// keyed by the numeric node directory name.
//
// The value's last two digits become Patch, the two before them Minor and the
// rest Major. Nodes reporting a version of "0" (the CPU) are left out of the
// map. Nodes with an otherwise malformed value are recorded with a zero
// Version. Files that cannot be opened, or whose node directory name is not a
// number, are skipped with a debug log.
func AMDGFXVersions() map[int]Version {
	res := map[int]Version{}
	// Glob only fails on a malformed pattern; an empty match list simply
	// yields an empty result.
	matches, _ := filepath.Glob(GPUPropertiesFileGlob)
	for _, match := range matches {
		amdRecordGFXVersion(res, match)
	}
	return res
}

// amdRecordGFXVersion reads one node properties file and stores the GFX
// version it declares into res. Handling one file per call lets the deferred
// Close run as soon as that file is done rather than when the scan ends.
func amdRecordGFXVersion(res map[int]Version, propsPath string) {
	fp, err := os.Open(propsPath)
	if err != nil {
		slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", propsPath, err))
		return
	}
	defer fp.Close()

	i, err := strconv.Atoi(filepath.Base(filepath.Dir(propsPath)))
	if err != nil {
		slog.Debug(fmt.Sprintf("failed to parse node ID %s", err))
		return
	}

	scanner := bufio.NewScanner(fp)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if !strings.HasPrefix(line, "gfx_target_version") {
			continue
		}
		ver := strings.Fields(line)
		// Check the length before indexing: a line carrying the key but no
		// value must not panic.
		if len(ver) >= 2 && ver[1] == "0" {
			// Silently skip the CPU
			continue
		}
		if len(ver) != 2 || len(ver[1]) < 5 {
			slog.Debug("malformed " + line)
			res[i] = Version{}
			continue
		}
		l := len(ver[1])
		patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
		minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
		major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
		if err1 != nil || err2 != nil || err3 != nil {
			slog.Debug("malformed int " + line)
			continue
		}
		res[i] = Version{
			Major: uint(major),
			Minor: uint(minor),
			Patch: uint(patch),
		}
	}
	if err := scanner.Err(); err != nil {
		slog.Debug(fmt.Sprintf("failed to read sysfs node file %s: %s", propsPath, err))
	}
}
// ToGFXString renders the version as a GFX target name such as "gfx1030" or
// "gfx90a".
//
// Target names write the minor and stepping (patch) parts as hexadecimal
// digits, so a patch of 10 must become "a" (gfx90a), not "10". Formatting
// them in decimal yields names like "gfx9010" that never match the target
// names the supported-GFX list is compared against.
func (v Version) ToGFXString() string {
	return fmt.Sprintf("gfx%d%x%x", v.Major, v.Minor, v.Patch)
}