Merge pull request #10 from MillionthOdin16/patch-1
Improve Shared Library Loading Mechanism
This commit is contained in:
commit
1d9a988644
1 changed files with 41 additions and 9 deletions
|
@ -1,17 +1,49 @@
|
|||
import sys
|
||||
import os
|
||||
import ctypes
|
||||
|
||||
from ctypes import c_int, c_float, c_char_p, c_void_p, c_bool, POINTER, Structure, Array, c_uint8, c_size_t
|
||||
|
||||
import pathlib
|
||||
from itertools import chain
|
||||
|
||||
# Load the library
|
||||
# TODO: fragile, should fix
|
||||
_base_path = pathlib.Path(__file__).parent
|
||||
(_lib_path,) = chain(
|
||||
_base_path.glob("*.so"), _base_path.glob("*.dylib"), _base_path.glob("*.dll")
|
||||
)
|
||||
_lib = ctypes.CDLL(str(_lib_path))
|
||||
def _load_shared_library(lib_base_name):
|
||||
# Determine the file extension based on the platform
|
||||
if sys.platform.startswith("linux"):
|
||||
lib_ext = ".so"
|
||||
elif sys.platform == "darwin":
|
||||
lib_ext = ".dylib"
|
||||
elif sys.platform == "win32":
|
||||
lib_ext = ".dll"
|
||||
else:
|
||||
raise RuntimeError("Unsupported platform")
|
||||
|
||||
# Construct the paths to the possible shared library names
|
||||
_base_path = pathlib.Path(__file__).parent.resolve()
|
||||
# Searching for the library in the current directory under the name "libllama" (default name
|
||||
# for llamacpp) and "llama" (default name for this repo)
|
||||
_lib_paths = [
|
||||
_base_path / f"lib{lib_base_name}{lib_ext}",
|
||||
_base_path / f"{lib_base_name}{lib_ext}"
|
||||
]
|
||||
|
||||
# Add the library directory to the DLL search path on Windows (if needed)
|
||||
if sys.platform == "win32" and sys.version_info >= (3, 8):
|
||||
os.add_dll_directory(str(_base_path))
|
||||
|
||||
# Try to load the shared library, handling potential errors
|
||||
for _lib_path in _lib_paths:
|
||||
if _lib_path.exists():
|
||||
try:
|
||||
return ctypes.CDLL(str(_lib_path))
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
|
||||
|
||||
raise FileNotFoundError(f"Shared library with base name '{lib_base_name}' not found")
|
||||
|
||||
# Specify the base name of the shared library to load
|
||||
_lib_base_name = "llama"
|
||||
|
||||
# Load the library
|
||||
_lib = _load_shared_library(_lib_base_name)
|
||||
|
||||
# C types
|
||||
llama_context_p = c_void_p
|
||||
|
|
Loading…
Add table
Reference in a new issue