Update llama_cpp.py
Make shared library code more robust with some platform specific functionality and more descriptive errors when failures occur
This commit is contained in:
parent
b9a4513363
commit
a40476e299
1 changed files with 41 additions and 9 deletions
|
@ -1,17 +1,49 @@
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
import ctypes
|
import ctypes
|
||||||
|
|
||||||
from ctypes import c_int, c_float, c_char_p, c_void_p, c_bool, POINTER, Structure, Array, c_uint8, c_size_t
|
from ctypes import c_int, c_float, c_char_p, c_void_p, c_bool, POINTER, Structure, Array, c_uint8, c_size_t
|
||||||
|
|
||||||
import pathlib
|
import pathlib
|
||||||
from itertools import chain
|
|
||||||
|
|
||||||
# Load the library
|
# Load the library
|
||||||
# TODO: fragile, should fix
|
def load_shared_library(lib_base_name):
|
||||||
_base_path = pathlib.Path(__file__).parent
|
# Determine the file extension based on the platform
|
||||||
(_lib_path,) = chain(
|
if sys.platform.startswith("linux"):
|
||||||
_base_path.glob("*.so"), _base_path.glob("*.dylib"), _base_path.glob("*.dll")
|
lib_ext = ".so"
|
||||||
)
|
elif sys.platform == "darwin":
|
||||||
_lib = ctypes.CDLL(str(_lib_path))
|
lib_ext = ".dylib"
|
||||||
|
elif sys.platform == "win32":
|
||||||
|
lib_ext = ".dll"
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Unsupported platform")
|
||||||
|
|
||||||
|
# Construct the paths to the possible shared library names
|
||||||
|
_base_path = pathlib.Path(__file__).parent.resolve()
|
||||||
|
# Searching for the library in the current directory under the name "libllama" (default name
|
||||||
|
# for llamacpp) and "llama" (default name for this repo)
|
||||||
|
_lib_paths = [
|
||||||
|
_base_path / f"lib{lib_base_name}{lib_ext}",
|
||||||
|
_base_path / f"{lib_base_name}{lib_ext}"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add the library directory to the DLL search path on Windows (if needed)
|
||||||
|
if sys.platform == "win32" and sys.version_info >= (3, 8):
|
||||||
|
os.add_dll_directory(str(_base_path))
|
||||||
|
|
||||||
|
# Try to load the shared library, handling potential errors
|
||||||
|
for _lib_path in _lib_paths:
|
||||||
|
if _lib_path.exists():
|
||||||
|
try:
|
||||||
|
return ctypes.CDLL(str(_lib_path))
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
|
||||||
|
|
||||||
|
raise FileNotFoundError(f"Shared library with base name '{lib_base_name}' not found")
|
||||||
|
|
||||||
|
# Specify the base name of the shared library to load
|
||||||
|
lib_base_name = "llama"
|
||||||
|
|
||||||
|
# Load the library
|
||||||
|
_lib = load_shared_library(lib_base_name)
|
||||||
|
|
||||||
# C types
|
# C types
|
||||||
llama_context_p = c_void_p
|
llama_context_p = c_void_p
|
||||||
|
|
Loading…
Reference in a new issue