diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py index c60fdff..10c1a6f 100644 --- a/llama_cpp/_internals.py +++ b/llama_cpp/_internals.py @@ -108,7 +108,7 @@ class _LlamaModel: scale, path_base_model.encode("utf-8") if path_base_model is not None - else llama_cpp.c_char_p(0), + else ctypes.c_char_p(0), n_threads, ) @@ -303,8 +303,8 @@ class _LlamaContext: assert self.ctx is not None assert batch.batch is not None return_code = llama_cpp.llama_decode( - ctx=self.ctx, - batch=batch.batch, + self.ctx, + batch.batch, ) if return_code != 0: raise RuntimeError(f"llama_decode returned {return_code}") @@ -493,7 +493,7 @@ class _LlamaBatch: def __init__( self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True ): - self.n_tokens = n_tokens + self._n_tokens = n_tokens self.embd = embd self.n_seq_max = n_seq_max self.verbose = verbose @@ -502,7 +502,7 @@ class _LlamaBatch: self.batch = None self.batch = llama_cpp.llama_batch_init( - self.n_tokens, self.embd, self.n_seq_max + self._n_tokens, self.embd, self.n_seq_max ) def __del__(self): @@ -570,12 +570,13 @@ class _LlamaTokenDataArray: self.candidates.data = self.candidates_data.ctypes.data_as( llama_cpp.llama_token_data_p ) - self.candidates.sorted = llama_cpp.c_bool(False) - self.candidates.size = llama_cpp.c_size_t(self.n_vocab) + self.candidates.sorted = ctypes.c_bool(False) + self.candidates.size = ctypes.c_size_t(self.n_vocab) # Python wrappers over common/common def _tokenize(model: _LlamaModel, text: str, add_bos: bool, special: bool) -> list[int]: + assert model.model is not None n_tokens = len(text) + 1 if add_bos else len(text) result = (llama_cpp.llama_token * n_tokens)() n_tokens = llama_cpp.llama_tokenize( diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 65902c8..9fc4ec2 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1818,7 +1818,7 @@ class Llama: self.input_ids = state.input_ids.copy() self.n_tokens = state.n_tokens state_size = state.llama_state_size - LLamaStateArrayType = llama_cpp.c_uint8 * state_size + LLamaStateArrayType = ctypes.c_uint8 * state_size llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state) if llama_cpp.llama_set_state_data(self._ctx.ctx, llama_state) != state_size: diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 3ebe82b..69dbe09 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -2,26 +2,11 @@ import sys import os import ctypes from ctypes import ( - c_bool, - c_char_p, - c_int, - c_int8, - c_int32, - c_uint8, - c_uint32, - c_int64, - c_size_t, - c_float, - c_double, - c_void_p, - POINTER, _Pointer, # type: ignore - Structure, - Union as CtypesUnion, Array, ) import pathlib -from typing import List, Union +from typing import List, Union, NewType, Optional # Load the library @@ -71,7 +56,7 @@ def _load_shared_library(lib_base_name: str): for _lib_path in _lib_paths: if _lib_path.exists(): try: - return ctypes.CDLL(str(_lib_path), **cdll_args) + return ctypes.CDLL(str(_lib_path), **cdll_args) # type: ignore except Exception as e: raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}") @@ -87,13 +72,13 @@ _lib_base_name = "llama" _lib = _load_shared_library(_lib_base_name) # Misc -c_float_p = POINTER(c_float) -c_uint8_p = POINTER(c_uint8) -c_size_t_p = POINTER(c_size_t) +c_float_p = ctypes.POINTER(ctypes.c_float) +c_uint8_p = ctypes.POINTER(ctypes.c_uint8) +c_size_t_p = ctypes.POINTER(ctypes.c_size_t) # from ggml-backend.h # typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, 
bool ask, void * user_data); -ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE(c_bool, c_void_p, c_bool, c_void_p) +ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_void_p, ctypes.c_bool, ctypes.c_void_p) # llama.h bindings @@ -121,19 +106,21 @@ LLAMA_SESSION_VERSION = 4 # struct llama_model; -llama_model_p = c_void_p +llama_model_p = NewType("llama_model_p", int) +llama_model_p_ctypes = ctypes.c_void_p # struct llama_context; -llama_context_p = c_void_p +llama_context_p = NewType("llama_context_p", int) +llama_context_p_ctypes = ctypes.c_void_p # typedef int32_t llama_pos; -llama_pos = c_int32 +llama_pos = ctypes.c_int32 # typedef int32_t llama_token; -llama_token = c_int32 -llama_token_p = POINTER(llama_token) +llama_token = ctypes.c_int32 +llama_token_p = ctypes.POINTER(llama_token) # typedef int32_t llama_seq_id; -llama_seq_id = c_int32 +llama_seq_id = ctypes.c_int32 # enum llama_vocab_type { @@ -258,7 +245,7 @@ LLAMA_SPLIT_ROW = 2 # float logit; // log-odds of the token # float p; // probability of the token # } llama_token_data; -class llama_token_data(Structure): +class llama_token_data(ctypes.Structure): """Used to store token data Attributes: @@ -268,12 +255,12 @@ class llama_token_data(Structure): _fields_ = [ ("id", llama_token), - ("logit", c_float), - ("p", c_float), + ("logit", ctypes.c_float), + ("p", ctypes.c_float), ] -llama_token_data_p = POINTER(llama_token_data) +llama_token_data_p = ctypes.POINTER(llama_token_data) # typedef struct llama_token_data_array { @@ -281,7 +268,7 @@ llama_token_data_p = POINTER(llama_token_data) # size_t size; # bool sorted; # } llama_token_data_array; -class llama_token_data_array(Structure): +class llama_token_data_array(ctypes.Structure): """Used to sample tokens given logits Attributes: @@ -291,15 +278,15 @@ class llama_token_data_array(Structure): _fields_ = [ ("data", llama_token_data_p), - ("size", c_size_t), - ("sorted", c_bool), + ("size", ctypes.c_size_t), + ("sorted", ctypes.c_bool), ] -llama_token_data_array_p = POINTER(llama_token_data_array) +llama_token_data_array_p = ctypes.POINTER(llama_token_data_array) # typedef bool (*llama_progress_callback)(float progress, void *ctx); -llama_progress_callback = ctypes.CFUNCTYPE(c_bool, c_float, c_void_p) +llama_progress_callback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_float, ctypes.c_void_p) # // Input data for llama_decode @@ -332,7 +319,7 @@ llama_progress_callback = ctypes.CFUNCTYPE(c_bool, c_float, c_void_p) # llama_pos all_pos_1; // used if pos == NULL # llama_seq_id all_seq_id; // used if seq_id == NULL # } llama_batch; -class llama_batch(Structure): +class llama_batch(ctypes.Structure): """Input data for llama_decode A llama_batch object can contain input about one or many sequences @@ -341,19 +328,19 @@ class llama_batch(Structure): Attributes: token (ctypes.Array[llama_token]): the token ids of the input (used when embd is NULL) - embd (ctypes.Array[ctypes.c_float]): token embeddings (i.e. float vector of size n_embd) (used when token is NULL) + embd (ctypes.Array[ctypes.ctypes.c_float]): token embeddings (i.e. 
float vector of size n_embd) (used when token is NULL) pos (ctypes.Array[ctypes.Array[llama_pos]]): the positions of the respective token in the sequence seq_id (ctypes.Array[ctypes.Array[llama_seq_id]]): the sequence to which the respective token belongs """ _fields_ = [ - ("n_tokens", c_int32), - ("token", POINTER(llama_token)), + ("n_tokens", ctypes.c_int32), + ("token", ctypes.POINTER(llama_token)), ("embd", c_float_p), - ("pos", POINTER(llama_pos)), - ("n_seq_id", POINTER(c_int32)), - ("seq_id", POINTER(POINTER(llama_seq_id))), - ("logits", POINTER(c_int8)), + ("pos", ctypes.POINTER(llama_pos)), + ("n_seq_id", ctypes.POINTER(ctypes.c_int32)), + ("seq_id", ctypes.POINTER(ctypes.POINTER(llama_seq_id))), + ("logits", ctypes.POINTER(ctypes.c_int8)), ("all_pos_0", llama_pos), ("all_pos_1", llama_pos), ("all_seq_id", llama_seq_id), @@ -379,18 +366,18 @@ LLAMA_KV_OVERRIDE_BOOL = 2 # bool bool_value; # }; # }; -class llama_model_kv_override_value(CtypesUnion): +class llama_model_kv_override_value(ctypes.Union): _fields_ = [ - ("int_value", c_int64), - ("float_value", c_double), - ("bool_value", c_bool), + ("int_value", ctypes.c_int64), + ("float_value", ctypes.c_double), + ("bool_value", ctypes.c_bool), ] -class llama_model_kv_override(Structure): +class llama_model_kv_override(ctypes.Structure): _fields_ = [ ("key", ctypes.c_char * 128), - ("tag", c_int), + ("tag", ctypes.c_int), ("value", llama_model_kv_override_value), ] @@ -425,32 +412,32 @@ class llama_model_kv_override(Structure): # bool use_mmap; // use mmap if possible # bool use_mlock; // force system to keep model in RAM # }; -class llama_model_params(Structure): +class llama_model_params(ctypes.Structure): """Parameters for llama_model Attributes: n_gpu_layers (int): number of layers to store in VRAM split_mode (int): how to split the model across multiple GPUs main_gpu (int): the GPU that is used for the entire model. main_gpu interpretation depends on split_mode: LLAMA_SPLIT_NONE: the GPU that is used for the entire model LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results LLAMA_SPLIT_LAYER: ignored - tensor_split (ctypes.Array[ctypes.c_float]): proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices() + tensor_split (ctypes.Array[ctypes.ctypes.c_float]): proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices() progress_callback (llama_progress_callback): called with a progress value between 0.0 and 1.0. Pass NULL to disable. If the provided progress_callback returns true, model loading continues. If it returns false, model loading is immediately aborted. 
- progress_callback_user_data (ctypes.c_void_p): context pointer passed to the progress callback + progress_callback_user_data (ctypes.ctypes.c_void_p): context pointer passed to the progress callback kv_overrides (ctypes.Array[llama_model_kv_override]): override key-value pairs of the model meta data vocab_only (bool): only load the vocabulary, no weights use_mmap (bool): use mmap if possible use_mlock (bool): force system to keep model in RAM""" _fields_ = [ - ("n_gpu_layers", c_int32), - ("split_mode", c_int), - ("main_gpu", c_int32), + ("n_gpu_layers", ctypes.c_int32), + ("split_mode", ctypes.c_int), + ("main_gpu", ctypes.c_int32), ("tensor_split", c_float_p), ("progress_callback", llama_progress_callback), - ("progress_callback_user_data", c_void_p), - ("kv_overrides", POINTER(llama_model_kv_override)), - ("vocab_only", c_bool), - ("use_mmap", c_bool), - ("use_mlock", c_bool), + ("progress_callback_user_data", ctypes.c_void_p), + ("kv_overrides", ctypes.POINTER(llama_model_kv_override)), + ("vocab_only", ctypes.c_bool), + ("use_mmap", ctypes.c_bool), + ("use_mlock", ctypes.c_bool), ] @@ -485,7 +472,7 @@ class llama_model_params(Structure): # bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU # bool do_pooling; // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer) # }; -class llama_context_params(Structure): +class llama_context_params(ctypes.Structure): """Parameters for llama_context Attributes: @@ -503,7 +490,7 @@ class llama_context_params(Structure): yarn_beta_slow (float): YaRN high correction dim yarn_orig_ctx (int): YaRN original context size cb_eval (ggml_backend_sched_eval_callback): callback for scheduling eval - cb_eval_user_data (ctypes.c_void_p): user data for cb_eval + cb_eval_user_data (ctypes.ctypes.c_void_p): user data for cb_eval type_k (int): data type for K cache type_v (int): data type for V cache mul_mat_q (bool): if true, use experimental mul_mat_q kernels (DEPRECATED - always true) @@ -514,28 +501,28 @@ class llama_context_params(Structure): """ _fields_ = [ - ("seed", c_uint32), - ("n_ctx", c_uint32), - ("n_batch", c_uint32), - ("n_threads", c_uint32), - ("n_threads_batch", c_uint32), - ("rope_scaling_type", c_int32), - ("rope_freq_base", c_float), - ("rope_freq_scale", c_float), - ("yarn_ext_factor", c_float), - ("yarn_attn_factor", c_float), - ("yarn_beta_fast", c_float), - ("yarn_beta_slow", c_float), - ("yarn_orig_ctx", c_uint32), + ("seed", ctypes.c_uint32), + ("n_ctx", ctypes.c_uint32), + ("n_batch", ctypes.c_uint32), + ("n_threads", ctypes.c_uint32), + ("n_threads_batch", ctypes.c_uint32), + ("rope_scaling_type", ctypes.c_int32), + ("rope_freq_base", ctypes.c_float), + ("rope_freq_scale", ctypes.c_float), + ("yarn_ext_factor", ctypes.c_float), + ("yarn_attn_factor", ctypes.c_float), + ("yarn_beta_fast", ctypes.c_float), + ("yarn_beta_slow", ctypes.c_float), + ("yarn_orig_ctx", ctypes.c_uint32), ("cb_eval", ggml_backend_sched_eval_callback), - ("cb_eval_user_data", c_void_p), - ("type_k", c_int), - ("type_v", c_int), - ("mul_mat_q", c_bool), - ("logits_all", c_bool), - ("embedding", c_bool), - ("offload_kqv", c_bool), - ("do_pooling", c_bool), + ("cb_eval_user_data", ctypes.c_void_p), + ("type_k", ctypes.c_int), + ("type_v", ctypes.c_int), + ("mul_mat_q", ctypes.c_bool), + ("logits_all", ctypes.c_bool), + ("embedding", ctypes.c_bool), + ("offload_kqv", ctypes.c_bool), + ("do_pooling", ctypes.c_bool), ] @@ -545,7 +532,7 @@ class llama_context_params(Structure): # // if it exists. 
# // It might not exist for progress report where '.' is output repeatedly. # typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data); -llama_log_callback = ctypes.CFUNCTYPE(None, c_int, c_char_p, c_void_p) +llama_log_callback = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p) """Signature for logging events Note that text includes the new line character at the end for most events. If your logging mechanism cannot handle that, check if the last character is '\n' and strip it @@ -563,7 +550,7 @@ It might not exist for progress report where '.' is output repeatedly.""" # bool pure; // disable k-quant mixtures and quantize all tensors to the same type # void * imatrix; // pointer to importance matrix data # } llama_model_quantize_params; -class llama_model_quantize_params(Structure): +class llama_model_quantize_params(ctypes.Structure): """Parameters for llama_model_quantize Attributes: @@ -573,23 +560,23 @@ class llama_model_quantize_params(Structure): quantize_output_tensor (bool): quantize output.weight only_copy (bool): only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored pure (bool): disable k-quant mixtures and quantize all tensors to the same type - imatrix (ctypes.c_void_p): pointer to importance matrix data + imatrix (ctypes.ctypes.c_void_p): pointer to importance matrix data """ _fields_ = [ - ("nthread", c_int32), - ("ftype", c_int), - ("allow_requantize", c_bool), - ("quantize_output_tensor", c_bool), - ("only_copy", c_bool), - ("pure", c_bool), - ("imatrix", c_void_p), + ("nthread", ctypes.c_int32), + ("ftype", ctypes.c_int), + ("allow_requantize", ctypes.c_bool), + ("quantize_output_tensor", ctypes.c_bool), + ("only_copy", ctypes.c_bool), + ("pure", ctypes.c_bool), + ("imatrix", ctypes.c_void_p), ] # // grammar types # struct llama_grammar; -llama_grammar_p = c_void_p +llama_grammar_p = ctypes.c_void_p # // grammar element type # enum llama_gretype { @@ -629,14 +616,14 @@ LLAMA_GRETYPE_CHAR_ALT = 6 # enum llama_gretype type; # uint32_t value; // Unicode code point or rule ID # } llama_grammar_element; -class llama_grammar_element(Structure): +class llama_grammar_element(ctypes.Structure): _fields_ = [ - ("type", c_int), - ("value", c_uint32), + ("type", ctypes.c_int), + ("value", ctypes.c_uint32), ] -llama_grammar_element_p = POINTER(llama_grammar_element) +llama_grammar_element_p = ctypes.POINTER(llama_grammar_element) # // performance timing information # struct llama_timings { @@ -652,17 +639,17 @@ llama_grammar_element_p = POINTER(llama_grammar_element) # int32_t n_p_eval; # int32_t n_eval; # }; -class llama_timings(Structure): +class llama_timings(ctypes.Structure): _fields_ = [ - ("t_start_ms", c_double), - ("t_end_ms", c_double), - ("t_load_ms", c_double), - ("t_sample_ms", c_double), - ("t_p_eval_ms", c_double), - ("t_eval_ms", c_double), - ("n_sample", c_int32), - ("n_p_eval", c_int32), - ("n_eval", c_int32), + ("t_start_ms", ctypes.c_double), + ("t_end_ms", ctypes.c_double), + ("t_load_ms", ctypes.c_double), + ("t_sample_ms", ctypes.c_double), + ("t_p_eval_ms", ctypes.c_double), + ("t_eval_ms", ctypes.c_double), + ("n_sample", ctypes.c_int32), + ("n_p_eval", ctypes.c_int32), + ("n_eval", ctypes.c_int32), ] @@ -671,10 +658,10 @@ class llama_timings(Structure): # const char * role; # const char * content; # } llama_chat_message; -class llama_chat_message(Structure): +class llama_chat_message(ctypes.Structure): _fields_ = [ - ("role", c_char_p), - ("content", c_char_p), + 
("role", ctypes.c_char_p), + ("content", ctypes.c_char_p), ] @@ -682,31 +669,34 @@ class llama_chat_message(Structure): # LLAMA_API struct llama_model_params llama_model_default_params(void); def llama_model_default_params() -> llama_model_params: """Get default parameters for llama_model""" - return _lib.llama_model_default_params() + ... -_lib.llama_model_default_params.argtypes = [] -_lib.llama_model_default_params.restype = llama_model_params +llama_model_default_params = _lib.llama_model_default_params +llama_model_default_params.argtypes = [] +llama_model_default_params.restype = llama_model_params # LLAMA_API struct llama_context_params llama_context_default_params(void); def llama_context_default_params() -> llama_context_params: """Get default parameters for llama_context""" - return _lib.llama_context_default_params() + ... -_lib.llama_context_default_params.argtypes = [] -_lib.llama_context_default_params.restype = llama_context_params +llama_context_default_params = _lib.llama_context_default_params +llama_context_default_params.argtypes = [] +llama_context_default_params.restype = llama_context_params # LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); def llama_model_quantize_default_params() -> llama_model_quantize_params: """Get default parameters for llama_model_quantize""" - return _lib.llama_model_quantize_default_params() + ... -_lib.llama_model_quantize_default_params.argtypes = [] -_lib.llama_model_quantize_default_params.restype = llama_model_quantize_params +llama_model_quantize_default_params = _lib.llama_model_quantize_default_params +llama_model_quantize_default_params.argtypes = [] +llama_model_quantize_default_params.restype = llama_model_quantize_params # // Initialize the llama + ggml backend @@ -718,11 +708,12 @@ def llama_backend_init(): """Initialize the llama + ggml backend If numa is true, use NUMA optimizations Call once at the start of the program""" - return _lib.llama_backend_init() + ... -_lib.llama_backend_init.argtypes = [] -_lib.llama_backend_init.restype = None +llama_backend_init = _lib.llama_backend_init +llama_backend_init.argtypes = [] +llama_backend_init.restype = None # // numa strategies @@ -744,206 +735,227 @@ GGML_NUMA_STRATEGY_COUNT = 5 # //optional: # LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa); -def llama_numa_init(numa: int): - return _lib.llama_numa_init(numa) +def llama_numa_init(numa: int, /): + ... -_lib.llama_numa_init.argtypes = [c_int] -_lib.llama_numa_init.restype = None +llama_numa_init = _lib.llama_numa_init +llama_numa_init.argtypes = [ctypes.c_int] +llama_numa_init.restype = None # // Call once at the end of the program - currently only used for MPI # LLAMA_API void llama_backend_free(void); def llama_backend_free(): """Call once at the end of the program - currently only used for MPI""" - return _lib.llama_backend_free() + ... -_lib.llama_backend_free.argtypes = [] -_lib.llama_backend_free.restype = None +llama_backend_free = _lib.llama_backend_free +llama_backend_free.argtypes = [] +llama_backend_free.restype = None # LLAMA_API struct llama_model * llama_load_model_from_file( # const char * path_model, # struct llama_model_params params); def llama_load_model_from_file( - path_model: bytes, params: llama_model_params -) -> llama_model_p: - return _lib.llama_load_model_from_file(path_model, params) + path_model: bytes, params: llama_model_params, / +) -> Optional[llama_model_p]: + ... 
-_lib.llama_load_model_from_file.argtypes = [c_char_p, llama_model_params] -_lib.llama_load_model_from_file.restype = llama_model_p +llama_load_model_from_file = _lib.llama_load_model_from_file +llama_load_model_from_file.argtypes = [ctypes.c_char_p, llama_model_params] +llama_load_model_from_file.restype = llama_model_p_ctypes # LLAMA_API void llama_free_model(struct llama_model * model); -def llama_free_model(model: llama_model_p): - return _lib.llama_free_model(model) +def llama_free_model(model: llama_model_p, /): + ... -_lib.llama_free_model.argtypes = [llama_model_p] -_lib.llama_free_model.restype = None +llama_free_model = _lib.llama_free_model +llama_free_model.argtypes = [llama_model_p_ctypes] +llama_free_model.restype = None # LLAMA_API struct llama_context * llama_new_context_with_model( # struct llama_model * model, # struct llama_context_params params); def llama_new_context_with_model( - model: llama_model_p, params: llama_context_params -) -> llama_context_p: - return _lib.llama_new_context_with_model(model, params) + model: llama_model_p, params: llama_context_params, / +) -> Optional[llama_context_p]: + ... -_lib.llama_new_context_with_model.argtypes = [llama_model_p, llama_context_params] -_lib.llama_new_context_with_model.restype = llama_context_p +llama_new_context_with_model = _lib.llama_new_context_with_model +llama_new_context_with_model.argtypes = [llama_model_p_ctypes, llama_context_params] +llama_new_context_with_model.restype = llama_context_p_ctypes # // Frees all allocated memory # LLAMA_API void llama_free(struct llama_context * ctx); -def llama_free(ctx: llama_context_p): +def llama_free(ctx: llama_context_p, /): """Frees all allocated memory""" - return _lib.llama_free(ctx) + ... -_lib.llama_free.argtypes = [llama_context_p] -_lib.llama_free.restype = None +llama_free = _lib.llama_free +llama_free.argtypes = [llama_context_p_ctypes] +llama_free.restype = None # LLAMA_API int64_t llama_time_us(void); def llama_time_us() -> int: - return _lib.llama_time_us() + ... -_lib.llama_time_us.argtypes = [] -_lib.llama_time_us.restype = ctypes.c_int64 +llama_time_us = _lib.llama_time_us +llama_time_us.argtypes = [] +llama_time_us.restype = ctypes.c_int64 # LLAMA_API size_t llama_max_devices(void); def llama_max_devices() -> int: - return _lib.llama_max_devices() + ... -_lib.llama_max_devices.argtypes = [] -_lib.llama_max_devices.restype = ctypes.c_size_t +llama_max_devices = _lib.llama_max_devices +llama_max_devices.argtypes = [] +llama_max_devices.restype = ctypes.c_size_t # LLAMA_API bool llama_supports_mmap (void); def llama_supports_mmap() -> bool: - return _lib.llama_supports_mmap() + ... -_lib.llama_supports_mmap.argtypes = [] -_lib.llama_supports_mmap.restype = c_bool +llama_supports_mmap = _lib.llama_supports_mmap +llama_supports_mmap.argtypes = [] +llama_supports_mmap.restype = ctypes.c_bool # LLAMA_API bool llama_supports_mlock (void); def llama_supports_mlock() -> bool: - return _lib.llama_supports_mlock() + ... -_lib.llama_supports_mlock.argtypes = [] -_lib.llama_supports_mlock.restype = c_bool +llama_supports_mlock = _lib.llama_supports_mlock +llama_supports_mlock.argtypes = [] +llama_supports_mlock.restype = ctypes.c_bool # LLAMA_API bool llama_supports_gpu_offload(void); def llama_supports_gpu_offload() -> bool: - return _lib.llama_supports_gpu_offload() + ... 
-_lib.llama_supports_gpu_offload.argtypes = [] -_lib.llama_supports_gpu_offload.restype = c_bool +llama_supports_gpu_offload = _lib.llama_supports_gpu_offload +llama_supports_gpu_offload.argtypes = [] +llama_supports_gpu_offload.restype = ctypes.c_bool # LLAMA_API DEPRECATED(bool llama_mmap_supported (void), "use llama_supports_mmap() instead"); def llama_mmap_supported() -> bool: - return _lib.llama_mmap_supported() + ... -_lib.llama_mmap_supported.argtypes = [] -_lib.llama_mmap_supported.restype = c_bool +llama_mmap_supported = _lib.llama_mmap_supported +llama_mmap_supported.argtypes = [] +llama_mmap_supported.restype = ctypes.c_bool # LLAMA_API DEPRECATED(bool llama_mlock_supported(void), "use llama_supports_mlock() instead"); def llama_mlock_supported() -> bool: - return _lib.llama_mlock_supported() + ... -_lib.llama_mlock_supported.argtypes = [] -_lib.llama_mlock_supported.restype = c_bool +llama_mlock_supported = _lib.llama_mlock_supported +llama_mlock_supported.argtypes = [] +llama_mlock_supported.restype = ctypes.c_bool # LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); -def llama_get_model(ctx: llama_context_p) -> llama_model_p: - return _lib.llama_get_model(ctx) +def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: + ... -_lib.llama_get_model.argtypes = [llama_context_p] -_lib.llama_get_model.restype = llama_model_p +llama_get_model = _lib.llama_get_model +llama_get_model.argtypes = [llama_context_p_ctypes] +llama_get_model.restype = llama_model_p_ctypes # LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx); -def llama_n_ctx(ctx: llama_context_p) -> int: - return _lib.llama_n_ctx(ctx) +def llama_n_ctx(ctx: llama_context_p, /) -> int: + ... -_lib.llama_n_ctx.argtypes = [llama_context_p] -_lib.llama_n_ctx.restype = c_uint32 +llama_n_ctx = _lib.llama_n_ctx +llama_n_ctx.argtypes = [llama_context_p_ctypes] +llama_n_ctx.restype = ctypes.c_uint32 # LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); -def llama_n_batch(ctx: llama_context_p) -> int: - return _lib.llama_n_batch(ctx) +def llama_n_batch(ctx: llama_context_p, /) -> int: + ... -_lib.llama_n_batch.argtypes = [llama_context_p] -_lib.llama_n_batch.restype = c_uint32 +llama_n_batch = _lib.llama_n_batch +llama_n_batch.argtypes = [llama_context_p_ctypes] +llama_n_batch.restype = ctypes.c_uint32 # LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model); -def llama_vocab_type(model: llama_model_p) -> int: - return _lib.llama_vocab_type(model) +def llama_vocab_type(model: llama_model_p, /) -> int: + ... -_lib.llama_vocab_type.argtypes = [llama_model_p] -_lib.llama_vocab_type.restype = c_int +llama_vocab_type = _lib.llama_vocab_type +llama_vocab_type.argtypes = [llama_model_p_ctypes] +llama_vocab_type.restype = ctypes.c_int # LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); -def llama_n_vocab(model: llama_model_p) -> int: - return _lib.llama_n_vocab(model) +def llama_n_vocab(model: llama_model_p, /) -> int: + ... -_lib.llama_n_vocab.argtypes = [llama_model_p] -_lib.llama_n_vocab.restype = c_int32 +llama_n_vocab = _lib.llama_n_vocab +llama_n_vocab.argtypes = [llama_model_p_ctypes] +llama_n_vocab.restype = ctypes.c_int32 # LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); -def llama_n_ctx_train(model: llama_model_p) -> int: - return _lib.llama_n_ctx_train(model) +def llama_n_ctx_train(model: llama_model_p, /) -> int: + ... 
-_lib.llama_n_ctx_train.argtypes = [llama_model_p] -_lib.llama_n_ctx_train.restype = c_int32 +llama_n_ctx_train = _lib.llama_n_ctx_train +llama_n_ctx_train.argtypes = [llama_model_p_ctypes] +llama_n_ctx_train.restype = ctypes.c_int32 # LLAMA_API int32_t llama_n_embd (const struct llama_model * model); -def llama_n_embd(model: llama_model_p) -> int: - return _lib.llama_n_embd(model) +def llama_n_embd(model: llama_model_p, /) -> int: + ... -_lib.llama_n_embd.argtypes = [llama_model_p] -_lib.llama_n_embd.restype = c_int32 +llama_n_embd = _lib.llama_n_embd +llama_n_embd.argtypes = [llama_model_p_ctypes] +llama_n_embd.restype = ctypes.c_int32 # // Get the model's RoPE frequency scaling factor # LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model); -def llama_rope_freq_scale_train(model: llama_model_p) -> float: +def llama_rope_freq_scale_train(model: llama_model_p, /) -> float: """Get the model's RoPE frequency scaling factor""" - return _lib.llama_rope_freq_scale_train(model) + ... -_lib.llama_rope_freq_scale_train.argtypes = [llama_model_p] -_lib.llama_rope_freq_scale_train.restype = c_float +llama_rope_freq_scale_train = _lib.llama_rope_freq_scale_train +llama_rope_freq_scale_train.argtypes = [llama_model_p_ctypes] +llama_rope_freq_scale_train.restype = ctypes.c_float # // Functions to access the model's GGUF metadata scalar values # // - The functions return the length of the string on success, or -1 on failure @@ -954,109 +966,117 @@ _lib.llama_rope_freq_scale_train.restype = c_float # // Get metadata value as a string by key name # LLAMA_API int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size); def llama_model_meta_val_str( - model: llama_model_p, key: Union[c_char_p, bytes], buf: bytes, buf_size: int + model: llama_model_p, key: Union[ctypes.c_char_p, bytes], buf: bytes, buf_size: int, / ) -> int: """Get metadata value as a string by key name""" - return _lib.llama_model_meta_val_str(model, key, buf, buf_size) + ... -_lib.llama_model_meta_val_str.argtypes = [llama_model_p, c_char_p, c_char_p, c_size_t] -_lib.llama_model_meta_val_str.restype = c_int32 +llama_model_meta_val_str = _lib.llama_model_meta_val_str +llama_model_meta_val_str.argtypes = [llama_model_p_ctypes, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_size_t] +llama_model_meta_val_str.restype = ctypes.c_int32 # // Get the number of metadata key/value pairs # LLAMA_API int32_t llama_model_meta_count(const struct llama_model * model); -def llama_model_meta_count(model: llama_model_p) -> int: +def llama_model_meta_count(model: llama_model_p, /) -> int: """Get the number of metadata key/value pairs""" - return _lib.llama_model_meta_count(model) + ... -_lib.llama_model_meta_count.argtypes = [llama_model_p] -_lib.llama_model_meta_count.restype = c_int32 +llama_model_meta_count = _lib.llama_model_meta_count +llama_model_meta_count.argtypes = [llama_model_p_ctypes] +llama_model_meta_count.restype = ctypes.c_int32 # // Get metadata key name by index # LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size); def llama_model_meta_key_by_index( - model: llama_model_p, i: Union[c_int, int], buf: bytes, buf_size: int + model: llama_model_p, i: Union[ctypes.c_int, int], buf: bytes, buf_size: int, / ) -> int: """Get metadata key name by index""" - return _lib.llama_model_meta_key_by_index(model, i, buf, buf_size) + ... 
-_lib.llama_model_meta_key_by_index.argtypes = [ - llama_model_p, - c_int32, - c_char_p, - c_size_t, +llama_model_meta_key_by_index = _lib.llama_model_meta_key_by_index +llama_model_meta_key_by_index.argtypes = [ + llama_model_p_ctypes, + ctypes.c_int32, + ctypes.c_char_p, + ctypes.c_size_t, ] -_lib.llama_model_meta_key_by_index.restype = c_int32 +llama_model_meta_key_by_index.restype = ctypes.c_int32 # // Get metadata value as a string by index # LLAMA_API int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size); def llama_model_meta_val_str_by_index( - model: llama_model_p, i: Union[c_int, int], buf: bytes, buf_size: int + model: llama_model_p, i: Union[ctypes.c_int, int], buf: bytes, buf_size: int, / ) -> int: """Get metadata value as a string by index""" - return _lib.llama_model_meta_val_str_by_index(model, i, buf, buf_size) + ... -_lib.llama_model_meta_val_str_by_index.argtypes = [ - llama_model_p, - c_int32, - c_char_p, - c_size_t, +llama_model_meta_val_str_by_index = _lib.llama_model_meta_val_str_by_index +llama_model_meta_val_str_by_index.argtypes = [ + llama_model_p_ctypes, + ctypes.c_int32, + ctypes.c_char_p, + ctypes.c_size_t, ] -_lib.llama_model_meta_val_str_by_index.restype = c_int32 +llama_model_meta_val_str_by_index.restype = ctypes.c_int32 # // Get a string describing the model type # LLAMA_API int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size); def llama_model_desc( - model: llama_model_p, buf: bytes, buf_size: Union[c_size_t, int] + model: llama_model_p, buf: bytes, buf_size: Union[ctypes.c_size_t, int], / ) -> int: """Get a string describing the model type""" - return _lib.llama_model_desc(model, buf, buf_size) + ... -_lib.llama_model_desc.argtypes = [llama_model_p, c_char_p, c_size_t] -_lib.llama_model_desc.restype = c_int32 +llama_model_desc = _lib.llama_model_desc +llama_model_desc.argtypes = [llama_model_p_ctypes, ctypes.c_char_p, ctypes.c_size_t] +llama_model_desc.restype = ctypes.c_int32 # // Returns the total size of all the tensors in the model in bytes # LLAMA_API uint64_t llama_model_size(const struct llama_model * model); -def llama_model_size(model: llama_model_p) -> int: +def llama_model_size(model: llama_model_p, /) -> int: """Returns the total size of all the tensors in the model in bytes""" - return _lib.llama_model_size(model) + ... -_lib.llama_model_size.argtypes = [llama_model_p] -_lib.llama_model_size.restype = ctypes.c_uint64 +llama_model_size = _lib.llama_model_size +llama_model_size.argtypes = [llama_model_p_ctypes] +llama_model_size.restype = ctypes.c_uint64 # // Returns the total number of parameters in the model # LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model); -def llama_model_n_params(model: llama_model_p) -> int: +def llama_model_n_params(model: llama_model_p, /) -> int: """Returns the total number of parameters in the model""" - return _lib.llama_model_n_params(model) + ... 
-_lib.llama_model_n_params.argtypes = [llama_model_p] -_lib.llama_model_n_params.restype = ctypes.c_uint64 +llama_model_n_params = _lib.llama_model_n_params +llama_model_n_params.argtypes = [llama_model_p_ctypes] +llama_model_n_params.restype = ctypes.c_uint64 # // Get a llama model tensor # LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name); def llama_get_model_tensor( - model: llama_model_p, name: Union[c_char_p, bytes] -) -> c_void_p: + model: llama_model_p, name: Union[ctypes.c_char_p, bytes], / +) -> ctypes.c_void_p: """Get a llama model tensor""" - return _lib.llama_get_model_tensor(model, name) + ... -_lib.llama_get_model_tensor.argtypes = [llama_model_p, c_char_p] -_lib.llama_get_model_tensor.restype = c_void_p +llama_get_model_tensor = _lib.llama_get_model_tensor +llama_get_model_tensor.argtypes = [llama_model_p_ctypes, ctypes.c_char_p] +llama_get_model_tensor.restype = ctypes.c_void_p # // Returns 0 on success @@ -1067,18 +1087,20 @@ _lib.llama_get_model_tensor.restype = c_void_p def llama_model_quantize( fname_inp: bytes, fname_out: bytes, - params, # type: POINTER(llama_model_quantize_params) # type: ignore + params, # type: ctypes.POINTER(llama_model_quantize_params) # type: ignore + / ) -> int: """Returns 0 on success""" - return _lib.llama_model_quantize(fname_inp, fname_out, params) + ... -_lib.llama_model_quantize.argtypes = [ - c_char_p, - c_char_p, - POINTER(llama_model_quantize_params), +llama_model_quantize = _lib.llama_model_quantize +llama_model_quantize.argtypes = [ + ctypes.c_char_p, + ctypes.c_char_p, + ctypes.POINTER(llama_model_quantize_params), ] -_lib.llama_model_quantize.restype = c_uint32 +llama_model_quantize.restype = ctypes.c_uint32 # // Apply a LoRA adapter to a loaded model @@ -1096,10 +1118,11 @@ _lib.llama_model_quantize.restype = c_uint32 # "use llama_model_apply_lora_from_file instead"); def llama_apply_lora_from_file( ctx: llama_context_p, - path_lora: Union[c_char_p, bytes], - scale: Union[c_float, float], - path_base_model: Union[c_char_p, bytes], - n_threads: Union[c_int, int], + path_lora: Union[ctypes.c_char_p, bytes], + scale: Union[ctypes.c_float, float], + path_base_model: Union[ctypes.c_char_p, bytes], + n_threads: Union[ctypes.c_int, int], + / ) -> int: """Apply a LoRA adapter to a loaded model path_base_model is the path to a higher quality model to use as a base for @@ -1107,19 +1130,18 @@ def llama_apply_lora_from_file( The model needs to be reloaded before applying a new adapter, otherwise the adapter will be applied on top of the previous one Returns 0 on success""" - return _lib.llama_apply_lora_from_file( - ctx, path_lora, scale, path_base_model, n_threads - ) + ... 
-_lib.llama_apply_lora_from_file.argtypes = [ - llama_context_p, - c_char_p, - c_float, - c_char_p, - c_int32, +llama_apply_lora_from_file = _lib.llama_apply_lora_from_file +llama_apply_lora_from_file.argtypes = [ + llama_context_p_ctypes, + ctypes.c_char_p, + ctypes.c_float, + ctypes.c_char_p, + ctypes.c_int32, ] -_lib.llama_apply_lora_from_file.restype = c_int32 +llama_apply_lora_from_file.restype = ctypes.c_int32 # LLAMA_API int32_t llama_model_apply_lora_from_file( @@ -1130,24 +1152,24 @@ _lib.llama_apply_lora_from_file.restype = c_int32 # int32_t n_threads); def llama_model_apply_lora_from_file( model: llama_model_p, - path_lora: Union[c_char_p, bytes], - scale: Union[c_float, float], - path_base_model: Union[c_char_p, bytes], - n_threads: Union[c_int, int], + path_lora: Union[ctypes.c_char_p, bytes], + scale: Union[ctypes.c_float, float], + path_base_model: Union[ctypes.c_char_p, bytes], + n_threads: Union[ctypes.c_int, int], + / ) -> int: - return _lib.llama_model_apply_lora_from_file( - model, path_lora, scale, path_base_model, n_threads - ) + ... -_lib.llama_model_apply_lora_from_file.argtypes = [ - llama_model_p, - c_char_p, - c_float, - c_char_p, - c_int32, +llama_model_apply_lora_from_file = _lib.llama_model_apply_lora_from_file +llama_model_apply_lora_from_file.argtypes = [ + llama_model_p_ctypes, + ctypes.c_char_p, + ctypes.c_float, + ctypes.c_char_p, + ctypes.c_int32, ] -_lib.llama_model_apply_lora_from_file.restype = c_int32 +llama_model_apply_lora_from_file.restype = ctypes.c_int32 # // # // KV cache @@ -1160,7 +1182,7 @@ _lib.llama_model_apply_lora_from_file.restype = c_int32 # // May be negative if the cell is not populated. # llama_pos pos; # }; -class llama_kv_cache_view_cell(Structure): +class llama_kv_cache_view_cell(ctypes.Structure): _fields_ = [("pos", llama_pos)] @@ -1196,92 +1218,98 @@ class llama_kv_cache_view_cell(Structure): # // The sequences for each cell. There will be n_max_seq items per cell. # llama_seq_id * cells_sequences; # }; -class llama_kv_cache_view(Structure): +class llama_kv_cache_view(ctypes.Structure): _fields_ = [ - ("n_cells", c_int32), - ("n_max_seq", c_int32), - ("token_count", c_int32), - ("used_cells", c_int32), - ("max_contiguous", c_int32), - ("max_contiguous_idx", c_int32), - ("cells", POINTER(llama_kv_cache_view_cell)), - ("cells_sequences", POINTER(llama_seq_id)), + ("n_cells", ctypes.c_int32), + ("n_max_seq", ctypes.c_int32), + ("token_count", ctypes.c_int32), + ("used_cells", ctypes.c_int32), + ("max_contiguous", ctypes.c_int32), + ("max_contiguous_idx", ctypes.c_int32), + ("cells", ctypes.POINTER(llama_kv_cache_view_cell)), + ("cells_sequences", ctypes.POINTER(llama_seq_id)), ] -llama_kv_cache_view_p = POINTER(llama_kv_cache_view) +llama_kv_cache_view_p = ctypes.POINTER(llama_kv_cache_view) # // Create an empty KV cache view. (use only for debugging purposes) # LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq); def llama_kv_cache_view_init( - ctx: llama_context_p, n_max_seq: Union[c_int32, int] + ctx: llama_context_p, n_max_seq: Union[ctypes.c_int32, int], / ) -> llama_kv_cache_view: """Create an empty KV cache view. (use only for debugging purposes)""" - return _lib.llama_kv_cache_view_init(ctx, n_max_seq) + ... 
-_lib.llama_kv_cache_view_init.argtypes = [llama_context_p, c_int32] -_lib.llama_kv_cache_view_init.restype = llama_kv_cache_view +llama_kv_cache_view_init = _lib.llama_kv_cache_view_init +llama_kv_cache_view_init.argtypes = [llama_context_p_ctypes, ctypes.c_int32] +llama_kv_cache_view_init.restype = llama_kv_cache_view # // Free a KV cache view. (use only for debugging purposes) # LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view); -def llama_kv_cache_view_free(view: "ctypes.pointer[llama_kv_cache_view]"): # type: ignore +def llama_kv_cache_view_free(view: "ctypes.pointer[llama_kv_cache_view]", /): # type: ignore """Free a KV cache view. (use only for debugging purposes)""" - return _lib.llama_kv_cache_view_free(view) + ... -_lib.llama_kv_cache_view_free.argtypes = [llama_kv_cache_view_p] -_lib.llama_kv_cache_view_free.restype = None +llama_kv_cache_view_free = _lib.llama_kv_cache_view_free +llama_kv_cache_view_free.argtypes = [llama_kv_cache_view_p] +llama_kv_cache_view_free.restype = None # // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) # LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view); -def llama_kv_cache_view_update(ctx: llama_context_p, view: "ctypes.pointer[llama_kv_cache_view]"): # type: ignore +def llama_kv_cache_view_update(ctx: llama_context_p, view: "ctypes.pointer[llama_kv_cache_view]", /): # type: ignore """Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)""" - return _lib.llama_kv_cache_view_update(ctx, view) + ... -_lib.llama_kv_cache_view_update.argtypes = [llama_context_p, llama_kv_cache_view_p] -_lib.llama_kv_cache_view_update.restype = None +llama_kv_cache_view_update = _lib.llama_kv_cache_view_update +llama_kv_cache_view_update.argtypes = [llama_context_p_ctypes, llama_kv_cache_view_p] +llama_kv_cache_view_update.restype = None # // Returns the number of tokens in the KV cache (slow, use only for debug) # // If a KV cell has multiple sequences assigned to it, it will be counted multiple times # LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx); -def llama_get_kv_cache_token_count(ctx: llama_context_p) -> int: +def llama_get_kv_cache_token_count(ctx: llama_context_p, /) -> int: """Returns the number of tokens in the KV cache (slow, use only for debug) If a KV cell has multiple sequences assigned to it, it will be counted multiple times """ - return _lib.llama_get_kv_cache_token_count(ctx) + ... -_lib.llama_get_kv_cache_token_count.argtypes = [llama_context_p] -_lib.llama_get_kv_cache_token_count.restype = c_int32 +llama_get_kv_cache_token_count = _lib.llama_get_kv_cache_token_count +llama_get_kv_cache_token_count.argtypes = [llama_context_p_ctypes] +llama_get_kv_cache_token_count.restype = ctypes.c_int32 # // Returns the number of used KV cells (i.e. have at least one sequence assigned to them) # LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx); -def llama_get_kv_cache_used_cells(ctx: llama_context_p) -> int: +def llama_get_kv_cache_used_cells(ctx: llama_context_p, /) -> int: """Returns the number of used KV cells (i.e. have at least one sequence assigned to them)""" - return _lib.llama_get_kv_cache_used_cells(ctx) + ... 
-_lib.llama_get_kv_cache_used_cells.argtypes = [llama_context_p] -_lib.llama_get_kv_cache_used_cells.restype = c_int32 +llama_get_kv_cache_used_cells = _lib.llama_get_kv_cache_used_cells +llama_get_kv_cache_used_cells.argtypes = [llama_context_p_ctypes] +llama_get_kv_cache_used_cells.restype = ctypes.c_int32 # // Clear the KV cache # LLAMA_API void llama_kv_cache_clear( # struct llama_context * ctx); -def llama_kv_cache_clear(ctx: llama_context_p): +def llama_kv_cache_clear(ctx: llama_context_p, /): """Clear the KV cache""" - return _lib.llama_kv_cache_clear(ctx) + ... -_lib.llama_kv_cache_clear.argtypes = [llama_context_p] -_lib.llama_kv_cache_clear.restype = None +llama_kv_cache_clear = _lib.llama_kv_cache_clear +llama_kv_cache_clear.argtypes = [llama_context_p_ctypes] +llama_kv_cache_clear.restype = None # // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) @@ -1298,21 +1326,23 @@ def llama_kv_cache_seq_rm( seq_id: Union[llama_seq_id, int], p0: Union[llama_pos, int], p1: Union[llama_pos, int], + / ): """Removes all tokens that belong to the specified sequence and have positions in [p0, p1) seq_id < 0 : match any sequence p0 < 0 : [0, p1] p1 < 0 : [p0, inf)""" - return _lib.llama_kv_cache_seq_rm(ctx, seq_id, p0, p1) + ... -_lib.llama_kv_cache_seq_rm.argtypes = [ - llama_context_p, +llama_kv_cache_seq_rm = _lib.llama_kv_cache_seq_rm +llama_kv_cache_seq_rm.argtypes = [ + llama_context_p_ctypes, llama_seq_id, llama_pos, llama_pos, ] -_lib.llama_kv_cache_seq_rm.restype = None +llama_kv_cache_seq_rm.restype = None # // Copy all tokens that belong to the specified sequence to another sequence @@ -1331,22 +1361,24 @@ def llama_kv_cache_seq_cp( seq_id_dst: Union[llama_seq_id, int], p0: Union[llama_pos, int], p1: Union[llama_pos, int], + / ): """Copy all tokens that belong to the specified sequence to another sequence Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence p0 < 0 : [0, p1] p1 < 0 : [p0, inf)""" - return _lib.llama_kv_cache_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1) + ... -_lib.llama_kv_cache_seq_cp.argtypes = [ - llama_context_p, +llama_kv_cache_seq_cp = _lib.llama_kv_cache_seq_cp +llama_kv_cache_seq_cp.argtypes = [ + llama_context_p_ctypes, llama_seq_id, llama_seq_id, llama_pos, llama_pos, ] -_lib.llama_kv_cache_seq_cp.restype = None +llama_kv_cache_seq_cp.restype = None # // Removes all tokens that do not belong to the specified sequence @@ -1356,13 +1388,15 @@ _lib.llama_kv_cache_seq_cp.restype = None def llama_kv_cache_seq_keep( ctx: llama_context_p, seq_id: Union[llama_seq_id, int], + / ): """Removes all tokens that do not belong to the specified sequence""" - return _lib.llama_kv_cache_seq_keep(ctx, seq_id) + ... 
-_lib.llama_kv_cache_seq_keep.argtypes = [llama_context_p, llama_seq_id] -_lib.llama_kv_cache_seq_keep.restype = None +llama_kv_cache_seq_keep = _lib.llama_kv_cache_seq_keep +llama_kv_cache_seq_keep.argtypes = [llama_context_p_ctypes, llama_seq_id] +llama_kv_cache_seq_keep.restype = None # // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) @@ -1381,22 +1415,24 @@ def llama_kv_cache_seq_shift( p0: Union[llama_pos, int], p1: Union[llama_pos, int], delta: Union[llama_pos, int], + / ): """Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) If the KV cache is RoPEd, the KV data is updated accordingly p0 < 0 : [0, p1] p1 < 0 : [p0, inf)""" - return _lib.llama_kv_cache_seq_shift(ctx, seq_id, p0, p1, delta) + ... -_lib.llama_kv_cache_seq_shift.argtypes = [ - llama_context_p, +llama_kv_cache_seq_shift = _lib.llama_kv_cache_seq_shift +llama_kv_cache_seq_shift.argtypes = [ + llama_context_p_ctypes, llama_seq_id, llama_pos, llama_pos, llama_pos, ] -_lib.llama_kv_cache_seq_shift.restype = None +llama_kv_cache_seq_shift.restype = None # // Integer division of the positions by factor of `d > 1` @@ -1414,23 +1450,25 @@ def llama_kv_cache_seq_div( seq_id: Union[llama_seq_id, int], p0: Union[llama_pos, int], p1: Union[llama_pos, int], - d: Union[c_int, int], + d: Union[ctypes.c_int, int], + / ): """Integer division of the positions by factor of `d > 1` If the KV cache is RoPEd, the KV data is updated accordingly p0 < 0 : [0, p1] p1 < 0 : [p0, inf)""" - return _lib.llama_kv_cache_seq_div(ctx, seq_id, p0, p1, d) + ... -_lib.llama_kv_cache_seq_div.argtypes = [ - llama_context_p, +llama_kv_cache_seq_div = _lib.llama_kv_cache_seq_div +llama_kv_cache_seq_div.argtypes = [ + llama_context_p_ctypes, llama_seq_id, llama_pos, llama_pos, - c_int, + ctypes.c_int, ] -_lib.llama_kv_cache_seq_div.restype = None +llama_kv_cache_seq_div.restype = None # // # // State / sessions @@ -1440,14 +1478,15 @@ _lib.llama_kv_cache_seq_div.restype = None # Returns the maximum size in bytes of the state (rng, logits, embedding # and kv_cache) - will often be smaller after compacting tokens # LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx); -def llama_get_state_size(ctx: llama_context_p) -> int: +def llama_get_state_size(ctx: llama_context_p, /) -> int: """Returns the maximum size in bytes of the state (rng, logits, embedding and kv_cache) - will often be smaller after compacting tokens""" - return _lib.llama_get_state_size(ctx) + ... -_lib.llama_get_state_size.argtypes = [llama_context_p] -_lib.llama_get_state_size.restype = c_size_t +llama_get_state_size = _lib.llama_get_state_size +llama_get_state_size.argtypes = [llama_context_p_ctypes] +llama_get_state_size.restype = ctypes.c_size_t # Copies the state to the specified destination address. @@ -1457,16 +1496,18 @@ _lib.llama_get_state_size.restype = c_size_t # struct llama_context * ctx, # uint8_t * dst); def llama_copy_state_data( - ctx: llama_context_p, dst # type: Array[c_uint8] + ctx: llama_context_p, dst, # type: Array[ctypes.c_uint8] + / ) -> int: """Copies the state to the specified destination address. Destination needs to have allocated enough memory. Returns the number of bytes copied""" - return _lib.llama_copy_state_data(ctx, dst) + ... 
-_lib.llama_copy_state_data.argtypes = [llama_context_p, c_uint8_p] -_lib.llama_copy_state_data.restype = c_size_t +llama_copy_state_data = _lib.llama_copy_state_data +llama_copy_state_data.argtypes = [llama_context_p_ctypes, c_uint8_p] +llama_copy_state_data.restype = ctypes.c_size_t # Set the state reading from the specified address @@ -1475,14 +1516,16 @@ _lib.llama_copy_state_data.restype = c_size_t # struct llama_context * ctx, # uint8_t * src); def llama_set_state_data( - ctx: llama_context_p, src # type: Array[c_uint8] + ctx: llama_context_p, src, # type: Array[ctypes.c_uint8] + / ) -> int: """Set the state reading from the specified address""" - return _lib.llama_set_state_data(ctx, src) + ... -_lib.llama_set_state_data.argtypes = [llama_context_p, c_uint8_p] -_lib.llama_set_state_data.restype = c_size_t +llama_set_state_data = _lib.llama_set_state_data +llama_set_state_data.argtypes = [llama_context_p_ctypes, c_uint8_p] +llama_set_state_data.restype = ctypes.c_size_t # Save/load session file @@ -1496,22 +1539,22 @@ def llama_load_session_file( ctx: llama_context_p, path_session: bytes, tokens_out, # type: Array[llama_token] - n_token_capacity: Union[c_size_t, int], - n_token_count_out, # type: _Pointer[c_size_t] + n_token_capacity: Union[ctypes.c_size_t, int], + n_token_count_out, # type: _Pointer[ctypes.c_size_t] + / ) -> int: - return _lib.llama_load_session_file( - ctx, path_session, tokens_out, n_token_capacity, n_token_count_out - ) + ... -_lib.llama_load_session_file.argtypes = [ - llama_context_p, - c_char_p, +llama_load_session_file = _lib.llama_load_session_file +llama_load_session_file.argtypes = [ + llama_context_p_ctypes, + ctypes.c_char_p, llama_token_p, - c_size_t, + ctypes.c_size_t, c_size_t_p, ] -_lib.llama_load_session_file.restype = c_size_t +llama_load_session_file.restype = ctypes.c_size_t # LLAMA_API bool llama_save_session_file( @@ -1523,18 +1566,20 @@ def llama_save_session_file( ctx: llama_context_p, path_session: bytes, tokens, # type: Array[llama_token] - n_token_count: Union[c_size_t, int], + n_token_count: Union[ctypes.c_size_t, int], + / ) -> int: - return _lib.llama_save_session_file(ctx, path_session, tokens, n_token_count) + ... -_lib.llama_save_session_file.argtypes = [ - llama_context_p, - c_char_p, +llama_save_session_file = _lib.llama_save_session_file +llama_save_session_file.argtypes = [ + llama_context_p_ctypes, + ctypes.c_char_p, llama_token_p, - c_size_t, + ctypes.c_size_t, ] -_lib.llama_save_session_file.restype = c_size_t +llama_save_session_file.restype = ctypes.c_size_t # // # // Decoding @@ -1555,19 +1600,21 @@ _lib.llama_save_session_file.restype = c_size_t def llama_eval( ctx: llama_context_p, tokens, # type: Array[llama_token] - n_tokens: Union[c_int, int], - n_past: Union[c_int, int], + n_tokens: Union[ctypes.c_int, int], + n_past: Union[ctypes.c_int, int], + / ) -> int: """Run the llama inference to obtain the logits and probabilities for the next token(s). tokens + n_tokens is the provided batch of new tokens to process n_past is the number of tokens to use from previous eval calls Returns 0 on success DEPRECATED: use llama_decode() instead""" - return _lib.llama_eval(ctx, tokens, n_tokens, n_past) + ... 
-_lib.llama_eval.argtypes = [llama_context_p, llama_token_p, c_int32, c_int32] -_lib.llama_eval.restype = c_int +llama_eval = _lib.llama_eval +llama_eval.argtypes = [llama_context_p_ctypes, llama_token_p, ctypes.c_int32, ctypes.c_int32] +llama_eval.restype = ctypes.c_int # // Same as llama_eval, but use float matrix input directly. @@ -1580,17 +1627,19 @@ _lib.llama_eval.restype = c_int # "use llama_decode() instead"); def llama_eval_embd( ctx: llama_context_p, - embd, # type: Array[c_float] - n_tokens: Union[c_int, int], - n_past: Union[c_int, int], + embd, # type: Array[ctypes.c_float] + n_tokens: Union[ctypes.c_int, int], + n_past: Union[ctypes.c_int, int], + / ) -> int: """Same as llama_eval, but use float matrix input directly. DEPRECATED: use llama_decode() instead""" - return _lib.llama_eval_embd(ctx, embd, n_tokens, n_past) + ... -_lib.llama_eval_embd.argtypes = [llama_context_p, c_float_p, c_int32, c_int32] -_lib.llama_eval_embd.restype = c_int +llama_eval_embd = _lib.llama_eval_embd +llama_eval_embd.argtypes = [llama_context_p_ctypes, c_float_p, ctypes.c_int32, ctypes.c_int32] +llama_eval_embd.restype = ctypes.c_int # // Return batch for single sequence of tokens starting at pos_0 @@ -1604,24 +1653,26 @@ _lib.llama_eval_embd.restype = c_int # llama_seq_id seq_id); def llama_batch_get_one( tokens, # type: Array[llama_token] - n_tokens: Union[c_int, int], + n_tokens: Union[ctypes.c_int, int], pos_0: Union[llama_pos, int], seq_id: llama_seq_id, + / ) -> llama_batch: """Return batch for single sequence of tokens starting at pos_0 NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it """ - return _lib.llama_batch_get_one(tokens, n_tokens, pos_0, seq_id) + ... -_lib.llama_batch_get_one.argtypes = [ +llama_batch_get_one = _lib.llama_batch_get_one +llama_batch_get_one.argtypes = [ llama_token_p, - c_int, + ctypes.c_int, llama_pos, llama_seq_id, ] -_lib.llama_batch_get_one.restype = llama_batch +llama_batch_get_one.restype = llama_batch # // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens @@ -1636,9 +1687,10 @@ _lib.llama_batch_get_one.restype = llama_batch # int32_t embd, # int32_t n_seq_max); def llama_batch_init( - n_tokens: Union[c_int32, int], - embd: Union[c_int32, int], - n_seq_max: Union[c_int32, int], + n_tokens: Union[ctypes.c_int32, int], + embd: Union[ctypes.c_int32, int], + n_seq_max: Union[ctypes.c_int32, int], + / ) -> llama_batch: """Allocates a batch of tokens on the heap that can hold a maximum of n_tokens Each token can be assigned up to n_seq_max sequence ids @@ -1647,22 +1699,24 @@ def llama_batch_init( Otherwise, llama_batch.token will be allocated to store n_tokens llama_token The rest of the llama_batch members are allocated with size n_tokens All members are left uninitialized""" - return _lib.llama_batch_init(n_tokens, embd, n_seq_max) + ... -_lib.llama_batch_init.argtypes = [c_int32, c_int32, c_int32] -_lib.llama_batch_init.restype = llama_batch +llama_batch_init = _lib.llama_batch_init +llama_batch_init.argtypes = [ctypes.c_int32, ctypes.c_int32, ctypes.c_int32] +llama_batch_init.restype = llama_batch # // Frees a batch of tokens allocated with llama_batch_init() # LLAMA_API void llama_batch_free(struct llama_batch batch); -def llama_batch_free(batch: llama_batch): +def llama_batch_free(batch: llama_batch, /): """Frees a batch of tokens allocated with llama_batch_init()""" - return _lib.llama_batch_free(batch) + ... 
-_lib.llama_batch_free.argtypes = [llama_batch] -_lib.llama_batch_free.restype = None +llama_batch_free = _lib.llama_batch_free +llama_batch_free.argtypes = [llama_batch] +llama_batch_free.restype = None # // Positive return values does not mean a fatal error, but rather a warning. @@ -1672,16 +1726,17 @@ _lib.llama_batch_free.restype = None # LLAMA_API int32_t llama_decode( # struct llama_context * ctx, # struct llama_batch batch); -def llama_decode(ctx: llama_context_p, batch: llama_batch) -> int: +def llama_decode(ctx: llama_context_p, batch: llama_batch, /) -> int: """Positive return values does not mean a fatal error, but rather a warning. 0 - success 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) < 0 - error""" - return _lib.llama_decode(ctx, batch) + ... -_lib.llama_decode.argtypes = [llama_context_p, llama_batch] -_lib.llama_decode.restype = c_int32 +llama_decode = _lib.llama_decode +llama_decode.argtypes = [llama_context_p_ctypes, llama_batch] +llama_decode.restype = ctypes.c_int32 # // Set the number of threads used for decoding @@ -1690,18 +1745,20 @@ _lib.llama_decode.restype = c_int32 # LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch); def llama_set_n_threads( ctx: llama_context_p, - n_threads: Union[c_uint32, int], - n_threads_batch: Union[c_uint32, int], + n_threads: Union[ctypes.c_uint32, int], + n_threads_batch: Union[ctypes.c_uint32, int], + / ): """Set the number of threads used for decoding n_threads is the number of threads used for generation (single token) n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) """ - return _lib.llama_set_n_threads(ctx, n_threads, n_threads_batch) + ... -_lib.llama_set_n_threads.argtypes = [llama_context_p, c_uint32, c_uint32] -_lib.llama_set_n_threads.restype = None +llama_set_n_threads = _lib.llama_set_n_threads +llama_set_n_threads.argtypes = [llama_context_p_ctypes, ctypes.c_uint32, ctypes.c_uint32] +llama_set_n_threads.restype = None # // Token logits obtained from the last call to llama_eval() @@ -1712,62 +1769,68 @@ _lib.llama_set_n_threads.restype = None # LLAMA_API float * llama_get_logits(struct llama_context * ctx); def llama_get_logits( ctx: llama_context_p, + / ): # type: (...) -> Array[float] # type: ignore """Token logits obtained from the last call to llama_eval() The logits for the last token are stored in the last row Logits for which llama_batch.logits[i] == 0 are undefined Rows: n_tokens provided with llama_batch Cols: n_vocab""" - return _lib.llama_get_logits(ctx) + ... -_lib.llama_get_logits.argtypes = [llama_context_p] -_lib.llama_get_logits.restype = c_float_p +llama_get_logits = _lib.llama_get_logits +llama_get_logits.argtypes = [llama_context_p_ctypes] +llama_get_logits.restype = c_float_p # // Logits for the ith token. Equivalent to: # // llama_get_logits(ctx) + i*n_vocab # LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); def llama_get_logits_ith( - ctx: llama_context_p, i: Union[c_int32, int] + ctx: llama_context_p, i: Union[ctypes.c_int32, int] + , / ): # type: (...) -> Array[float] # type: ignore """Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab""" - return _lib.llama_get_logits_ith(ctx, i) + ... 
-_lib.llama_get_logits_ith.argtypes = [llama_context_p, c_int32] -_lib.llama_get_logits_ith.restype = c_float_p +llama_get_logits_ith = _lib.llama_get_logits_ith +llama_get_logits_ith.argtypes = [llama_context_p_ctypes, ctypes.c_int32] +llama_get_logits_ith.restype = c_float_p # Get the embeddings for the input # shape: [n_embd] (1-dimensional) # LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); def llama_get_embeddings( - ctx: llama_context_p, + ctx: llama_context_p, / ): # type: (...) -> Array[float] # type: ignore """Get the embeddings for the input shape: [n_embd] (1-dimensional)""" - return _lib.llama_get_embeddings(ctx) + ... -_lib.llama_get_embeddings.argtypes = [llama_context_p] -_lib.llama_get_embeddings.restype = c_float_p +llama_get_embeddings = _lib.llama_get_embeddings +llama_get_embeddings.argtypes = [llama_context_p_ctypes] +llama_get_embeddings.restype = c_float_p # // Get the embeddings for the ith sequence # // llama_get_embeddings(ctx) + i*n_embd # LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); def llama_get_embeddings_ith( - ctx: llama_context_p, i: Union[c_int32, int] + ctx: llama_context_p, i: Union[ctypes.c_int32, int], / ): # type: (...) -> Array[float] # type: ignore """Get the embeddings for the ith sequence llama_get_embeddings(ctx) + i*n_embd""" - return _lib.llama_get_embeddings_ith(ctx, i) + ... -_lib.llama_get_embeddings_ith.argtypes = [llama_context_p, c_int32] -_lib.llama_get_embeddings_ith.restype = c_float_p +llama_get_embeddings_ith = _lib.llama_get_embeddings_ith +llama_get_embeddings_ith.argtypes = [llama_context_p_ctypes, ctypes.c_int32] +llama_get_embeddings_ith.restype = c_float_p # // @@ -1776,125 +1839,137 @@ _lib.llama_get_embeddings_ith.restype = c_float_p # LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token); -def llama_token_get_text(model: llama_model_p, token: Union[llama_token, int]) -> bytes: - return _lib.llama_token_get_text(model, token) +def llama_token_get_text(model: llama_model_p, token: Union[llama_token, int], /) -> bytes: + ... -_lib.llama_token_get_text.argtypes = [llama_model_p, llama_token] -_lib.llama_token_get_text.restype = c_char_p +llama_token_get_text = _lib.llama_token_get_text +llama_token_get_text.argtypes = [llama_model_p_ctypes, llama_token] +llama_token_get_text.restype = ctypes.c_char_p # LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token); def llama_token_get_score( - model: llama_model_p, token: Union[llama_token, int] + model: llama_model_p, token: Union[llama_token, int], / ) -> float: - return _lib.llama_token_get_score(model, token) + ... -_lib.llama_token_get_score.argtypes = [llama_model_p, llama_token] -_lib.llama_token_get_score.restype = c_float +llama_token_get_score = _lib.llama_token_get_score +llama_token_get_score.argtypes = [llama_model_p_ctypes, llama_token] +llama_token_get_score.restype = ctypes.c_float # LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token); -def llama_token_get_type(model: llama_model_p, token: Union[llama_token, int]) -> int: - return _lib.llama_token_get_type(model, token) +def llama_token_get_type(model: llama_model_p, token: Union[llama_token, int], /) -> int: + ... 
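# Hedged sketch of the embedding accessor bound above; it assumes the context was
# created with embeddings enabled and that llama_n_embd is exposed by the module.
import llama_cpp

n_embd = llama_cpp.llama_n_embd(model)
emb_ptr = llama_cpp.llama_get_embeddings(ctx)           # float* of length n_embd
embedding = [emb_ptr[i] for i in range(n_embd)] if emb_ptr else None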
-_lib.llama_token_get_type.argtypes = [llama_model_p, llama_token] -_lib.llama_token_get_type.restype = ctypes.c_int +llama_token_get_type = _lib.llama_token_get_type +llama_token_get_type.argtypes = [llama_model_p_ctypes, llama_token] +llama_token_get_type.restype = ctypes.c_int # // Special tokens # LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence -def llama_token_bos(model: llama_model_p) -> int: +def llama_token_bos(model: llama_model_p, /) -> int: """beginning-of-sentence""" - return _lib.llama_token_bos(model) + ... -_lib.llama_token_bos.argtypes = [llama_model_p] -_lib.llama_token_bos.restype = llama_token +llama_token_bos = _lib.llama_token_bos +llama_token_bos.argtypes = [llama_model_p_ctypes] +llama_token_bos.restype = llama_token # LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence -def llama_token_eos(model: llama_model_p) -> int: +def llama_token_eos(model: llama_model_p, /) -> int: """end-of-sentence""" - return _lib.llama_token_eos(model) + ... -_lib.llama_token_eos.argtypes = [llama_model_p] -_lib.llama_token_eos.restype = llama_token +llama_token_eos = _lib.llama_token_eos +llama_token_eos.argtypes = [llama_model_p_ctypes] +llama_token_eos.restype = llama_token # LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line -def llama_token_nl(model: llama_model_p) -> int: +def llama_token_nl(model: llama_model_p, /) -> int: """next-line""" - return _lib.llama_token_nl(model) + ... -_lib.llama_token_nl.argtypes = [llama_model_p] -_lib.llama_token_nl.restype = llama_token +llama_token_nl = _lib.llama_token_nl +llama_token_nl.argtypes = [llama_model_p_ctypes] +llama_token_nl.restype = llama_token # // Returns -1 if unknown, 1 for true or 0 for false. # LLAMA_API int32_t llama_add_bos_token(const struct llama_model * model); -def llama_add_bos_token(model: llama_model_p) -> int: +def llama_add_bos_token(model: llama_model_p, /) -> int: """Returns -1 if unknown, 1 for true or 0 for false.""" - return _lib.llama_add_bos_token(model) + ... -_lib.llama_add_bos_token.argtypes = [llama_model_p] -_lib.llama_add_bos_token.restype = c_int32 +llama_add_bos_token = _lib.llama_add_bos_token +llama_add_bos_token.argtypes = [llama_model_p_ctypes] +llama_add_bos_token.restype = ctypes.c_int32 # // Returns -1 if unknown, 1 for true or 0 for false. # LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model); -def llama_add_eos_token(model: llama_model_p) -> int: +def llama_add_eos_token(model: llama_model_p, /) -> int: """Returns -1 if unknown, 1 for true or 0 for false.""" - return _lib.llama_add_eos_token(model) + ... -_lib.llama_add_eos_token.argtypes = [llama_model_p] -_lib.llama_add_eos_token.restype = c_int32 +llama_add_eos_token = _lib.llama_add_eos_token +llama_add_eos_token.argtypes = [llama_model_p_ctypes] +llama_add_eos_token.restype = ctypes.c_int32 # // codellama infill tokens # LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix def llama_token_prefix(model: llama_model_p) -> int: """codellama infill tokens""" - return _lib.llama_token_prefix(model) + ... 
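# Sketch of the special-token helpers above; `model` is assumed to be a loaded
# llama_model_p handle.
import llama_cpp

bos = llama_cpp.llama_token_bos(model)
eos = llama_cpp.llama_token_eos(model)
add_bos = llama_cpp.llama_add_bos_token(model)   # -1 unknown, 1 true, 0 false
prepend_bos = True if add_bos == -1 else bool(add_bos)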
-_lib.llama_token_prefix.argtypes = [llama_model_p] -_lib.llama_token_prefix.restype = llama_token +llama_token_prefix = _lib.llama_token_prefix +llama_token_prefix.argtypes = [llama_model_p_ctypes] +llama_token_prefix.restype = llama_token # LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle -def llama_token_middle(model: llama_model_p) -> int: - return _lib.llama_token_middle(model) +def llama_token_middle(model: llama_model_p, /) -> int: + ... -_lib.llama_token_middle.argtypes = [llama_model_p] -_lib.llama_token_middle.restype = llama_token +llama_token_middle = _lib.llama_token_middle +llama_token_middle.argtypes = [llama_model_p_ctypes] +llama_token_middle.restype = llama_token # LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix -def llama_token_suffix(model: llama_model_p) -> int: - return _lib.llama_token_suffix(model) +def llama_token_suffix(model: llama_model_p, /) -> int: + ... -_lib.llama_token_suffix.argtypes = [llama_model_p] -_lib.llama_token_suffix.restype = llama_token +llama_token_suffix = _lib.llama_token_suffix +llama_token_suffix.argtypes = [llama_model_p_ctypes] +llama_token_suffix.restype = llama_token # LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle -def llama_token_eot(model: llama_model_p) -> int: - return _lib.llama_token_eot(model) +def llama_token_eot(model: llama_model_p, /) -> int: + ... -_lib.llama_token_eot.argtypes = [llama_model_p] -_lib.llama_token_eot.restype = llama_token +llama_token_eot = _lib.llama_token_eot +llama_token_eot.argtypes = [llama_model_p_ctypes] +llama_token_eot.restype = llama_token # // @@ -1919,28 +1994,28 @@ _lib.llama_token_eot.restype = llama_token def llama_tokenize( model: llama_model_p, text: bytes, - text_len: Union[c_int, int], + text_len: Union[ctypes.c_int, int], tokens, # type: Array[llama_token] - n_max_tokens: Union[c_int, int], - add_bos: Union[c_bool, bool], - special: Union[c_bool, bool], + n_max_tokens: Union[ctypes.c_int, int], + add_bos: Union[ctypes.c_bool, bool], + special: Union[ctypes.c_bool, bool], + / ) -> int: """Convert the provided text into tokens.""" - return _lib.llama_tokenize( - model, text, text_len, tokens, n_max_tokens, add_bos, special - ) + ... -_lib.llama_tokenize.argtypes = [ - llama_model_p, - c_char_p, - c_int32, +llama_tokenize = _lib.llama_tokenize +llama_tokenize.argtypes = [ + llama_model_p_ctypes, + ctypes.c_char_p, + ctypes.c_int32, llama_token_p, - c_int32, - c_bool, - c_bool, + ctypes.c_int32, + ctypes.c_bool, + ctypes.c_bool, ] -_lib.llama_tokenize.restype = c_int32 +llama_tokenize.restype = ctypes.c_int32 # // Token Id -> Piece. @@ -1955,19 +2030,21 @@ _lib.llama_tokenize.restype = c_int32 def llama_token_to_piece( model: llama_model_p, token: Union[llama_token, int], - buf: Union[c_char_p, bytes], - length: Union[c_int, int], + buf: Union[ctypes.c_char_p, bytes], + length: Union[ctypes.c_int, int], + / ) -> int: """Token Id -> Piece. Uses the vocabulary in the provided context. Does not write null terminator to the buffer. User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens. """ - return _lib.llama_token_to_piece(model, token, buf, length) + ... 
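# Sketch of a tokenize / token_to_piece round trip using the signatures above.
# `model` is assumed loaded; the 32-byte piece buffer is an assumption borrowed
# from the high-level wrapper, not something this patch prescribes.
import ctypes
import llama_cpp

text = b"Hello, llama!"
max_tokens = len(text) + 1                              # room for an added BOS token
tokens = (llama_cpp.llama_token * max_tokens)()
n = llama_cpp.llama_tokenize(model, text, len(text), tokens, max_tokens, True, False)
if n < 0:
    raise RuntimeError("token buffer too small")

buf = ctypes.create_string_buffer(32)
pieces = []
for i in range(n):
    size = llama_cpp.llama_token_to_piece(model, tokens[i], buf, len(buf))
    pieces.append(buf.raw[:size])
print(b"".join(pieces).decode("utf-8", errors="replace"))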
-_lib.llama_token_to_piece.argtypes = [llama_model_p, llama_token, c_char_p, c_int32] -_lib.llama_token_to_piece.restype = c_int32 +llama_token_to_piece = _lib.llama_token_to_piece +llama_token_to_piece.argtypes = [llama_model_p_ctypes, llama_token, ctypes.c_char_p, ctypes.c_int32] +llama_token_to_piece.restype = ctypes.c_int32 # /// Apply chat template. Inspired by hf apply_chat_template() on python. @@ -1991,23 +2068,20 @@ _lib.llama_token_to_piece.restype = c_int32 def llama_chat_apply_template( model: llama_model_p, tmpl: bytes, - chat: "ctypes._Pointer[llama_chat_message]", + chat: "ctypes._Pointer[llama_chat_message]", # type: ignore n_msg: int, + / ) -> int: - return _lib.llama_chat_apply_template( - model, - tmpl, - chat, - n_msg - ) + ... -_lib.llama_chat_apply_template.argtypes = [ +llama_chat_apply_template = _lib.llama_chat_apply_template +llama_chat_apply_template.argtypes = [ ctypes.c_void_p, ctypes.c_char_p, ctypes.POINTER(llama_chat_message), ctypes.c_size_t ] -_lib.llama_chat_apply_template.restype = ctypes.c_int32 +llama_chat_apply_template.restype = ctypes.c_int32 @@ -2022,39 +2096,43 @@ _lib.llama_chat_apply_template.restype = ctypes.c_int32 # size_t start_rule_index); def llama_grammar_init( rules, # type: Array[llama_grammar_element_p] # type: ignore - n_rules: Union[c_size_t, int], - start_rule_index: Union[c_size_t, int], + n_rules: Union[ctypes.c_size_t, int], + start_rule_index: Union[ctypes.c_size_t, int], + / ) -> llama_grammar_p: """Initialize a grammar from a set of rules.""" - return _lib.llama_grammar_init(rules, n_rules, start_rule_index) + ... -_lib.llama_grammar_init.argtypes = [ - POINTER(llama_grammar_element_p), - c_size_t, - c_size_t, +llama_grammar_init = _lib.llama_grammar_init +llama_grammar_init.argtypes = [ + ctypes.POINTER(llama_grammar_element_p), + ctypes.c_size_t, + ctypes.c_size_t, ] -_lib.llama_grammar_init.restype = llama_grammar_p +llama_grammar_init.restype = llama_grammar_p # LLAMA_API void llama_grammar_free(struct llama_grammar * grammar); -def llama_grammar_free(grammar: llama_grammar_p): +def llama_grammar_free(grammar: llama_grammar_p, /): """Free a grammar.""" - return _lib.llama_grammar_free(grammar) + ... -_lib.llama_grammar_free.argtypes = [llama_grammar_p] -_lib.llama_grammar_free.restype = None +llama_grammar_free = _lib.llama_grammar_free +llama_grammar_free.argtypes = [llama_grammar_p] +llama_grammar_free.restype = None # LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar); -def llama_grammar_copy(grammar: llama_grammar_p) -> llama_grammar_p: +def llama_grammar_copy(grammar: llama_grammar_p, /) -> llama_grammar_p: """Copy a grammar.""" - return _lib.llama_grammar_copy(grammar) + ... -_lib.llama_grammar_copy.argtypes = [llama_grammar_p] -_lib.llama_grammar_copy.restype = llama_grammar_p +llama_grammar_copy = _lib.llama_grammar_copy +llama_grammar_copy.argtypes = [llama_grammar_p] +llama_grammar_copy.restype = llama_grammar_p # // # // Sampling functions @@ -2063,13 +2141,14 @@ _lib.llama_grammar_copy.restype = llama_grammar_p # // Sets the current rng seed. # LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); -def llama_set_rng_seed(ctx: llama_context_p, seed: Union[c_uint32, int]): +def llama_set_rng_seed(ctx: llama_context_p, seed: Union[ctypes.c_uint32, int], /): """Sets the current rng seed.""" - return _lib.llama_set_rng_seed(ctx, seed) + ... 
-_lib.llama_set_rng_seed.argtypes = [llama_context_p, c_uint32] -_lib.llama_set_rng_seed.restype = None +llama_set_rng_seed = _lib.llama_set_rng_seed +llama_set_rng_seed.argtypes = [llama_context_p_ctypes, ctypes.c_uint32] +llama_set_rng_seed.restype = None # /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. @@ -2086,35 +2165,29 @@ def llama_sample_repetition_penalties( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] last_tokens_data, # type: Array[llama_token] - penalty_last_n: Union[c_size_t, int], - penalty_repeat: Union[c_float, float], - penalty_freq: Union[c_float, float], - penalty_present: Union[c_float, float], + penalty_last_n: Union[ctypes.c_size_t, int], + penalty_repeat: Union[ctypes.c_float, float], + penalty_freq: Union[ctypes.c_float, float], + penalty_present: Union[ctypes.c_float, float], + / ): """Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. """ - return _lib.llama_sample_repetition_penalties( - ctx, - candidates, - last_tokens_data, - penalty_last_n, - penalty_repeat, - penalty_freq, - penalty_present, - ) + ... -_lib.llama_sample_repetition_penalties.argtypes = [ - llama_context_p, +llama_sample_repetition_penalties = _lib.llama_sample_repetition_penalties +llama_sample_repetition_penalties.argtypes = [ + llama_context_p_ctypes, llama_token_data_array_p, llama_token_p, - c_size_t, - c_float, - c_float, - c_float, + ctypes.c_size_t, + ctypes.c_float, + ctypes.c_float, + ctypes.c_float, ] -_lib.llama_sample_repetition_penalties.restype = None +llama_sample_repetition_penalties.restype = None # /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 @@ -2128,21 +2201,23 @@ _lib.llama_sample_repetition_penalties.restype = None # float scale); def llama_sample_apply_guidance( ctx: llama_context_p, - logits, # type: _Pointer[c_float] - logits_guidance, # type: _Pointer[c_float] - scale: Union[c_float, float], + logits, # type: _Pointer[ctypes.c_float] + logits_guidance, # type: _Pointer[ctypes.c_float] + scale: Union[ctypes.c_float, float], + / ): """Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806""" - return _lib.llama_sample_apply_guidance(ctx, logits, logits_guidance, scale) + ... 
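# Sketch: building the llama_token_data_array consumed by the samplers below and
# applying the repetition/frequency/presence penalties. `ctx`, `n_vocab`, and the
# `last_logits` row come from the earlier decode sketch; penalty values are
# placeholders.
import ctypes
import llama_cpp

candidates = (llama_cpp.llama_token_data * n_vocab)()
for tok in range(n_vocab):
    candidates[tok].id = tok
    candidates[tok].logit = last_logits[tok]
    candidates[tok].p = 0.0

candidates_array = llama_cpp.llama_token_data_array(data=candidates, size=n_vocab, sorted=False)
recent = (llama_cpp.llama_token * 2)(15043, 29991)      # placeholder recent tokens
llama_cpp.llama_sample_repetition_penalties(
    ctx, ctypes.byref(candidates_array), recent, len(recent),
    1.1,   # penalty_repeat
    0.0,   # penalty_freq
    0.0,   # penalty_present
)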
-_lib.llama_sample_apply_guidance.argtypes = [ - llama_context_p, +llama_sample_apply_guidance = _lib.llama_sample_apply_guidance +llama_sample_apply_guidance.argtypes = [ + llama_context_p_ctypes, c_float_p, c_float_p, - c_float, + ctypes.c_float, ] -_lib.llama_sample_apply_guidance.restype = None +llama_sample_apply_guidance.restype = None # LLAMA_API DEPRECATED(void llama_sample_classifier_free_guidance( @@ -2155,21 +2230,21 @@ def llama_sample_classifier_free_guidance( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] guidance_ctx: llama_context_p, - scale: Union[c_float, float], + scale: Union[ctypes.c_float, float], + / ): """Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806""" - return _lib.llama_sample_classifier_free_guidance( - ctx, candidates, guidance_ctx, scale - ) + ... -_lib.llama_sample_classifier_free_guidance.argtypes = [ - llama_context_p, +llama_sample_classifier_free_guidance = _lib.llama_sample_classifier_free_guidance +llama_sample_classifier_free_guidance.argtypes = [ + llama_context_p_ctypes, llama_token_data_array_p, - llama_context_p, - c_float, + llama_context_p_ctypes, + ctypes.c_float, ] -_lib.llama_sample_classifier_free_guidance.restype = None +llama_sample_classifier_free_guidance.restype = None # /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. @@ -2177,17 +2252,19 @@ _lib.llama_sample_classifier_free_guidance.restype = None # struct llama_context * ctx, # llama_token_data_array * candidates); def llama_sample_softmax( - ctx: llama_context_p, candidates # type: _Pointer[llama_token_data] + ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data] + / ): """Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.""" - return _lib.llama_sample_softmax(ctx, candidates) + ... -_lib.llama_sample_softmax.argtypes = [ - llama_context_p, +llama_sample_softmax = _lib.llama_sample_softmax +llama_sample_softmax.argtypes = [ + llama_context_p_ctypes, llama_token_data_array_p, ] -_lib.llama_sample_softmax.restype = None +llama_sample_softmax.restype = None # /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 @@ -2199,20 +2276,22 @@ _lib.llama_sample_softmax.restype = None def llama_sample_top_k( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] - k: Union[c_int, int], - min_keep: Union[c_size_t, int], + k: Union[ctypes.c_int, int], + min_keep: Union[ctypes.c_size_t, int], + / ): """Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751""" - return _lib.llama_sample_top_k(ctx, candidates, k, min_keep) + ... 
-_lib.llama_sample_top_k.argtypes = [ - llama_context_p, +llama_sample_top_k = _lib.llama_sample_top_k +llama_sample_top_k.argtypes = [ + llama_context_p_ctypes, llama_token_data_array_p, - c_int32, - c_size_t, + ctypes.c_int32, + ctypes.c_size_t, ] -_lib.llama_sample_top_k.restype = None +llama_sample_top_k.restype = None # /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 @@ -2224,20 +2303,22 @@ _lib.llama_sample_top_k.restype = None def llama_sample_top_p( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] - p: Union[c_float, float], - min_keep: Union[c_size_t, int], + p: Union[ctypes.c_float, float], + min_keep: Union[ctypes.c_size_t, int], + / ): """Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751""" - return _lib.llama_sample_top_p(ctx, candidates, p, min_keep) + ... -_lib.llama_sample_top_p.argtypes = [ - llama_context_p, +llama_sample_top_p = _lib.llama_sample_top_p +llama_sample_top_p.argtypes = [ + llama_context_p_ctypes, llama_token_data_array_p, - c_float, - c_size_t, + ctypes.c_float, + ctypes.c_size_t, ] -_lib.llama_sample_top_p.restype = None +llama_sample_top_p.restype = None # /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 @@ -2249,20 +2330,22 @@ _lib.llama_sample_top_p.restype = None def llama_sample_min_p( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] - p: Union[c_float, float], - min_keep: Union[c_size_t, int], + p: Union[ctypes.c_float, float], + min_keep: Union[ctypes.c_size_t, int], + / ): """Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841""" - return _lib.llama_sample_min_p(ctx, candidates, p, min_keep) + ... -_lib.llama_sample_min_p.argtypes = [ - llama_context_p, +llama_sample_min_p = _lib.llama_sample_min_p +llama_sample_min_p.argtypes = [ + llama_context_p_ctypes, llama_token_data_array_p, - c_float, - c_size_t, + ctypes.c_float, + ctypes.c_size_t, ] -_lib.llama_sample_min_p.restype = None +llama_sample_min_p.restype = None # /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. @@ -2274,20 +2357,22 @@ _lib.llama_sample_min_p.restype = None def llama_sample_tail_free( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] - z: Union[c_float, float], - min_keep: Union[c_size_t, int], + z: Union[ctypes.c_float, float], + min_keep: Union[ctypes.c_size_t, int], + / ): """Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.""" - return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep) + ... -_lib.llama_sample_tail_free.argtypes = [ - llama_context_p, +llama_sample_tail_free = _lib.llama_sample_tail_free +llama_sample_tail_free.argtypes = [ + llama_context_p_ctypes, llama_token_data_array_p, - c_float, - c_size_t, + ctypes.c_float, + ctypes.c_size_t, ] -_lib.llama_sample_tail_free.restype = None +llama_sample_tail_free.restype = None # /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. 
@@ -2299,20 +2384,22 @@ _lib.llama_sample_tail_free.restype = None def llama_sample_typical( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] - p: Union[c_float, float], - min_keep: Union[c_size_t, int], + p: Union[ctypes.c_float, float], + min_keep: Union[ctypes.c_size_t, int], + / ): """Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.""" - return _lib.llama_sample_typical(ctx, candidates, p, min_keep) + ... -_lib.llama_sample_typical.argtypes = [ - llama_context_p, +llama_sample_typical = _lib.llama_sample_typical +llama_sample_typical.argtypes = [ + llama_context_p_ctypes, llama_token_data_array_p, - c_float, - c_size_t, + ctypes.c_float, + ctypes.c_size_t, ] -_lib.llama_sample_typical.restype = None +llama_sample_typical.restype = None # /// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772. @@ -2325,22 +2412,24 @@ _lib.llama_sample_typical.restype = None def llama_sample_entropy( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] - min_temp: Union[c_float, float], - max_temp: Union[c_float, float], - exponent_val: Union[c_float, float], + min_temp: Union[ctypes.c_float, float], + max_temp: Union[ctypes.c_float, float], + exponent_val: Union[ctypes.c_float, float], + / ): """Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.""" - return _lib.llama_sample_entropy(ctx, candidates, min_temp, max_temp, exponent_val) + ... -_lib.llama_sample_entropy.argtypes = [ - llama_context_p, +llama_sample_entropy = _lib.llama_sample_entropy +llama_sample_entropy.argtypes = [ + llama_context_p_ctypes, llama_token_data_array_p, - c_float, - c_float, - c_float, + ctypes.c_float, + ctypes.c_float, + ctypes.c_float, ] -_lib.llama_sample_entropy.restype = None +llama_sample_entropy.restype = None # LLAMA_API void llama_sample_temp( @@ -2350,7 +2439,8 @@ _lib.llama_sample_entropy.restype = None def llama_sample_temp( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] - temp: Union[c_float, float], + temp: Union[ctypes.c_float, float], + / ): """Temperature sampling described in academic paper "Generating Long Sequences with Sparse Transformers" https://arxiv.org/abs/1904.10509 @@ -2358,15 +2448,16 @@ def llama_sample_temp( candidates: A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. temp: The temperature value to use for the sampling. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. """ - return _lib.llama_sample_temp(ctx, candidates, temp) + ... -_lib.llama_sample_temp.argtypes = [ - llama_context_p, +llama_sample_temp = _lib.llama_sample_temp +llama_sample_temp.argtypes = [ + llama_context_p_ctypes, llama_token_data_array_p, - c_float, + ctypes.c_float, ] -_lib.llama_sample_temp.restype = None +llama_sample_temp.restype = None # LLAMA_API DEPRECATED(void llama_sample_temperature( @@ -2377,18 +2468,20 @@ _lib.llama_sample_temp.restype = None def llama_sample_temperature( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] - temp: Union[c_float, float], + temp: Union[ctypes.c_float, float], + / ): """use llama_sample_temp instead""" - return _lib.llama_sample_temperature(ctx, candidates, temp) + ... 
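# Sketch of a typical truncation chain over the candidates array built in the
# earlier sketch: top-k, then top-p, then temperature. The hyper-parameters are
# illustrative defaults, not values mandated by the patch.
import ctypes
import llama_cpp

cand_ref = ctypes.byref(candidates_array)
llama_cpp.llama_sample_top_k(ctx, cand_ref, 40, 1)
llama_cpp.llama_sample_top_p(ctx, cand_ref, 0.95, 1)
llama_cpp.llama_sample_temp(ctx, cand_ref, 0.80)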
-_lib.llama_sample_temperature.argtypes = [ - llama_context_p, +llama_sample_temperature = _lib.llama_sample_temperature +llama_sample_temperature.argtypes = [ + llama_context_p_ctypes, llama_token_data_array_p, - c_float, + ctypes.c_float, ] -_lib.llama_sample_temperature.restype = None +llama_sample_temperature.restype = None # /// @details Apply constraints from grammar @@ -2400,6 +2493,7 @@ def llama_sample_grammar( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] grammar, # type: llama_grammar_p + / ): """Apply constraints from grammar @@ -2407,15 +2501,16 @@ def llama_sample_grammar( candidates: A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. grammar: A grammar object containing the rules and constraints to apply to the generated text. """ - return _lib.llama_sample_grammar(ctx, candidates, grammar) + ... -_lib.llama_sample_grammar.argtypes = [ - llama_context_p, +llama_sample_grammar = _lib.llama_sample_grammar +llama_sample_grammar.argtypes = [ + llama_context_p_ctypes, llama_token_data_array_p, llama_grammar_p, ] -_lib.llama_sample_grammar.restype = None +llama_sample_grammar.restype = None # /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. @@ -2434,10 +2529,11 @@ _lib.llama_sample_grammar.restype = None def llama_sample_token_mirostat( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] - tau: Union[c_float, float], - eta: Union[c_float, float], - m: Union[c_int, int], - mu, # type: _Pointer[c_float] + tau: Union[ctypes.c_float, float], + eta: Union[ctypes.c_float, float], + m: Union[ctypes.c_int, int], + mu, # type: _Pointer[ctypes.c_float] + / ) -> int: """Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. @@ -2448,18 +2544,19 @@ def llama_sample_token_mirostat( m: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. mu: Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. """ - return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu) + ... -_lib.llama_sample_token_mirostat.argtypes = [ - llama_context_p, +llama_sample_token_mirostat = _lib.llama_sample_token_mirostat +llama_sample_token_mirostat.argtypes = [ + llama_context_p_ctypes, llama_token_data_array_p, - c_float, - c_float, - c_int32, + ctypes.c_float, + ctypes.c_float, + ctypes.c_int32, c_float_p, ] -_lib.llama_sample_token_mirostat.restype = llama_token +llama_sample_token_mirostat.restype = llama_token # /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. 
@@ -2476,9 +2573,10 @@ _lib.llama_sample_token_mirostat.restype = llama_token def llama_sample_token_mirostat_v2( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] - tau: Union[c_float, float], - eta: Union[c_float, float], - mu, # type: _Pointer[c_float] + tau: Union[ctypes.c_float, float], + eta: Union[ctypes.c_float, float], + mu, # type: _Pointer[ctypes.c_float] + / ) -> int: """Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. @@ -2488,17 +2586,18 @@ def llama_sample_token_mirostat_v2( eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. mu: Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. """ - return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu) + ... -_lib.llama_sample_token_mirostat_v2.argtypes = [ - llama_context_p, +llama_sample_token_mirostat_v2 = _lib.llama_sample_token_mirostat_v2 +llama_sample_token_mirostat_v2.argtypes = [ + llama_context_p_ctypes, llama_token_data_array_p, - c_float, - c_float, + ctypes.c_float, + ctypes.c_float, c_float_p, ] -_lib.llama_sample_token_mirostat_v2.restype = llama_token +llama_sample_token_mirostat_v2.restype = llama_token # /// @details Selects the token with the highest probability. @@ -2509,16 +2608,18 @@ _lib.llama_sample_token_mirostat_v2.restype = llama_token def llama_sample_token_greedy( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] + / ) -> int: """Selects the token with the highest probability.""" - return _lib.llama_sample_token_greedy(ctx, candidates) + ... -_lib.llama_sample_token_greedy.argtypes = [ - llama_context_p, +llama_sample_token_greedy = _lib.llama_sample_token_greedy +llama_sample_token_greedy.argtypes = [ + llama_context_p_ctypes, llama_token_data_array_p, ] -_lib.llama_sample_token_greedy.restype = llama_token +llama_sample_token_greedy.restype = llama_token # /// @details Randomly selects a token from the candidates based on their probabilities. @@ -2528,16 +2629,18 @@ _lib.llama_sample_token_greedy.restype = llama_token def llama_sample_token( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] + / ) -> int: """Randomly selects a token from the candidates based on their probabilities.""" - return _lib.llama_sample_token(ctx, candidates) + ... -_lib.llama_sample_token.argtypes = [ - llama_context_p, +llama_sample_token = _lib.llama_sample_token +llama_sample_token.argtypes = [ + llama_context_p_ctypes, llama_token_data_array_p, ] -_lib.llama_sample_token.restype = llama_token +llama_sample_token.restype = llama_token # /// @details Accepts the sampled token into the grammar @@ -2549,17 +2652,19 @@ def llama_grammar_accept_token( ctx: llama_context_p, grammar: llama_grammar_p, token: Union[llama_token, int], + / ) -> None: """Accepts the sampled token into the grammar""" - _lib.llama_grammar_accept_token(ctx, grammar, token) + ... 
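# Sketch of final token selection over the filtered candidates, with the optional
# grammar bookkeeping shown above; `grammar` is an assumed llama_grammar_p (or None
# when no grammar constraint is active).
import ctypes
import llama_cpp

new_token = llama_cpp.llama_sample_token(ctx, ctypes.byref(candidates_array))
# A greedy decoder would call llama_sample_token_greedy instead.
if grammar is not None:
    llama_cpp.llama_grammar_accept_token(ctx, grammar, new_token)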
-_lib.llama_grammar_accept_token.argtypes = [ - llama_context_p, +llama_grammar_accept_token = _lib.llama_grammar_accept_token +llama_grammar_accept_token.argtypes = [ + llama_context_p_ctypes, llama_grammar_p, llama_token, ] -_lib.llama_grammar_accept_token.restype = None +llama_grammar_accept_token.restype = None # // @@ -2577,9 +2682,9 @@ _lib.llama_grammar_accept_token.restype = None class llama_beam_view(ctypes.Structure): _fields_ = [ ("tokens", llama_token_p), - ("n_tokens", c_size_t), - ("p", c_float), - ("eob", c_bool), + ("n_tokens", ctypes.c_size_t), + ("p", ctypes.c_float), + ("eob", ctypes.c_bool), ] @@ -2595,10 +2700,10 @@ class llama_beam_view(ctypes.Structure): # }; class llama_beams_state(ctypes.Structure): _fields_ = [ - ("beam_views", POINTER(llama_beam_view)), - ("n_beams", c_size_t), - ("common_prefix_length", c_size_t), - ("last_call", c_bool), + ("beam_views", ctypes.POINTER(llama_beam_view)), + ("n_beams", ctypes.c_size_t), + ("common_prefix_length", ctypes.c_size_t), + ("last_call", ctypes.c_bool), ] @@ -2606,7 +2711,7 @@ class llama_beams_state(ctypes.Structure): # // void* callback_data is any custom data passed to llama_beam_search, that is subsequently # // passed back to beam_search_callback. This avoids having to use global variables in the callback. # typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state); -llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, c_void_p, llama_beams_state) +llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, ctypes.c_void_p, llama_beams_state) # /// @details Deterministically returns entire sentence constructed by a beam search. @@ -2626,70 +2731,74 @@ llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, c_void_p, llama_beams_s # int32_t n_predict); def llama_beam_search( ctx: llama_context_p, - callback: "ctypes._CFuncPtr[None, c_void_p, llama_beams_state]", # type: ignore - callback_data: c_void_p, - n_beams: Union[c_size_t, int], - n_past: Union[c_int, int], - n_predict: Union[c_int, int], + callback: "ctypes._CFuncPtr[None, ctypes.c_void_p, llama_beams_state]", # type: ignore + callback_data: ctypes.c_void_p, + n_beams: Union[ctypes.c_size_t, int], + n_past: Union[ctypes.c_int, int], + n_predict: Union[ctypes.c_int, int], + / ): - return _lib.llama_beam_search( - ctx, callback, callback_data, n_beams, n_past, n_predict - ) + ... -_lib.llama_beam_search.argtypes = [ - llama_context_p, +llama_beam_search = _lib.llama_beam_search +llama_beam_search.argtypes = [ + llama_context_p_ctypes, llama_beam_search_callback_fn_t, - c_void_p, - c_size_t, - c_int32, - c_int32, + ctypes.c_void_p, + ctypes.c_size_t, + ctypes.c_int32, + ctypes.c_int32, ] -_lib.llama_beam_search.restype = None +llama_beam_search.restype = None # Performance information # LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); -def llama_get_timings(ctx: llama_context_p) -> llama_timings: +def llama_get_timings(ctx: llama_context_p, /) -> llama_timings: """Get performance information""" - return _lib.llama_get_timings(ctx) + ... 
-_lib.llama_get_timings.argtypes = [llama_context_p] -_lib.llama_get_timings.restype = llama_timings +llama_get_timings = _lib.llama_get_timings +llama_get_timings.argtypes = [llama_context_p_ctypes] +llama_get_timings.restype = llama_timings # LLAMA_API void llama_print_timings(struct llama_context * ctx); -def llama_print_timings(ctx: llama_context_p): +def llama_print_timings(ctx: llama_context_p, /): """Print performance information""" - _lib.llama_print_timings(ctx) + ... -_lib.llama_print_timings.argtypes = [llama_context_p] -_lib.llama_print_timings.restype = None +llama_print_timings = _lib.llama_print_timings +llama_print_timings.argtypes = [llama_context_p_ctypes] +llama_print_timings.restype = None # LLAMA_API void llama_reset_timings(struct llama_context * ctx); -def llama_reset_timings(ctx: llama_context_p): +def llama_reset_timings(ctx: llama_context_p, /): """Reset performance information""" - _lib.llama_reset_timings(ctx) + ... -_lib.llama_reset_timings.argtypes = [llama_context_p] -_lib.llama_reset_timings.restype = None +llama_reset_timings = _lib.llama_reset_timings +llama_reset_timings.argtypes = [llama_context_p_ctypes] +llama_reset_timings.restype = None # Print system information # LLAMA_API const char * llama_print_system_info(void); def llama_print_system_info() -> bytes: """Print system information""" - return _lib.llama_print_system_info() + ... -_lib.llama_print_system_info.argtypes = [] -_lib.llama_print_system_info.restype = c_char_p +llama_print_system_info = _lib.llama_print_system_info +llama_print_system_info.argtypes = [] +llama_print_system_info.restype = ctypes.c_char_p # NOTE: THIS IS CURRENTLY BROKEN AS ggml_log_callback IS NOT EXPOSED IN LLAMA.H @@ -2697,22 +2806,25 @@ _lib.llama_print_system_info.restype = c_char_p # // If this is not called, or NULL is supplied, everything is output on stderr. # LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data); def llama_log_set( - log_callback: Union["ctypes._FuncPointer", c_void_p], user_data: c_void_p # type: ignore + log_callback: Union["ctypes._FuncPointer", ctypes.c_void_p], user_data: ctypes.c_void_p, # type: ignore + / ): """Set callback for all future logging events. If this is not called, or NULL is supplied, everything is output on stderr.""" - return _lib.llama_log_set(log_callback, user_data) + ... -_lib.llama_log_set.argtypes = [ctypes.c_void_p, c_void_p] -_lib.llama_log_set.restype = None +llama_log_set = _lib.llama_log_set +llama_log_set.argtypes = [ctypes.c_void_p, ctypes.c_void_p] +llama_log_set.restype = None # LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx); -def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p): - return _lib.llama_dump_timing_info_yaml(stream, ctx) +def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /): + ... 
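# Sketch of the diagnostics helpers bound above: llama_print_system_info returns
# bytes, while the timing helpers print to stderr for the given context.
import llama_cpp

print(llama_cpp.llama_print_system_info().decode("utf-8"))
llama_cpp.llama_print_timings(ctx)
llama_cpp.llama_reset_timings(ctx)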
-_lib.llama_dump_timing_info_yaml.argtypes = [ctypes.c_void_p, llama_context_p] -_lib.llama_dump_timing_info_yaml.restype = None +llama_dump_timing_info_yaml = _lib.llama_dump_timing_info_yaml +llama_dump_timing_info_yaml.argtypes = [ctypes.c_void_p, llama_context_p_ctypes] +llama_dump_timing_info_yaml.restype = None diff --git a/llama_cpp/llava_cpp.py b/llama_cpp/llava_cpp.py index 8195bd4..4eaa9e5 100644 --- a/llama_cpp/llava_cpp.py +++ b/llama_cpp/llava_cpp.py @@ -5,21 +5,15 @@ from ctypes import ( c_bool, c_char_p, c_int, - c_int8, - c_int32, c_uint8, - c_uint32, - c_size_t, c_float, - c_double, c_void_p, POINTER, _Pointer, # type: ignore Structure, - Array, ) import pathlib -from typing import List, Union +from typing import List, Union, NewType, Optional import llama_cpp.llama_cpp as llama_cpp @@ -67,7 +61,7 @@ def _load_shared_library(lib_base_name: str): for _lib_path in _lib_paths: if _lib_path.exists(): try: - return ctypes.CDLL(str(_lib_path), **cdll_args) + return ctypes.CDLL(str(_lib_path), **cdll_args) # type: ignore except Exception as e: raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}") @@ -88,7 +82,8 @@ _libllava = _load_shared_library(_libllava_base_name) ################################################ # struct clip_ctx; -clip_ctx_p = c_void_p +clip_ctx_p = NewType("clip_ctx_p", int) +clip_ctx_p_ctypes = c_void_p # struct llava_image_embed { # float * embed; @@ -102,43 +97,48 @@ class llava_image_embed(Structure): # /** sanity check for clip <-> llava embed size match */ # LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip); -def llava_validate_embed_size(ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p) -> bool: - return _libllava.llava_validate_embed_size(ctx_llama, ctx_clip) +def llava_validate_embed_size(ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p, /) -> bool: + ... -_libllava.llava_validate_embed_size.argtypes = [llama_cpp.llama_context_p, clip_ctx_p] -_libllava.llava_validate_embed_size.restype = c_bool +llava_validate_embed_size = _libllava.llava_validate_embed_size +llava_validate_embed_size.argtypes = [llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes] +llava_validate_embed_size.restype = c_bool # /** build an image embed from image file bytes */ # LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length); -def llava_image_embed_make_with_bytes(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_bytes: bytes, image_bytes_length: Union[c_int, int]) -> "_Pointer[llava_image_embed]": - return _libllava.llava_image_embed_make_with_bytes(ctx_clip, n_threads, image_bytes, image_bytes_length) +def llava_image_embed_make_with_bytes(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_bytes: bytes, image_bytes_length: Union[c_int, int], /) -> "_Pointer[llava_image_embed]": + ... 
-_libllava.llava_image_embed_make_with_bytes.argtypes = [clip_ctx_p, c_int, POINTER(c_uint8), c_int] -_libllava.llava_image_embed_make_with_bytes.restype = POINTER(llava_image_embed) +llava_image_embed_make_with_bytes = _libllava.llava_image_embed_make_with_bytes +llava_image_embed_make_with_bytes.argtypes = [clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int] +llava_image_embed_make_with_bytes.restype = POINTER(llava_image_embed) # /** build an image embed from a path to an image filename */ # LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path); -def llava_image_embed_make_with_filename(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes) -> "_Pointer[llava_image_embed]": - return _libllava.llava_image_embed_make_with_filename(ctx_clip, n_threads, image_path) +def llava_image_embed_make_with_filename(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes, /) -> "_Pointer[llava_image_embed]": + ... -_libllava.llava_image_embed_make_with_filename.argtypes = [clip_ctx_p, c_int, c_char_p] -_libllava.llava_image_embed_make_with_filename.restype = POINTER(llava_image_embed) +llava_image_embed_make_with_filename = _libllava.llava_image_embed_make_with_filename +llava_image_embed_make_with_filename.argtypes = [clip_ctx_p_ctypes, c_int, c_char_p] +llava_image_embed_make_with_filename.restype = POINTER(llava_image_embed) # LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed); # /** free an embedding made with llava_image_embed_make_* */ -def llava_image_embed_free(embed: "_Pointer[llava_image_embed]"): - return _libllava.llava_image_embed_free(embed) +def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /): + ... -_libllava.llava_image_embed_free.argtypes = [POINTER(llava_image_embed)] -_libllava.llava_image_embed_free.restype = None +llava_image_embed_free = _libllava.llava_image_embed_free +llava_image_embed_free.argtypes = [POINTER(llava_image_embed)] +llava_image_embed_free.restype = None # /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */ # LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past); -def llava_eval_image_embed(ctx_llama: llama_cpp.llama_context_p, embed: "_Pointer[llava_image_embed]", n_batch: Union[c_int, int], n_past: "_Pointer[c_int]") -> bool: - return _libllava.llava_eval_image_embed(ctx_llama, embed, n_batch, n_past) +def llava_eval_image_embed(ctx_llama: llama_cpp.llama_context_p, embed: "_Pointer[llava_image_embed]", n_batch: Union[c_int, int], n_past: "_Pointer[c_int]", /) -> bool: + ... 
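# Sketch of the llava helpers above: embed an image file and feed the embedding
# into a llama context. `ctx_clip` (from clip_model_load further below) and
# `ctx_llama` are assumed valid; the path, thread count, and batch size are
# placeholders.
import ctypes
import llama_cpp.llava_cpp as llava_cpp

embed = llava_cpp.llava_image_embed_make_with_filename(ctx_clip, 4, b"image.png")
if not embed:
    raise RuntimeError("failed to build image embed")
try:
    n_past = ctypes.c_int(0)
    if not llava_cpp.llava_eval_image_embed(ctx_llama, embed, 512, ctypes.byref(n_past)):
        raise RuntimeError("failed to evaluate image embed")
finally:
    llava_cpp.llava_image_embed_free(embed)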
-_libllava.llava_eval_image_embed.argtypes = [llama_cpp.llama_context_p, POINTER(llava_image_embed), c_int, POINTER(c_int)] -_libllava.llava_eval_image_embed.restype = c_bool +llava_eval_image_embed = _libllava.llava_eval_image_embed +llava_eval_image_embed.argtypes = [llama_cpp.llama_context_p_ctypes, POINTER(llava_image_embed), c_int, POINTER(c_int)] +llava_eval_image_embed.restype = c_bool ################################################ @@ -148,16 +148,18 @@ _libllava.llava_eval_image_embed.restype = c_bool # /** load mmproj model */ # CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity); -def clip_model_load(fname: bytes, verbosity: Union[c_int, int]) -> clip_ctx_p: - return _libllava.clip_model_load(fname, verbosity) +def clip_model_load(fname: bytes, verbosity: Union[c_int, int], /) -> Optional[clip_ctx_p]: + ... -_libllava.clip_model_load.argtypes = [c_char_p, c_int] -_libllava.clip_model_load.restype = clip_ctx_p +clip_model_load = _libllava.clip_model_load +clip_model_load.argtypes = [c_char_p, c_int] +clip_model_load.restype = clip_ctx_p_ctypes # /** free mmproj model */ # CLIP_API void clip_free(struct clip_ctx * ctx); -def clip_free(ctx: clip_ctx_p): - return _libllava.clip_free(ctx) +def clip_free(ctx: clip_ctx_p, /): + ... -_libllava.clip_free.argtypes = [clip_ctx_p] -_libllava.clip_free.restype = None +clip_free = _libllava.clip_free +clip_free.argtypes = [clip_ctx_p_ctypes] +clip_free.restype = None diff --git a/tests/test_llama.py b/tests/test_llama.py index dac33b7..5cf421b 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -54,7 +54,7 @@ def mock_llama(monkeypatch): output_tokens = llama.tokenize( output_text.encode("utf-8"), add_bos=True, special=True ) - logits = (llama_cpp.c_float * (n_vocab * n_ctx))(-100.0) + logits = (ctypes.c_float * (n_vocab * n_ctx))(-100.0) for i in range(n_ctx): output_idx = i + 1 # logits for first tokens predict second token if output_idx < len(output_tokens): @@ -90,9 +90,9 @@ def mock_llama(monkeypatch): assert n > 0, "mock_llama_decode not called" assert last_n_tokens > 0, "mock_llama_decode not called" # Return view of logits for last_n_tokens - return (llama_cpp.c_float * (last_n_tokens * n_vocab)).from_address( + return (ctypes.c_float * (last_n_tokens * n_vocab)).from_address( ctypes.addressof(logits) - + (n - last_n_tokens) * n_vocab * ctypes.sizeof(llama_cpp.c_float) + + (n - last_n_tokens) * n_vocab * ctypes.sizeof(ctypes.c_float) ) monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
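# Sketch of the pointer arithmetic used by the test mock above: viewing the last
# row of a flat (n, n_vocab) c_float buffer without copying it. The view created
# by from_address does not own the memory, so the source array must stay alive.
import ctypes

n, n_vocab = 4, 8
logits = (ctypes.c_float * (n * n_vocab))(*range(n * n_vocab))
last_row = (ctypes.c_float * n_vocab).from_address(
    ctypes.addressof(logits) + (n - 1) * n_vocab * ctypes.sizeof(ctypes.c_float)
)
assert list(last_row) == [float(v) for v in range((n - 1) * n_vocab, n * n_vocab)]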