diff --git a/gpu/gpu.go b/gpu/gpu.go index b7d1c1ad..6937de7a 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -110,6 +110,8 @@ func GetGPUInfo() GpuInfo { C.free(unsafe.Pointer(memInfo.err)) return resp } + + resp.DeviceCount = uint32(memInfo.count) resp.FreeMemory = uint64(memInfo.free) resp.TotalMemory = uint64(memInfo.total) return resp @@ -132,7 +134,7 @@ func CheckVRAM() (int64, error) { gpuInfo := GetGPUInfo() if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") { // leave 10% or 384Mi of VRAM free for unaccounted for overhead - overhead := gpuInfo.FreeMemory / 10 + overhead := gpuInfo.FreeMemory * uint64(gpuInfo.DeviceCount) / 10 if overhead < 384*1024*1024 { overhead = 384 * 1024 * 1024 } diff --git a/gpu/gpu_darwin.go b/gpu/gpu_darwin.go index b3556f90..23c95e36 100644 --- a/gpu/gpu_darwin.go +++ b/gpu/gpu_darwin.go @@ -42,6 +42,7 @@ func getCPUMem() (memInfo, error) { return memInfo{ TotalMemory: 0, FreeMemory: 0, + DeviceCount: 0, }, nil } diff --git a/gpu/gpu_info.h b/gpu/gpu_info.h index 3b2edc70..5ba19271 100644 --- a/gpu/gpu_info.h +++ b/gpu/gpu_info.h @@ -34,6 +34,7 @@ extern "C" { typedef struct mem_info { uint64_t total; uint64_t free; + unsigned int count; char *err; // If non-nill, caller responsible for freeing } mem_info_t; diff --git a/gpu/gpu_info_cpu.c b/gpu/gpu_info_cpu.c index 38e2a563..0c4d62c5 100644 --- a/gpu/gpu_info_cpu.c +++ b/gpu/gpu_info_cpu.c @@ -8,6 +8,7 @@ void cpu_check_ram(mem_info_t *resp) { MEMORYSTATUSEX info; info.dwLength = sizeof(info); if (GlobalMemoryStatusEx(&info) != 0) { + resp->count = 1; resp->total = info.ullTotalPhys; resp->free = info.ullAvailPhys; } else { @@ -26,6 +27,7 @@ void cpu_check_ram(mem_info_t *resp) { if (sysinfo(&info) != 0) { resp->err = strdup(strerror(errno)); } else { + resp->count = 1; resp->total = info.totalram * info.mem_unit; resp->free = info.freeram * info.mem_unit; } diff --git a/gpu/gpu_info_cuda.c b/gpu/gpu_info_cuda.c index 9dc97bd9..9e76b791 100644 --- a/gpu/gpu_info_cuda.c +++ b/gpu/gpu_info_cuda.c @@ -94,8 +94,7 @@ void cuda_check_vram(cuda_handle_t h, mem_info_t *resp) { return; } - unsigned int devices; - ret = (*h.getCount)(&devices); + ret = (*h.getCount)(&resp->count); if (ret != NVML_SUCCESS) { snprintf(buf, buflen, "unable to get device count: %d", ret); resp->err = strdup(buf); @@ -104,8 +103,7 @@ void cuda_check_vram(cuda_handle_t h, mem_info_t *resp) { resp->total = 0; resp->free = 0; - - for (i = 0; i < devices; i++) { + for (i = 0; i < resp->count; i++) { ret = (*h.getHandle)(i, &device); if (ret != NVML_SUCCESS) { snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret); diff --git a/gpu/gpu_info_rocm.c b/gpu/gpu_info_rocm.c index 367d11fd..9901172b 100644 --- a/gpu/gpu_info_rocm.c +++ b/gpu/gpu_info_rocm.c @@ -110,6 +110,8 @@ void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) { return; } + // TODO: set this to the actual number of devices + resp->count = 1; resp->total = totalMem; resp->free = totalMem - usedMem; return; diff --git a/gpu/gpu_test.go b/gpu/gpu_test.go index d5585d3c..c260211e 100644 --- a/gpu/gpu_test.go +++ b/gpu/gpu_test.go @@ -18,6 +18,7 @@ func TestBasicGetGPUInfo(t *testing.T) { case "linux", "windows": assert.Greater(t, info.TotalMemory, uint64(0)) assert.Greater(t, info.FreeMemory, uint64(0)) + assert.Greater(t, info.DeviceCount, uint64(0)) default: return } @@ -35,7 +36,6 @@ func TestCPUMemInfo(t *testing.T) { default: return } - } // TODO - add some logic to figure out card type through other means and actually verify we got back what we expected diff --git a/gpu/types.go b/gpu/types.go index c3c39210..abc16dbc 100644 --- a/gpu/types.go +++ b/gpu/types.go @@ -3,6 +3,7 @@ package gpu type memInfo struct { TotalMemory uint64 `json:"total_memory,omitempty"` FreeMemory uint64 `json:"free_memory,omitempty"` + DeviceCount uint32 `json:"device_count,omitempty"` } // Beginning of an `ollama info` command