Don't disable GPUs on arm without AVX

AVX is an x86 feature, so ARM should be excluded from
the check.
This commit is contained in:
Daniel Hiltgen 2024-01-28 15:22:38 -08:00
parent e02ecfb6c8
commit 15562e887d

View file

@ -122,15 +122,15 @@ func GetGPUInfo() GpuInfo {
initGPUHandles() initGPUHandles()
} }
// All our GPU builds have AVX enabled, so fallback to CPU if we don't detect at least AVX // All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX
cpuVariant := GetCPUVariant() cpuVariant := GetCPUVariant()
if cpuVariant == "" { if cpuVariant == "" && runtime.GOARCH == "amd64" {
slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.") slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
} }
var memInfo C.mem_info_t var memInfo C.mem_info_t
resp := GpuInfo{} resp := GpuInfo{}
if gpuHandles.cuda != nil && cpuVariant != "" { if gpuHandles.cuda != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
C.cuda_check_vram(*gpuHandles.cuda, &memInfo) C.cuda_check_vram(*gpuHandles.cuda, &memInfo)
if memInfo.err != nil { if memInfo.err != nil {
slog.Info(fmt.Sprintf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err))) slog.Info(fmt.Sprintf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err)))
@ -149,7 +149,7 @@ func GetGPUInfo() GpuInfo {
slog.Info(fmt.Sprintf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor)) slog.Info(fmt.Sprintf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
} }
} }
} else if gpuHandles.rocm != nil && cpuVariant != "" { } else if gpuHandles.rocm != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
C.rocm_check_vram(*gpuHandles.rocm, &memInfo) C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
if memInfo.err != nil { if memInfo.err != nil {
slog.Info(fmt.Sprintf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err))) slog.Info(fmt.Sprintf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err)))