diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 50eae2d..75800c0 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -62,6 +62,9 @@ def _load_shared_library(lib_base_name: str): if "CUDA_PATH" in os.environ: os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "bin")) os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "lib")) + if "HIP_PATH" in os.environ: + os.add_dll_directory(os.path.join(os.environ["HIP_PATH"], "bin")) + os.add_dll_directory(os.path.join(os.environ["HIP_PATH"], "lib")) cdll_args["winmode"] = ctypes.RTLD_GLOBAL # Try to load the shared library, handling potential errors