Refactor intel gpu discovery

This commit is contained in:
Daniel Hiltgen 2024-05-29 16:37:34 -07:00
parent 48702dd149
commit 4e2b7e181d
4 changed files with 304 additions and 169 deletions

View file

@ -16,7 +16,6 @@ import (
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"unsafe"
@ -25,16 +24,21 @@ import (
"github.com/ollama/ollama/format"
)
type handles struct {
type cudaHandles struct {
deviceCount int
cudart *C.cudart_handle_t
nvcuda *C.nvcuda_handle_t
}
type oneapiHandles struct {
oneapi *C.oneapi_handle_t
deviceCount int
}
const (
cudaMinimumMemory = 457 * format.MebiByte
rocmMinimumMemory = 457 * format.MebiByte
// TODO OneAPI minimum memory
)
var (
@ -107,19 +111,19 @@ var OneapiLinuxGlobs = []string{
var CudaTegra string = os.Getenv("JETSON_JETPACK")
// Note: gpuMutex must already be held
func initCudaHandles() *handles {
func initCudaHandles() *cudaHandles {
// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
gpuHandles := &handles{}
cHandles := &cudaHandles{}
// Short Circuit if we already know which library to use
if nvcudaLibPath != "" {
gpuHandles.deviceCount, gpuHandles.nvcuda, _ = LoadNVCUDAMgmt([]string{nvcudaLibPath})
return gpuHandles
cHandles.deviceCount, cHandles.nvcuda, _ = LoadNVCUDAMgmt([]string{nvcudaLibPath})
return cHandles
}
if cudartLibPath != "" {
gpuHandles.deviceCount, gpuHandles.cudart, _ = LoadCUDARTMgmt([]string{cudartLibPath})
return gpuHandles
cHandles.deviceCount, cHandles.cudart, _ = LoadCUDARTMgmt([]string{cudartLibPath})
return cHandles
}
slog.Debug("searching for GPU discovery libraries for NVIDIA")
@ -127,8 +131,6 @@ func initCudaHandles() *handles {
var cudartMgmtPatterns []string
var nvcudaMgmtName string
var nvcudaMgmtPatterns []string
var oneapiMgmtName string
var oneapiMgmtPatterns []string
tmpDir, _ := PayloadsDir()
switch runtime.GOOS {
@ -140,8 +142,6 @@ func initCudaHandles() *handles {
// Aligned with driver, we can't carry as payloads
nvcudaMgmtName = "nvcuda.dll"
nvcudaMgmtPatterns = NvcudaWindowsGlobs
oneapiMgmtName = "ze_intel_gpu64.dll"
oneapiMgmtPatterns = OneapiWindowsGlobs
case "linux":
cudartMgmtName = "libcudart.so*"
if tmpDir != "" {
@ -152,10 +152,8 @@ func initCudaHandles() *handles {
// Aligned with driver, we can't carry as payloads
nvcudaMgmtName = "libcuda.so*"
nvcudaMgmtPatterns = NvcudaLinuxGlobs
oneapiMgmtName = "libze_intel_gpu.so"
oneapiMgmtPatterns = OneapiLinuxGlobs
default:
return gpuHandles
return cHandles
}
nvcudaLibPaths := FindGPULibs(nvcudaMgmtName, nvcudaMgmtPatterns)
@ -163,10 +161,10 @@ func initCudaHandles() *handles {
deviceCount, nvcuda, libPath := LoadNVCUDAMgmt(nvcudaLibPaths)
if nvcuda != nil {
slog.Debug("detected GPUs", "count", deviceCount, "library", libPath)
gpuHandles.nvcuda = nvcuda
gpuHandles.deviceCount = deviceCount
cHandles.nvcuda = nvcuda
cHandles.deviceCount = deviceCount
nvcudaLibPath = libPath
return gpuHandles
return cHandles
}
}
@ -175,26 +173,45 @@ func initCudaHandles() *handles {
deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths)
if cudart != nil {
slog.Debug("detected GPUs", "library", libPath, "count", deviceCount)
gpuHandles.cudart = cudart
gpuHandles.deviceCount = deviceCount
cHandles.cudart = cudart
cHandles.deviceCount = deviceCount
cudartLibPath = libPath
return gpuHandles
return cHandles
}
}
return cHandles
}
// Note: gpuMutex must already be held
func initOneAPIHandles() *oneapiHandles {
oHandles := &oneapiHandles{}
var oneapiMgmtName string
var oneapiMgmtPatterns []string
// Short Circuit if we already know which library to use
if oneapiLibPath != "" {
oHandles.deviceCount, oHandles.oneapi, _ = LoadOneapiMgmt([]string{oneapiLibPath})
return oHandles
}
switch runtime.GOOS {
case "windows":
oneapiMgmtName = "ze_intel_gpu64.dll"
oneapiMgmtPatterns = OneapiWindowsGlobs
case "linux":
oneapiMgmtName = "libze_intel_gpu.so"
oneapiMgmtPatterns = OneapiLinuxGlobs
default:
return oHandles
}
oneapiLibPaths := FindGPULibs(oneapiMgmtName, oneapiMgmtPatterns)
if len(oneapiLibPaths) > 0 {
deviceCount, oneapi, libPath := LoadOneapiMgmt(oneapiLibPaths)
if oneapi != nil {
slog.Debug("detected Intel GPUs", "library", libPath, "count", deviceCount)
gpuHandles.oneapi = oneapi
gpuHandles.deviceCount = deviceCount
oneapiLibPath = libPath
return gpuHandles
}
oHandles.deviceCount, oHandles.oneapi, oneapiLibPath = LoadOneapiMgmt(oneapiLibPaths)
}
return gpuHandles
return oHandles
}
func GetGPUInfo() GpuInfoList {
@ -203,16 +220,22 @@ func GetGPUInfo() GpuInfoList {
gpuMutex.Lock()
defer gpuMutex.Unlock()
needRefresh := true
var gpuHandles *handles
var cHandles *cudaHandles
var oHandles *oneapiHandles
defer func() {
if gpuHandles == nil {
return
if cHandles != nil {
if cHandles.cudart != nil {
C.cudart_release(*cHandles.cudart)
}
if cHandles.nvcuda != nil {
C.nvcuda_release(*cHandles.nvcuda)
}
}
if gpuHandles.cudart != nil {
C.cudart_release(*gpuHandles.cudart)
}
if gpuHandles.nvcuda != nil {
C.nvcuda_release(*gpuHandles.nvcuda)
if oHandles != nil {
if oHandles.oneapi != nil {
// TODO - is this needed?
C.oneapi_release(*oHandles.oneapi)
}
}
}()
@ -253,13 +276,11 @@ func GetGPUInfo() GpuInfoList {
}
// Load ALL libraries
gpuHandles = initCudaHandles()
// TODO needs a refactoring pass to init oneapi handles
cHandles = initCudaHandles()
// NVIDIA
for i := range gpuHandles.deviceCount {
if gpuHandles.cudart != nil || gpuHandles.nvcuda != nil {
for i := range cHandles.deviceCount {
if cHandles.cudart != nil || cHandles.nvcuda != nil {
gpuInfo := CudaGPUInfo{
GpuInfo: GpuInfo{
Library: "cuda",
@ -268,12 +289,12 @@ func GetGPUInfo() GpuInfoList {
}
var driverMajor int
var driverMinor int
if gpuHandles.cudart != nil {
C.cudart_bootstrap(*gpuHandles.cudart, C.int(i), &memInfo)
if cHandles.cudart != nil {
C.cudart_bootstrap(*cHandles.cudart, C.int(i), &memInfo)
} else {
C.nvcuda_bootstrap(*gpuHandles.nvcuda, C.int(i), &memInfo)
driverMajor = int(gpuHandles.nvcuda.driver_major)
driverMinor = int(gpuHandles.nvcuda.driver_minor)
C.nvcuda_bootstrap(*cHandles.nvcuda, C.int(i), &memInfo)
driverMajor = int(cHandles.nvcuda.driver_major)
driverMinor = int(cHandles.nvcuda.driver_minor)
}
if memInfo.err != nil {
slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
@ -297,20 +318,35 @@ func GetGPUInfo() GpuInfoList {
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
cudaGPUs = append(cudaGPUs, gpuInfo)
}
if gpuHandles.oneapi != nil {
}
// Intel
oHandles = initOneAPIHandles()
for d := 0; oHandles.oneapi != nil && d < int(oHandles.oneapi.num_drivers); d++ {
if oHandles.oneapi == nil {
// shouldn't happen
slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers))
continue
}
devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d))
for i := 0; i < int(devCount); i++ {
gpuInfo := OneapiGPUInfo{
GpuInfo: GpuInfo{
Library: "oneapi",
},
index: i,
driverIndex: d,
gpuIndex: i,
}
// TODO - split bootstrapping from updating free memory
C.oneapi_check_vram(*gpuHandles.oneapi, &memInfo)
C.oneapi_check_vram(*oHandles.oneapi, C.int(d), C.int(i), &memInfo)
// TODO - convert this to MinimumMemory based on testing...
var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
memInfo.free = C.uint64_t(totalFreeMem)
gpuInfo.TotalMemory = uint64(memInfo.total)
gpuInfo.FreeMemory = uint64(memInfo.free)
gpuInfo.ID = strconv.Itoa(i)
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
// TODO dependency path?
oneapiGPUs = append(oneapiGPUs, gpuInfo)
}
}
@ -325,14 +361,14 @@ func GetGPUInfo() GpuInfoList {
if needRefresh {
// TODO - CPU system memory tracking/refresh
var memInfo C.mem_info_t
if gpuHandles == nil && len(cudaGPUs) > 0 {
gpuHandles = initCudaHandles()
if cHandles == nil && len(cudaGPUs) > 0 {
cHandles = initCudaHandles()
}
for i, gpu := range cudaGPUs {
if gpuHandles.cudart != nil {
C.cudart_bootstrap(*gpuHandles.cudart, C.int(gpu.index), &memInfo)
if cHandles.cudart != nil {
C.cudart_bootstrap(*cHandles.cudart, C.int(gpu.index), &memInfo)
} else {
C.nvcuda_get_free(*gpuHandles.nvcuda, C.int(gpu.index), &memInfo.free)
C.nvcuda_get_free(*cHandles.nvcuda, C.int(gpu.index), &memInfo.free)
}
if memInfo.err != nil {
slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
@ -346,6 +382,23 @@ func GetGPUInfo() GpuInfoList {
slog.Debug("updating cuda free memory", "gpu", gpu.ID, "name", gpu.Name, "before", format.HumanBytes2(gpu.FreeMemory), "now", format.HumanBytes2(uint64(memInfo.free)))
cudaGPUs[i].FreeMemory = uint64(memInfo.free)
}
if oHandles == nil && len(oneapiGPUs) > 0 {
oHandles = initOneAPIHandles()
}
for i, gpu := range oneapiGPUs {
if oHandles.oneapi == nil {
// shouldn't happen
slog.Warn("nil oneapi handle with device count", "count", oHandles.deviceCount)
continue
}
C.oneapi_check_vram(*oHandles.oneapi, C.int(gpu.driverIndex), C.int(gpu.gpuIndex), &memInfo)
// TODO - convert this to MinimumMemory based on testing...
var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
memInfo.free = C.uint64_t(totalFreeMem)
oneapiGPUs[i].FreeMemory = uint64(memInfo.free)
}
err := RocmGPUInfoList(rocmGPUs).RefreshFreeMemory()
if err != nil {
slog.Debug("problem refreshing ROCm free memory", "error", err)
@ -359,6 +412,9 @@ func GetGPUInfo() GpuInfoList {
for _, gpu := range rocmGPUs {
resp = append(resp, gpu.GpuInfo)
}
for _, gpu := range oneapiGPUs {
resp = append(resp, gpu.GpuInfo)
}
if len(resp) == 0 {
resp = append(resp, cpus[0].GpuInfo)
}
@ -476,6 +532,7 @@ func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) {
func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) {
var resp C.oneapi_init_resp_t
num_devices := 0
resp.oh.verbose = getVerboseState()
for _, libPath := range oneapiLibPaths {
lib := C.CString(libPath)
@ -485,7 +542,10 @@ func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) {
slog.Debug("Unable to load oneAPI management library", "library", libPath, "error", C.GoString(resp.err))
C.free(unsafe.Pointer(resp.err))
} else {
return int(resp.num_devices), &resp.oh, libPath
for i := 0; i < int(resp.oh.num_drivers); i++ {
num_devices += int(C.oneapi_get_device_count(resp.oh, C.int(i)))
}
return num_devices, &resp.oh, libPath
}
}
return 0, nil, ""

View file

@ -8,9 +8,13 @@ void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp)
{
ze_result_t ret;
resp->err = NULL;
resp->oh.devices = NULL;
resp->oh.num_devices = NULL;
resp->oh.drivers = NULL;
resp->oh.num_drivers = 0;
const int buflen = 256;
char buf[buflen + 1];
int i;
int i, d, count;
struct lookup
{
char *s;
@ -66,19 +70,65 @@ void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp)
ret = (*resp->oh.zesInit)(0);
if (ret != ZE_RESULT_SUCCESS)
{
LOG(resp->oh.verbose, "zesInit err: %d\n", ret);
UNLOAD_LIBRARY(resp->oh.handle);
resp->oh.handle = NULL;
snprintf(buf, buflen, "oneapi vram init failure: %d", ret);
LOG(resp->oh.verbose, "zesInit err: %x\n", ret);
snprintf(buf, buflen, "oneapi vram init failure: %x", ret);
resp->err = strdup(buf);
oneapi_release(resp->oh);
return;
}
(*resp->oh.zesDriverGet)(&resp->num_devices, NULL);
count = 0;
ret = (*resp->oh.zesDriverGet)(&resp->oh.num_drivers, NULL);
if (ret != ZE_RESULT_SUCCESS)
{
LOG(resp->oh.verbose, "zesDriverGet err: %x\n", ret);
snprintf(buf, buflen, "unable to get driver count: %x", ret);
resp->err = strdup(buf);
oneapi_release(resp->oh);
return;
}
LOG(resp->oh.verbose, "oneapi driver count: %d\n", resp->oh.num_drivers);
resp->oh.drivers = malloc(resp->oh.num_drivers * sizeof(zes_driver_handle_t));
resp->oh.num_devices = malloc(resp->oh.num_drivers * sizeof(uint32_t));
memset(&resp->oh.num_devices[0], 0, resp->oh.num_drivers * sizeof(uint32_t));
resp->oh.devices = malloc(resp->oh.num_drivers * sizeof(zes_device_handle_t*));
ret = (*resp->oh.zesDriverGet)(&resp->oh.num_drivers, &resp->oh.drivers[0]);
if (ret != ZE_RESULT_SUCCESS)
{
LOG(resp->oh.verbose, "zesDriverGet err: %x\n", ret);
snprintf(buf, buflen, "unable to get driver count: %x", ret);
resp->err = strdup(buf);
oneapi_release(resp->oh);
return;
}
for (d = 0; d < resp->oh.num_drivers; d++) {
ret = (*resp->oh.zesDeviceGet)(resp->oh.drivers[d], &resp->oh.num_devices[d], NULL);
if (ret != ZE_RESULT_SUCCESS)
{
LOG(resp->oh.verbose, "zesDeviceGet err: %x\n", ret);
snprintf(buf, buflen, "unable to get device count: %x", ret);
resp->err = strdup(buf);
oneapi_release(resp->oh);
return;
}
resp->oh.devices[d] = malloc(resp->oh.num_devices[d] * sizeof(zes_device_handle_t));
ret = (*resp->oh.zesDeviceGet)(resp->oh.drivers[d], &resp->oh.num_devices[d], resp->oh.devices[d]);
if (ret != ZE_RESULT_SUCCESS)
{
LOG(resp->oh.verbose, "zesDeviceGet err: %x\n", ret);
snprintf(buf, buflen, "unable to get device count: %x", ret);
resp->err = strdup(buf);
oneapi_release(resp->oh);
return;
}
count += resp->oh.num_devices[d];
}
return;
}
void oneapi_check_vram(oneapi_handle_t h, mem_info_t *resp)
void oneapi_check_vram(oneapi_handle_t h, int driver, int device, mem_info_t *resp)
{
ze_result_t ret;
resp->err = NULL;
@ -93,122 +143,135 @@ void oneapi_check_vram(oneapi_handle_t h, mem_info_t *resp)
resp->err = strdup("Level-Zero handle not initialized");
return;
}
uint32_t driversCount = 0;
ret = (*h.zesDriverGet)(&driversCount, NULL);
if (ret != ZE_RESULT_SUCCESS)
{
snprintf(buf, buflen, "unable to get driver count: %d", ret);
resp->err = strdup(buf);
if (driver > h.num_drivers || device > h.num_devices[driver]) {
resp->err = strdup("driver of device index out of bounds");
return;
}
LOG(h.verbose, "discovered %d Level-Zero drivers\n", driversCount);
zes_driver_handle_t *allDrivers =
malloc(driversCount * sizeof(zes_driver_handle_t));
(*h.zesDriverGet)(&driversCount, allDrivers);
resp->total = 0;
resp->free = 0;
for (d = 0; d < driversCount; d++)
zes_device_ext_properties_t ext_props;
ext_props.stype = ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES;
ext_props.pNext = NULL;
zes_device_properties_t props;
props.stype = ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES;
props.pNext = &ext_props;
ret = (*h.zesDeviceGetProperties)(h.devices[driver][device], &props);
if (ret != ZE_RESULT_SUCCESS)
{
uint32_t deviceCount = 0;
ret = (*h.zesDeviceGet)(allDrivers[d], &deviceCount, NULL);
snprintf(buf, buflen, "unable to get device properties: %d", ret);
resp->err = strdup(buf);
return;
}
snprintf(&resp->gpu_name[0], GPU_NAME_LEN, props.modelName);
// TODO this needs to map to ONEAPI_DEVICE_SELECTOR syntax
// (this is probably wrong...)
// TODO - the driver isn't included - what if there are multiple drivers?
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", device);
if (h.verbose)
{
// When in verbose mode, report more information about
// the card we discover.
LOG(h.verbose, "[%d:%d] oneAPI device name: %s\n", driver, device,
props.modelName);
LOG(h.verbose, "[%d:%d] oneAPI brand: %s\n", driver, device,
props.brandName);
LOG(h.verbose, "[%d:%d] oneAPI vendor: %s\n", driver, device,
props.vendorName);
LOG(h.verbose, "[%d:%d] oneAPI S/N: %s\n", driver, device,
props.serialNumber);
LOG(h.verbose, "[%d:%d] oneAPI board number: %s\n", driver, device,
props.boardNumber);
}
// TODO
// Compute Capability equivalent in resp->major, resp->minor, resp->patch
uint32_t memCount = 0;
ret = (*h.zesDeviceEnumMemoryModules)(h.devices[driver][device], &memCount, NULL);
if (ret != ZE_RESULT_SUCCESS)
{
snprintf(buf, buflen,
"unable to enumerate Level-Zero memory modules: %x", ret);
resp->err = strdup(buf);
return;
}
LOG(h.verbose, "discovered %d Level-Zero memory modules\n", memCount);
zes_mem_handle_t *mems = malloc(memCount * sizeof(zes_mem_handle_t));
(*h.zesDeviceEnumMemoryModules)(h.devices[driver][device], &memCount, mems);
for (m = 0; m < memCount; m++)
{
zes_mem_state_t state;
state.stype = ZES_STRUCTURE_TYPE_MEM_STATE;
state.pNext = NULL;
ret = (*h.zesMemoryGetState)(mems[m], &state);
if (ret != ZE_RESULT_SUCCESS)
{
snprintf(buf, buflen, "unable to get device count: %d", ret);
snprintf(buf, buflen, "unable to get memory state: %x", ret);
resp->err = strdup(buf);
free(allDrivers);
free(mems);
return;
}
LOG(h.verbose, "discovered %d Level-Zero devices\n", deviceCount);
zes_device_handle_t *devices =
malloc(deviceCount * sizeof(zes_device_handle_t));
(*h.zesDeviceGet)(allDrivers[d], &deviceCount, devices);
for (i = 0; i < deviceCount; i++)
{
zes_device_ext_properties_t ext_props;
ext_props.stype = ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES;
ext_props.pNext = NULL;
zes_device_properties_t props;
props.stype = ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES;
props.pNext = &ext_props;
ret = (*h.zesDeviceGetProperties)(devices[i], &props);
if (ret != ZE_RESULT_SUCCESS)
{
snprintf(buf, buflen, "unable to get device properties: %d", ret);
resp->err = strdup(buf);
free(allDrivers);
free(devices);
return;
}
if (h.verbose)
{
// When in verbose mode, report more information about
// the card we discover.
LOG(h.verbose, "[%d] oneAPI device name: %s\n", i,
props.modelName);
LOG(h.verbose, "[%d] oneAPI brand: %s\n", i,
props.brandName);
LOG(h.verbose, "[%d] oneAPI vendor: %s\n", i,
props.vendorName);
LOG(h.verbose, "[%d] oneAPI S/N: %s\n", i,
props.serialNumber);
LOG(h.verbose, "[%d] oneAPI board number: %s\n", i,
props.boardNumber);
}
uint32_t memCount = 0;
ret = (*h.zesDeviceEnumMemoryModules)(devices[i], &memCount, NULL);
if (ret != ZE_RESULT_SUCCESS)
{
snprintf(buf, buflen,
"unable to enumerate Level-Zero memory modules: %d", ret);
resp->err = strdup(buf);
free(allDrivers);
free(devices);
return;
}
LOG(h.verbose, "discovered %d Level-Zero memory modules\n", memCount);
zes_mem_handle_t *mems = malloc(memCount * sizeof(zes_mem_handle_t));
(*h.zesDeviceEnumMemoryModules)(devices[i], &memCount, mems);
for (m = 0; m < memCount; m++)
{
zes_mem_state_t state;
state.stype = ZES_STRUCTURE_TYPE_MEM_STATE;
state.pNext = NULL;
ret = (*h.zesMemoryGetState)(mems[m], &state);
if (ret != ZE_RESULT_SUCCESS)
{
snprintf(buf, buflen, "unable to get memory state: %d", ret);
resp->err = strdup(buf);
free(allDrivers);
free(devices);
free(mems);
return;
}
resp->total += state.size;
resp->free += state.free;
}
free(mems);
}
free(devices);
resp->total += state.size;
resp->free += state.free;
}
free(allDrivers);
free(mems);
}
void oneapi_release(oneapi_handle_t h)
{
int d;
LOG(h.verbose, "releasing oneapi library\n");
for (d = 0; d < h.num_drivers; d++)
{
if (h.devices != NULL && h.devices[d] != NULL)
{
free(h.devices[d]);
}
}
if (h.devices != NULL)
{
free(h.devices);
h.devices = NULL;
}
if (h.num_devices != NULL)
{
free(h.num_devices);
h.num_devices = NULL;
}
if (h.drivers != NULL)
{
free(h.drivers);
h.drivers = NULL;
}
h.num_drivers = 0;
UNLOAD_LIBRARY(h.handle);
h.handle = NULL;
}
int oneapi_get_device_count(oneapi_handle_t h, int driver)
{
if (h.handle == NULL || h.num_devices == NULL)
{
return 0;
}
if (driver > h.num_drivers)
{
return 0;
}
return (int)h.num_devices[driver];
}
#endif // __APPLE__

View file

@ -175,6 +175,16 @@ typedef struct oneapi_handle
{
void *handle;
uint16_t verbose;
uint32_t num_drivers;
zes_driver_handle_t *drivers;
uint32_t *num_devices;
zes_device_handle_t **devices;
// TODO Driver major, minor information
// int driver_major;
// int driver_minor;
ze_result_t (*zesInit)(int);
ze_result_t (*zesDriverGet)(uint32_t *pCount, zes_driver_handle_t *phDrivers);
ze_result_t (*zesDeviceGet)(zes_driver_handle_t hDriver, uint32_t *pCount,
@ -194,7 +204,6 @@ typedef struct oneapi_handle
typedef struct oneapi_init_resp
{
char *err; // If err is non-null handle is invalid
int num_devices;
oneapi_handle_t oh;
} oneapi_init_resp_t;
@ -205,7 +214,9 @@ typedef struct oneapi_version_resp
} oneapi_version_resp_t;
void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp);
void oneapi_check_vram(oneapi_handle_t rh, mem_info_t *resp);
void oneapi_check_vram(oneapi_handle_t h, int driver, int device, mem_info_t *resp);
void oneapi_release(oneapi_handle_t h);
int oneapi_get_device_count(oneapi_handle_t h, int driver);
#endif // __GPU_INFO_INTEL_H__
#endif // __APPLE__

View file

@ -57,7 +57,8 @@ type RocmGPUInfoList []RocmGPUInfo
type OneapiGPUInfo struct {
GpuInfo
index int // device index
driverIndex int // nolint: unused
gpuIndex int // nolint: unused
}
type OneapiGPUInfoList []OneapiGPUInfo