diff --git a/gpu/gpu_darwin.go b/gpu/gpu_darwin.go index ecca3b7b..ed63f718 100644 --- a/gpu/gpu_darwin.go +++ b/gpu/gpu_darwin.go @@ -33,11 +33,15 @@ func getCPUMem() (memInfo, error) { } func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int { + if opts.NumGPU != -1 { + return opts.NumGPU + } + + // metal only supported on arm64 if runtime.GOARCH == "arm64" { return 1 } - // metal only supported on arm64 return 0 }