gpu: Group GPU Library sets by variant (#6483)

The recent cuda variant changes uncovered a bug in ByLibrary
which failed to group by common variant for GPU types.
This commit is contained in:
Daniel Hiltgen 2024-08-23 15:11:56 -07:00 committed by GitHub
parent 9638c24c58
commit 69be940bf6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 26 additions and 1 deletions

View file

@ -32,4 +32,29 @@ func TestCPUMemInfo(t *testing.T) {
} }
} }
func TestByLibrary(t *testing.T) {
type testCase struct {
input []GpuInfo
expect int
}
testCases := map[string]*testCase{
"empty": {input: []GpuInfo{}, expect: 0},
"cpu": {input: []GpuInfo{{Library: "cpu"}}, expect: 1},
"cpu + GPU": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda"}}, expect: 2},
"cpu + 2 GPU no variant": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda"}, {Library: "cuda"}}, expect: 2},
"cpu + 2 GPU same variant": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda", Variant: "v11"}, {Library: "cuda", Variant: "v11"}}, expect: 2},
"cpu + 2 GPU diff variant": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda", Variant: "v11"}, {Library: "cuda", Variant: "v12"}}, expect: 3},
}
for k, v := range testCases {
t.Run(k, func(t *testing.T) {
resp := (GpuInfoList)(v.input).ByLibrary()
if len(resp) != v.expect {
t.Fatalf("expected length %d, got %d => %+v", v.expect, len(resp), resp)
}
})
}
}
// TODO - add some logic to figure out card type through other means and actually verify we got back what we expected // TODO - add some logic to figure out card type through other means and actually verify we got back what we expected

View file

@ -94,7 +94,7 @@ func (l GpuInfoList) ByLibrary() []GpuInfoList {
} }
} }
if !found { if !found {
libs = append(libs, info.Library) libs = append(libs, requested)
resp = append(resp, []GpuInfo{info}) resp = append(resp, []GpuInfo{info})
} }
} }