diff --git a/llm/llama.go b/llm/llama.go index f731acf4..903c5f74 100644 --- a/llm/llama.go +++ b/llm/llama.go @@ -196,7 +196,10 @@ type llama struct { Running } -var errNoGPU = errors.New("nvidia-smi command failed") +var ( + errNvidiaSMI = errors.New("nvidia-smi command failed") + errAvailableVRAM = errors.New("not enough VRAM available, falling back to CPU only") +) // CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs func CheckVRAM() (int64, error) { @@ -205,7 +208,7 @@ func CheckVRAM() (int64, error) { cmd.Stdout = &stdout err := cmd.Run() if err != nil { - return 0, errNoGPU + return 0, errNvidiaSMI } var freeMiB int64 @@ -226,8 +229,8 @@ func CheckVRAM() (int64, error) { freeBytes := freeMiB * 1024 * 1024 if freeBytes < 2*format.GigaByte { - log.Printf("less than 2 GB VRAM available, falling back to CPU only") - freeMiB = 0 + log.Printf("less than 2 GB VRAM available") + return 0, errAvailableVRAM } return freeBytes, nil @@ -240,7 +243,7 @@ func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int { if runtime.GOOS == "linux" { freeBytes, err := CheckVRAM() if err != nil { - if err.Error() != "nvidia-smi command failed" { + if !errors.Is(err, errNvidiaSMI) { log.Print(err.Error()) } // nvidia driver not installed or no nvidia GPU found