Refactor intel gpu discovery

This commit is contained in:
Daniel Hiltgen 2024-05-29 16:37:34 -07:00
parent 48702dd149
commit 4e2b7e181d
4 changed files with 304 additions and 169 deletions

View file

@ -16,7 +16,6 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv"
"strings" "strings"
"sync" "sync"
"unsafe" "unsafe"
@ -25,16 +24,21 @@ import (
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
) )
type handles struct { type cudaHandles struct {
deviceCount int deviceCount int
cudart *C.cudart_handle_t cudart *C.cudart_handle_t
nvcuda *C.nvcuda_handle_t nvcuda *C.nvcuda_handle_t
}
type oneapiHandles struct {
oneapi *C.oneapi_handle_t oneapi *C.oneapi_handle_t
deviceCount int
} }
const ( const (
cudaMinimumMemory = 457 * format.MebiByte cudaMinimumMemory = 457 * format.MebiByte
rocmMinimumMemory = 457 * format.MebiByte rocmMinimumMemory = 457 * format.MebiByte
// TODO OneAPI minimum memory
) )
var ( var (
@ -107,19 +111,19 @@ var OneapiLinuxGlobs = []string{
var CudaTegra string = os.Getenv("JETSON_JETPACK") var CudaTegra string = os.Getenv("JETSON_JETPACK")
// Note: gpuMutex must already be held // Note: gpuMutex must already be held
func initCudaHandles() *handles { func initCudaHandles() *cudaHandles {
// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing // TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
gpuHandles := &handles{} cHandles := &cudaHandles{}
// Short Circuit if we already know which library to use // Short Circuit if we already know which library to use
if nvcudaLibPath != "" { if nvcudaLibPath != "" {
gpuHandles.deviceCount, gpuHandles.nvcuda, _ = LoadNVCUDAMgmt([]string{nvcudaLibPath}) cHandles.deviceCount, cHandles.nvcuda, _ = LoadNVCUDAMgmt([]string{nvcudaLibPath})
return gpuHandles return cHandles
} }
if cudartLibPath != "" { if cudartLibPath != "" {
gpuHandles.deviceCount, gpuHandles.cudart, _ = LoadCUDARTMgmt([]string{cudartLibPath}) cHandles.deviceCount, cHandles.cudart, _ = LoadCUDARTMgmt([]string{cudartLibPath})
return gpuHandles return cHandles
} }
slog.Debug("searching for GPU discovery libraries for NVIDIA") slog.Debug("searching for GPU discovery libraries for NVIDIA")
@ -127,8 +131,6 @@ func initCudaHandles() *handles {
var cudartMgmtPatterns []string var cudartMgmtPatterns []string
var nvcudaMgmtName string var nvcudaMgmtName string
var nvcudaMgmtPatterns []string var nvcudaMgmtPatterns []string
var oneapiMgmtName string
var oneapiMgmtPatterns []string
tmpDir, _ := PayloadsDir() tmpDir, _ := PayloadsDir()
switch runtime.GOOS { switch runtime.GOOS {
@ -140,8 +142,6 @@ func initCudaHandles() *handles {
// Aligned with driver, we can't carry as payloads // Aligned with driver, we can't carry as payloads
nvcudaMgmtName = "nvcuda.dll" nvcudaMgmtName = "nvcuda.dll"
nvcudaMgmtPatterns = NvcudaWindowsGlobs nvcudaMgmtPatterns = NvcudaWindowsGlobs
oneapiMgmtName = "ze_intel_gpu64.dll"
oneapiMgmtPatterns = OneapiWindowsGlobs
case "linux": case "linux":
cudartMgmtName = "libcudart.so*" cudartMgmtName = "libcudart.so*"
if tmpDir != "" { if tmpDir != "" {
@ -152,10 +152,8 @@ func initCudaHandles() *handles {
// Aligned with driver, we can't carry as payloads // Aligned with driver, we can't carry as payloads
nvcudaMgmtName = "libcuda.so*" nvcudaMgmtName = "libcuda.so*"
nvcudaMgmtPatterns = NvcudaLinuxGlobs nvcudaMgmtPatterns = NvcudaLinuxGlobs
oneapiMgmtName = "libze_intel_gpu.so"
oneapiMgmtPatterns = OneapiLinuxGlobs
default: default:
return gpuHandles return cHandles
} }
nvcudaLibPaths := FindGPULibs(nvcudaMgmtName, nvcudaMgmtPatterns) nvcudaLibPaths := FindGPULibs(nvcudaMgmtName, nvcudaMgmtPatterns)
@ -163,10 +161,10 @@ func initCudaHandles() *handles {
deviceCount, nvcuda, libPath := LoadNVCUDAMgmt(nvcudaLibPaths) deviceCount, nvcuda, libPath := LoadNVCUDAMgmt(nvcudaLibPaths)
if nvcuda != nil { if nvcuda != nil {
slog.Debug("detected GPUs", "count", deviceCount, "library", libPath) slog.Debug("detected GPUs", "count", deviceCount, "library", libPath)
gpuHandles.nvcuda = nvcuda cHandles.nvcuda = nvcuda
gpuHandles.deviceCount = deviceCount cHandles.deviceCount = deviceCount
nvcudaLibPath = libPath nvcudaLibPath = libPath
return gpuHandles return cHandles
} }
} }
@ -175,26 +173,45 @@ func initCudaHandles() *handles {
deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths) deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths)
if cudart != nil { if cudart != nil {
slog.Debug("detected GPUs", "library", libPath, "count", deviceCount) slog.Debug("detected GPUs", "library", libPath, "count", deviceCount)
gpuHandles.cudart = cudart cHandles.cudart = cudart
gpuHandles.deviceCount = deviceCount cHandles.deviceCount = deviceCount
cudartLibPath = libPath cudartLibPath = libPath
return gpuHandles return cHandles
} }
} }
return cHandles
}
// Note: gpuMutex must already be held
func initOneAPIHandles() *oneapiHandles {
oHandles := &oneapiHandles{}
var oneapiMgmtName string
var oneapiMgmtPatterns []string
// Short Circuit if we already know which library to use
if oneapiLibPath != "" {
oHandles.deviceCount, oHandles.oneapi, _ = LoadOneapiMgmt([]string{oneapiLibPath})
return oHandles
}
switch runtime.GOOS {
case "windows":
oneapiMgmtName = "ze_intel_gpu64.dll"
oneapiMgmtPatterns = OneapiWindowsGlobs
case "linux":
oneapiMgmtName = "libze_intel_gpu.so"
oneapiMgmtPatterns = OneapiLinuxGlobs
default:
return oHandles
}
oneapiLibPaths := FindGPULibs(oneapiMgmtName, oneapiMgmtPatterns) oneapiLibPaths := FindGPULibs(oneapiMgmtName, oneapiMgmtPatterns)
if len(oneapiLibPaths) > 0 { if len(oneapiLibPaths) > 0 {
deviceCount, oneapi, libPath := LoadOneapiMgmt(oneapiLibPaths) oHandles.deviceCount, oHandles.oneapi, oneapiLibPath = LoadOneapiMgmt(oneapiLibPaths)
if oneapi != nil {
slog.Debug("detected Intel GPUs", "library", libPath, "count", deviceCount)
gpuHandles.oneapi = oneapi
gpuHandles.deviceCount = deviceCount
oneapiLibPath = libPath
return gpuHandles
}
} }
return gpuHandles return oHandles
} }
func GetGPUInfo() GpuInfoList { func GetGPUInfo() GpuInfoList {
@ -203,16 +220,22 @@ func GetGPUInfo() GpuInfoList {
gpuMutex.Lock() gpuMutex.Lock()
defer gpuMutex.Unlock() defer gpuMutex.Unlock()
needRefresh := true needRefresh := true
var gpuHandles *handles var cHandles *cudaHandles
var oHandles *oneapiHandles
defer func() { defer func() {
if gpuHandles == nil { if cHandles != nil {
return if cHandles.cudart != nil {
C.cudart_release(*cHandles.cudart)
} }
if gpuHandles.cudart != nil { if cHandles.nvcuda != nil {
C.cudart_release(*gpuHandles.cudart) C.nvcuda_release(*cHandles.nvcuda)
}
}
if oHandles != nil {
if oHandles.oneapi != nil {
// TODO - is this needed?
C.oneapi_release(*oHandles.oneapi)
} }
if gpuHandles.nvcuda != nil {
C.nvcuda_release(*gpuHandles.nvcuda)
} }
}() }()
@ -253,13 +276,11 @@ func GetGPUInfo() GpuInfoList {
} }
// Load ALL libraries // Load ALL libraries
gpuHandles = initCudaHandles() cHandles = initCudaHandles()
// TODO needs a refactoring pass to init oneapi handles
// NVIDIA // NVIDIA
for i := range gpuHandles.deviceCount { for i := range cHandles.deviceCount {
if gpuHandles.cudart != nil || gpuHandles.nvcuda != nil { if cHandles.cudart != nil || cHandles.nvcuda != nil {
gpuInfo := CudaGPUInfo{ gpuInfo := CudaGPUInfo{
GpuInfo: GpuInfo{ GpuInfo: GpuInfo{
Library: "cuda", Library: "cuda",
@ -268,12 +289,12 @@ func GetGPUInfo() GpuInfoList {
} }
var driverMajor int var driverMajor int
var driverMinor int var driverMinor int
if gpuHandles.cudart != nil { if cHandles.cudart != nil {
C.cudart_bootstrap(*gpuHandles.cudart, C.int(i), &memInfo) C.cudart_bootstrap(*cHandles.cudart, C.int(i), &memInfo)
} else { } else {
C.nvcuda_bootstrap(*gpuHandles.nvcuda, C.int(i), &memInfo) C.nvcuda_bootstrap(*cHandles.nvcuda, C.int(i), &memInfo)
driverMajor = int(gpuHandles.nvcuda.driver_major) driverMajor = int(cHandles.nvcuda.driver_major)
driverMinor = int(gpuHandles.nvcuda.driver_minor) driverMinor = int(cHandles.nvcuda.driver_minor)
} }
if memInfo.err != nil { if memInfo.err != nil {
slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err)) slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
@ -297,20 +318,35 @@ func GetGPUInfo() GpuInfoList {
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does... // TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
cudaGPUs = append(cudaGPUs, gpuInfo) cudaGPUs = append(cudaGPUs, gpuInfo)
} }
if gpuHandles.oneapi != nil { }
// Intel
oHandles = initOneAPIHandles()
for d := 0; oHandles.oneapi != nil && d < int(oHandles.oneapi.num_drivers); d++ {
if oHandles.oneapi == nil {
// shouldn't happen
slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers))
continue
}
devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d))
for i := 0; i < int(devCount); i++ {
gpuInfo := OneapiGPUInfo{ gpuInfo := OneapiGPUInfo{
GpuInfo: GpuInfo{ GpuInfo: GpuInfo{
Library: "oneapi", Library: "oneapi",
}, },
index: i, driverIndex: d,
gpuIndex: i,
} }
// TODO - split bootstrapping from updating free memory // TODO - split bootstrapping from updating free memory
C.oneapi_check_vram(*gpuHandles.oneapi, &memInfo) C.oneapi_check_vram(*oHandles.oneapi, C.int(d), C.int(i), &memInfo)
// TODO - convert this to MinimumMemory based on testing...
var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend. var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
memInfo.free = C.uint64_t(totalFreeMem) memInfo.free = C.uint64_t(totalFreeMem)
gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.TotalMemory = uint64(memInfo.total)
gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.FreeMemory = uint64(memInfo.free)
gpuInfo.ID = strconv.Itoa(i) gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
// TODO dependency path?
oneapiGPUs = append(oneapiGPUs, gpuInfo) oneapiGPUs = append(oneapiGPUs, gpuInfo)
} }
} }
@ -325,14 +361,14 @@ func GetGPUInfo() GpuInfoList {
if needRefresh { if needRefresh {
// TODO - CPU system memory tracking/refresh // TODO - CPU system memory tracking/refresh
var memInfo C.mem_info_t var memInfo C.mem_info_t
if gpuHandles == nil && len(cudaGPUs) > 0 { if cHandles == nil && len(cudaGPUs) > 0 {
gpuHandles = initCudaHandles() cHandles = initCudaHandles()
} }
for i, gpu := range cudaGPUs { for i, gpu := range cudaGPUs {
if gpuHandles.cudart != nil { if cHandles.cudart != nil {
C.cudart_bootstrap(*gpuHandles.cudart, C.int(gpu.index), &memInfo) C.cudart_bootstrap(*cHandles.cudart, C.int(gpu.index), &memInfo)
} else { } else {
C.nvcuda_get_free(*gpuHandles.nvcuda, C.int(gpu.index), &memInfo.free) C.nvcuda_get_free(*cHandles.nvcuda, C.int(gpu.index), &memInfo.free)
} }
if memInfo.err != nil { if memInfo.err != nil {
slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err)) slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
@ -346,6 +382,23 @@ func GetGPUInfo() GpuInfoList {
slog.Debug("updating cuda free memory", "gpu", gpu.ID, "name", gpu.Name, "before", format.HumanBytes2(gpu.FreeMemory), "now", format.HumanBytes2(uint64(memInfo.free))) slog.Debug("updating cuda free memory", "gpu", gpu.ID, "name", gpu.Name, "before", format.HumanBytes2(gpu.FreeMemory), "now", format.HumanBytes2(uint64(memInfo.free)))
cudaGPUs[i].FreeMemory = uint64(memInfo.free) cudaGPUs[i].FreeMemory = uint64(memInfo.free)
} }
if oHandles == nil && len(oneapiGPUs) > 0 {
oHandles = initOneAPIHandles()
}
for i, gpu := range oneapiGPUs {
if oHandles.oneapi == nil {
// shouldn't happen
slog.Warn("nil oneapi handle with device count", "count", oHandles.deviceCount)
continue
}
C.oneapi_check_vram(*oHandles.oneapi, C.int(gpu.driverIndex), C.int(gpu.gpuIndex), &memInfo)
// TODO - convert this to MinimumMemory based on testing...
var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
memInfo.free = C.uint64_t(totalFreeMem)
oneapiGPUs[i].FreeMemory = uint64(memInfo.free)
}
err := RocmGPUInfoList(rocmGPUs).RefreshFreeMemory() err := RocmGPUInfoList(rocmGPUs).RefreshFreeMemory()
if err != nil { if err != nil {
slog.Debug("problem refreshing ROCm free memory", "error", err) slog.Debug("problem refreshing ROCm free memory", "error", err)
@ -359,6 +412,9 @@ func GetGPUInfo() GpuInfoList {
for _, gpu := range rocmGPUs { for _, gpu := range rocmGPUs {
resp = append(resp, gpu.GpuInfo) resp = append(resp, gpu.GpuInfo)
} }
for _, gpu := range oneapiGPUs {
resp = append(resp, gpu.GpuInfo)
}
if len(resp) == 0 { if len(resp) == 0 {
resp = append(resp, cpus[0].GpuInfo) resp = append(resp, cpus[0].GpuInfo)
} }
@ -476,6 +532,7 @@ func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) {
func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) { func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) {
var resp C.oneapi_init_resp_t var resp C.oneapi_init_resp_t
num_devices := 0
resp.oh.verbose = getVerboseState() resp.oh.verbose = getVerboseState()
for _, libPath := range oneapiLibPaths { for _, libPath := range oneapiLibPaths {
lib := C.CString(libPath) lib := C.CString(libPath)
@ -485,7 +542,10 @@ func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) {
slog.Debug("Unable to load oneAPI management library", "library", libPath, "error", C.GoString(resp.err)) slog.Debug("Unable to load oneAPI management library", "library", libPath, "error", C.GoString(resp.err))
C.free(unsafe.Pointer(resp.err)) C.free(unsafe.Pointer(resp.err))
} else { } else {
return int(resp.num_devices), &resp.oh, libPath for i := 0; i < int(resp.oh.num_drivers); i++ {
num_devices += int(C.oneapi_get_device_count(resp.oh, C.int(i)))
}
return num_devices, &resp.oh, libPath
} }
} }
return 0, nil, "" return 0, nil, ""

View file

@ -8,9 +8,13 @@ void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp)
{ {
ze_result_t ret; ze_result_t ret;
resp->err = NULL; resp->err = NULL;
resp->oh.devices = NULL;
resp->oh.num_devices = NULL;
resp->oh.drivers = NULL;
resp->oh.num_drivers = 0;
const int buflen = 256; const int buflen = 256;
char buf[buflen + 1]; char buf[buflen + 1];
int i; int i, d, count;
struct lookup struct lookup
{ {
char *s; char *s;
@ -66,19 +70,65 @@ void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp)
ret = (*resp->oh.zesInit)(0); ret = (*resp->oh.zesInit)(0);
if (ret != ZE_RESULT_SUCCESS) if (ret != ZE_RESULT_SUCCESS)
{ {
LOG(resp->oh.verbose, "zesInit err: %d\n", ret); LOG(resp->oh.verbose, "zesInit err: %x\n", ret);
UNLOAD_LIBRARY(resp->oh.handle); snprintf(buf, buflen, "oneapi vram init failure: %x", ret);
resp->oh.handle = NULL;
snprintf(buf, buflen, "oneapi vram init failure: %d", ret);
resp->err = strdup(buf); resp->err = strdup(buf);
oneapi_release(resp->oh);
return;
} }
(*resp->oh.zesDriverGet)(&resp->num_devices, NULL); count = 0;
ret = (*resp->oh.zesDriverGet)(&resp->oh.num_drivers, NULL);
if (ret != ZE_RESULT_SUCCESS)
{
LOG(resp->oh.verbose, "zesDriverGet err: %x\n", ret);
snprintf(buf, buflen, "unable to get driver count: %x", ret);
resp->err = strdup(buf);
oneapi_release(resp->oh);
return;
}
LOG(resp->oh.verbose, "oneapi driver count: %d\n", resp->oh.num_drivers);
resp->oh.drivers = malloc(resp->oh.num_drivers * sizeof(zes_driver_handle_t));
resp->oh.num_devices = malloc(resp->oh.num_drivers * sizeof(uint32_t));
memset(&resp->oh.num_devices[0], 0, resp->oh.num_drivers * sizeof(uint32_t));
resp->oh.devices = malloc(resp->oh.num_drivers * sizeof(zes_device_handle_t*));
ret = (*resp->oh.zesDriverGet)(&resp->oh.num_drivers, &resp->oh.drivers[0]);
if (ret != ZE_RESULT_SUCCESS)
{
LOG(resp->oh.verbose, "zesDriverGet err: %x\n", ret);
snprintf(buf, buflen, "unable to get driver count: %x", ret);
resp->err = strdup(buf);
oneapi_release(resp->oh);
return;
}
for (d = 0; d < resp->oh.num_drivers; d++) {
ret = (*resp->oh.zesDeviceGet)(resp->oh.drivers[d], &resp->oh.num_devices[d], NULL);
if (ret != ZE_RESULT_SUCCESS)
{
LOG(resp->oh.verbose, "zesDeviceGet err: %x\n", ret);
snprintf(buf, buflen, "unable to get device count: %x", ret);
resp->err = strdup(buf);
oneapi_release(resp->oh);
return;
}
resp->oh.devices[d] = malloc(resp->oh.num_devices[d] * sizeof(zes_device_handle_t));
ret = (*resp->oh.zesDeviceGet)(resp->oh.drivers[d], &resp->oh.num_devices[d], resp->oh.devices[d]);
if (ret != ZE_RESULT_SUCCESS)
{
LOG(resp->oh.verbose, "zesDeviceGet err: %x\n", ret);
snprintf(buf, buflen, "unable to get device count: %x", ret);
resp->err = strdup(buf);
oneapi_release(resp->oh);
return;
}
count += resp->oh.num_devices[d];
}
return; return;
} }
void oneapi_check_vram(oneapi_handle_t h, mem_info_t *resp) void oneapi_check_vram(oneapi_handle_t h, int driver, int device, mem_info_t *resp)
{ {
ze_result_t ret; ze_result_t ret;
resp->err = NULL; resp->err = NULL;
@ -94,43 +144,14 @@ void oneapi_check_vram(oneapi_handle_t h, mem_info_t *resp)
return; return;
} }
uint32_t driversCount = 0; if (driver > h.num_drivers || device > h.num_devices[driver]) {
ret = (*h.zesDriverGet)(&driversCount, NULL); resp->err = strdup("driver or device index out of bounds");
if (ret != ZE_RESULT_SUCCESS)
{
snprintf(buf, buflen, "unable to get driver count: %d", ret);
resp->err = strdup(buf);
return; return;
} }
LOG(h.verbose, "discovered %d Level-Zero drivers\n", driversCount);
zes_driver_handle_t *allDrivers =
malloc(driversCount * sizeof(zes_driver_handle_t));
(*h.zesDriverGet)(&driversCount, allDrivers);
resp->total = 0; resp->total = 0;
resp->free = 0; resp->free = 0;
for (d = 0; d < driversCount; d++)
{
uint32_t deviceCount = 0;
ret = (*h.zesDeviceGet)(allDrivers[d], &deviceCount, NULL);
if (ret != ZE_RESULT_SUCCESS)
{
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
free(allDrivers);
return;
}
LOG(h.verbose, "discovered %d Level-Zero devices\n", deviceCount);
zes_device_handle_t *devices =
malloc(deviceCount * sizeof(zes_device_handle_t));
(*h.zesDeviceGet)(allDrivers[d], &deviceCount, devices);
for (i = 0; i < deviceCount; i++)
{
zes_device_ext_properties_t ext_props; zes_device_ext_properties_t ext_props;
ext_props.stype = ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES; ext_props.stype = ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES;
ext_props.pNext = NULL; ext_props.pNext = NULL;
@ -139,48 +160,54 @@ void oneapi_check_vram(oneapi_handle_t h, mem_info_t *resp)
props.stype = ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES; props.stype = ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES;
props.pNext = &ext_props; props.pNext = &ext_props;
ret = (*h.zesDeviceGetProperties)(devices[i], &props); ret = (*h.zesDeviceGetProperties)(h.devices[driver][device], &props);
if (ret != ZE_RESULT_SUCCESS) if (ret != ZE_RESULT_SUCCESS)
{ {
snprintf(buf, buflen, "unable to get device properties: %d", ret); snprintf(buf, buflen, "unable to get device properties: %d", ret);
resp->err = strdup(buf); resp->err = strdup(buf);
free(allDrivers);
free(devices);
return; return;
} }
snprintf(&resp->gpu_name[0], GPU_NAME_LEN, props.modelName);
// TODO this needs to map to ONEAPI_DEVICE_SELECTOR syntax
// (this is probably wrong...)
// TODO - the driver isn't included - what if there are multiple drivers?
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", device);
if (h.verbose) if (h.verbose)
{ {
// When in verbose mode, report more information about // When in verbose mode, report more information about
// the card we discover. // the card we discover.
LOG(h.verbose, "[%d] oneAPI device name: %s\n", i, LOG(h.verbose, "[%d:%d] oneAPI device name: %s\n", driver, device,
props.modelName); props.modelName);
LOG(h.verbose, "[%d] oneAPI brand: %s\n", i, LOG(h.verbose, "[%d:%d] oneAPI brand: %s\n", driver, device,
props.brandName); props.brandName);
LOG(h.verbose, "[%d] oneAPI vendor: %s\n", i, LOG(h.verbose, "[%d:%d] oneAPI vendor: %s\n", driver, device,
props.vendorName); props.vendorName);
LOG(h.verbose, "[%d] oneAPI S/N: %s\n", i, LOG(h.verbose, "[%d:%d] oneAPI S/N: %s\n", driver, device,
props.serialNumber); props.serialNumber);
LOG(h.verbose, "[%d] oneAPI board number: %s\n", i, LOG(h.verbose, "[%d:%d] oneAPI board number: %s\n", driver, device,
props.boardNumber); props.boardNumber);
} }
// TODO
// Compute Capability equivalent in resp->major, resp->minor, resp->patch
uint32_t memCount = 0; uint32_t memCount = 0;
ret = (*h.zesDeviceEnumMemoryModules)(devices[i], &memCount, NULL); ret = (*h.zesDeviceEnumMemoryModules)(h.devices[driver][device], &memCount, NULL);
if (ret != ZE_RESULT_SUCCESS) if (ret != ZE_RESULT_SUCCESS)
{ {
snprintf(buf, buflen, snprintf(buf, buflen,
"unable to enumerate Level-Zero memory modules: %d", ret); "unable to enumerate Level-Zero memory modules: %x", ret);
resp->err = strdup(buf); resp->err = strdup(buf);
free(allDrivers);
free(devices);
return; return;
} }
LOG(h.verbose, "discovered %d Level-Zero memory modules\n", memCount); LOG(h.verbose, "discovered %d Level-Zero memory modules\n", memCount);
zes_mem_handle_t *mems = malloc(memCount * sizeof(zes_mem_handle_t)); zes_mem_handle_t *mems = malloc(memCount * sizeof(zes_mem_handle_t));
(*h.zesDeviceEnumMemoryModules)(devices[i], &memCount, mems); (*h.zesDeviceEnumMemoryModules)(h.devices[driver][device], &memCount, mems);
for (m = 0; m < memCount; m++) for (m = 0; m < memCount; m++)
{ {
@ -190,10 +217,8 @@ void oneapi_check_vram(oneapi_handle_t h, mem_info_t *resp)
ret = (*h.zesMemoryGetState)(mems[m], &state); ret = (*h.zesMemoryGetState)(mems[m], &state);
if (ret != ZE_RESULT_SUCCESS) if (ret != ZE_RESULT_SUCCESS)
{ {
snprintf(buf, buflen, "unable to get memory state: %d", ret); snprintf(buf, buflen, "unable to get memory state: %x", ret);
resp->err = strdup(buf); resp->err = strdup(buf);
free(allDrivers);
free(devices);
free(mems); free(mems);
return; return;
} }
@ -205,10 +230,48 @@ void oneapi_check_vram(oneapi_handle_t h, mem_info_t *resp)
free(mems); free(mems);
} }
free(devices); void oneapi_release(oneapi_handle_t h)
{
int d;
LOG(h.verbose, "releasing oneapi library\n");
for (d = 0; d < h.num_drivers; d++)
{
if (h.devices != NULL && h.devices[d] != NULL)
{
free(h.devices[d]);
}
}
if (h.devices != NULL)
{
free(h.devices);
h.devices = NULL;
}
if (h.num_devices != NULL)
{
free(h.num_devices);
h.num_devices = NULL;
}
if (h.drivers != NULL)
{
free(h.drivers);
h.drivers = NULL;
}
h.num_drivers = 0;
UNLOAD_LIBRARY(h.handle);
h.handle = NULL;
} }
free(allDrivers); int oneapi_get_device_count(oneapi_handle_t h, int driver)
{
if (h.handle == NULL || h.num_devices == NULL)
{
return 0;
}
if (driver > h.num_drivers)
{
return 0;
}
return (int)h.num_devices[driver];
} }
#endif // __APPLE__ #endif // __APPLE__

View file

@ -175,6 +175,16 @@ typedef struct oneapi_handle
{ {
void *handle; void *handle;
uint16_t verbose; uint16_t verbose;
uint32_t num_drivers;
zes_driver_handle_t *drivers;
uint32_t *num_devices;
zes_device_handle_t **devices;
// TODO Driver major, minor information
// int driver_major;
// int driver_minor;
ze_result_t (*zesInit)(int); ze_result_t (*zesInit)(int);
ze_result_t (*zesDriverGet)(uint32_t *pCount, zes_driver_handle_t *phDrivers); ze_result_t (*zesDriverGet)(uint32_t *pCount, zes_driver_handle_t *phDrivers);
ze_result_t (*zesDeviceGet)(zes_driver_handle_t hDriver, uint32_t *pCount, ze_result_t (*zesDeviceGet)(zes_driver_handle_t hDriver, uint32_t *pCount,
@ -194,7 +204,6 @@ typedef struct oneapi_handle
typedef struct oneapi_init_resp typedef struct oneapi_init_resp
{ {
char *err; // If err is non-null handle is invalid char *err; // If err is non-null handle is invalid
int num_devices;
oneapi_handle_t oh; oneapi_handle_t oh;
} oneapi_init_resp_t; } oneapi_init_resp_t;
@ -205,7 +214,9 @@ typedef struct oneapi_version_resp
} oneapi_version_resp_t; } oneapi_version_resp_t;
void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp); void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp);
void oneapi_check_vram(oneapi_handle_t rh, mem_info_t *resp); void oneapi_check_vram(oneapi_handle_t h, int driver, int device, mem_info_t *resp);
void oneapi_release(oneapi_handle_t h);
int oneapi_get_device_count(oneapi_handle_t h, int driver);
#endif // __GPU_INFO_INTEL_H__ #endif // __GPU_INFO_INTEL_H__
#endif // __APPLE__ #endif // __APPLE__

View file

@ -57,7 +57,8 @@ type RocmGPUInfoList []RocmGPUInfo
type OneapiGPUInfo struct { type OneapiGPUInfo struct {
GpuInfo GpuInfo
index int // device index driverIndex int // nolint: unused
gpuIndex int // nolint: unused
} }
type OneapiGPUInfoList []OneapiGPUInfo type OneapiGPUInfoList []OneapiGPUInfo