From 4e2b7e181d069166134be5391974b7a49ca08890 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Wed, 29 May 2024 16:37:34 -0700 Subject: [PATCH] Refactor intel gpu discovery --- gpu/gpu.go | 176 +++++++++++++++++--------- gpu/gpu_info_oneapi.c | 279 ++++++++++++++++++++++++++---------------- gpu/gpu_info_oneapi.h | 15 ++- gpu/types.go | 3 +- 4 files changed, 304 insertions(+), 169 deletions(-) diff --git a/gpu/gpu.go b/gpu/gpu.go index 1832667b..d7a3ba44 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -16,7 +16,6 @@ import ( "os" "path/filepath" "runtime" - "strconv" "strings" "sync" "unsafe" @@ -25,16 +24,21 @@ import ( "github.com/ollama/ollama/format" ) -type handles struct { +type cudaHandles struct { deviceCount int cudart *C.cudart_handle_t nvcuda *C.nvcuda_handle_t +} + +type oneapiHandles struct { oneapi *C.oneapi_handle_t + deviceCount int } const ( cudaMinimumMemory = 457 * format.MebiByte rocmMinimumMemory = 457 * format.MebiByte + // TODO OneAPI minimum memory ) var ( @@ -107,19 +111,19 @@ var OneapiLinuxGlobs = []string{ var CudaTegra string = os.Getenv("JETSON_JETPACK") // Note: gpuMutex must already be held -func initCudaHandles() *handles { +func initCudaHandles() *cudaHandles { // TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing - gpuHandles := &handles{} + cHandles := &cudaHandles{} // Short Circuit if we already know which library to use if nvcudaLibPath != "" { - gpuHandles.deviceCount, gpuHandles.nvcuda, _ = LoadNVCUDAMgmt([]string{nvcudaLibPath}) - return gpuHandles + cHandles.deviceCount, cHandles.nvcuda, _ = LoadNVCUDAMgmt([]string{nvcudaLibPath}) + return cHandles } if cudartLibPath != "" { - gpuHandles.deviceCount, gpuHandles.cudart, _ = LoadCUDARTMgmt([]string{cudartLibPath}) - return gpuHandles + cHandles.deviceCount, cHandles.cudart, _ = LoadCUDARTMgmt([]string{cudartLibPath}) + return cHandles } slog.Debug("searching for GPU discovery libraries for NVIDIA") @@ -127,8 +131,6 @@ func initCudaHandles() *handles { var cudartMgmtPatterns []string var nvcudaMgmtName string var nvcudaMgmtPatterns []string - var oneapiMgmtName string - var oneapiMgmtPatterns []string tmpDir, _ := PayloadsDir() switch runtime.GOOS { @@ -140,8 +142,6 @@ func initCudaHandles() *handles { // Aligned with driver, we can't carry as payloads nvcudaMgmtName = "nvcuda.dll" nvcudaMgmtPatterns = NvcudaWindowsGlobs - oneapiMgmtName = "ze_intel_gpu64.dll" - oneapiMgmtPatterns = OneapiWindowsGlobs case "linux": cudartMgmtName = "libcudart.so*" if tmpDir != "" { @@ -152,10 +152,8 @@ func initCudaHandles() *handles { // Aligned with driver, we can't carry as payloads nvcudaMgmtName = "libcuda.so*" nvcudaMgmtPatterns = NvcudaLinuxGlobs - oneapiMgmtName = "libze_intel_gpu.so" - oneapiMgmtPatterns = OneapiLinuxGlobs default: - return gpuHandles + return cHandles } nvcudaLibPaths := FindGPULibs(nvcudaMgmtName, nvcudaMgmtPatterns) @@ -163,10 +161,10 @@ func initCudaHandles() *handles { deviceCount, nvcuda, libPath := LoadNVCUDAMgmt(nvcudaLibPaths) if nvcuda != nil { slog.Debug("detected GPUs", "count", deviceCount, "library", libPath) - gpuHandles.nvcuda = nvcuda - gpuHandles.deviceCount = deviceCount + cHandles.nvcuda = nvcuda + cHandles.deviceCount = deviceCount nvcudaLibPath = libPath - return gpuHandles + return cHandles } } @@ -175,26 +173,45 @@ func initCudaHandles() *handles { deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths) if cudart != nil { slog.Debug("detected GPUs", "library", libPath, "count", deviceCount) - gpuHandles.cudart = cudart - gpuHandles.deviceCount = deviceCount + cHandles.cudart = cudart + cHandles.deviceCount = deviceCount cudartLibPath = libPath - return gpuHandles + return cHandles } } + return cHandles +} + +// Note: gpuMutex must already be held +func initOneAPIHandles() *oneapiHandles { + oHandles := &oneapiHandles{} + var oneapiMgmtName string + var oneapiMgmtPatterns []string + + // Short Circuit if we already know which library to use + if oneapiLibPath != "" { + oHandles.deviceCount, oHandles.oneapi, _ = LoadOneapiMgmt([]string{oneapiLibPath}) + return oHandles + } + + switch runtime.GOOS { + case "windows": + oneapiMgmtName = "ze_intel_gpu64.dll" + oneapiMgmtPatterns = OneapiWindowsGlobs + case "linux": + oneapiMgmtName = "libze_intel_gpu.so" + oneapiMgmtPatterns = OneapiLinuxGlobs + default: + return oHandles + } + oneapiLibPaths := FindGPULibs(oneapiMgmtName, oneapiMgmtPatterns) if len(oneapiLibPaths) > 0 { - deviceCount, oneapi, libPath := LoadOneapiMgmt(oneapiLibPaths) - if oneapi != nil { - slog.Debug("detected Intel GPUs", "library", libPath, "count", deviceCount) - gpuHandles.oneapi = oneapi - gpuHandles.deviceCount = deviceCount - oneapiLibPath = libPath - return gpuHandles - } + oHandles.deviceCount, oHandles.oneapi, oneapiLibPath = LoadOneapiMgmt(oneapiLibPaths) } - return gpuHandles + return oHandles } func GetGPUInfo() GpuInfoList { @@ -203,16 +220,22 @@ func GetGPUInfo() GpuInfoList { gpuMutex.Lock() defer gpuMutex.Unlock() needRefresh := true - var gpuHandles *handles + var cHandles *cudaHandles + var oHandles *oneapiHandles defer func() { - if gpuHandles == nil { - return + if cHandles != nil { + if cHandles.cudart != nil { + C.cudart_release(*cHandles.cudart) + } + if cHandles.nvcuda != nil { + C.nvcuda_release(*cHandles.nvcuda) + } } - if gpuHandles.cudart != nil { - C.cudart_release(*gpuHandles.cudart) - } - if gpuHandles.nvcuda != nil { - C.nvcuda_release(*gpuHandles.nvcuda) + if oHandles != nil { + if oHandles.oneapi != nil { + // TODO - is this needed? + C.oneapi_release(*oHandles.oneapi) + } } }() @@ -253,13 +276,11 @@ func GetGPUInfo() GpuInfoList { } // Load ALL libraries - gpuHandles = initCudaHandles() - - // TODO needs a refactoring pass to init oneapi handles + cHandles = initCudaHandles() // NVIDIA - for i := range gpuHandles.deviceCount { - if gpuHandles.cudart != nil || gpuHandles.nvcuda != nil { + for i := range cHandles.deviceCount { + if cHandles.cudart != nil || cHandles.nvcuda != nil { gpuInfo := CudaGPUInfo{ GpuInfo: GpuInfo{ Library: "cuda", @@ -268,12 +289,12 @@ func GetGPUInfo() GpuInfoList { } var driverMajor int var driverMinor int - if gpuHandles.cudart != nil { - C.cudart_bootstrap(*gpuHandles.cudart, C.int(i), &memInfo) + if cHandles.cudart != nil { + C.cudart_bootstrap(*cHandles.cudart, C.int(i), &memInfo) } else { - C.nvcuda_bootstrap(*gpuHandles.nvcuda, C.int(i), &memInfo) - driverMajor = int(gpuHandles.nvcuda.driver_major) - driverMinor = int(gpuHandles.nvcuda.driver_minor) + C.nvcuda_bootstrap(*cHandles.nvcuda, C.int(i), &memInfo) + driverMajor = int(cHandles.nvcuda.driver_major) + driverMinor = int(cHandles.nvcuda.driver_minor) } if memInfo.err != nil { slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err)) @@ -297,20 +318,35 @@ func GetGPUInfo() GpuInfoList { // TODO potentially sort on our own algorithm instead of what the underlying GPU library does... cudaGPUs = append(cudaGPUs, gpuInfo) } - if gpuHandles.oneapi != nil { + } + + // Intel + oHandles = initOneAPIHandles() + for d := 0; oHandles.oneapi != nil && d < int(oHandles.oneapi.num_drivers); d++ { + if oHandles.oneapi == nil { + // shouldn't happen + slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers)) + continue + } + devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d)) + for i := 0; i < int(devCount); i++ { gpuInfo := OneapiGPUInfo{ GpuInfo: GpuInfo{ Library: "oneapi", }, - index: i, + driverIndex: d, + gpuIndex: i, } // TODO - split bootstrapping from updating free memory - C.oneapi_check_vram(*gpuHandles.oneapi, &memInfo) + C.oneapi_check_vram(*oHandles.oneapi, C.int(d), C.int(i), &memInfo) + // TODO - convert this to MinimumMemory based on testing... var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend. memInfo.free = C.uint64_t(totalFreeMem) gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.FreeMemory = uint64(memInfo.free) - gpuInfo.ID = strconv.Itoa(i) + gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) + gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) + // TODO dependency path? oneapiGPUs = append(oneapiGPUs, gpuInfo) } } @@ -325,14 +361,14 @@ func GetGPUInfo() GpuInfoList { if needRefresh { // TODO - CPU system memory tracking/refresh var memInfo C.mem_info_t - if gpuHandles == nil && len(cudaGPUs) > 0 { - gpuHandles = initCudaHandles() + if cHandles == nil && len(cudaGPUs) > 0 { + cHandles = initCudaHandles() } for i, gpu := range cudaGPUs { - if gpuHandles.cudart != nil { - C.cudart_bootstrap(*gpuHandles.cudart, C.int(gpu.index), &memInfo) + if cHandles.cudart != nil { + C.cudart_bootstrap(*cHandles.cudart, C.int(gpu.index), &memInfo) } else { - C.nvcuda_get_free(*gpuHandles.nvcuda, C.int(gpu.index), &memInfo.free) + C.nvcuda_get_free(*cHandles.nvcuda, C.int(gpu.index), &memInfo.free) } if memInfo.err != nil { slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err)) @@ -346,6 +382,23 @@ func GetGPUInfo() GpuInfoList { slog.Debug("updating cuda free memory", "gpu", gpu.ID, "name", gpu.Name, "before", format.HumanBytes2(gpu.FreeMemory), "now", format.HumanBytes2(uint64(memInfo.free))) cudaGPUs[i].FreeMemory = uint64(memInfo.free) } + + if oHandles == nil && len(oneapiGPUs) > 0 { + oHandles = initOneAPIHandles() + } + for i, gpu := range oneapiGPUs { + if oHandles.oneapi == nil { + // shouldn't happen + slog.Warn("nil oneapi handle with device count", "count", oHandles.deviceCount) + continue + } + C.oneapi_check_vram(*oHandles.oneapi, C.int(gpu.driverIndex), C.int(gpu.gpuIndex), &memInfo) + // TODO - convert this to MinimumMemory based on testing... + var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend. + memInfo.free = C.uint64_t(totalFreeMem) + oneapiGPUs[i].FreeMemory = uint64(memInfo.free) + } + err := RocmGPUInfoList(rocmGPUs).RefreshFreeMemory() if err != nil { slog.Debug("problem refreshing ROCm free memory", "error", err) @@ -359,6 +412,9 @@ func GetGPUInfo() GpuInfoList { for _, gpu := range rocmGPUs { resp = append(resp, gpu.GpuInfo) } + for _, gpu := range oneapiGPUs { + resp = append(resp, gpu.GpuInfo) + } if len(resp) == 0 { resp = append(resp, cpus[0].GpuInfo) } @@ -476,6 +532,7 @@ func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) { func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) { var resp C.oneapi_init_resp_t + num_devices := 0 resp.oh.verbose = getVerboseState() for _, libPath := range oneapiLibPaths { lib := C.CString(libPath) @@ -485,7 +542,10 @@ func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) { slog.Debug("Unable to load oneAPI management library", "library", libPath, "error", C.GoString(resp.err)) C.free(unsafe.Pointer(resp.err)) } else { - return int(resp.num_devices), &resp.oh, libPath + for i := 0; i < int(resp.oh.num_drivers); i++ { + num_devices += int(C.oneapi_get_device_count(resp.oh, C.int(i))) + } + return num_devices, &resp.oh, libPath } } return 0, nil, "" diff --git a/gpu/gpu_info_oneapi.c b/gpu/gpu_info_oneapi.c index 4be90e80..cc58f7a2 100644 --- a/gpu/gpu_info_oneapi.c +++ b/gpu/gpu_info_oneapi.c @@ -8,9 +8,13 @@ void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp) { ze_result_t ret; resp->err = NULL; + resp->oh.devices = NULL; + resp->oh.num_devices = NULL; + resp->oh.drivers = NULL; + resp->oh.num_drivers = 0; const int buflen = 256; char buf[buflen + 1]; - int i; + int i, d, count; struct lookup { char *s; @@ -66,19 +70,65 @@ void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp) ret = (*resp->oh.zesInit)(0); if (ret != ZE_RESULT_SUCCESS) { - LOG(resp->oh.verbose, "zesInit err: %d\n", ret); - UNLOAD_LIBRARY(resp->oh.handle); - resp->oh.handle = NULL; - snprintf(buf, buflen, "oneapi vram init failure: %d", ret); + LOG(resp->oh.verbose, "zesInit err: %x\n", ret); + snprintf(buf, buflen, "oneapi vram init failure: %x", ret); resp->err = strdup(buf); + oneapi_release(resp->oh); + return; } - (*resp->oh.zesDriverGet)(&resp->num_devices, NULL); + count = 0; + ret = (*resp->oh.zesDriverGet)(&resp->oh.num_drivers, NULL); + if (ret != ZE_RESULT_SUCCESS) + { + LOG(resp->oh.verbose, "zesDriverGet err: %x\n", ret); + snprintf(buf, buflen, "unable to get driver count: %x", ret); + resp->err = strdup(buf); + oneapi_release(resp->oh); + return; + } + LOG(resp->oh.verbose, "oneapi driver count: %d\n", resp->oh.num_drivers); + resp->oh.drivers = malloc(resp->oh.num_drivers * sizeof(zes_driver_handle_t)); + resp->oh.num_devices = malloc(resp->oh.num_drivers * sizeof(uint32_t)); + memset(&resp->oh.num_devices[0], 0, resp->oh.num_drivers * sizeof(uint32_t)); + resp->oh.devices = malloc(resp->oh.num_drivers * sizeof(zes_device_handle_t*)); + ret = (*resp->oh.zesDriverGet)(&resp->oh.num_drivers, &resp->oh.drivers[0]); + if (ret != ZE_RESULT_SUCCESS) + { + LOG(resp->oh.verbose, "zesDriverGet err: %x\n", ret); + snprintf(buf, buflen, "unable to get driver count: %x", ret); + resp->err = strdup(buf); + oneapi_release(resp->oh); + return; + } + + for (d = 0; d < resp->oh.num_drivers; d++) { + ret = (*resp->oh.zesDeviceGet)(resp->oh.drivers[d], &resp->oh.num_devices[d], NULL); + if (ret != ZE_RESULT_SUCCESS) + { + LOG(resp->oh.verbose, "zesDeviceGet err: %x\n", ret); + snprintf(buf, buflen, "unable to get device count: %x", ret); + resp->err = strdup(buf); + oneapi_release(resp->oh); + return; + } + resp->oh.devices[d] = malloc(resp->oh.num_devices[d] * sizeof(zes_device_handle_t)); + ret = (*resp->oh.zesDeviceGet)(resp->oh.drivers[d], &resp->oh.num_devices[d], resp->oh.devices[d]); + if (ret != ZE_RESULT_SUCCESS) + { + LOG(resp->oh.verbose, "zesDeviceGet err: %x\n", ret); + snprintf(buf, buflen, "unable to get device count: %x", ret); + resp->err = strdup(buf); + oneapi_release(resp->oh); + return; + } + count += resp->oh.num_devices[d]; + } return; } -void oneapi_check_vram(oneapi_handle_t h, mem_info_t *resp) +void oneapi_check_vram(oneapi_handle_t h, int driver, int device, mem_info_t *resp) { ze_result_t ret; resp->err = NULL; @@ -93,122 +143,135 @@ void oneapi_check_vram(oneapi_handle_t h, mem_info_t *resp) resp->err = strdup("Level-Zero handle not initialized"); return; } - - uint32_t driversCount = 0; - ret = (*h.zesDriverGet)(&driversCount, NULL); - if (ret != ZE_RESULT_SUCCESS) - { - snprintf(buf, buflen, "unable to get driver count: %d", ret); - resp->err = strdup(buf); + + if (driver > h.num_drivers || device > h.num_devices[driver]) { + resp->err = strdup("driver of device index out of bounds"); return; } - LOG(h.verbose, "discovered %d Level-Zero drivers\n", driversCount); - - zes_driver_handle_t *allDrivers = - malloc(driversCount * sizeof(zes_driver_handle_t)); - (*h.zesDriverGet)(&driversCount, allDrivers); resp->total = 0; resp->free = 0; - for (d = 0; d < driversCount; d++) + zes_device_ext_properties_t ext_props; + ext_props.stype = ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES; + ext_props.pNext = NULL; + + zes_device_properties_t props; + props.stype = ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES; + props.pNext = &ext_props; + + ret = (*h.zesDeviceGetProperties)(h.devices[driver][device], &props); + if (ret != ZE_RESULT_SUCCESS) { - uint32_t deviceCount = 0; - ret = (*h.zesDeviceGet)(allDrivers[d], &deviceCount, NULL); + snprintf(buf, buflen, "unable to get device properties: %d", ret); + resp->err = strdup(buf); + return; + } + + snprintf(&resp->gpu_name[0], GPU_NAME_LEN, props.modelName); + + // TODO this needs to map to ONEAPI_DEVICE_SELECTOR syntax + // (this is probably wrong...) + // TODO - the driver isn't included - what if there are multiple drivers? + snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", device); + + if (h.verbose) + { + // When in verbose mode, report more information about + // the card we discover. + LOG(h.verbose, "[%d:%d] oneAPI device name: %s\n", driver, device, + props.modelName); + LOG(h.verbose, "[%d:%d] oneAPI brand: %s\n", driver, device, + props.brandName); + LOG(h.verbose, "[%d:%d] oneAPI vendor: %s\n", driver, device, + props.vendorName); + LOG(h.verbose, "[%d:%d] oneAPI S/N: %s\n", driver, device, + props.serialNumber); + LOG(h.verbose, "[%d:%d] oneAPI board number: %s\n", driver, device, + props.boardNumber); + } + + // TODO + // Compute Capability equivalent in resp->major, resp->minor, resp->patch + + uint32_t memCount = 0; + ret = (*h.zesDeviceEnumMemoryModules)(h.devices[driver][device], &memCount, NULL); + if (ret != ZE_RESULT_SUCCESS) + { + snprintf(buf, buflen, + "unable to enumerate Level-Zero memory modules: %x", ret); + resp->err = strdup(buf); + return; + } + + LOG(h.verbose, "discovered %d Level-Zero memory modules\n", memCount); + + zes_mem_handle_t *mems = malloc(memCount * sizeof(zes_mem_handle_t)); + (*h.zesDeviceEnumMemoryModules)(h.devices[driver][device], &memCount, mems); + + for (m = 0; m < memCount; m++) + { + zes_mem_state_t state; + state.stype = ZES_STRUCTURE_TYPE_MEM_STATE; + state.pNext = NULL; + ret = (*h.zesMemoryGetState)(mems[m], &state); if (ret != ZE_RESULT_SUCCESS) { - snprintf(buf, buflen, "unable to get device count: %d", ret); + snprintf(buf, buflen, "unable to get memory state: %x", ret); resp->err = strdup(buf); - free(allDrivers); + free(mems); return; } - LOG(h.verbose, "discovered %d Level-Zero devices\n", deviceCount); - - zes_device_handle_t *devices = - malloc(deviceCount * sizeof(zes_device_handle_t)); - (*h.zesDeviceGet)(allDrivers[d], &deviceCount, devices); - - for (i = 0; i < deviceCount; i++) - { - zes_device_ext_properties_t ext_props; - ext_props.stype = ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES; - ext_props.pNext = NULL; - - zes_device_properties_t props; - props.stype = ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES; - props.pNext = &ext_props; - - ret = (*h.zesDeviceGetProperties)(devices[i], &props); - if (ret != ZE_RESULT_SUCCESS) - { - snprintf(buf, buflen, "unable to get device properties: %d", ret); - resp->err = strdup(buf); - free(allDrivers); - free(devices); - return; - } - - if (h.verbose) - { - // When in verbose mode, report more information about - // the card we discover. - LOG(h.verbose, "[%d] oneAPI device name: %s\n", i, - props.modelName); - LOG(h.verbose, "[%d] oneAPI brand: %s\n", i, - props.brandName); - LOG(h.verbose, "[%d] oneAPI vendor: %s\n", i, - props.vendorName); - LOG(h.verbose, "[%d] oneAPI S/N: %s\n", i, - props.serialNumber); - LOG(h.verbose, "[%d] oneAPI board number: %s\n", i, - props.boardNumber); - } - - uint32_t memCount = 0; - ret = (*h.zesDeviceEnumMemoryModules)(devices[i], &memCount, NULL); - if (ret != ZE_RESULT_SUCCESS) - { - snprintf(buf, buflen, - "unable to enumerate Level-Zero memory modules: %d", ret); - resp->err = strdup(buf); - free(allDrivers); - free(devices); - return; - } - - LOG(h.verbose, "discovered %d Level-Zero memory modules\n", memCount); - - zes_mem_handle_t *mems = malloc(memCount * sizeof(zes_mem_handle_t)); - (*h.zesDeviceEnumMemoryModules)(devices[i], &memCount, mems); - - for (m = 0; m < memCount; m++) - { - zes_mem_state_t state; - state.stype = ZES_STRUCTURE_TYPE_MEM_STATE; - state.pNext = NULL; - ret = (*h.zesMemoryGetState)(mems[m], &state); - if (ret != ZE_RESULT_SUCCESS) - { - snprintf(buf, buflen, "unable to get memory state: %d", ret); - resp->err = strdup(buf); - free(allDrivers); - free(devices); - free(mems); - return; - } - - resp->total += state.size; - resp->free += state.free; - } - - free(mems); - } - - free(devices); + resp->total += state.size; + resp->free += state.free; } - free(allDrivers); + free(mems); +} + +void oneapi_release(oneapi_handle_t h) +{ + int d; + LOG(h.verbose, "releasing oneapi library\n"); + for (d = 0; d < h.num_drivers; d++) + { + if (h.devices != NULL && h.devices[d] != NULL) + { + free(h.devices[d]); + } + } + if (h.devices != NULL) + { + free(h.devices); + h.devices = NULL; + } + if (h.num_devices != NULL) + { + free(h.num_devices); + h.num_devices = NULL; + } + if (h.drivers != NULL) + { + free(h.drivers); + h.drivers = NULL; + } + h.num_drivers = 0; + UNLOAD_LIBRARY(h.handle); + h.handle = NULL; +} + +int oneapi_get_device_count(oneapi_handle_t h, int driver) +{ + if (h.handle == NULL || h.num_devices == NULL) + { + return 0; + } + if (driver > h.num_drivers) + { + return 0; + } + return (int)h.num_devices[driver]; } #endif // __APPLE__ diff --git a/gpu/gpu_info_oneapi.h b/gpu/gpu_info_oneapi.h index 9db9fae0..7607935c 100644 --- a/gpu/gpu_info_oneapi.h +++ b/gpu/gpu_info_oneapi.h @@ -175,6 +175,16 @@ typedef struct oneapi_handle { void *handle; uint16_t verbose; + + uint32_t num_drivers; + zes_driver_handle_t *drivers; + uint32_t *num_devices; + zes_device_handle_t **devices; + + // TODO Driver major, minor information + // int driver_major; + // int driver_minor; + ze_result_t (*zesInit)(int); ze_result_t (*zesDriverGet)(uint32_t *pCount, zes_driver_handle_t *phDrivers); ze_result_t (*zesDeviceGet)(zes_driver_handle_t hDriver, uint32_t *pCount, @@ -194,7 +204,6 @@ typedef struct oneapi_handle typedef struct oneapi_init_resp { char *err; // If err is non-null handle is invalid - int num_devices; oneapi_handle_t oh; } oneapi_init_resp_t; @@ -205,7 +214,9 @@ typedef struct oneapi_version_resp } oneapi_version_resp_t; void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp); -void oneapi_check_vram(oneapi_handle_t rh, mem_info_t *resp); +void oneapi_check_vram(oneapi_handle_t h, int driver, int device, mem_info_t *resp); +void oneapi_release(oneapi_handle_t h); +int oneapi_get_device_count(oneapi_handle_t h, int driver); #endif // __GPU_INFO_INTEL_H__ #endif // __APPLE__ diff --git a/gpu/types.go b/gpu/types.go index a633e6c7..2b1ea429 100644 --- a/gpu/types.go +++ b/gpu/types.go @@ -57,7 +57,8 @@ type RocmGPUInfoList []RocmGPUInfo type OneapiGPUInfo struct { GpuInfo - index int // device index + driverIndex int // nolint: unused + gpuIndex int // nolint: unused } type OneapiGPUInfoList []OneapiGPUInfo