Refine handling of shim presence

This allows CPU-only builds to work on systems with Radeon cards.
Daniel Hiltgen 2023-12-15 14:27:27 -08:00
parent 1b991d0ba9
commit 3269535a4c
2 changed files with 8 additions and 7 deletions
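
Why CPU-only builds previously failed on Radeon machines: gpu.GetGPUInfo() reports a "ROCM" driver, the old switch routed unconditionally to newRocmShimExtServer, and a build without the shim then errored out with RocmShimMissing. The new ShimPresent flag makes the dispatch depend on what the binary can actually load. A minimal, stand-alone Go sketch of the pattern (loadShim and pickServer are illustrative names, not ollama's):

package main

import (
	"errors"
	"fmt"
)

// Package-level capability flag, as in the diff below: defaults to false
// and is flipped only after the optional ROCm shim actually loads.
var shimPresent = false

// loadShim stands in for ollama's shim-loading step; here we simulate a
// CPU-only build in which the shim library was never compiled in.
func loadShim() error {
	return errors.New("ROCm shim library not included in this build")
}

// pickServer mirrors the new dispatch: the shim server is chosen only
// when the driver is ROCm AND the shim is present.
func pickServer(driver string) string {
	if driver == "ROCM" && shimPresent {
		return "rocm shim server"
	}
	// Built-in CUDA/Metal server, which falls back to CPU.
	return "built-in server"
}

func main() {
	if err := loadShim(); err == nil {
		shimPresent = true
	}
	// A Radeon system running a CPU-only build now gets the built-in
	// server instead of an error.
	fmt.Println(pickServer("ROCM")) // built-in server
}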

@@ -22,6 +22,9 @@ type LLM interface {
 	Close()
 }
 
+// Set to true on linux/windows if we are able to load the shim
+var ShimPresent = false
+
 func New(workDir, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
 	if _, err := os.Stat(model); err != nil {
 		return nil, err
@@ -79,11 +82,10 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
 	opts.RopeFrequencyBase = 0.0
 	opts.RopeFrequencyScale = 0.0
 	gpuInfo := gpu.GetGPUInfo()
-	switch gpuInfo.Driver {
-	case "ROCM":
+	if gpuInfo.Driver == "ROCM" && ShimPresent {
 		return newRocmShimExtServer(model, adapters, projectors, ggml.NumLayers(), opts)
-	default:
-		// Rely on the built-in CUDA based server which will fall back to CPU
+	} else {
+		// Rely on the built-in CUDA/Metal based server which will fall back to CPU
 		return newLlamaExtServer(model, adapters, projectors, ggml.NumLayers(), opts)
 	}
 }
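
Note where the flag now lives: ShimPresent is declared next to the LLM interface rather than in the ROCm-specific file (whose NoShim variable is removed below), so every platform, including builds that never compile the shim code at all, sees a well-defined false and takes the newLlamaExtServer path.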

@@ -30,7 +30,6 @@ import (
 var libEmbed embed.FS
 
 var RocmShimMissing = fmt.Errorf("ROCm shim library not included in this build of ollama. Radeon GPUs are not supported")
-var NoShim = true
 
 type shimExtServer struct {
 	s C.struct_rocm_llama_server
@@ -78,7 +77,7 @@ func (llm *shimExtServer) llama_server_release_json_resp(json_resp **C.char) {
 }
 
 func newRocmShimExtServer(model string, adapters, projectors []string, numLayers int64, opts api.Options) (extServer, error) {
-	if NoShim {
+	if !ShimPresent {
 		return nil, RocmShimMissing
 	}
 	log.Printf("Loading ROCM llm server")
@@ -207,6 +206,6 @@ func extractLib(workDir string) error {
 	case err != nil:
 		return fmt.Errorf("stat ROCm shim %s: %v", files[0], err)
 	}
-	NoShim = false
+	ShimPresent = true
 	return nil
 }
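
The other half of the handshake: ShimPresent is assigned only on extractLib's success path, after the shim library has been stat'ed on disk, so any early return leaves the flag false and New falls back to the built-in server. A stand-alone sketch of that shape (the library file name and the simplified body are assumptions; ollama's extractLib also extracts the file from an embedded FS first):

package main

import (
	"fmt"
	"os"
	"path/filepath"
)

// Defaults to false; only extractLib's success path flips it.
var shimPresent = false

// extractLib models only the final presence check from the diff above:
// confirm the shim library exists before declaring it usable.
func extractLib(workDir string) error {
	lib := filepath.Join(workDir, "librocm_llama_server.so") // illustrative name
	if _, err := os.Stat(lib); err != nil {
		return fmt.Errorf("stat ROCm shim %s: %v", lib, err)
	}
	shimPresent = true // reached only when the shim is really present
	return nil
}

func main() {
	if err := extractLib(os.TempDir()); err != nil {
		fmt.Printf("shim unavailable (%v); the flag stays false\n", err)
	}
	fmt.Println("shim present:", shimPresent)
}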