diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 37d4637..e5ae4f7 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -109,12 +109,13 @@ if TYPE_CHECKING: CtypesFuncPointer: TypeAlias = ctypes._FuncPointer # type: ignore +F = TypeVar("F", bound=Callable[..., Any]) def ctypes_function_for_shared_library(lib: ctypes.CDLL): def ctypes_function( name: str, argtypes: List[Any], restype: Any, enabled: bool = True ): - def decorator(f: Callable[..., Any]): + def decorator(f: F) -> F: if enabled: func = getattr(lib, name) func.argtypes = argtypes