diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index 8a5869c..0f2b4d5 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -1,9 +1,21 @@
 import sys
 import os
 import ctypes
-from ctypes import c_int, c_float, c_char_p, c_void_p, c_bool, POINTER, Structure, Array, c_uint8, c_size_t
+from ctypes import (
+    c_int,
+    c_float,
+    c_char_p,
+    c_void_p,
+    c_bool,
+    POINTER,
+    Structure,
+    Array,
+    c_uint8,
+    c_size_t,
+)
 import pathlib
+
 
 # Load the library
 def _load_shared_library(lib_base_name):
     # Determine the file extension based on the platform
@@ -22,10 +34,10 @@ def _load_shared_library(lib_base_name):
     # for llamacpp) and "llama" (default name for this repo)
     _lib_paths = [
         _base_path / f"lib{lib_base_name}{lib_ext}",
-        _base_path / f"{lib_base_name}{lib_ext}"
+        _base_path / f"{lib_base_name}{lib_ext}",
     ]
 
-    if ("LLAMA_CPP_LIB" in os.environ):
+    if "LLAMA_CPP_LIB" in os.environ:
         lib_base_name = os.environ["LLAMA_CPP_LIB"]
         _lib = pathlib.Path(lib_base_name)
         _base_path = _lib.parent.resolve()
@@ -43,7 +55,10 @@ def _load_shared_library(lib_base_name):
         except Exception as e:
             raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
 
-    raise FileNotFoundError(f"Shared library with base name '{lib_base_name}' not found")
+    raise FileNotFoundError(
+        f"Shared library with base name '{lib_base_name}' not found"
+    )
+
 
 # Specify the base name of the shared library to load
 _lib_base_name = "llama"
@@ -95,6 +110,10 @@ class llama_context_params(Structure):
 
 llama_context_params_p = POINTER(llama_context_params)
 
+LLAMA_FTYPE_ALL_F32 = ctypes.c_int(0)
+LLAMA_FTYPE_MOSTLY_F16 = ctypes.c_int(1)  # except 1d tensors
+LLAMA_FTYPE_MOSTLY_Q4_0 = ctypes.c_int(2)  # except 1d tensors
+LLAMA_FTYPE_MOSTLY_Q4_1 = ctypes.c_int(3)  # except 1d tensors
 
 
 # Functions
@@ -106,18 +125,23 @@ def llama_context_default_params() -> llama_context_params:
 _lib.llama_context_default_params.argtypes = []
 _lib.llama_context_default_params.restype = llama_context_params
 
+
 def llama_mmap_supported() -> c_bool:
     return _lib.llama_mmap_supported()
 
+
 _lib.llama_mmap_supported.argtypes = []
 _lib.llama_mmap_supported.restype = c_bool
 
+
 def llama_mlock_supported() -> c_bool:
     return _lib.llama_mlock_supported()
 
+
 _lib.llama_mlock_supported.argtypes = []
 _lib.llama_mlock_supported.restype = c_bool
 
+
 # Various functions for loading a ggml llama model.
 # Allocate (almost) all memory needed for the model.
 # Return NULL on failure
@@ -142,42 +166,49 @@
 _lib.llama_free.restype = None
 
 
 # TODO: not great API - very likely to change
 # Returns 0 on success
-def llama_model_quantize(
-    fname_inp: bytes, fname_out: bytes, itype: c_int
-) -> c_int:
+def llama_model_quantize(fname_inp: bytes, fname_out: bytes, itype: c_int) -> c_int:
     return _lib.llama_model_quantize(fname_inp, fname_out, itype)
 
 
 _lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int]
 _lib.llama_model_quantize.restype = c_int
 
+
 # Returns the KV cache that will contain the context for the
 # ongoing prediction with the model.
 def llama_get_kv_cache(ctx: llama_context_p):
     return _lib.llama_get_kv_cache(ctx)
 
+
 _lib.llama_get_kv_cache.argtypes = [llama_context_p]
 _lib.llama_get_kv_cache.restype = POINTER(c_uint8)
 
+
 # Returns the size of the KV cache
 def llama_get_kv_cache_size(ctx: llama_context_p) -> c_size_t:
     return _lib.llama_get_kv_cache_size(ctx)
 
+
 _lib.llama_get_kv_cache_size.argtypes = [llama_context_p]
 _lib.llama_get_kv_cache_size.restype = c_size_t
 
+
 # Returns the number of tokens in the KV cache
 def llama_get_kv_cache_token_count(ctx: llama_context_p) -> c_int:
     return _lib.llama_get_kv_cache_token_count(ctx)
 
+
 _lib.llama_get_kv_cache_token_count.argtypes = [llama_context_p]
 _lib.llama_get_kv_cache_token_count.restype = c_int
 
 
 # Sets the KV cache containing the current context for the model
-def llama_set_kv_cache(ctx: llama_context_p, kv_cache, n_size: c_size_t, n_token_count: c_int):
+def llama_set_kv_cache(
+    ctx: llama_context_p, kv_cache, n_size: c_size_t, n_token_count: c_int
+):
    return _lib.llama_set_kv_cache(ctx, kv_cache, n_size, n_token_count)
 
+
 _lib.llama_set_kv_cache.argtypes = [llama_context_p, POINTER(c_uint8), c_size_t, c_int]
 _lib.llama_set_kv_cache.restype = None
diff --git a/vendor/llama.cpp b/vendor/llama.cpp
index 684da25..3e6e70d 160000
--- a/vendor/llama.cpp
+++ b/vendor/llama.cpp
@@ -1 +1 @@
-Subproject commit 684da25926e5c505f725b4f10b5485b218fa1fc7
+Subproject commit 3e6e70d8e8917b5bd14c7c9f9b89a585f1ff0b31
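For reviewers: a minimal sketch of how the new `LLAMA_FTYPE_*` constants pair with `llama_model_quantize`. The model file paths are placeholders, not part of this PR.

```python
# Sketch: quantize an f16 ggml model to q4_0 using the constants added above.
# Paths are hypothetical examples.
import llama_cpp

ret = llama_cpp.llama_model_quantize(
    b"./models/7B/ggml-model-f16.bin",   # fname_inp: source model
    b"./models/7B/ggml-model-q4_0.bin",  # fname_out: quantized output
    llama_cpp.LLAMA_FTYPE_MOSTLY_Q4_0,   # itype: q4_0 for all but 1d tensors
)
assert ret == 0  # llama_model_quantize returns 0 on success
```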
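Likewise, a sketch of how the KV-cache accessors could be combined to snapshot and later restore context state. The setup (model path, the prior eval calls that fill the cache) is assumed, using the existing `llama_context_default_params`/`llama_init_from_file` loaders from this module.

```python
# Sketch: snapshot and restore the KV cache via the bindings added above.
import ctypes
import llama_cpp

params = llama_cpp.llama_context_default_params()
ctx = llama_cpp.llama_init_from_file(b"./models/7B/ggml-model-q4_0.bin", params)

# ... evaluate some tokens here so the cache holds context ...

# Snapshot: copy the raw cache bytes and remember how many tokens they cover.
size = llama_cpp.llama_get_kv_cache_size(ctx)
n_tokens = llama_cpp.llama_get_kv_cache_token_count(ctx)
snapshot = (ctypes.c_uint8 * size)()
ctypes.memmove(snapshot, llama_cpp.llama_get_kv_cache(ctx), size)

# ... evaluate more tokens, then roll the context back ...

# Restore: hand the saved bytes and token count back to the context.
llama_cpp.llama_set_kv_cache(ctx, snapshot, size, n_tokens)

llama_cpp.llama_free(ctx)
```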