diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 3da131e..f4d523b 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -3,12 +3,12 @@ from __future__ import annotations import sys import os import ctypes -from ctypes import ( - _Pointer, # type: ignore - Array, -) +import functools import pathlib + from typing import ( + Any, + Callable, List, Union, NewType, @@ -110,6 +110,36 @@ if TYPE_CHECKING: CtypesFuncPointer: TypeAlias = ctypes._FuncPointer # type: ignore +def ctypes_function_for_shared_library(lib: ctypes.CDLL): + def ctypes_function( + name: str, argtypes: List[Any], restype: Any, enabled: bool = True + ): + def decorator(f: Callable[..., Any]): + if enabled: + func = getattr(lib, name) + func.argtypes = argtypes + func.restype = restype + functools.wraps(f)(func) + return func + else: + return f + + return decorator + + return ctypes_function + + +ctypes_function = ctypes_function_for_shared_library(_lib) + + +def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCData]: + """Type-annotated version of ctypes.byref""" + ... + + +byref = ctypes.byref # type: ignore + + # from ggml-backend.h # typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data); ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE( @@ -707,43 +737,48 @@ class llama_chat_message(ctypes.Structure): # // Helpers for getting default parameters # LLAMA_API struct llama_model_params llama_model_default_params(void); +@ctypes_function( + "llama_model_default_params", + [], + llama_model_params, +) def llama_model_default_params() -> llama_model_params: """Get default parameters for llama_model""" ... -llama_model_default_params = _lib.llama_model_default_params -llama_model_default_params.argtypes = [] -llama_model_default_params.restype = llama_model_params - - # LLAMA_API struct llama_context_params llama_context_default_params(void); +@ctypes_function( + "llama_context_default_params", + [], + llama_context_params, +) def llama_context_default_params() -> llama_context_params: """Get default parameters for llama_context""" ... -llama_context_default_params = _lib.llama_context_default_params -llama_context_default_params.argtypes = [] -llama_context_default_params.restype = llama_context_params - - # LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); +@ctypes_function( + "llama_model_quantize_default_params", + [], + llama_model_quantize_params, +) def llama_model_quantize_default_params() -> llama_model_quantize_params: """Get default parameters for llama_model_quantize""" ... -llama_model_quantize_default_params = _lib.llama_model_quantize_default_params -llama_model_quantize_default_params.argtypes = [] -llama_model_quantize_default_params.restype = llama_model_quantize_params - - # // Initialize the llama + ggml backend # // If numa is true, use NUMA optimizations # // Call once at the start of the program # LLAMA_API void llama_backend_init(bool numa); # LLAMA_API void llama_backend_init(void); +@ctypes_function( + "llama_backend_init", + [], + None, +) def llama_backend_init(): """Initialize the llama + ggml backend If numa is true, use NUMA optimizations @@ -751,11 +786,6 @@ def llama_backend_init(): ... 
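The `ctypes_function_for_shared_library` factory introduced above is the core of this refactor: the typed stub keeps the Python signature and docstring for tooling, and `functools.wraps` copies that metadata onto the real foreign function that replaces it. A minimal sketch of the same pattern, illustrative only and not part of this patch, assuming a Unix-like system where `find_library("c")` resolves:

```python
import ctypes
import ctypes.util

libc = ctypes.CDLL(ctypes.util.find_library("c"))
libc_function = ctypes_function_for_shared_library(libc)

@libc_function("strlen", [ctypes.c_char_p], ctypes.c_size_t)
def strlen(s: bytes, /) -> int:
    """Length of a null-terminated C string."""
    ...

assert strlen(b"llama") == 5  # dispatches to libc; the stub body never runs
```

Calls go straight to the configured ctypes function pointer, so the `...` body of the decorated stub is never executed.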
-llama_backend_init = _lib.llama_backend_init -llama_backend_init.argtypes = [] -llama_backend_init.restype = None - - # // numa strategies # enum ggml_numa_strategy { # GGML_NUMA_STRATEGY_DISABLED = 0, @@ -775,228 +805,201 @@ GGML_NUMA_STRATEGY_COUNT = 5 # //optional: # LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa); +@ctypes_function( + "llama_numa_init", + [ctypes.c_int], + None, +) def llama_numa_init(numa: int, /): ... -llama_numa_init = _lib.llama_numa_init -llama_numa_init.argtypes = [ctypes.c_int] -llama_numa_init.restype = None - - # // Call once at the end of the program - currently only used for MPI # LLAMA_API void llama_backend_free(void); +@ctypes_function( + "llama_backend_free", + [], + None, +) def llama_backend_free(): """Call once at the end of the program - currently only used for MPI""" ... -llama_backend_free = _lib.llama_backend_free -llama_backend_free.argtypes = [] -llama_backend_free.restype = None - - # LLAMA_API struct llama_model * llama_load_model_from_file( # const char * path_model, # struct llama_model_params params); +@ctypes_function( + "llama_load_model_from_file", + [ctypes.c_char_p, llama_model_params], + llama_model_p_ctypes, +) def llama_load_model_from_file( path_model: bytes, params: llama_model_params, / ) -> Optional[llama_model_p]: ... -llama_load_model_from_file = _lib.llama_load_model_from_file -llama_load_model_from_file.argtypes = [ctypes.c_char_p, llama_model_params] -llama_load_model_from_file.restype = llama_model_p_ctypes - - # LLAMA_API void llama_free_model(struct llama_model * model); +@ctypes_function( + "llama_free_model", + [llama_model_p_ctypes], + None, +) def llama_free_model(model: llama_model_p, /): ... -llama_free_model = _lib.llama_free_model -llama_free_model.argtypes = [llama_model_p_ctypes] -llama_free_model.restype = None - - # LLAMA_API struct llama_context * llama_new_context_with_model( # struct llama_model * model, # struct llama_context_params params); +@ctypes_function( + "llama_new_context_with_model", + [llama_model_p_ctypes, llama_context_params], + llama_context_p_ctypes, +) def llama_new_context_with_model( model: llama_model_p, params: llama_context_params, / ) -> Optional[llama_context_p]: ... -llama_new_context_with_model = _lib.llama_new_context_with_model -llama_new_context_with_model.argtypes = [llama_model_p_ctypes, llama_context_params] -llama_new_context_with_model.restype = llama_context_p_ctypes - - # // Frees all allocated memory # LLAMA_API void llama_free(struct llama_context * ctx); +@ctypes_function( + "llama_free", + [llama_context_p_ctypes], + None, +) def llama_free(ctx: llama_context_p, /): """Frees all allocated memory""" ... -llama_free = _lib.llama_free -llama_free.argtypes = [llama_context_p_ctypes] -llama_free.restype = None - - # LLAMA_API int64_t llama_time_us(void); +@ctypes_function( + "llama_time_us", + [], + ctypes.c_int64, +) def llama_time_us() -> int: ... -llama_time_us = _lib.llama_time_us -llama_time_us.argtypes = [] -llama_time_us.restype = ctypes.c_int64 - - # LLAMA_API size_t llama_max_devices(void); + + +@ctypes_function("llama_max_devices", [], ctypes.c_size_t) def llama_max_devices() -> int: ... -llama_max_devices = _lib.llama_max_devices -llama_max_devices.argtypes = [] -llama_max_devices.restype = ctypes.c_size_t - - # LLAMA_API bool llama_supports_mmap (void); + + +@ctypes_function("llama_supports_mmap", [], ctypes.c_bool) def llama_supports_mmap() -> bool: ... 
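Taken together, the bindings above follow the usual llama.cpp lifecycle. A hedged usage sketch, with the model path as a placeholder:

```python
llama_backend_init()
model = llama_load_model_from_file(
    b"/path/to/model.gguf", llama_model_default_params()
)
if model is None:
    raise RuntimeError("failed to load model")
ctx = llama_new_context_with_model(model, llama_context_default_params())
if ctx is None:
    raise RuntimeError("failed to create context")
# ... tokenize / decode / sample ...
llama_free(ctx)
llama_free_model(model)
llama_backend_free()
```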
-llama_supports_mmap = _lib.llama_supports_mmap -llama_supports_mmap.argtypes = [] -llama_supports_mmap.restype = ctypes.c_bool - - # LLAMA_API bool llama_supports_mlock (void); + + +@ctypes_function("llama_supports_mlock", [], ctypes.c_bool) def llama_supports_mlock() -> bool: ... -llama_supports_mlock = _lib.llama_supports_mlock -llama_supports_mlock.argtypes = [] -llama_supports_mlock.restype = ctypes.c_bool - - # LLAMA_API bool llama_supports_gpu_offload(void); + + +@ctypes_function("llama_supports_gpu_offload", [], ctypes.c_bool) def llama_supports_gpu_offload() -> bool: ... -llama_supports_gpu_offload = _lib.llama_supports_gpu_offload -llama_supports_gpu_offload.argtypes = [] -llama_supports_gpu_offload.restype = ctypes.c_bool - - # LLAMA_API DEPRECATED(bool llama_mmap_supported (void), "use llama_supports_mmap() instead"); + + +@ctypes_function("llama_mmap_supported", [], ctypes.c_bool) def llama_mmap_supported() -> bool: ... -llama_mmap_supported = _lib.llama_mmap_supported -llama_mmap_supported.argtypes = [] -llama_mmap_supported.restype = ctypes.c_bool - - # LLAMA_API DEPRECATED(bool llama_mlock_supported(void), "use llama_supports_mlock() instead"); + + +@ctypes_function("llama_mlock_supported", [], ctypes.c_bool) def llama_mlock_supported() -> bool: ... -llama_mlock_supported = _lib.llama_mlock_supported -llama_mlock_supported.argtypes = [] -llama_mlock_supported.restype = ctypes.c_bool - - # LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); + + +@ctypes_function("llama_get_model", [llama_context_p_ctypes], llama_model_p_ctypes) def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: ... -llama_get_model = _lib.llama_get_model -llama_get_model.argtypes = [llama_context_p_ctypes] -llama_get_model.restype = llama_model_p_ctypes - - # LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx); + + +@ctypes_function("llama_n_ctx", [llama_context_p_ctypes], ctypes.c_uint32) def llama_n_ctx(ctx: llama_context_p, /) -> int: ... -llama_n_ctx = _lib.llama_n_ctx -llama_n_ctx.argtypes = [llama_context_p_ctypes] -llama_n_ctx.restype = ctypes.c_uint32 - - # LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); + + +@ctypes_function("llama_n_batch", [llama_context_p_ctypes], ctypes.c_uint32) def llama_n_batch(ctx: llama_context_p, /) -> int: ... -llama_n_batch = _lib.llama_n_batch -llama_n_batch.argtypes = [llama_context_p_ctypes] -llama_n_batch.restype = ctypes.c_uint32 - - # LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model); + + +@ctypes_function("llama_vocab_type", [llama_model_p_ctypes], ctypes.c_int) def llama_vocab_type(model: llama_model_p, /) -> int: ... -llama_vocab_type = _lib.llama_vocab_type -llama_vocab_type.argtypes = [llama_model_p_ctypes] -llama_vocab_type.restype = ctypes.c_int - - # LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); + + +@ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32) def llama_n_vocab(model: llama_model_p, /) -> int: ... -llama_n_vocab = _lib.llama_n_vocab -llama_n_vocab.argtypes = [llama_model_p_ctypes] -llama_n_vocab.restype = ctypes.c_int32 - - # LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); + + +@ctypes_function("llama_n_ctx_train", [llama_model_p_ctypes], ctypes.c_int32) def llama_n_ctx_train(model: llama_model_p, /) -> int: ... 
-llama_n_ctx_train = _lib.llama_n_ctx_train -llama_n_ctx_train.argtypes = [llama_model_p_ctypes] -llama_n_ctx_train.restype = ctypes.c_int32 - - # LLAMA_API int32_t llama_n_embd (const struct llama_model * model); + + +@ctypes_function("llama_n_embd", [llama_model_p_ctypes], ctypes.c_int32) def llama_n_embd(model: llama_model_p, /) -> int: ... -llama_n_embd = _lib.llama_n_embd -llama_n_embd.argtypes = [llama_model_p_ctypes] -llama_n_embd.restype = ctypes.c_int32 - - # // Get the model's RoPE frequency scaling factor # LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model); + + +@ctypes_function("llama_rope_freq_scale_train", [llama_model_p_ctypes], ctypes.c_float) def llama_rope_freq_scale_train(model: llama_model_p, /) -> float: """Get the model's RoPE frequency scaling factor""" ... -llama_rope_freq_scale_train = _lib.llama_rope_freq_scale_train -llama_rope_freq_scale_train.argtypes = [llama_model_p_ctypes] -llama_rope_freq_scale_train.restype = ctypes.c_float - # // Functions to access the model's GGUF metadata scalar values # // - The functions return the length of the string on success, or -1 on failure # // - The output string is always null-terminated and cleared on failure @@ -1005,6 +1008,18 @@ llama_rope_freq_scale_train.restype = ctypes.c_float # // Get metadata value as a string by key name # LLAMA_API int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size); + + +@ctypes_function( + "llama_model_meta_val_str", + [ + llama_model_p_ctypes, + ctypes.c_char_p, + ctypes.c_char_p, + ctypes.c_size_t, + ], + ctypes.c_int32, +) def llama_model_meta_val_str( model: llama_model_p, key: Union[ctypes.c_char_p, bytes], @@ -1016,106 +1031,112 @@ def llama_model_meta_val_str( ... -llama_model_meta_val_str = _lib.llama_model_meta_val_str -llama_model_meta_val_str.argtypes = [ - llama_model_p_ctypes, - ctypes.c_char_p, - ctypes.c_char_p, - ctypes.c_size_t, -] -llama_model_meta_val_str.restype = ctypes.c_int32 - - # // Get the number of metadata key/value pairs # LLAMA_API int32_t llama_model_meta_count(const struct llama_model * model); + + +@ctypes_function("llama_model_meta_count", [llama_model_p_ctypes], ctypes.c_int32) def llama_model_meta_count(model: llama_model_p, /) -> int: """Get the number of metadata key/value pairs""" ... -llama_model_meta_count = _lib.llama_model_meta_count -llama_model_meta_count.argtypes = [llama_model_p_ctypes] -llama_model_meta_count.restype = ctypes.c_int32 - - # // Get metadata key name by index # LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size); + + +@ctypes_function( + "llama_model_meta_key_by_index", + [ + llama_model_p_ctypes, + ctypes.c_int32, + ctypes.c_char_p, + ctypes.c_size_t, + ], + ctypes.c_int32, +) def llama_model_meta_key_by_index( - model: llama_model_p, i: Union[ctypes.c_int, int], buf: bytes, buf_size: int, / + model: llama_model_p, + i: Union[ctypes.c_int, int], + buf: Union[bytes, CtypesArray[ctypes.c_char]], + buf_size: int, + /, ) -> int: """Get metadata key name by index""" ... 
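The metadata getters above all share C's string-out convention: the caller supplies the buffer, and the return value is the string length on success or -1 on failure. An illustrative call, reusing `model` from the earlier sketch (the key name is only an example):

```python
buf = ctypes.create_string_buffer(1024)
n = llama_model_meta_val_str(model, b"general.name", buf, ctypes.sizeof(buf))
if n >= 0:  # length on success, -1 on failure
    print(buf.value.decode("utf-8"))
```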
-llama_model_meta_key_by_index = _lib.llama_model_meta_key_by_index -llama_model_meta_key_by_index.argtypes = [ - llama_model_p_ctypes, - ctypes.c_int32, - ctypes.c_char_p, - ctypes.c_size_t, -] -llama_model_meta_key_by_index.restype = ctypes.c_int32 - - # // Get metadata value as a string by index # LLAMA_API int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size); + + +@ctypes_function( + "llama_model_meta_val_str_by_index", + [ + llama_model_p_ctypes, + ctypes.c_int32, + ctypes.c_char_p, + ctypes.c_size_t, + ], + ctypes.c_int32, +) def llama_model_meta_val_str_by_index( - model: llama_model_p, i: Union[ctypes.c_int, int], buf: bytes, buf_size: int, / + model: llama_model_p, + i: Union[ctypes.c_int, int], + buf: Union[bytes, CtypesArray[ctypes.c_char]], + buf_size: int, + /, ) -> int: """Get metadata value as a string by index""" ... -llama_model_meta_val_str_by_index = _lib.llama_model_meta_val_str_by_index -llama_model_meta_val_str_by_index.argtypes = [ - llama_model_p_ctypes, - ctypes.c_int32, - ctypes.c_char_p, - ctypes.c_size_t, -] -llama_model_meta_val_str_by_index.restype = ctypes.c_int32 - - # // Get a string describing the model type # LLAMA_API int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size); + + +@ctypes_function( + "llama_model_desc", + [llama_model_p_ctypes, ctypes.c_char_p, ctypes.c_size_t], + ctypes.c_int32, +) def llama_model_desc( - model: llama_model_p, buf: bytes, buf_size: Union[ctypes.c_size_t, int], / + model: llama_model_p, + buf: Union[bytes, CtypesArray[ctypes.c_char]], + buf_size: Union[ctypes.c_size_t, int], + /, ) -> int: """Get a string describing the model type""" ... -llama_model_desc = _lib.llama_model_desc -llama_model_desc.argtypes = [llama_model_p_ctypes, ctypes.c_char_p, ctypes.c_size_t] -llama_model_desc.restype = ctypes.c_int32 - - # // Returns the total size of all the tensors in the model in bytes # LLAMA_API uint64_t llama_model_size(const struct llama_model * model); + + +@ctypes_function("llama_model_size", [llama_model_p_ctypes], ctypes.c_uint64) def llama_model_size(model: llama_model_p, /) -> int: """Returns the total size of all the tensors in the model in bytes""" ... -llama_model_size = _lib.llama_model_size -llama_model_size.argtypes = [llama_model_p_ctypes] -llama_model_size.restype = ctypes.c_uint64 - - # // Returns the total number of parameters in the model # LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model); + + +@ctypes_function("llama_model_n_params", [llama_model_p_ctypes], ctypes.c_uint64) def llama_model_n_params(model: llama_model_p, /) -> int: """Returns the total number of parameters in the model""" ... -llama_model_n_params = _lib.llama_model_n_params -llama_model_n_params.argtypes = [llama_model_p_ctypes] -llama_model_n_params.restype = ctypes.c_uint64 - - # // Get a llama model tensor # LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name); + + +@ctypes_function( + "llama_get_model_tensor", [llama_model_p_ctypes, ctypes.c_char_p], ctypes.c_void_p +) def llama_get_model_tensor( model: llama_model_p, name: Union[ctypes.c_char_p, bytes], / ) -> ctypes.c_void_p: @@ -1123,16 +1144,22 @@ def llama_get_model_tensor( ... 
-llama_get_model_tensor = _lib.llama_get_model_tensor -llama_get_model_tensor.argtypes = [llama_model_p_ctypes, ctypes.c_char_p] -llama_get_model_tensor.restype = ctypes.c_void_p - - # // Returns 0 on success # LLAMA_API uint32_t llama_model_quantize( # const char * fname_inp, # const char * fname_out, # const llama_model_quantize_params * params); + + +@ctypes_function( + "llama_model_quantize", + [ + ctypes.c_char_p, + ctypes.c_char_p, + ctypes.POINTER(llama_model_quantize_params), + ], + ctypes.c_uint32, +) def llama_model_quantize( fname_inp: bytes, fname_out: bytes, @@ -1143,15 +1170,6 @@ def llama_model_quantize( ... -llama_model_quantize = _lib.llama_model_quantize -llama_model_quantize.argtypes = [ - ctypes.c_char_p, - ctypes.c_char_p, - ctypes.POINTER(llama_model_quantize_params), -] -llama_model_quantize.restype = ctypes.c_uint32 - - # // Apply a LoRA adapter to a loaded model # // path_base_model is the path to a higher quality model to use as a base for # // the layers modified by the adapter. Can be NULL to use the current loaded model. @@ -1165,6 +1183,19 @@ llama_model_quantize.restype = ctypes.c_uint32 # const char * path_base_model, # int32_t n_threads), # "use llama_model_apply_lora_from_file instead"); + + +@ctypes_function( + "llama_apply_lora_from_file", + [ + llama_context_p_ctypes, + ctypes.c_char_p, + ctypes.c_float, + ctypes.c_char_p, + ctypes.c_int32, + ], + ctypes.c_int32, +) def llama_apply_lora_from_file( ctx: llama_context_p, path_lora: Union[ctypes.c_char_p, bytes], @@ -1182,23 +1213,25 @@ def llama_apply_lora_from_file( ... -llama_apply_lora_from_file = _lib.llama_apply_lora_from_file -llama_apply_lora_from_file.argtypes = [ - llama_context_p_ctypes, - ctypes.c_char_p, - ctypes.c_float, - ctypes.c_char_p, - ctypes.c_int32, -] -llama_apply_lora_from_file.restype = ctypes.c_int32 - - # LLAMA_API int32_t llama_model_apply_lora_from_file( # const struct llama_model * model, # const char * path_lora, # float scale, # const char * path_base_model, # int32_t n_threads); + + +@ctypes_function( + "llama_model_apply_lora_from_file", + [ + llama_model_p_ctypes, + ctypes.c_char_p, + ctypes.c_float, + ctypes.c_char_p, + ctypes.c_int32, + ], + ctypes.c_int32, +) def llama_model_apply_lora_from_file( model: llama_model_p, path_lora: Union[ctypes.c_char_p, bytes], @@ -1210,16 +1243,6 @@ def llama_model_apply_lora_from_file( ... -llama_model_apply_lora_from_file = _lib.llama_model_apply_lora_from_file -llama_model_apply_lora_from_file.argtypes = [ - llama_model_p_ctypes, - ctypes.c_char_p, - ctypes.c_float, - ctypes.c_char_p, - ctypes.c_int32, -] -llama_model_apply_lora_from_file.restype = ctypes.c_int32 - # // # // KV cache # // @@ -1285,6 +1308,13 @@ llama_kv_cache_view_p = ctypes.POINTER(llama_kv_cache_view) # // Create an empty KV cache view. (use only for debugging purposes) # LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq); + + +@ctypes_function( + "llama_kv_cache_view_init", + [llama_context_p_ctypes, ctypes.c_int32], + llama_kv_cache_view, +) def llama_kv_cache_view_init( ctx: llama_context_p, n_max_seq: Union[ctypes.c_int32, int], / ) -> llama_kv_cache_view: @@ -1292,38 +1322,36 @@ def llama_kv_cache_view_init( ... -llama_kv_cache_view_init = _lib.llama_kv_cache_view_init -llama_kv_cache_view_init.argtypes = [llama_context_p_ctypes, ctypes.c_int32] -llama_kv_cache_view_init.restype = llama_kv_cache_view - - # // Free a KV cache view. 
(use only for debugging purposes) # LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view); + + +@ctypes_function("llama_kv_cache_view_free", [llama_kv_cache_view_p], None) def llama_kv_cache_view_free(view: "ctypes.pointer[llama_kv_cache_view]", /): # type: ignore """Free a KV cache view. (use only for debugging purposes)""" ... -llama_kv_cache_view_free = _lib.llama_kv_cache_view_free -llama_kv_cache_view_free.argtypes = [llama_kv_cache_view_p] -llama_kv_cache_view_free.restype = None - - # // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) # LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view); + + +@ctypes_function( + "llama_kv_cache_view_update", [llama_context_p_ctypes, llama_kv_cache_view_p], None +) def llama_kv_cache_view_update(ctx: llama_context_p, view: CtypesPointerOrRef[llama_kv_cache_view], /): # type: ignore """Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)""" ... -llama_kv_cache_view_update = _lib.llama_kv_cache_view_update -llama_kv_cache_view_update.argtypes = [llama_context_p_ctypes, llama_kv_cache_view_p] -llama_kv_cache_view_update.restype = None - - # // Returns the number of tokens in the KV cache (slow, use only for debug) # // If a KV cell has multiple sequences assigned to it, it will be counted multiple times # LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx); + + +@ctypes_function( + "llama_get_kv_cache_token_count", [llama_context_p_ctypes], ctypes.c_int32 +) def llama_get_kv_cache_token_count(ctx: llama_context_p, /) -> int: """Returns the number of tokens in the KV cache (slow, use only for debug) If a KV cell has multiple sequences assigned to it, it will be counted multiple times @@ -1331,36 +1359,29 @@ def llama_get_kv_cache_token_count(ctx: llama_context_p, /) -> int: ... -llama_get_kv_cache_token_count = _lib.llama_get_kv_cache_token_count -llama_get_kv_cache_token_count.argtypes = [llama_context_p_ctypes] -llama_get_kv_cache_token_count.restype = ctypes.c_int32 - - # // Returns the number of used KV cells (i.e. have at least one sequence assigned to them) # LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx); + + +@ctypes_function( + "llama_get_kv_cache_used_cells", [llama_context_p_ctypes], ctypes.c_int32 +) def llama_get_kv_cache_used_cells(ctx: llama_context_p, /) -> int: """Returns the number of used KV cells (i.e. have at least one sequence assigned to them)""" ... -llama_get_kv_cache_used_cells = _lib.llama_get_kv_cache_used_cells -llama_get_kv_cache_used_cells.argtypes = [llama_context_p_ctypes] -llama_get_kv_cache_used_cells.restype = ctypes.c_int32 - - # // Clear the KV cache # LLAMA_API void llama_kv_cache_clear( # struct llama_context * ctx); + + +@ctypes_function("llama_kv_cache_clear", [llama_context_p_ctypes], None) def llama_kv_cache_clear(ctx: llama_context_p, /): """Clear the KV cache""" ... 
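A debug-only sketch of the view helpers above, reusing `ctx`; `byref` here is the typed alias defined near the top of this file:

```python
view = llama_kv_cache_view_init(ctx, 1)  # n_max_seq=1: track one seq id per cell
llama_kv_cache_view_update(ctx, byref(view))
print("used cells:", llama_get_kv_cache_used_cells(ctx))
llama_kv_cache_view_free(byref(view))
```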
-llama_kv_cache_clear = _lib.llama_kv_cache_clear -llama_kv_cache_clear.argtypes = [llama_context_p_ctypes] -llama_kv_cache_clear.restype = None - - # // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) # // seq_id < 0 : match any sequence # // p0 < 0 : [0, p1] @@ -1370,6 +1391,18 @@ llama_kv_cache_clear.restype = None # llama_seq_id seq_id, # llama_pos p0, # llama_pos p1); + + +@ctypes_function( + "llama_kv_cache_seq_rm", + [ + llama_context_p_ctypes, + llama_seq_id, + llama_pos, + llama_pos, + ], + None, +) def llama_kv_cache_seq_rm( ctx: llama_context_p, seq_id: Union[llama_seq_id, int], @@ -1384,16 +1417,6 @@ def llama_kv_cache_seq_rm( ... -llama_kv_cache_seq_rm = _lib.llama_kv_cache_seq_rm -llama_kv_cache_seq_rm.argtypes = [ - llama_context_p_ctypes, - llama_seq_id, - llama_pos, - llama_pos, -] -llama_kv_cache_seq_rm.restype = None - - # // Copy all tokens that belong to the specified sequence to another sequence # // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence # // p0 < 0 : [0, p1] @@ -1404,6 +1427,19 @@ llama_kv_cache_seq_rm.restype = None # llama_seq_id seq_id_dst, # llama_pos p0, # llama_pos p1); + + +@ctypes_function( + "llama_kv_cache_seq_cp", + [ + llama_context_p_ctypes, + llama_seq_id, + llama_seq_id, + llama_pos, + llama_pos, + ], + None, +) def llama_kv_cache_seq_cp( ctx: llama_context_p, seq_id_src: Union[llama_seq_id, int], @@ -1419,31 +1455,20 @@ def llama_kv_cache_seq_cp( ... -llama_kv_cache_seq_cp = _lib.llama_kv_cache_seq_cp -llama_kv_cache_seq_cp.argtypes = [ - llama_context_p_ctypes, - llama_seq_id, - llama_seq_id, - llama_pos, - llama_pos, -] -llama_kv_cache_seq_cp.restype = None - - # // Removes all tokens that do not belong to the specified sequence # LLAMA_API void llama_kv_cache_seq_keep( # struct llama_context * ctx, # llama_seq_id seq_id); + + +@ctypes_function( + "llama_kv_cache_seq_keep", [llama_context_p_ctypes, llama_seq_id], None +) def llama_kv_cache_seq_keep(ctx: llama_context_p, seq_id: Union[llama_seq_id, int], /): """Removes all tokens that do not belong to the specified sequence""" ... -llama_kv_cache_seq_keep = _lib.llama_kv_cache_seq_keep -llama_kv_cache_seq_keep.argtypes = [llama_context_p_ctypes, llama_seq_id] -llama_kv_cache_seq_keep.restype = None - - # // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) # // If the KV cache is RoPEd, the KV data is updated accordingly # // p0 < 0 : [0, p1] @@ -1454,6 +1479,19 @@ llama_kv_cache_seq_keep.restype = None # llama_pos p0, # llama_pos p1, # llama_pos delta); + + +@ctypes_function( + "llama_kv_cache_seq_shift", + [ + llama_context_p_ctypes, + llama_seq_id, + llama_pos, + llama_pos, + llama_pos, + ], + None, +) def llama_kv_cache_seq_shift( ctx: llama_context_p, seq_id: Union[llama_seq_id, int], @@ -1469,17 +1507,6 @@ def llama_kv_cache_seq_shift( ... 
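The removal and shift primitives above are the building blocks of the usual context-shifting trick. A hedged sketch for sequence 0, with illustrative values (per the comments above, a RoPEd cache gets its K data updated by the shift):

```python
n_past, n_discard = 512, 256
llama_kv_cache_seq_rm(ctx, 0, 0, n_discard)       # drop the oldest positions
llama_kv_cache_seq_shift(ctx, 0, n_discard, n_past, -n_discard)  # slide back
n_past -= n_discard                               # room for new tokens
```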
-llama_kv_cache_seq_shift = _lib.llama_kv_cache_seq_shift -llama_kv_cache_seq_shift.argtypes = [ - llama_context_p_ctypes, - llama_seq_id, - llama_pos, - llama_pos, - llama_pos, -] -llama_kv_cache_seq_shift.restype = None - - # // Integer division of the positions by factor of `d > 1` # // If the KV cache is RoPEd, the KV data is updated accordingly # // p0 < 0 : [0, p1] @@ -1490,6 +1517,19 @@ llama_kv_cache_seq_shift.restype = None # llama_pos p0, # llama_pos p1, # int d); + + +@ctypes_function( + "llama_kv_cache_seq_div", + [ + llama_context_p_ctypes, + llama_seq_id, + llama_pos, + llama_pos, + ctypes.c_int, + ], + None, +) def llama_kv_cache_seq_div( ctx: llama_context_p, seq_id: Union[llama_seq_id, int], @@ -1505,16 +1545,6 @@ def llama_kv_cache_seq_div( ... -llama_kv_cache_seq_div = _lib.llama_kv_cache_seq_div -llama_kv_cache_seq_div.argtypes = [ - llama_context_p_ctypes, - llama_seq_id, - llama_pos, - llama_pos, - ctypes.c_int, -] -llama_kv_cache_seq_div.restype = None - # // # // State / sessions # // @@ -1523,23 +1553,31 @@ llama_kv_cache_seq_div.restype = None # Returns the maximum size in bytes of the state (rng, logits, embedding # and kv_cache) - will often be smaller after compacting tokens # LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx); + + +@ctypes_function("llama_get_state_size", [llama_context_p_ctypes], ctypes.c_size_t) def llama_get_state_size(ctx: llama_context_p, /) -> int: """Returns the maximum size in bytes of the state (rng, logits, embedding and kv_cache) - will often be smaller after compacting tokens""" ... -llama_get_state_size = _lib.llama_get_state_size -llama_get_state_size.argtypes = [llama_context_p_ctypes] -llama_get_state_size.restype = ctypes.c_size_t - - # Copies the state to the specified destination address. # Destination needs to have allocated enough memory. # Returns the number of bytes copied # LLAMA_API size_t llama_copy_state_data( # struct llama_context * ctx, # uint8_t * dst); + + +@ctypes_function( + "llama_copy_state_data", + [ + llama_context_p_ctypes, + ctypes.POINTER(ctypes.c_uint8), + ], + ctypes.c_size_t, +) def llama_copy_state_data( ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], / ) -> int: @@ -1549,19 +1587,18 @@ def llama_copy_state_data( ... -llama_copy_state_data = _lib.llama_copy_state_data -llama_copy_state_data.argtypes = [ - llama_context_p_ctypes, - ctypes.POINTER(ctypes.c_uint8), -] -llama_copy_state_data.restype = ctypes.c_size_t - - # Set the state reading from the specified address # Returns the number of bytes read # LLAMA_API size_t llama_set_state_data( # struct llama_context * ctx, # uint8_t * src); + + +@ctypes_function( + "llama_set_state_data", + [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)], + ctypes.c_size_t, +) def llama_set_state_data( ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], / ) -> int: @@ -1569,11 +1606,6 @@ def llama_set_state_data( ... 
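A snapshot/rollback sketch built on the three state functions above:

```python
size = llama_get_state_size(ctx)
state = (ctypes.c_uint8 * size)()
n_written = llama_copy_state_data(ctx, state)  # bytes actually copied
# ... decode speculative tokens ...
llama_set_state_data(ctx, state)               # roll the context back
```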
-llama_set_state_data = _lib.llama_set_state_data -llama_set_state_data.argtypes = [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)] -llama_set_state_data.restype = ctypes.c_size_t - - # Save/load session file # LLAMA_API bool llama_load_session_file( # struct llama_context * ctx, @@ -1581,6 +1613,19 @@ llama_set_state_data.restype = ctypes.c_size_t # llama_token * tokens_out, # size_t n_token_capacity, # size_t * n_token_count_out); + + +@ctypes_function( + "llama_load_session_file", + [ + llama_context_p_ctypes, + ctypes.c_char_p, + llama_token_p, + ctypes.c_size_t, + ctypes.POINTER(ctypes.c_size_t), + ], + ctypes.c_size_t, +) def llama_load_session_file( ctx: llama_context_p, path_session: bytes, @@ -1592,22 +1637,23 @@ def llama_load_session_file( ... -llama_load_session_file = _lib.llama_load_session_file -llama_load_session_file.argtypes = [ - llama_context_p_ctypes, - ctypes.c_char_p, - llama_token_p, - ctypes.c_size_t, - ctypes.POINTER(ctypes.c_size_t), -] -llama_load_session_file.restype = ctypes.c_size_t - - # LLAMA_API bool llama_save_session_file( # struct llama_context * ctx, # const char * path_session, # const llama_token * tokens, # size_t n_token_count); + + +@ctypes_function( + "llama_save_session_file", + [ + llama_context_p_ctypes, + ctypes.c_char_p, + llama_token_p, + ctypes.c_size_t, + ], + ctypes.c_size_t, +) def llama_save_session_file( ctx: llama_context_p, path_session: bytes, @@ -1618,15 +1664,6 @@ def llama_save_session_file( ... -llama_save_session_file = _lib.llama_save_session_file -llama_save_session_file.argtypes = [ - llama_context_p_ctypes, - ctypes.c_char_p, - llama_token_p, - ctypes.c_size_t, -] -llama_save_session_file.restype = ctypes.c_size_t - # // # // Decoding # // @@ -1643,6 +1680,18 @@ llama_save_session_file.restype = ctypes.c_size_t # int32_t n_tokens, # int32_t n_past), # "use llama_decode() instead"); + + +@ctypes_function( + "llama_eval", + [ + llama_context_p_ctypes, + llama_token_p, + ctypes.c_int32, + ctypes.c_int32, + ], + ctypes.c_int, +) def llama_eval( ctx: llama_context_p, tokens: CtypesArray[llama_token], @@ -1658,16 +1707,6 @@ def llama_eval( ... -llama_eval = _lib.llama_eval -llama_eval.argtypes = [ - llama_context_p_ctypes, - llama_token_p, - ctypes.c_int32, - ctypes.c_int32, -] -llama_eval.restype = ctypes.c_int - - # // Same as llama_eval, but use float matrix input directly. # // DEPRECATED: use llama_decode() instead # LLAMA_API DEPRECATED(int llama_eval_embd( @@ -1676,6 +1715,18 @@ llama_eval.restype = ctypes.c_int # int32_t n_tokens, # int32_t n_past), # "use llama_decode() instead"); + + +@ctypes_function( + "llama_eval_embd", + [ + llama_context_p_ctypes, + ctypes.POINTER(ctypes.c_float), + ctypes.c_int32, + ctypes.c_int32, + ], + ctypes.c_int, +) def llama_eval_embd( ctx: llama_context_p, embd: CtypesArray[ctypes.c_float], @@ -1688,16 +1739,6 @@ def llama_eval_embd( ... 
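Looping back to the session-file pair bound above, a hedged restore sketch where `"session.bin"` is a placeholder path:

```python
n_ctx = llama_n_ctx(ctx)
toks = (llama_token * n_ctx)()
n_loaded = ctypes.c_size_t(0)
if llama_load_session_file(ctx, b"session.bin", toks, n_ctx, byref(n_loaded)):
    print("restored", n_loaded.value, "prompt tokens")
```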
-llama_eval_embd = _lib.llama_eval_embd -llama_eval_embd.argtypes = [ - llama_context_p_ctypes, - ctypes.POINTER(ctypes.c_float), - ctypes.c_int32, - ctypes.c_int32, -] -llama_eval_embd.restype = ctypes.c_int - - # // Return batch for single sequence of tokens starting at pos_0 # // # // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it @@ -1707,6 +1748,18 @@ llama_eval_embd.restype = ctypes.c_int # int32_t n_tokens, # llama_pos pos_0, # llama_seq_id seq_id); + + +@ctypes_function( + "llama_batch_get_one", + [ + llama_token_p, + ctypes.c_int, + llama_pos, + llama_seq_id, + ], + llama_batch, +) def llama_batch_get_one( tokens: CtypesArray[llama_token], n_tokens: Union[ctypes.c_int, int], @@ -1721,16 +1774,6 @@ def llama_batch_get_one( ... -llama_batch_get_one = _lib.llama_batch_get_one -llama_batch_get_one.argtypes = [ - llama_token_p, - ctypes.c_int, - llama_pos, - llama_seq_id, -] -llama_batch_get_one.restype = llama_batch - - # // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens # // Each token can be assigned up to n_seq_max sequence ids # // The batch has to be freed with llama_batch_free() @@ -1742,6 +1785,11 @@ llama_batch_get_one.restype = llama_batch # int32_t n_tokens, # int32_t embd, # int32_t n_seq_max); + + +@ctypes_function( + "llama_batch_init", [ctypes.c_int32, ctypes.c_int32, ctypes.c_int32], llama_batch +) def llama_batch_init( n_tokens: Union[ctypes.c_int32, int], embd: Union[ctypes.c_int32, int], @@ -1758,23 +1806,16 @@ def llama_batch_init( ... -llama_batch_init = _lib.llama_batch_init -llama_batch_init.argtypes = [ctypes.c_int32, ctypes.c_int32, ctypes.c_int32] -llama_batch_init.restype = llama_batch - - # // Frees a batch of tokens allocated with llama_batch_init() # LLAMA_API void llama_batch_free(struct llama_batch batch); + + +@ctypes_function("llama_batch_free", [llama_batch], None) def llama_batch_free(batch: llama_batch, /): """Frees a batch of tokens allocated with llama_batch_init()""" ... -llama_batch_free = _lib.llama_batch_free -llama_batch_free.argtypes = [llama_batch] -llama_batch_free.restype = None - - # // Positive return values does not mean a fatal error, but rather a warning. # // 0 - success # // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) @@ -1782,6 +1823,9 @@ llama_batch_free.restype = None # LLAMA_API int32_t llama_decode( # struct llama_context * ctx, # struct llama_batch batch); + + +@ctypes_function("llama_decode", [llama_context_p_ctypes, llama_batch], ctypes.c_int32) def llama_decode(ctx: llama_context_p, batch: llama_batch, /) -> int: """Positive return values does not mean a fatal error, but rather a warning. 0 - success @@ -1790,15 +1834,21 @@ def llama_decode(ctx: llama_context_p, batch: llama_batch, /) -> int: ... 
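A hedged single-sequence decode sketch using the transitional batch helper above; the token ids are placeholders, not real vocab entries:

```python
toks = (llama_token * 3)(1, 15043, 2787)
batch = llama_batch_get_one(toks, len(toks), 0, 0)  # pos_0=0, seq_id=0
rc = llama_decode(ctx, batch)
if rc == 1:
    print("no KV slot: shrink the batch or grow the context")
elif rc != 0:
    raise RuntimeError(f"llama_decode failed: {rc}")
```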
-llama_decode = _lib.llama_decode -llama_decode.argtypes = [llama_context_p_ctypes, llama_batch] -llama_decode.restype = ctypes.c_int32 - - # // Set the number of threads used for decoding # // n_threads is the number of threads used for generation (single token) # // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) # LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch); + + +@ctypes_function( + "llama_set_n_threads", + [ + llama_context_p_ctypes, + ctypes.c_uint32, + ctypes.c_uint32, + ], + None, +) def llama_set_n_threads( ctx: llama_context_p, n_threads: Union[ctypes.c_uint32, int], @@ -1812,24 +1862,18 @@ def llama_set_n_threads( ... -llama_set_n_threads = _lib.llama_set_n_threads -llama_set_n_threads.argtypes = [ - llama_context_p_ctypes, - ctypes.c_uint32, - ctypes.c_uint32, -] -llama_set_n_threads.restype = None - - # // Token logits obtained from the last call to llama_eval() # // The logits for the last token are stored in the last row # // Logits for which llama_batch.logits[i] == 0 are undefined # // Rows: n_tokens provided with llama_batch # // Cols: n_vocab # LLAMA_API float * llama_get_logits(struct llama_context * ctx); -def llama_get_logits( - ctx: llama_context_p, / -): # type: (...) -> Array[float] # type: ignore + + +@ctypes_function( + "llama_get_logits", [llama_context_p_ctypes], ctypes.POINTER(ctypes.c_float) +) +def llama_get_logits(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]: """Token logits obtained from the last call to llama_eval() The logits for the last token are stored in the last row Logits for which llama_batch.logits[i] == 0 are undefined @@ -1838,202 +1882,181 @@ def llama_get_logits( ... -llama_get_logits = _lib.llama_get_logits -llama_get_logits.argtypes = [llama_context_p_ctypes] -llama_get_logits.restype = ctypes.POINTER(ctypes.c_float) - - # // Logits for the ith token. Equivalent to: # // llama_get_logits(ctx) + i*n_vocab # LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); + + +@ctypes_function( + "llama_get_logits_ith", + [llama_context_p_ctypes, ctypes.c_int32], + ctypes.POINTER(ctypes.c_float), +) def llama_get_logits_ith( ctx: llama_context_p, i: Union[ctypes.c_int32, int], / -): # type: (...) -> Array[float] # type: ignore +) -> CtypesArray[ctypes.c_float]: """Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab""" ... -llama_get_logits_ith = _lib.llama_get_logits_ith -llama_get_logits_ith.argtypes = [llama_context_p_ctypes, ctypes.c_int32] -llama_get_logits_ith.restype = ctypes.POINTER(ctypes.c_float) - - # Get the embeddings for the input # shape: [n_embd] (1-dimensional) # LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); -def llama_get_embeddings( - ctx: llama_context_p, / -): # type: (...) -> Array[float] # type: ignore + + +@ctypes_function( + "llama_get_embeddings", [llama_context_p_ctypes], ctypes.POINTER(ctypes.c_float) +) +def llama_get_embeddings(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]: """Get the embeddings for the input shape: [n_embd] (1-dimensional)""" ... 
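With the annotation now `CtypesArray[ctypes.c_float]`, rows can be sliced straight off the returned pointer. Illustrative, reusing `model`, `ctx`, and the 3-token batch from the sketches above; only the last row is defined here, since the helper batch requests logits for the final token only:

```python
n_vocab = llama_n_vocab(model)
logits = llama_get_logits(ctx)            # n_tokens x n_vocab, row-major
last = logits[2 * n_vocab : 3 * n_vocab]  # row of the 3rd (last) token
# equivalent: llama_get_logits_ith(ctx, 2)[0:n_vocab]
```

Slicing a ctypes float pointer copies the values out into a plain Python list.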
-llama_get_embeddings = _lib.llama_get_embeddings -llama_get_embeddings.argtypes = [llama_context_p_ctypes] -llama_get_embeddings.restype = ctypes.POINTER(ctypes.c_float) - - # // Get the embeddings for the ith sequence # // llama_get_embeddings(ctx) + i*n_embd # LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); + + +@ctypes_function( + "llama_get_embeddings_ith", + [llama_context_p_ctypes, ctypes.c_int32], + ctypes.POINTER(ctypes.c_float), +) def llama_get_embeddings_ith( ctx: llama_context_p, i: Union[ctypes.c_int32, int], / -): # type: (...) -> Array[float] # type: ignore +) -> CtypesArray[ctypes.c_float]: """Get the embeddings for the ith sequence llama_get_embeddings(ctx) + i*n_embd""" ... -llama_get_embeddings_ith = _lib.llama_get_embeddings_ith -llama_get_embeddings_ith.argtypes = [llama_context_p_ctypes, ctypes.c_int32] -llama_get_embeddings_ith.restype = ctypes.POINTER(ctypes.c_float) - - # // # // Vocab # // # LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token); + + +@ctypes_function( + "llama_token_get_text", [llama_model_p_ctypes, llama_token], ctypes.c_char_p +) def llama_token_get_text( model: llama_model_p, token: Union[llama_token, int], / ) -> bytes: ... -llama_token_get_text = _lib.llama_token_get_text -llama_token_get_text.argtypes = [llama_model_p_ctypes, llama_token] -llama_token_get_text.restype = ctypes.c_char_p - - # LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token); + + +@ctypes_function( + "llama_token_get_score", [llama_model_p_ctypes, llama_token], ctypes.c_float +) def llama_token_get_score( model: llama_model_p, token: Union[llama_token, int], / ) -> float: ... -llama_token_get_score = _lib.llama_token_get_score -llama_token_get_score.argtypes = [llama_model_p_ctypes, llama_token] -llama_token_get_score.restype = ctypes.c_float - - # LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token); + + +@ctypes_function( + "llama_token_get_type", [llama_model_p_ctypes, llama_token], ctypes.c_int +) def llama_token_get_type( model: llama_model_p, token: Union[llama_token, int], / ) -> int: ... -llama_token_get_type = _lib.llama_token_get_type -llama_token_get_type.argtypes = [llama_model_p_ctypes, llama_token] -llama_token_get_type.restype = ctypes.c_int - - # // Special tokens # LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence + + +@ctypes_function("llama_token_bos", [llama_model_p_ctypes], llama_token) def llama_token_bos(model: llama_model_p, /) -> int: """beginning-of-sentence""" ... -llama_token_bos = _lib.llama_token_bos -llama_token_bos.argtypes = [llama_model_p_ctypes] -llama_token_bos.restype = llama_token - - # LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence + + +@ctypes_function("llama_token_eos", [llama_model_p_ctypes], llama_token) def llama_token_eos(model: llama_model_p, /) -> int: """end-of-sentence""" ... -llama_token_eos = _lib.llama_token_eos -llama_token_eos.argtypes = [llama_model_p_ctypes] -llama_token_eos.restype = llama_token - - # LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line + + +@ctypes_function("llama_token_nl", [llama_model_p_ctypes], llama_token) def llama_token_nl(model: llama_model_p, /) -> int: """next-line""" ... 
-llama_token_nl = _lib.llama_token_nl -llama_token_nl.argtypes = [llama_model_p_ctypes] -llama_token_nl.restype = llama_token - - # // Returns -1 if unknown, 1 for true or 0 for false. # LLAMA_API int32_t llama_add_bos_token(const struct llama_model * model); + + +@ctypes_function("llama_add_bos_token", [llama_model_p_ctypes], ctypes.c_int32) def llama_add_bos_token(model: llama_model_p, /) -> int: """Returns -1 if unknown, 1 for true or 0 for false.""" ... -llama_add_bos_token = _lib.llama_add_bos_token -llama_add_bos_token.argtypes = [llama_model_p_ctypes] -llama_add_bos_token.restype = ctypes.c_int32 - - # // Returns -1 if unknown, 1 for true or 0 for false. # LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model); + + +@ctypes_function("llama_add_eos_token", [llama_model_p_ctypes], ctypes.c_int32) def llama_add_eos_token(model: llama_model_p, /) -> int: """Returns -1 if unknown, 1 for true or 0 for false.""" ... -llama_add_eos_token = _lib.llama_add_eos_token -llama_add_eos_token.argtypes = [llama_model_p_ctypes] -llama_add_eos_token.restype = ctypes.c_int32 - - # // codellama infill tokens # LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix + + +@ctypes_function("llama_token_prefix", [llama_model_p_ctypes], llama_token) def llama_token_prefix(model: llama_model_p, /) -> int: """codellama infill tokens""" ... -llama_token_prefix = _lib.llama_token_prefix -llama_token_prefix.argtypes = [llama_model_p_ctypes] -llama_token_prefix.restype = llama_token - - # LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle + + +@ctypes_function("llama_token_middle", [llama_model_p_ctypes], llama_token) def llama_token_middle(model: llama_model_p, /) -> int: ... -llama_token_middle = _lib.llama_token_middle -llama_token_middle.argtypes = [llama_model_p_ctypes] -llama_token_middle.restype = llama_token - - # LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix + + +@ctypes_function("llama_token_suffix", [llama_model_p_ctypes], llama_token) def llama_token_suffix(model: llama_model_p, /) -> int: ... -llama_token_suffix = _lib.llama_token_suffix -llama_token_suffix.argtypes = [llama_model_p_ctypes] -llama_token_suffix.restype = llama_token - - # LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle + + +@ctypes_function("llama_token_eot", [llama_model_p_ctypes], llama_token) def llama_token_eot(model: llama_model_p, /) -> int: ... -llama_token_eot = _lib.llama_token_eot -llama_token_eot.argtypes = [llama_model_p_ctypes] -llama_token_eot.restype = llama_token - - # // # // Tokenization # // @@ -2053,6 +2076,21 @@ llama_token_eot.restype = llama_token # int32_t n_max_tokens, # bool add_bos, # bool special); + + +@ctypes_function( + "llama_tokenize", + [ + llama_model_p_ctypes, + ctypes.c_char_p, + ctypes.c_int32, + llama_token_p, + ctypes.c_int32, + ctypes.c_bool, + ctypes.c_bool, + ], + ctypes.c_int32, +) def llama_tokenize( model: llama_model_p, text: bytes, @@ -2067,19 +2105,6 @@ def llama_tokenize( ... -llama_tokenize = _lib.llama_tokenize -llama_tokenize.argtypes = [ - llama_model_p_ctypes, - ctypes.c_char_p, - ctypes.c_int32, - llama_token_p, - ctypes.c_int32, - ctypes.c_bool, - ctypes.c_bool, -] -llama_tokenize.restype = ctypes.c_int32 - - # // Token Id -> Piece. # // Uses the vocabulary in the provided context. # // Does not write null terminator to the buffer.
@@ -2089,10 +2114,22 @@ llama_tokenize.restype = ctypes.c_int32 # llama_token token, # char * buf, # int32_t length); + + +@ctypes_function( + "llama_token_to_piece", + [ + llama_model_p_ctypes, + llama_token, + ctypes.c_char_p, + ctypes.c_int32, + ], + ctypes.c_int32, +) def llama_token_to_piece( model: llama_model_p, token: Union[llama_token, int], - buf: Union[ctypes.c_char_p, bytes], + buf: Union[ctypes.c_char_p, bytes, CtypesArray[ctypes.c_char]], length: Union[ctypes.c_int, int], /, ) -> int: @@ -2104,16 +2141,6 @@ def llama_token_to_piece( ... -llama_token_to_piece = _lib.llama_token_to_piece -llama_token_to_piece.argtypes = [ - llama_model_p_ctypes, - llama_token, - ctypes.c_char_p, - ctypes.c_int32, -] -llama_token_to_piece.restype = ctypes.c_int32 - - # /// Apply chat template. Inspired by hf apply_chat_template() on python. # /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" # /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template @@ -2132,6 +2159,18 @@ llama_token_to_piece.restype = ctypes.c_int32 # bool add_ass, # char * buf, # int32_t length); + + +@ctypes_function( + "llama_chat_apply_template", + [ + ctypes.c_void_p, + ctypes.c_char_p, + ctypes.POINTER(llama_chat_message), + ctypes.c_size_t, + ], + ctypes.c_int32, +) def llama_chat_apply_template( model: llama_model_p, tmpl: bytes, @@ -2142,16 +2181,6 @@ def llama_chat_apply_template( ... -llama_chat_apply_template = _lib.llama_chat_apply_template -llama_chat_apply_template.argtypes = [ - ctypes.c_void_p, - ctypes.c_char_p, - ctypes.POINTER(llama_chat_message), - ctypes.c_size_t, -] -llama_chat_apply_template.restype = ctypes.c_int32 - - # // # // Grammar # // @@ -2161,6 +2190,17 @@ llama_chat_apply_template.restype = ctypes.c_int32 # const llama_grammar_element ** rules, # size_t n_rules, # size_t start_rule_index); + + +@ctypes_function( + "llama_grammar_init", + [ + ctypes.POINTER(llama_grammar_element_p), + ctypes.c_size_t, + ctypes.c_size_t, + ], + llama_grammar_p, +) def llama_grammar_init( rules: CtypesArray[ CtypesPointer[llama_grammar_element] @@ -2173,36 +2213,28 @@ def llama_grammar_init( ... -llama_grammar_init = _lib.llama_grammar_init -llama_grammar_init.argtypes = [ - ctypes.POINTER(llama_grammar_element_p), - ctypes.c_size_t, - ctypes.c_size_t, -] -llama_grammar_init.restype = llama_grammar_p - - # LLAMA_API void llama_grammar_free(struct llama_grammar * grammar); +@ctypes_function( + "llama_grammar_free", + [llama_grammar_p], + None, +) def llama_grammar_free(grammar: llama_grammar_p, /): """Free a grammar.""" ... -llama_grammar_free = _lib.llama_grammar_free -llama_grammar_free.argtypes = [llama_grammar_p] -llama_grammar_free.restype = None - - # LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar); +@ctypes_function( + "llama_grammar_copy", + [llama_grammar_p], + llama_grammar_p, +) def llama_grammar_copy(grammar: llama_grammar_p, /) -> llama_grammar_p: """Copy a grammar.""" ... -llama_grammar_copy = _lib.llama_grammar_copy -llama_grammar_copy.argtypes = [llama_grammar_p] -llama_grammar_copy.restype = llama_grammar_p - # // # // Sampling functions # // @@ -2210,16 +2242,16 @@ llama_grammar_copy.restype = llama_grammar_p # // Sets the current rng seed. 
# LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); +@ctypes_function( + "llama_set_rng_seed", + [llama_context_p_ctypes, ctypes.c_uint32], + None, +) def llama_set_rng_seed(ctx: llama_context_p, seed: Union[ctypes.c_uint32, int], /): """Sets the current rng seed.""" ... -llama_set_rng_seed = _lib.llama_set_rng_seed -llama_set_rng_seed.argtypes = [llama_context_p_ctypes, ctypes.c_uint32] -llama_set_rng_seed.restype = None - - # /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. # /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. # LLAMA_API void llama_sample_repetition_penalties( @@ -2230,6 +2262,19 @@ llama_set_rng_seed.restype = None # float penalty_repeat, # float penalty_freq, # float penalty_present); +@ctypes_function( + "llama_sample_repetition_penalties", + [ + llama_context_p_ctypes, + llama_token_data_array_p, + llama_token_p, + ctypes.c_size_t, + ctypes.c_float, + ctypes.c_float, + ctypes.c_float, + ], + None, +) def llama_sample_repetition_penalties( ctx: llama_context_p, candidates: Union[ @@ -2248,19 +2293,6 @@ def llama_sample_repetition_penalties( ... -llama_sample_repetition_penalties = _lib.llama_sample_repetition_penalties -llama_sample_repetition_penalties.argtypes = [ - llama_context_p_ctypes, - llama_token_data_array_p, - llama_token_p, - ctypes.c_size_t, - ctypes.c_float, - ctypes.c_float, - ctypes.c_float, -] -llama_sample_repetition_penalties.restype = None - - # /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 # /// @param logits Logits extracted from the original generation context. # /// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. @@ -2270,6 +2302,16 @@ llama_sample_repetition_penalties.restype = None # float * logits, # float * logits_guidance, # float scale); +@ctypes_function( + "llama_sample_apply_guidance", + [ + llama_context_p_ctypes, + ctypes.POINTER(ctypes.c_float), + ctypes.POINTER(ctypes.c_float), + ctypes.c_float, + ], + None, +) def llama_sample_apply_guidance( ctx: llama_context_p, logits: CtypesArray[ctypes.c_float], @@ -2281,22 +2323,22 @@ def llama_sample_apply_guidance( ... -llama_sample_apply_guidance = _lib.llama_sample_apply_guidance -llama_sample_apply_guidance.argtypes = [ - llama_context_p_ctypes, - ctypes.POINTER(ctypes.c_float), - ctypes.POINTER(ctypes.c_float), - ctypes.c_float, -] -llama_sample_apply_guidance.restype = None - - # LLAMA_API DEPRECATED(void llama_sample_classifier_free_guidance( # struct llama_context * ctx, # llama_token_data_array * candidates, # struct llama_context * guidance_ctx, # float scale), # "use llama_sample_apply_guidance() instead"); +@ctypes_function( + "llama_sample_classifier_free_guidance", + [ + llama_context_p_ctypes, + llama_token_data_array_p, + llama_context_p_ctypes, + ctypes.c_float, + ], + None, +) def llama_sample_classifier_free_guidance( ctx: llama_context_p, candidates: Union[ @@ -2310,20 +2352,15 @@ def llama_sample_classifier_free_guidance( ... 
-llama_sample_classifier_free_guidance = _lib.llama_sample_classifier_free_guidance -llama_sample_classifier_free_guidance.argtypes = [ - llama_context_p_ctypes, - llama_token_data_array_p, - llama_context_p_ctypes, - ctypes.c_float, -] -llama_sample_classifier_free_guidance.restype = None - - # /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. # LLAMA_API void llama_sample_softmax( # struct llama_context * ctx, # llama_token_data_array * candidates); +@ctypes_function( + "llama_sample_softmax", + [llama_context_p_ctypes, llama_token_data_array_p], + None, +) def llama_sample_softmax( ctx: llama_context_p, candidates: Union[ @@ -2335,20 +2372,17 @@ def llama_sample_softmax( ... -llama_sample_softmax = _lib.llama_sample_softmax -llama_sample_softmax.argtypes = [ - llama_context_p_ctypes, - llama_token_data_array_p, -] -llama_sample_softmax.restype = None - - # /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 # LLAMA_API void llama_sample_top_k( # struct llama_context * ctx, # llama_token_data_array * candidates, # int32_t k, # size_t min_keep); +@ctypes_function( + "llama_sample_top_k", + [llama_context_p_ctypes, llama_token_data_array_p, ctypes.c_int32, ctypes.c_size_t], + None, +) def llama_sample_top_k( ctx: llama_context_p, candidates: Union[ @@ -2362,22 +2396,17 @@ def llama_sample_top_k( ... -llama_sample_top_k = _lib.llama_sample_top_k -llama_sample_top_k.argtypes = [ - llama_context_p_ctypes, - llama_token_data_array_p, - ctypes.c_int32, - ctypes.c_size_t, -] -llama_sample_top_k.restype = None - - # /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 # LLAMA_API void llama_sample_top_p( # struct llama_context * ctx, # llama_token_data_array * candidates, # float p, # size_t min_keep); +@ctypes_function( + "llama_sample_top_p", + [llama_context_p_ctypes, llama_token_data_array_p, ctypes.c_float, ctypes.c_size_t], + None, +) def llama_sample_top_p( ctx: llama_context_p, candidates: Union[ @@ -2391,22 +2420,17 @@ def llama_sample_top_p( ... -llama_sample_top_p = _lib.llama_sample_top_p -llama_sample_top_p.argtypes = [ - llama_context_p_ctypes, - llama_token_data_array_p, - ctypes.c_float, - ctypes.c_size_t, -] -llama_sample_top_p.restype = None - - # /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 # LLAMA_API void llama_sample_min_p( # struct llama_context * ctx, # llama_token_data_array * candidates, # float p, # size_t min_keep); +@ctypes_function( + "llama_sample_min_p", + [llama_context_p_ctypes, llama_token_data_array_p, ctypes.c_float, ctypes.c_size_t], + None, +) def llama_sample_min_p( ctx: llama_context_p, candidates: Union[ @@ -2420,22 +2444,17 @@ def llama_sample_min_p( ... -llama_sample_min_p = _lib.llama_sample_min_p -llama_sample_min_p.argtypes = [ - llama_context_p_ctypes, - llama_token_data_array_p, - ctypes.c_float, - ctypes.c_size_t, -] -llama_sample_min_p.restype = None - - # /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. 
# LLAMA_API void llama_sample_tail_free( # struct llama_context * ctx, # llama_token_data_array * candidates, # float z, # size_t min_keep); +@ctypes_function( + "llama_sample_tail_free", + [llama_context_p_ctypes, llama_token_data_array_p, ctypes.c_float, ctypes.c_size_t], + None, +) def llama_sample_tail_free( ctx: llama_context_p, candidates: Union[ @@ -2449,22 +2468,17 @@ def llama_sample_tail_free( ... -llama_sample_tail_free = _lib.llama_sample_tail_free -llama_sample_tail_free.argtypes = [ - llama_context_p_ctypes, - llama_token_data_array_p, - ctypes.c_float, - ctypes.c_size_t, -] -llama_sample_tail_free.restype = None - - # /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. # LLAMA_API void llama_sample_typical( # struct llama_context * ctx, # llama_token_data_array * candidates, # float p, # size_t min_keep); +@ctypes_function( + "llama_sample_typical", + [llama_context_p_ctypes, llama_token_data_array_p, ctypes.c_float, ctypes.c_size_t], + None, +) def llama_sample_typical( ctx: llama_context_p, candidates: Union[ @@ -2478,16 +2492,6 @@ def llama_sample_typical( ... -llama_sample_typical = _lib.llama_sample_typical -llama_sample_typical.argtypes = [ - llama_context_p_ctypes, - llama_token_data_array_p, - ctypes.c_float, - ctypes.c_size_t, -] -llama_sample_typical.restype = None - - # /// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772. # LLAMA_API void llama_sample_entropy( # struct llama_context * ctx, @@ -2495,6 +2499,17 @@ llama_sample_typical.restype = None # float min_temp, # float max_temp, # float exponent_val); +@ctypes_function( + "llama_sample_entropy", + [ + llama_context_p_ctypes, + llama_token_data_array_p, + ctypes.c_float, + ctypes.c_float, + ctypes.c_float, + ], + None, +) def llama_sample_entropy( ctx: llama_context_p, candidates: Union[ @@ -2509,21 +2524,15 @@ def llama_sample_entropy( ... -llama_sample_entropy = _lib.llama_sample_entropy -llama_sample_entropy.argtypes = [ - llama_context_p_ctypes, - llama_token_data_array_p, - ctypes.c_float, - ctypes.c_float, - ctypes.c_float, -] -llama_sample_entropy.restype = None - - # LLAMA_API void llama_sample_temp( # struct llama_context * ctx, # llama_token_data_array * candidates, # float temp); +@ctypes_function( + "llama_sample_temp", + [llama_context_p_ctypes, llama_token_data_array_p, ctypes.c_float], + None, +) def llama_sample_temp( ctx: llama_context_p, candidates: Union[ @@ -2541,20 +2550,16 @@ def llama_sample_temp( ... -llama_sample_temp = _lib.llama_sample_temp -llama_sample_temp.argtypes = [ - llama_context_p_ctypes, - llama_token_data_array_p, - ctypes.c_float, -] -llama_sample_temp.restype = None - - # LLAMA_API DEPRECATED(void llama_sample_temperature( # struct llama_context * ctx, # llama_token_data_array * candidates, # float temp), # "use llama_sample_temp instead"); +@ctypes_function( + "llama_sample_temperature", + [llama_context_p_ctypes, llama_token_data_array_p, ctypes.c_float], + None, +) def llama_sample_temperature( ctx: llama_context_p, candidates: Union[ @@ -2567,20 +2572,16 @@ def llama_sample_temperature( ... 
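A hedged sketch chaining the samplers above into the classic top-k, top-p, temperature pipeline. It fills a `llama_token_data_array` (the struct and the `llama_token_data_p` alias are defined earlier in this file) from the `last` logits row of the previous sketch; the pure-Python fill loop is slow but explicit:

```python
candidates = (llama_token_data * n_vocab)()
for tid in range(n_vocab):
    candidates[tid].id = tid
    candidates[tid].logit = last[tid]
    candidates[tid].p = 0.0
arr = llama_token_data_array(
    ctypes.cast(candidates, llama_token_data_p), n_vocab, False
)
llama_sample_top_k(ctx, byref(arr), 40, 1)   # keep the 40 best, min_keep=1
llama_sample_top_p(ctx, byref(arr), 0.95, 1)
llama_sample_temp(ctx, byref(arr), 0.8)
# a token id is then drawn with llama_sample_token(ctx, byref(arr)),
# bound further below
```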
-llama_sample_temperature = _lib.llama_sample_temperature -llama_sample_temperature.argtypes = [ - llama_context_p_ctypes, - llama_token_data_array_p, - ctypes.c_float, -] -llama_sample_temperature.restype = None - - # /// @details Apply constraints from grammar # LLAMA_API void llama_sample_grammar( # struct llama_context * ctx, # llama_token_data_array * candidates, # const struct llama_grammar * grammar); +@ctypes_function( + "llama_sample_grammar", + [llama_context_p_ctypes, llama_token_data_array_p, llama_grammar_p], + None, +) def llama_sample_grammar( ctx: llama_context_p, candidates: Union[ @@ -2598,15 +2599,6 @@ def llama_sample_grammar( ... -llama_sample_grammar = _lib.llama_sample_grammar -llama_sample_grammar.argtypes = [ - llama_context_p_ctypes, - llama_token_data_array_p, - llama_grammar_p, -] -llama_sample_grammar.restype = None - - # /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. # /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. # /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. @@ -2620,6 +2612,18 @@ llama_sample_grammar.restype = None # float eta, # int32_t m, # float * mu); +@ctypes_function( + "llama_sample_token_mirostat", + [ + llama_context_p_ctypes, + llama_token_data_array_p, + ctypes.c_float, + ctypes.c_float, + ctypes.c_int32, + ctypes.POINTER(ctypes.c_float), + ], + llama_token, +) def llama_sample_token_mirostat( ctx: llama_context_p, candidates: Union[ @@ -2643,18 +2647,6 @@ def llama_sample_token_mirostat( ... -llama_sample_token_mirostat = _lib.llama_sample_token_mirostat -llama_sample_token_mirostat.argtypes = [ - llama_context_p_ctypes, - llama_token_data_array_p, - ctypes.c_float, - ctypes.c_float, - ctypes.c_int32, - ctypes.POINTER(ctypes.c_float), -] -llama_sample_token_mirostat.restype = llama_token - - # /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. # /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. # /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. @@ -2666,6 +2658,17 @@ llama_sample_token_mirostat.restype = llama_token # float tau, # float eta, # float * mu); +@ctypes_function( + "llama_sample_token_mirostat_v2", + [ + llama_context_p_ctypes, + llama_token_data_array_p, + ctypes.c_float, + ctypes.c_float, + ctypes.POINTER(ctypes.c_float), + ], + llama_token, +) def llama_sample_token_mirostat_v2( ctx: llama_context_p, candidates: Union[ @@ -2673,7 +2676,7 @@ def llama_sample_token_mirostat_v2( ], tau: Union[ctypes.c_float, float], eta: Union[ctypes.c_float, float], - mu, # type: _Pointer[ctypes.c_float] + mu: CtypesPointerOrRef[ctypes.c_float], /, ) -> int: """Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. @@ -2687,22 +2690,16 @@ def llama_sample_token_mirostat_v2( ... 
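# --- Editorial usage sketch (not part of this diff) ------------------------
# Mirostat keeps caller-owned state in `mu`: the conventional initial value
# is 2 * tau, and the same ctypes.c_float must be passed back on every step
# so the sampler can keep adapting it. `ctx` and `candidates_p` are the
# hypothetical objects from the sketch above.
import ctypes
import llama_cpp

tau = 5.0  # target surprise (cross-entropy) of the generated text
eta = 0.1  # learning rate for the mu update
mu = ctypes.c_float(2.0 * tau)
token = llama_cpp.llama_sample_token_mirostat_v2(
    ctx, candidates_p, tau, eta, ctypes.byref(mu)
)
# ----------------------------------------------------------------------------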
-llama_sample_token_mirostat_v2 = _lib.llama_sample_token_mirostat_v2 -llama_sample_token_mirostat_v2.argtypes = [ - llama_context_p_ctypes, - llama_token_data_array_p, - ctypes.c_float, - ctypes.c_float, - ctypes.POINTER(ctypes.c_float), -] -llama_sample_token_mirostat_v2.restype = llama_token - - # /// @details Selects the token with the highest probability. # /// Does not compute the token probabilities. Use llama_sample_softmax() instead. # LLAMA_API llama_token llama_sample_token_greedy( # struct llama_context * ctx, # llama_token_data_array * candidates); +@ctypes_function( + "llama_sample_token_greedy", + [llama_context_p_ctypes, llama_token_data_array_p], + llama_token, +) def llama_sample_token_greedy( ctx: llama_context_p, candidates: Union[ @@ -2714,18 +2711,15 @@ def llama_sample_token_greedy( ... -llama_sample_token_greedy = _lib.llama_sample_token_greedy -llama_sample_token_greedy.argtypes = [ - llama_context_p_ctypes, - llama_token_data_array_p, -] -llama_sample_token_greedy.restype = llama_token - - # /// @details Randomly selects a token from the candidates based on their probabilities. # LLAMA_API llama_token llama_sample_token( # struct llama_context * ctx, # llama_token_data_array * candidates); +@ctypes_function( + "llama_sample_token", + [llama_context_p_ctypes, llama_token_data_array_p], + llama_token, +) def llama_sample_token( ctx: llama_context_p, candidates: Union[ @@ -2737,19 +2731,16 @@ def llama_sample_token( ... -llama_sample_token = _lib.llama_sample_token -llama_sample_token.argtypes = [ - llama_context_p_ctypes, - llama_token_data_array_p, -] -llama_sample_token.restype = llama_token - - # /// @details Accepts the sampled token into the grammar # LLAMA_API void llama_grammar_accept_token( # struct llama_context * ctx, # struct llama_grammar * grammar, # llama_token token); +@ctypes_function( + "llama_grammar_accept_token", + [llama_context_p_ctypes, llama_grammar_p, llama_token], + None, +) def llama_grammar_accept_token( ctx: llama_context_p, grammar: llama_grammar_p, token: Union[llama_token, int], / ) -> None: @@ -2757,15 +2748,6 @@ def llama_grammar_accept_token( ... -llama_grammar_accept_token = _lib.llama_grammar_accept_token -llama_grammar_accept_token.argtypes = [ - llama_context_p_ctypes, - llama_grammar_p, - llama_token, -] -llama_grammar_accept_token.restype = None - - # // # // Beam search # // @@ -2830,6 +2812,18 @@ llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE( # size_t n_beams, # int32_t n_past, # int32_t n_predict); +@ctypes_function( + "llama_beam_search", + [ + llama_context_p_ctypes, + llama_beam_search_callback_fn_t, + ctypes.c_void_p, + ctypes.c_size_t, + ctypes.c_int32, + ctypes.c_int32, + ], + None, +) def llama_beam_search( ctx: llama_context_p, callback: CtypesFuncPointer, @@ -2842,73 +2836,66 @@ def llama_beam_search( ... -llama_beam_search = _lib.llama_beam_search -llama_beam_search.argtypes = [ - llama_context_p_ctypes, - llama_beam_search_callback_fn_t, - ctypes.c_void_p, - ctypes.c_size_t, - ctypes.c_int32, - ctypes.c_int32, -] -llama_beam_search.restype = None - - # Performance information # LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); +@ctypes_function( + "llama_get_timings", + [llama_context_p_ctypes], + llama_timings, +) def llama_get_timings(ctx: llama_context_p, /) -> llama_timings: """Get performance information""" ... 
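# --- Editorial usage sketch (not part of this diff) ------------------------
# One step of grammar-constrained decoding with the wrappers above.
# `grammar` is hypothetical, assumed built elsewhere (e.g. via
# llama_grammar_init); `ctx` and `candidates_p` are as in the earlier
# sketches. The sampled token must be fed back through
# llama_grammar_accept_token so the grammar state advances with the output.
import llama_cpp

llama_cpp.llama_sample_grammar(ctx, candidates_p, grammar)
token = llama_cpp.llama_sample_token(ctx, candidates_p)
llama_cpp.llama_grammar_accept_token(ctx, grammar, token)
# ----------------------------------------------------------------------------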
-llama_get_timings = _lib.llama_get_timings -llama_get_timings.argtypes = [llama_context_p_ctypes] -llama_get_timings.restype = llama_timings - - # LLAMA_API void llama_print_timings(struct llama_context * ctx); +@ctypes_function( + "llama_print_timings", + [llama_context_p_ctypes], + None, +) def llama_print_timings(ctx: llama_context_p, /): """Print performance information""" ... -llama_print_timings = _lib.llama_print_timings -llama_print_timings.argtypes = [llama_context_p_ctypes] -llama_print_timings.restype = None - - # LLAMA_API void llama_reset_timings(struct llama_context * ctx); +@ctypes_function( + "llama_reset_timings", + [llama_context_p_ctypes], + None, +) def llama_reset_timings(ctx: llama_context_p, /): """Reset performance information""" ... -llama_reset_timings = _lib.llama_reset_timings -llama_reset_timings.argtypes = [llama_context_p_ctypes] -llama_reset_timings.restype = None - - # Print system information # LLAMA_API const char * llama_print_system_info(void); +@ctypes_function( + "llama_print_system_info", + [], + ctypes.c_char_p, +) def llama_print_system_info() -> bytes: """Print system information""" ... -llama_print_system_info = _lib.llama_print_system_info -llama_print_system_info.argtypes = [] -llama_print_system_info.restype = ctypes.c_char_p - - # NOTE: THIS IS CURRENTLY BROKEN AS ggml_log_callback IS NOT EXPOSED IN LLAMA.H # // Set callback for all future logging events. # // If this is not called, or NULL is supplied, everything is output on stderr. # LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data); +@ctypes_function( + "llama_log_set", + [ctypes.c_void_p, ctypes.c_void_p], + None, +) def llama_log_set( log_callback: Optional[CtypesFuncPointer], - user_data: ctypes.c_void_p, # type: ignore + user_data: ctypes.c_void_p, /, ): """Set callback for all future logging events. @@ -2917,16 +2904,11 @@ def llama_log_set( ... -llama_log_set = _lib.llama_log_set -llama_log_set.argtypes = [ctypes.c_void_p, ctypes.c_void_p] -llama_log_set.restype = None - - # LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx); +@ctypes_function( + "llama_dump_timing_info_yaml", + [ctypes.c_void_p, llama_context_p_ctypes], + None, +) def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /): ... - - -llama_dump_timing_info_yaml = _lib.llama_dump_timing_info_yaml -llama_dump_timing_info_yaml.argtypes = [ctypes.c_void_p, llama_context_p_ctypes] -llama_dump_timing_info_yaml.restype = None
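# --- Editorial usage sketch (not part of this diff) ------------------------
# Bracketing an evaluation with the timing helpers above; `ctx` is the
# hypothetical live llama_context from the earlier sketches.
import llama_cpp

llama_cpp.llama_reset_timings(ctx)
# ... llama_decode / sampling happens here ...
timings = llama_cpp.llama_get_timings(ctx)
print(f"{timings.n_eval} eval tokens in {timings.t_eval_ms:.2f} ms")
llama_cpp.llama_print_timings(ctx)  # same counters, logged to stderr
print(llama_cpp.llama_print_system_info().decode("utf-8"))
# ----------------------------------------------------------------------------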