2024-03-30 16:50:05 +00:00
|
|
|
//go:build linux || windows
|
|
|
|
|
|
|
|
package gpu
|
|
|
|
|
|
|
|
import (
|
|
|
|
"log/slog"
|
2024-06-14 03:46:14 +00:00
|
|
|
"os"
|
|
|
|
"regexp"
|
|
|
|
"runtime"
|
|
|
|
"strconv"
|
2024-03-30 16:50:05 +00:00
|
|
|
"strings"
|
|
|
|
)
|
|
|
|
|
2024-06-14 03:46:14 +00:00
|
|
|
// Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed.
|
|
|
|
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
|
|
|
|
var CudaTegra string = os.Getenv("JETSON_JETPACK")
|
|
|
|
|
2024-03-30 16:50:05 +00:00
|
|
|
func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
|
|
|
ids := []string{}
|
|
|
|
for _, info := range gpuInfo {
|
|
|
|
if info.Library != "cuda" {
|
|
|
|
// TODO shouldn't happen if things are wired correctly...
|
|
|
|
slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
ids = append(ids, info.ID)
|
|
|
|
}
|
|
|
|
return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
|
|
|
|
}
|
2024-06-14 03:46:14 +00:00
|
|
|
|
|
|
|
func cudaGetVariant(gpuInfo CudaGPUInfo) string {
|
|
|
|
if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
|
|
|
|
if CudaTegra != "" {
|
|
|
|
ver := strings.Split(CudaTegra, ".")
|
|
|
|
if len(ver) > 0 {
|
|
|
|
return "jetpack" + ver[0]
|
|
|
|
}
|
|
|
|
} else if data, err := os.ReadFile("/etc/nv_tegra_release"); err == nil {
|
|
|
|
r := regexp.MustCompile(` R(\d+) `)
|
|
|
|
m := r.FindSubmatch(data)
|
|
|
|
if len(m) != 2 {
|
|
|
|
slog.Info("Unexpected format for /etc/nv_tegra_release. Set JETSON_JETPACK to select version")
|
|
|
|
} else {
|
|
|
|
if l4t, err := strconv.Atoi(string(m[1])); err == nil {
|
|
|
|
// Note: mapping from L4t -> JP is inconsistent (can't just subtract 30)
|
|
|
|
// https://developer.nvidia.com/embedded/jetpack-archive
|
|
|
|
switch l4t {
|
|
|
|
case 35:
|
|
|
|
return "jetpack5"
|
|
|
|
case 36:
|
|
|
|
return "jetpack6"
|
|
|
|
default:
|
|
|
|
slog.Info("unsupported L4T version", "nv_tegra_release", string(data))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if gpuInfo.computeMajor < 6 || gpuInfo.DriverMajor < 12 {
|
|
|
|
return "v11"
|
|
|
|
}
|
|
|
|
return "v12"
|
|
|
|
}
|