Merge pull request #1935 from dhiltgen/cpu_fallback

Fix up the CPU fallback selection
This commit is contained in:
Daniel Hiltgen 2024-01-11 15:52:32 -08:00 committed by GitHub
commit 3773fb6465
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 29 additions and 16 deletions

View file

@ -34,7 +34,7 @@ func GetGPUInfo() GpuInfo {
mem, _ := getCPUMem() mem, _ := getCPUMem()
if runtime.GOARCH == "amd64" { if runtime.GOARCH == "amd64" {
return GpuInfo{ return GpuInfo{
Library: "default", Library: "cpu",
Variant: GetCPUVariant(), Variant: GetCPUVariant(),
memInfo: mem, memInfo: mem,
} }

View file

@ -51,7 +51,6 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
graph := int64(ggml.NumGQA()) * kv / 6 graph := int64(ggml.NumGQA()) * kv / 6
info := gpu.GetGPUInfo() info := gpu.GetGPUInfo()
library := info.Library
switch runtime.GOOS { switch runtime.GOOS {
case "darwin": case "darwin":
if opts.NumGPU == 0 { if opts.NumGPU == 0 {
@ -60,13 +59,15 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
if size+kv+graph > vram { if size+kv+graph > vram {
log.Println("not enough vram available, falling back to CPU only") log.Println("not enough vram available, falling back to CPU only")
info.Library = "cpu"
info.Variant = gpu.GetCPUVariant()
opts.NumGPU = 0 opts.NumGPU = 0
break break
} }
opts.NumGPU = 1 opts.NumGPU = 1
default: default:
if library == "cpu" || library == "default" { if info.Library == "cpu" {
log.Println("GPU not available, falling back to CPU") log.Println("GPU not available, falling back to CPU")
opts.NumGPU = 0 opts.NumGPU = 0
break break
@ -74,7 +75,8 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
// don't use GPU at all if no layers are loaded // don't use GPU at all if no layers are loaded
if opts.NumGPU == 0 { if opts.NumGPU == 0 {
library = "cpu" info.Library = "cpu"
info.Variant = gpu.GetCPUVariant()
break break
} }
@ -101,7 +103,8 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
min := graph + kv*layers/maxlayers min := graph + kv*layers/maxlayers
if layers <= 0 || min > avg { if layers <= 0 || min > avg {
log.Printf("not enough vram available, falling back to CPU only") log.Printf("not enough vram available, falling back to CPU only")
library = "cpu" info.Library = "cpu"
info.Variant = gpu.GetCPUVariant()
opts.NumGPU = 0 opts.NumGPU = 0
break break
} }
@ -111,8 +114,7 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
opts.RopeFrequencyBase = 0.0 opts.RopeFrequencyBase = 0.0
opts.RopeFrequencyScale = 0.0 opts.RopeFrequencyScale = 0.0
gpuInfo := gpu.GetGPUInfo() return newLlmServer(info, model, adapters, projectors, opts)
return newLlmServer(gpuInfo, model, adapters, projectors, opts)
} }
// Give any native cgo implementations an opportunity to initialize // Give any native cgo implementations an opportunity to initialize

View file

@ -28,6 +28,13 @@ func getDynLibs(gpuInfo gpu.GpuInfo) []string {
if gpuInfo.Library == "default" { if gpuInfo.Library == "default" {
return []string{"default"} return []string{"default"}
} }
// TODO - temporary until we have multiple CPU variations for Darwin
// Short circuit on darwin with metal only
if len(availableDynLibs) == 1 {
if _, onlyMetal := availableDynLibs["metal"]; onlyMetal {
return []string{availableDynLibs["metal"]}
}
}
exactMatch := "" exactMatch := ""
dynLibs := []string{} dynLibs := []string{}

View file

@ -16,39 +16,43 @@ func TestGetDynLibs(t *testing.T) {
assert.Len(t, res, 1) assert.Len(t, res, 1)
assert.Equal(t, availableDynLibs["cpu"], res[0]) assert.Equal(t, availableDynLibs["cpu"], res[0])
variant := gpu.GetCPUVariant()
if variant != "" {
variant = "_" + variant
}
availableDynLibs = map[string]string{ availableDynLibs = map[string]string{
"rocm_v5": "X_rocm_v5", "rocm_v5": "X_rocm_v5",
"rocm_v6": "X_rocm_v6", "rocm_v6": "X_rocm_v6",
"cpu": "X_cpu", "cpu" + variant: "X_cpu",
} }
assert.Equal(t, true, rocmDynLibPresent()) assert.Equal(t, true, rocmDynLibPresent())
res = getDynLibs(gpu.GpuInfo{Library: "rocm"}) res = getDynLibs(gpu.GpuInfo{Library: "rocm"})
assert.Len(t, res, 3) assert.Len(t, res, 3)
assert.Equal(t, availableDynLibs["rocm_v5"], res[0]) assert.Equal(t, availableDynLibs["rocm_v5"], res[0])
assert.Equal(t, availableDynLibs["rocm_v6"], res[1]) assert.Equal(t, availableDynLibs["rocm_v6"], res[1])
assert.Equal(t, availableDynLibs["cpu"], res[2]) assert.Equal(t, availableDynLibs["cpu"+variant], res[2])
res = getDynLibs(gpu.GpuInfo{Library: "rocm", Variant: "v6"}) res = getDynLibs(gpu.GpuInfo{Library: "rocm", Variant: "v6"})
assert.Len(t, res, 3) assert.Len(t, res, 3)
assert.Equal(t, availableDynLibs["rocm_v6"], res[0]) assert.Equal(t, availableDynLibs["rocm_v6"], res[0])
assert.Equal(t, availableDynLibs["rocm_v5"], res[1]) assert.Equal(t, availableDynLibs["rocm_v5"], res[1])
assert.Equal(t, availableDynLibs["cpu"], res[2]) assert.Equal(t, availableDynLibs["cpu"+variant], res[2])
res = getDynLibs(gpu.GpuInfo{Library: "cuda"}) res = getDynLibs(gpu.GpuInfo{Library: "cuda"})
assert.Len(t, res, 1) assert.Len(t, res, 1)
assert.Equal(t, availableDynLibs["cpu"], res[0]) assert.Equal(t, availableDynLibs["cpu"+variant], res[0])
res = getDynLibs(gpu.GpuInfo{Library: "default"}) res = getDynLibs(gpu.GpuInfo{Library: "default"})
assert.Len(t, res, 1) assert.Len(t, res, 1)
assert.Equal(t, "default", res[0]) assert.Equal(t, "default", res[0])
availableDynLibs = map[string]string{ availableDynLibs = map[string]string{
"rocm": "X_rocm_v5", "rocm": "X_rocm_v5",
"cpu": "X_cpu", "cpu" + variant: "X_cpu",
} }
assert.Equal(t, true, rocmDynLibPresent()) assert.Equal(t, true, rocmDynLibPresent())
res = getDynLibs(gpu.GpuInfo{Library: "rocm", Variant: "v6"}) res = getDynLibs(gpu.GpuInfo{Library: "rocm", Variant: "v6"})
assert.Len(t, res, 2) assert.Len(t, res, 2)
assert.Equal(t, availableDynLibs["rocm"], res[0]) assert.Equal(t, availableDynLibs["rocm"], res[0])
assert.Equal(t, availableDynLibs["cpu"], res[1]) assert.Equal(t, availableDynLibs["cpu"+variant], res[1])
} }