2024-02-15 17:15:09 -08:00
package gpu
import (
"bufio"
"errors"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
2024-05-07 14:54:26 -07:00
"regexp"
2024-02-15 17:15:09 -08:00
"slices"
2024-07-24 13:43:26 -07:00
"sort"
2024-02-15 17:15:09 -08:00
"strconv"
"strings"
2024-03-30 09:50:05 -07:00
2024-05-08 11:11:50 -07:00
"github.com/ollama/ollama/envconfig"
2024-03-30 09:50:05 -07:00
"github.com/ollama/ollama/format"
2024-02-15 17:15:09 -08:00
)
// Sysfs paths and file names used for discovery of AMD/ROCm GPUs.
const (
	// DriverVersionFile holds the loaded amdgpu module's version string.
	DriverVersionFile = "/sys/module/amdgpu/version"

	// AMDNodesSysfsDir is the KFD topology directory. It has one numbered
	// subdirectory per compute node, covering host CPUs as well as GPUs.
	AMDNodesSysfsDir = "/sys/class/kfd/kfd/topology/nodes/"
	// GPUPropertiesFileGlob matches the properties file of every KFD node.
	GPUPropertiesFileGlob = AMDNodesSysfsDir + "*/properties"
	// GPUTotalMemoryFileGlob is relative to a node directory; the matched
	// properties files contain a size_in_bytes line.
	GPUTotalMemoryFileGlob = "mem_banks/*/properties"

	// DRMDeviceDirGlob matches the Direct Rendering Manager (DRM) sysfs device
	// directories. The DRM* file names below are relative to one of them.
	DRMDeviceDirGlob = "/sys/class/drm/card*/device"

	// VRAM counters exposed by the DRM device.
	DRMTotalMemoryFile = "mem_info_vram_total"
	DRMUsedMemoryFile  = "mem_info_vram_used"

	// Identity files used to pair a DRM device with a KFD node. Their contents
	// are hex, whereas the KFD properties file reports the same IDs in decimal.
	DRMUniqueIDFile = "unique_id"
	DRMVendorFile   = "vendor"
	DRMDeviceFile   = "device"
)
var (
	// ROCmLibGlobs are file patterns used to validate whether a given ROCm
	// library directory is usable.
	// TODO - probably include more coverage of files here...
	ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"}

	// RocmStandardLocations are host directories where a system-wide ROCm
	// install may be found. NOTE(review): the consumer is not in this section;
	// presumably the shared lib-dir validation — confirm.
	RocmStandardLocations = []string{"/opt/rocm/lib", "/usr/lib64"}
)
// Gather GPU information from the amdgpu driver if any supported GPUs are detected
2024-05-15 15:13:16 -07:00
func AMDGetGPUInfo ( ) [ ] RocmGPUInfo {
resp := [ ] RocmGPUInfo { }
2024-02-15 17:15:09 -08:00
if ! AMDDetected ( ) {
2024-03-30 09:50:05 -07:00
return resp
2024-02-15 17:15:09 -08:00
}
// Opportunistic logging of driver version to aid in troubleshooting
2024-05-07 14:54:26 -07:00
driverMajor , driverMinor , err := AMDDriverVersion ( )
if err != nil {
2024-02-15 17:15:09 -08:00
// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
2024-03-30 09:50:05 -07:00
slog . Warn ( "ollama recommends running the https://www.amd.com/en/support/linux-drivers" , "error" , err )
2024-02-15 17:15:09 -08:00
}
2024-03-30 09:50:05 -07:00
// Determine if the user has already pre-selected which GPUs to look at, then ignore the others
var visibleDevices [ ] string
2024-07-03 19:30:19 -07:00
hipVD := envconfig . HipVisibleDevices ( ) // zero based index only
rocrVD := envconfig . RocrVisibleDevices ( ) // zero based index or UUID, but consumer cards seem to not support UUID
gpuDO := envconfig . GpuDeviceOrdinal ( ) // zero based index
2024-03-30 09:50:05 -07:00
switch {
// TODO is this priorty order right?
case hipVD != "" :
visibleDevices = strings . Split ( hipVD , "," )
case rocrVD != "" :
visibleDevices = strings . Split ( rocrVD , "," )
// TODO - since we don't yet support UUIDs, consider detecting and reporting here
// all our test systems show GPU-XX indicating UUID is not supported
case gpuDO != "" :
visibleDevices = strings . Split ( gpuDO , "," )
2024-02-15 17:15:09 -08:00
}
2024-07-03 19:30:19 -07:00
gfxOverride := envconfig . HsaOverrideGfxVersion ( )
2024-03-30 09:50:05 -07:00
var supported [ ] string
libDir := ""
// The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
matches , _ := filepath . Glob ( GPUPropertiesFileGlob )
2024-07-24 13:43:26 -07:00
sort . Slice ( matches , func ( i , j int ) bool {
// /sys/class/kfd/kfd/topology/nodes/<number>/properties
a , err := strconv . ParseInt ( filepath . Base ( filepath . Dir ( matches [ i ] ) ) , 10 , 64 )
if err != nil {
slog . Debug ( "parse err" , "error" , err , "match" , matches [ i ] )
return false
}
b , err := strconv . ParseInt ( filepath . Base ( filepath . Dir ( matches [ j ] ) ) , 10 , 64 )
if err != nil {
slog . Debug ( "parse err" , "error" , err , "match" , matches [ i ] )
return false
}
return a < b
} )
2024-03-30 09:50:05 -07:00
cpuCount := 0
for _ , match := range matches {
slog . Debug ( "evaluating amdgpu node " + match )
fp , err := os . Open ( match )
if err != nil {
slog . Debug ( "failed to open sysfs node" , "file" , match , "error" , err )
2024-02-15 17:15:09 -08:00
continue
}
2024-03-30 09:50:05 -07:00
defer fp . Close ( )
nodeID , err := strconv . Atoi ( filepath . Base ( filepath . Dir ( match ) ) )
if err != nil {
slog . Debug ( "failed to parse node ID" , "error" , err )
continue
2024-02-15 17:15:09 -08:00
}
2024-03-30 09:50:05 -07:00
scanner := bufio . NewScanner ( fp )
isCPU := false
var major , minor , patch uint64
2024-05-14 16:18:42 -07:00
var vendor , device , uniqueID uint64
2024-03-30 09:50:05 -07:00
for scanner . Scan ( ) {
line := strings . TrimSpace ( scanner . Text ( ) )
// Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
if strings . HasPrefix ( line , "gfx_target_version" ) {
ver := strings . Fields ( line )
2024-02-15 17:15:09 -08:00
2024-03-30 09:50:05 -07:00
// Detect CPUs
if len ( ver ) == 2 && ver [ 1 ] == "0" {
slog . Debug ( "detected CPU " + match )
isCPU = true
break
}
2024-02-15 17:15:09 -08:00
2024-03-30 09:50:05 -07:00
if len ( ver ) != 2 || len ( ver [ 1 ] ) < 5 {
slog . Warn ( "malformed " + match , "gfx_target_version" , line )
// If this winds up being a CPU, our offsets may be wrong
continue
}
l := len ( ver [ 1 ] )
var err1 , err2 , err3 error
patch , err1 = strconv . ParseUint ( ver [ 1 ] [ l - 2 : l ] , 10 , 32 )
minor , err2 = strconv . ParseUint ( ver [ 1 ] [ l - 4 : l - 2 ] , 10 , 32 )
major , err3 = strconv . ParseUint ( ver [ 1 ] [ : l - 4 ] , 10 , 32 )
if err1 != nil || err2 != nil || err3 != nil {
slog . Debug ( "malformed int " + line )
continue
}
2024-05-07 14:54:26 -07:00
} else if strings . HasPrefix ( line , "vendor_id" ) {
ver := strings . Fields ( line )
if len ( ver ) != 2 {
2024-05-14 16:18:42 -07:00
slog . Debug ( "malformed" , "vendor_id" , line )
2024-05-07 14:54:26 -07:00
continue
}
2024-05-14 16:18:42 -07:00
vendor , err = strconv . ParseUint ( ver [ 1 ] , 10 , 64 )
2024-05-07 14:54:26 -07:00
if err != nil {
2024-05-14 16:18:42 -07:00
slog . Debug ( "malformed" , "vendor_id" , line , "error" , err )
2024-05-07 14:54:26 -07:00
}
} else if strings . HasPrefix ( line , "device_id" ) {
ver := strings . Fields ( line )
if len ( ver ) != 2 {
2024-05-14 16:18:42 -07:00
slog . Debug ( "malformed" , "device_id" , line )
2024-05-07 14:54:26 -07:00
continue
}
2024-05-14 16:18:42 -07:00
device , err = strconv . ParseUint ( ver [ 1 ] , 10 , 64 )
2024-05-07 14:54:26 -07:00
if err != nil {
2024-05-14 16:18:42 -07:00
slog . Debug ( "malformed" , "device_id" , line , "error" , err )
}
} else if strings . HasPrefix ( line , "unique_id" ) {
ver := strings . Fields ( line )
if len ( ver ) != 2 {
slog . Debug ( "malformed" , "unique_id" , line )
continue
}
uniqueID , err = strconv . ParseUint ( ver [ 1 ] , 10 , 64 )
if err != nil {
slog . Debug ( "malformed" , "unique_id" , line , "error" , err )
2024-05-07 14:54:26 -07:00
}
2024-03-30 09:50:05 -07:00
}
// TODO - any other properties we want to extract and record?
// vendor_id + device_id -> pci lookup for "Name"
// Other metrics that may help us understand relative performance between multiple GPUs
2024-02-15 17:15:09 -08:00
}
2024-05-14 16:18:42 -07:00
// Note: while ./mem_banks/*/used_memory exists, it doesn't appear to take other VRAM consumers
// into consideration, so we instead map the device over to the DRM driver sysfs nodes which
// do reliably report VRAM usage.
2024-03-30 09:50:05 -07:00
if isCPU {
cpuCount ++
continue
2024-02-15 17:15:09 -08:00
}
2024-03-30 09:50:05 -07:00
// CPUs are always first in the list
gpuID := nodeID - cpuCount
2024-02-15 17:15:09 -08:00
2024-03-30 09:50:05 -07:00
// Shouldn't happen, but just in case...
if gpuID < 0 {
slog . Error ( "unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue" )
2024-06-05 12:07:20 -07:00
return nil
2024-02-15 17:15:09 -08:00
}
2024-03-30 09:50:05 -07:00
if int ( major ) < RocmComputeMin {
2024-05-07 14:54:26 -07:00
slog . Warn ( fmt . Sprintf ( "amdgpu too old gfx%d%x%x" , major , minor , patch ) , "gpu" , gpuID )
2024-02-15 17:15:09 -08:00
continue
}
2024-03-30 09:50:05 -07:00
// Look up the memory for the current node
2024-02-15 17:15:09 -08:00
totalMemory := uint64 ( 0 )
usedMemory := uint64 ( 0 )
2024-05-15 15:13:16 -07:00
var usedFile string
2024-05-14 16:18:42 -07:00
mapping := [ ] struct {
id uint64
filename string
} {
{ vendor , DRMVendorFile } ,
{ device , DRMDeviceFile } ,
{ uniqueID , DRMUniqueIDFile } , // Not all devices will report this
2024-02-15 17:15:09 -08:00
}
2024-05-14 16:18:42 -07:00
slog . Debug ( "mapping amdgpu to drm sysfs nodes" , "amdgpu" , match , "vendor" , vendor , "device" , device , "unique_id" , uniqueID )
// Map over to DRM location to find the total/free memory
drmMatches , _ := filepath . Glob ( DRMDeviceDirGlob )
for _ , devDir := range drmMatches {
matched := true
for _ , m := range mapping {
if m . id == 0 {
2024-06-05 12:07:20 -07:00
// Null ID means it didn't populate, so we can't use it to match
2024-05-14 16:18:42 -07:00
continue
}
filename := filepath . Join ( devDir , m . filename )
2024-06-05 12:07:20 -07:00
buf , err := os . ReadFile ( filename )
2024-05-14 16:18:42 -07:00
if err != nil {
slog . Debug ( "failed to read sysfs node" , "file" , filename , "error" , err )
matched = false
break
}
2024-06-05 12:07:20 -07:00
// values here are in hex, strip off the lead 0x and parse so we can compare the numeric (decimal) values in amdgpu
2024-05-14 16:18:42 -07:00
cmp , err := strconv . ParseUint ( strings . TrimPrefix ( strings . TrimSpace ( string ( buf ) ) , "0x" ) , 16 , 64 )
if err != nil {
slog . Debug ( "failed to parse sysfs node" , "file" , filename , "error" , err )
matched = false
break
}
if cmp != m . id {
matched = false
break
}
}
if ! matched {
2024-02-15 17:15:09 -08:00
continue
}
2024-05-14 16:18:42 -07:00
// Found the matching DRM directory
slog . Debug ( "matched" , "amdgpu" , match , "drm" , devDir )
totalFile := filepath . Join ( devDir , DRMTotalMemoryFile )
2024-06-05 12:07:20 -07:00
buf , err := os . ReadFile ( totalFile )
2024-02-15 17:15:09 -08:00
if err != nil {
2024-05-14 16:18:42 -07:00
slog . Debug ( "failed to read sysfs node" , "file" , totalFile , "error" , err )
break
2024-02-15 17:15:09 -08:00
}
2024-05-14 16:18:42 -07:00
totalMemory , err = strconv . ParseUint ( strings . TrimSpace ( string ( buf ) ) , 10 , 64 )
2024-02-15 17:15:09 -08:00
if err != nil {
2024-05-14 16:18:42 -07:00
slog . Debug ( "failed to parse sysfs node" , "file" , totalFile , "error" , err )
break
}
2024-05-15 15:13:16 -07:00
usedFile = filepath . Join ( devDir , DRMUsedMemoryFile )
usedMemory , err = getFreeMemory ( usedFile )
2024-05-14 16:18:42 -07:00
if err != nil {
2024-05-15 15:13:16 -07:00
slog . Debug ( "failed to update used memory" , "error" , err )
2024-02-15 17:15:09 -08:00
}
2024-05-14 16:18:42 -07:00
break
2024-02-15 17:15:09 -08:00
}
2024-03-30 09:50:05 -07:00
// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
if totalMemory < IGPUMemLimit {
2024-05-07 14:54:26 -07:00
slog . Info ( "unsupported Radeon iGPU detected skipping" , "id" , gpuID , "total" , format . HumanBytes2 ( totalMemory ) )
2024-03-30 09:50:05 -07:00
continue
}
2024-05-07 14:54:26 -07:00
var name string
// TODO - PCI ID lookup
if vendor > 0 && device > 0 {
name = fmt . Sprintf ( "%04x:%04x" , vendor , device )
}
2024-03-30 09:50:05 -07:00
2024-05-07 14:54:26 -07:00
slog . Debug ( "amdgpu memory" , "gpu" , gpuID , "total" , format . HumanBytes2 ( totalMemory ) )
slog . Debug ( "amdgpu memory" , "gpu" , gpuID , "available" , format . HumanBytes2 ( totalMemory - usedMemory ) )
2024-05-15 15:13:16 -07:00
gpuInfo := RocmGPUInfo {
GpuInfo : GpuInfo {
Library : "rocm" ,
memInfo : memInfo {
TotalMemory : totalMemory ,
FreeMemory : ( totalMemory - usedMemory ) ,
} ,
2024-06-05 12:07:20 -07:00
ID : strconv . Itoa ( gpuID ) ,
2024-05-15 15:13:16 -07:00
Name : name ,
Compute : fmt . Sprintf ( "gfx%d%x%x" , major , minor , patch ) ,
MinimumMemory : rocmMinimumMemory ,
DriverMajor : driverMajor ,
DriverMinor : driverMinor ,
2024-03-30 09:50:05 -07:00
} ,
2024-05-15 15:13:16 -07:00
usedFilepath : usedFile ,
2024-03-30 09:50:05 -07:00
}
// If the user wants to filter to a subset of devices, filter out if we aren't a match
if len ( visibleDevices ) > 0 {
include := false
for _ , visible := range visibleDevices {
if visible == gpuInfo . ID {
include = true
break
}
}
if ! include {
slog . Info ( "filtering out device per user request" , "id" , gpuInfo . ID , "visible_devices" , visibleDevices )
continue
}
}
// Final validation is gfx compatibility - load the library if we haven't already loaded it
// even if the user overrides, we still need to validate the library
if libDir == "" {
libDir , err = AMDValidateLibDir ( )
if err != nil {
slog . Warn ( "unable to verify rocm library, will use cpu" , "error" , err )
2024-06-05 12:07:20 -07:00
return nil
2024-03-30 09:50:05 -07:00
}
}
gpuInfo . DependencyPath = libDir
if gfxOverride == "" {
// Only load supported list once
if len ( supported ) == 0 {
supported , err = GetSupportedGFX ( libDir )
if err != nil {
slog . Warn ( "failed to lookup supported GFX types, falling back to CPU mode" , "error" , err )
2024-06-05 12:07:20 -07:00
return nil
2024-03-30 09:50:05 -07:00
}
slog . Debug ( "rocm supported GPUs" , "types" , supported )
}
2024-05-07 14:54:26 -07:00
gfx := gpuInfo . Compute
2024-03-30 09:50:05 -07:00
if ! slices . Contains [ [ ] string , string ] ( supported , gfx ) {
slog . Warn ( "amdgpu is not supported" , "gpu" , gpuInfo . ID , "gpu_type" , gfx , "library" , libDir , "supported_types" , supported )
// TODO - consider discrete markdown just for ROCM troubleshooting?
slog . Warn ( "See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage" )
continue
} else {
slog . Info ( "amdgpu is supported" , "gpu" , gpuInfo . ID , "gpu_type" , gfx )
}
} else {
2024-05-07 14:54:26 -07:00
slog . Info ( "skipping rocm gfx compatibility check" , "HSA_OVERRIDE_GFX_VERSION" , gfxOverride )
2024-03-30 09:50:05 -07:00
}
2024-05-31 16:15:21 -07:00
// Check for env var workarounds
if name == "1002:687f" { // Vega RX 56
gpuInfo . EnvWorkarounds = append ( gpuInfo . EnvWorkarounds , [ 2 ] string { "HSA_ENABLE_SDMA" , "0" } )
}
2024-03-30 09:50:05 -07:00
// The GPU has passed all the verification steps and is supported
resp = append ( resp , gpuInfo )
2024-02-15 17:15:09 -08:00
}
2024-03-30 09:50:05 -07:00
if len ( resp ) == 0 {
slog . Info ( "no compatible amdgpu devices detected" )
2024-02-15 17:15:09 -08:00
}
2024-03-30 09:50:05 -07:00
return resp
2024-02-15 17:15:09 -08:00
}
// Quick check for AMD driver so we can skip amdgpu discovery if not present
func AMDDetected ( ) bool {
// Some driver versions (older?) don't have a version file, so just lookup the parent dir
sysfsDir := filepath . Dir ( DriverVersionFile )
_ , err := os . Stat ( sysfsDir )
if errors . Is ( err , os . ErrNotExist ) {
slog . Debug ( "amdgpu driver not detected " + sysfsDir )
return false
} else if err != nil {
2024-03-30 09:50:05 -07:00
slog . Debug ( "error looking up amd driver" , "path" , sysfsDir , "error" , err )
2024-02-15 17:15:09 -08:00
return false
}
return true
}
// Prefer to use host installed ROCm, as long as it meets our minimum requirements
// failing that, tell the user how to download it on their own
func AMDValidateLibDir ( ) ( string , error ) {
2024-03-30 09:50:05 -07:00
libDir , err := commonAMDValidateLibDir ( )
2024-03-10 12:13:46 -07:00
if err == nil {
2024-03-30 09:50:05 -07:00
return libDir , nil
2024-03-10 12:13:46 -07:00
}
2024-03-08 09:45:55 -08:00
// Well known ollama installer path
installedRocmDir := "/usr/share/ollama/lib/rocm"
if rocmLibUsable ( installedRocmDir ) {
2024-03-30 09:50:05 -07:00
return installedRocmDir , nil
2024-02-15 17:15:09 -08:00
}
2024-03-08 09:45:55 -08:00
// If we still haven't found a usable rocm, the user will have to install it on their own
slog . Warn ( "amdgpu detected, but no compatible rocm library found. Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install" )
2024-08-01 14:52:15 -07:00
return "" , errors . New ( "no suitable rocm found, falling back to CPU" )
2024-02-15 17:15:09 -08:00
}
2024-05-07 14:54:26 -07:00
func AMDDriverVersion ( ) ( driverMajor , driverMinor int , err error ) {
_ , err = os . Stat ( DriverVersionFile )
2024-02-15 17:15:09 -08:00
if err != nil {
2024-05-07 14:54:26 -07:00
return 0 , 0 , fmt . Errorf ( "amdgpu version file missing: %s %w" , DriverVersionFile , err )
2024-02-15 17:15:09 -08:00
}
fp , err := os . Open ( DriverVersionFile )
if err != nil {
2024-05-07 14:54:26 -07:00
return 0 , 0 , err
2024-02-15 17:15:09 -08:00
}
defer fp . Close ( )
verString , err := io . ReadAll ( fp )
if err != nil {
2024-05-07 14:54:26 -07:00
return 0 , 0 , err
}
pattern := ` \A(\d+)\.(\d+).* `
regex := regexp . MustCompile ( pattern )
match := regex . FindStringSubmatch ( string ( verString ) )
if len ( match ) < 2 {
return 0 , 0 , fmt . Errorf ( "malformed version string %s" , string ( verString ) )
}
driverMajor , err = strconv . Atoi ( match [ 1 ] )
if err != nil {
return 0 , 0 , err
}
driverMinor , err = strconv . Atoi ( match [ 2 ] )
if err != nil {
return 0 , 0 , err
2024-02-15 17:15:09 -08:00
}
2024-05-07 14:54:26 -07:00
return driverMajor , driverMinor , nil
2024-02-15 17:15:09 -08:00
}
2024-05-15 15:13:16 -07:00
// RefreshFreeMemory re-reads each GPU's used-VRAM sysfs file (usedFilepath).
// It sets FreeMemory to TotalMemory minus the current usage.
//
// It stops at the first read/parse error and returns it, leaving any later
// GPUs in the list unrefreshed. An empty list is a no-op.
func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
	for idx := range gpus {
		gpu := &gpus[idx]
		used, err := getFreeMemory(gpu.usedFilepath)
		if err != nil {
			return err
		}
		free := gpu.TotalMemory - used
		slog.Debug("updating rocm free memory", "gpu", gpu.ID, "name", gpu.Name, "before", format.HumanBytes2(gpu.FreeMemory), "now", format.HumanBytes2(free))
		gpu.FreeMemory = free
	}
	return nil
}
// getFreeMemory reads a DRM sysfs memory counter and returns its value.
//
// Despite the name, callers pass the *used* VRAM file (mem_info_vram_used)
// and derive free memory themselves as total minus this value.
//
// Errors from reading or parsing the file are wrapped with the file path and
// returned without logging, so the caller decides how to report them.
func getFreeMemory(usedFile string) (uint64, error) {
	buf, err := os.ReadFile(usedFile)
	if err != nil {
		return 0, fmt.Errorf("failed to read sysfs node %s %w", usedFile, err)
	}
	// The file holds a decimal integer, possibly with trailing whitespace/newline.
	usedMemory, err := strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64)
	if err != nil {
		return 0, fmt.Errorf("failed to parse sysfs node %s %w", usedFile, err)
	}
	return usedMemory, nil
}