From 11d82d7b9b94b971ee2965d9a683a80e4929097a Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 13 Oct 2023 14:45:50 -0700 Subject: [PATCH] update checkvram --- llm/llama.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/llm/llama.go b/llm/llama.go index db51429c..80463eeb 100644 --- a/llm/llama.go +++ b/llm/llama.go @@ -24,6 +24,7 @@ import ( "time" "github.com/jmorganca/ollama/api" + "github.com/jmorganca/ollama/format" ) //go:embed llama.cpp/*/build/*/bin/* @@ -197,7 +198,7 @@ type llama struct { var errNoGPU = errors.New("nvidia-smi command failed") -// CheckVRAM returns the available VRAM in MiB on Linux machines with NVIDIA GPUs +// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs func CheckVRAM() (int64, error) { cmd := exec.Command("nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits") var stdout bytes.Buffer @@ -207,7 +208,7 @@ func CheckVRAM() (int64, error) { return 0, errNoGPU } - var free int64 + var freeMiB int64 scanner := bufio.NewScanner(&stdout) for scanner.Scan() { line := scanner.Text() @@ -216,15 +217,16 @@ func CheckVRAM() (int64, error) { return 0, fmt.Errorf("failed to parse available VRAM: %v", err) } - free += vram + freeMiB += vram } - if free*1024*1024 < 2*1000*1000*1000 { + freeBytes := freeMiB * 1024 * 1024 + if freeBytes < 2*format.GigaByte { log.Printf("less than 2 GB VRAM available, falling back to CPU only") - free = 0 + freeMiB = 0 } - return free, nil + return freeBytes, nil } func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int { @@ -232,7 +234,7 @@ func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int { return opts.NumGPU } if runtime.GOOS == "linux" { - vramMib, err := CheckVRAM() + freeBytes, err := CheckVRAM() if err != nil { if err.Error() != "nvidia-smi command failed" { log.Print(err.Error()) @@ -241,15 +243,13 @@ func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int { return 0 } - freeVramBytes := int64(vramMib) * 1024 * 1024 // 1 MiB = 1024^2 bytes - // Calculate bytes per layer // TODO: this is a rough heuristic, better would be to calculate this based on number of layers and context size bytesPerLayer := fileSizeBytes / numLayer // max number of layers we can fit in VRAM, subtract 8% to prevent consuming all available VRAM and running out of memory - layers := int(freeVramBytes/bytesPerLayer) * 92 / 100 - log.Printf("%d MiB VRAM available, loading up to %d GPU layers", vramMib, layers) + layers := int(freeBytes/bytesPerLayer) * 92 / 100 + log.Printf("%d MiB VRAM available, loading up to %d GPU layers", freeBytes, layers) return layers }