feat: Add typechecking for ctypes structure attributes

Andrei Betlen 2024-04-10 02:40:41 -04:00
parent 889d0e8981
commit 1347e1d050

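The diff below applies one pattern throughout llama_cpp.py: each ctypes.Structure keeps its runtime layout in _fields_ and gains a TYPE_CHECKING-only block of attribute annotations that static type checkers can read. A minimal, self-contained sketch of that pattern (point_demo is an illustrative placeholder, not code from this commit):

    import ctypes
    from typing import TYPE_CHECKING


    class point_demo(ctypes.Structure):
        """Stand-in structure used only to illustrate the pattern."""

        if TYPE_CHECKING:
            # Seen only by static type checkers; TYPE_CHECKING is False at
            # runtime, so this block never executes and ctypes is unaffected.
            x: float
            y: float

        # The actual runtime layout. ctypes builds the attribute descriptors
        # from this list, but type checkers cannot infer types from it.
        _fields_ = [
            ("x", ctypes.c_float),
            ("y", ctypes.c_float),
        ]


    p = point_demo(1.0, 2.0)
    distance_sq = p.x * p.x + p.y * p.y  # checkers now see float, not Any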

@@ -431,6 +431,11 @@ class llama_token_data(ctypes.Structure):
         logit (float): log-odds of the token
         p (float): probability of the token"""
+    if TYPE_CHECKING:
+        id: llama_token
+        logit: float
+        p: float
+
     _fields_ = [
         ("id", llama_token),
         ("logit", ctypes.c_float),
@@ -454,6 +459,11 @@ class llama_token_data_array(ctypes.Structure):
         size (int): size of the array
         sorted (bool): whether the array is sorted"""
+    if TYPE_CHECKING:
+        data: CtypesArray[llama_token_data]
+        size: int
+        sorted: bool
+
     _fields_ = [
         ("data", llama_token_data_p),
         ("size", ctypes.c_size_t),
@@ -515,6 +525,15 @@ class llama_batch(ctypes.Structure):
         logits (ctypes.Array[ctypes.ctypes.c_int8]): if zero, the logits for the respective token will not be output
     """
+    if TYPE_CHECKING:
+        n_tokens: int
+        token: CtypesArray[llama_token]
+        embd: CtypesArray[ctypes.c_float]
+        pos: CtypesArray[CtypesArray[llama_pos]]
+        n_seq_id: CtypesArray[ctypes.c_int]
+        seq_id: CtypesArray[CtypesArray[llama_seq_id]]
+        logits: CtypesArray[ctypes.c_int8]
+
     _fields_ = [
         ("n_tokens", ctypes.c_int32),
         ("token", ctypes.POINTER(llama_token)),
@@ -609,6 +628,18 @@ class llama_model_params(ctypes.Structure):
         use_mmap (bool): use mmap if possible
         use_mlock (bool): force system to keep model in RAM"""
+    if TYPE_CHECKING:
+        n_gpu_layers: int
+        split_mode: int
+        main_gpu: int
+        tensor_split: CtypesArray[ctypes.c_float]
+        progress_callback: Callable[[float, ctypes.c_void_p], bool]
+        progress_callback_user_data: ctypes.c_void_p
+        kv_overrides: CtypesArray[llama_model_kv_override]
+        vocab_only: bool
+        use_mmap: bool
+        use_mlock: bool
+
     _fields_ = [
         ("n_gpu_layers", ctypes.c_int32),
         ("split_mode", ctypes.c_int),
@@ -696,6 +727,34 @@ class llama_context_params(ctypes.Structure):
         abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback
     """
+    if TYPE_CHECKING:
+        seed: int
+        n_ctx: int
+        n_batch: int
+        n_ubatch: int
+        n_seq_max: int
+        n_threads: int
+        n_threads_batch: int
+        rope_scaling_type: int
+        pooling_type: int
+        rope_freq_base: float
+        rope_freq_scale: float
+        yarn_ext_factor: float
+        yarn_attn_factor: float
+        yarn_beta_fast: float
+        yarn_beta_slow: float
+        yarn_orig_ctx: int
+        defrag_thold: float
+        cb_eval: Callable[[ctypes.c_void_p, bool], bool]
+        cb_eval_user_data: ctypes.c_void_p
+        type_k: int
+        type_v: int
+        logits_all: bool
+        embeddings: bool
+        offload_kqv: bool
+        abort_callback: Callable[[ctypes.c_void_p], bool]
+        abort_callback_data: ctypes.c_void_p
+
     _fields_ = [
         ("seed", ctypes.c_uint32),
         ("n_ctx", ctypes.c_uint32),
@@ -771,6 +830,18 @@ class llama_model_quantize_params(ctypes.Structure):
         kv_overrides (ctypes.c_void_p): pointer to vector containing overrides
     """
+    if TYPE_CHECKING:
+        nthread: int
+        ftype: int
+        output_tensor_type: int
+        token_embedding_type: int
+        allow_requantize: bool
+        quantize_output_tensor: bool
+        only_copy: bool
+        pure: bool
+        imatrix: ctypes.c_void_p
+        kv_overrides: ctypes.c_void_p
+
     _fields_ = [
         ("nthread", ctypes.c_int32),
         ("ftype", ctypes.c_int),
@@ -828,6 +899,10 @@ LLAMA_GRETYPE_CHAR_ALT = 6
 # uint32_t value; // Unicode code point or rule ID
 # } llama_grammar_element;
 class llama_grammar_element(ctypes.Structure):
+    if TYPE_CHECKING:
+        type: int
+        value: int
+
     _fields_ = [
         ("type", ctypes.c_int),
         ("value", ctypes.c_uint32),
@@ -851,6 +926,17 @@ llama_grammar_element_p = ctypes.POINTER(llama_grammar_element)
 # int32_t n_eval;
 # };
 class llama_timings(ctypes.Structure):
+    if TYPE_CHECKING:
+        t_start_ms: float
+        t_end_ms: float
+        t_load_ms: float
+        t_sample_ms: float
+        t_p_eval_ms: float
+        t_eval_ms: float
+        n_sample: int
+        n_p_eval: int
+        n_eval: int
+
     _fields_ = [
         ("t_start_ms", ctypes.c_double),
         ("t_end_ms", ctypes.c_double),
@@ -951,7 +1037,8 @@ GGML_NUMA_STRATEGY_COUNT = 5
     [ctypes.c_int],
     None,
 )
-def llama_numa_init(numa: int, /): ...
+def llama_numa_init(numa: int, /):
+    ...


 # // Call once at the end of the program - currently only used for MPI
@@ -976,7 +1063,8 @@ def llama_backend_free():
 )
 def llama_load_model_from_file(
     path_model: bytes, params: llama_model_params, /
-) -> Optional[llama_model_p]: ...
+) -> Optional[llama_model_p]:
+    ...


 # LLAMA_API void llama_free_model(struct llama_model * model);
@@ -985,7 +1073,8 @@ def llama_load_model_from_file(
     [llama_model_p_ctypes],
     None,
 )
-def llama_free_model(model: llama_model_p, /): ...
+def llama_free_model(model: llama_model_p, /):
+    ...


 # LLAMA_API struct llama_context * llama_new_context_with_model(
@@ -998,7 +1087,8 @@ def llama_free_model(model: llama_model_p, /): ...
 )
 def llama_new_context_with_model(
     model: llama_model_p, params: llama_context_params, /
-) -> Optional[llama_context_p]: ...
+) -> Optional[llama_context_p]:
+    ...


 # // Frees all allocated memory
@@ -1019,82 +1109,98 @@ def llama_free(ctx: llama_context_p, /):
     [],
     ctypes.c_int64,
 )
-def llama_time_us() -> int: ...
+def llama_time_us() -> int:
+    ...


 # LLAMA_API size_t llama_max_devices(void);
 @ctypes_function("llama_max_devices", [], ctypes.c_size_t)
-def llama_max_devices() -> int: ...
+def llama_max_devices() -> int:
+    ...


 # LLAMA_API bool llama_supports_mmap (void);
 @ctypes_function("llama_supports_mmap", [], ctypes.c_bool)
-def llama_supports_mmap() -> bool: ...
+def llama_supports_mmap() -> bool:
+    ...


 # LLAMA_API bool llama_supports_mlock (void);
 @ctypes_function("llama_supports_mlock", [], ctypes.c_bool)
-def llama_supports_mlock() -> bool: ...
+def llama_supports_mlock() -> bool:
+    ...


 # LLAMA_API bool llama_supports_gpu_offload(void);
 @ctypes_function("llama_supports_gpu_offload", [], ctypes.c_bool)
-def llama_supports_gpu_offload() -> bool: ...
+def llama_supports_gpu_offload() -> bool:
+    ...


 # LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
 @ctypes_function("llama_get_model", [llama_context_p_ctypes], llama_model_p_ctypes)
-def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: ...
+def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]:
+    ...


 # LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
 @ctypes_function("llama_n_ctx", [llama_context_p_ctypes], ctypes.c_uint32)
-def llama_n_ctx(ctx: llama_context_p, /) -> int: ...
+def llama_n_ctx(ctx: llama_context_p, /) -> int:
+    ...


 # LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
 @ctypes_function("llama_n_batch", [llama_context_p_ctypes], ctypes.c_uint32)
-def llama_n_batch(ctx: llama_context_p, /) -> int: ...
+def llama_n_batch(ctx: llama_context_p, /) -> int:
+    ...


 # LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
 @ctypes_function("llama_n_ubatch", [llama_context_p_ctypes], ctypes.c_uint32)
-def llama_n_ubatch(ctx: llama_context_p, /) -> int: ...
+def llama_n_ubatch(ctx: llama_context_p, /) -> int:
+    ...


 # LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
 @ctypes_function("llama_n_seq_max", [llama_context_p_ctypes], ctypes.c_uint32)
-def llama_n_seq_max(ctx: llama_context_p, /) -> int: ...
+def llama_n_seq_max(ctx: llama_context_p, /) -> int:
+    ...


 # LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
 @ctypes_function("llama_vocab_type", [llama_model_p_ctypes], ctypes.c_int)
-def llama_vocab_type(model: llama_model_p, /) -> int: ...
+def llama_vocab_type(model: llama_model_p, /) -> int:
+    ...


 # LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
 @ctypes_function("llama_rope_type", [llama_model_p_ctypes], ctypes.c_int)
-def llama_rope_type(model: llama_model_p, /) -> int: ...
+def llama_rope_type(model: llama_model_p, /) -> int:
+    ...


 # LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
 @ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32)
-def llama_n_vocab(model: llama_model_p, /) -> int: ...
+def llama_n_vocab(model: llama_model_p, /) -> int:
+    ...


 # LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
 @ctypes_function("llama_n_ctx_train", [llama_model_p_ctypes], ctypes.c_int32)
-def llama_n_ctx_train(model: llama_model_p, /) -> int: ...
+def llama_n_ctx_train(model: llama_model_p, /) -> int:
+    ...


 # LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
 @ctypes_function("llama_n_embd", [llama_model_p_ctypes], ctypes.c_int32)
-def llama_n_embd(model: llama_model_p, /) -> int: ...
+def llama_n_embd(model: llama_model_p, /) -> int:
+    ...


 # LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
 @ctypes_function("llama_n_layer", [llama_model_p_ctypes], ctypes.c_int32)
-def llama_n_layer(model: llama_model_p, /) -> int: ...
+def llama_n_layer(model: llama_model_p, /) -> int:
+    ...


 # // Get the model's RoPE frequency scaling factor
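Every function in this region is a stub whose body is just "...": the @ctypes_function(...) decorator, defined elsewhere in these bindings and not shown in this diff, binds the name to the shared library with the given argument and return types, while the annotated Python signature is what type checkers and IDEs see. The helper below is a rough analogue written purely for illustration; the bindings' real decorator may be implemented differently:

    import ctypes
    import functools
    from typing import Any, Callable, List, TypeVar, cast

    F = TypeVar("F", bound=Callable[..., Any])


    def ctypes_function_sketch(
        lib: ctypes.CDLL, name: str, argtypes: List[Any], restype: Any
    ) -> Callable[[F], F]:
        """Bind a typed Python stub to a C symbol (illustrative analogue only)."""
        c_func = getattr(lib, name)
        c_func.argtypes = argtypes
        c_func.restype = restype

        def decorator(stub: F) -> F:
            @functools.wraps(stub)
            def wrapper(*args: Any, **kwargs: Any) -> Any:
                # Forward to the configured C function; the stub body never runs.
                return c_func(*args, **kwargs)

            return cast(F, wrapper)

        return decorator

    # Hypothetical usage, mirroring the stubs in this file (names assumed):
    # @ctypes_function_sketch(lib, "llama_n_ctx", [llama_context_p_ctypes], ctypes.c_uint32)
    # def llama_n_ctx(ctx: llama_context_p, /) -> int:
    #     ...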
@@ -1358,6 +1464,9 @@ class llama_kv_cache_view_cell(ctypes.Structure):
         pos (llama_pos): The position for this cell. Takes KV cache shifts into account.
             May be negative if the cell is not populated."""
+    if TYPE_CHECKING:
+        pos: llama_pos
+
     _fields_ = [("pos", llama_pos)]
@@ -1394,6 +1503,16 @@ class llama_kv_cache_view_cell(ctypes.Structure):
 # llama_seq_id * cells_sequences;
 # };
 class llama_kv_cache_view(ctypes.Structure):
+    if TYPE_CHECKING:
+        n_cells: int
+        n_max_seq: int
+        token_count: int
+        used_cells: int
+        max_contiguous: int
+        max_contiguous_idx: int
+        cells: CtypesArray[llama_kv_cache_view_cell]
+        cells_sequences: CtypesArray[llama_seq_id]
+
     _fields_ = [
         ("n_cells", ctypes.c_int32),
         ("n_max_seq", ctypes.c_int32),
@@ -1783,7 +1902,8 @@ def llama_state_load_file(
     n_token_capacity: Union[ctypes.c_size_t, int],
     n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
     /,
-) -> bool: ...
+) -> bool:
+    ...


 # LLAMA_API DEPRECATED(bool llama_load_session_file(
@@ -1811,7 +1931,8 @@ def llama_load_session_file(
     n_token_capacity: Union[ctypes.c_size_t, int],
     n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
     /,
-) -> int: ...
+) -> int:
+    ...


 # LLAMA_API bool llama_state_save_file(
@@ -1835,7 +1956,8 @@ def llama_state_save_file(
     tokens: CtypesArray[llama_token],
     n_token_count: Union[ctypes.c_size_t, int],
     /,
-) -> bool: ...
+) -> bool:
+    ...


 # LLAMA_API DEPRECATED(bool llama_save_session_file(
@@ -1860,7 +1982,8 @@ def llama_save_session_file(
     tokens: CtypesArray[llama_token],
     n_token_count: Union[ctypes.c_size_t, int],
     /,
-) -> int: ...
+) -> int:
+    ...


 # // Get the exact size needed to copy the KV cache of a single sequence
@@ -2233,7 +2356,8 @@ def llama_get_embeddings_seq(
 )
 def llama_token_get_text(
     model: llama_model_p, token: Union[llama_token, int], /
-) -> bytes: ...
+) -> bytes:
+    ...


 # LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
@@ -2242,7 +2366,8 @@ def llama_token_get_text(
 )
 def llama_token_get_score(
     model: llama_model_p, token: Union[llama_token, int], /
-) -> float: ...
+) -> float:
+    ...


 # LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
@@ -2251,7 +2376,8 @@ def llama_token_get_score(
 )
 def llama_token_get_type(
     model: llama_model_p, token: Union[llama_token, int], /
-) -> int: ...
+) -> int:
+    ...


 # // Special tokens
@@ -2318,17 +2444,20 @@ def llama_token_prefix(model: llama_model_p) -> int:
 # LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
 @ctypes_function("llama_token_middle", [llama_model_p_ctypes], llama_token)
-def llama_token_middle(model: llama_model_p, /) -> int: ...
+def llama_token_middle(model: llama_model_p, /) -> int:
+    ...


 # LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
 @ctypes_function("llama_token_suffix", [llama_model_p_ctypes], llama_token)
-def llama_token_suffix(model: llama_model_p, /) -> int: ...
+def llama_token_suffix(model: llama_model_p, /) -> int:
+    ...


 # LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle
 @ctypes_function("llama_token_eot", [llama_model_p_ctypes], llama_token)
-def llama_token_eot(model: llama_model_p, /) -> int: ...
+def llama_token_eot(model: llama_model_p, /) -> int:
+    ...


 # //
@@ -2459,7 +2588,8 @@ def llama_chat_apply_template(
     chat: CtypesArray[llama_chat_message],
     n_msg: int,
     /,
-) -> int: ...
+) -> int:
+    ...


 # //
@@ -2989,6 +3119,12 @@ def llama_grammar_accept_token(
 # bool eob; // Callback should set this to true when a beam is at end-of-beam.
 # };
 class llama_beam_view(ctypes.Structure):
+    if TYPE_CHECKING:
+        tokens: CtypesArray[llama_token]
+        n_tokens: int
+        p: float
+        eob: bool
+
     _fields_ = [
         ("tokens", llama_token_p),
         ("n_tokens", ctypes.c_size_t),
@@ -3008,6 +3144,12 @@ class llama_beam_view(ctypes.Structure):
 # bool last_call; // True iff this is the last callback invocation.
 # };
 class llama_beams_state(ctypes.Structure):
+    if TYPE_CHECKING:
+        beam_views: CtypesArray[llama_beam_view]
+        n_beams: int
+        common_prefix_length: int
+        last_call: bool
+
     _fields_ = [
         ("beam_views", ctypes.POINTER(llama_beam_view)),
         ("n_beams", ctypes.c_size_t),
@@ -3060,7 +3202,8 @@ def llama_beam_search(
     n_past: Union[ctypes.c_int, int],
     n_predict: Union[ctypes.c_int, int],
     /,
-): ...
+):
+    ...


 # /// @details Build a split GGUF final path for this chunk.
@@ -3179,4 +3322,5 @@ def llama_log_set(
     [ctypes.c_void_p, llama_context_p_ctypes],
     None,
 )
-def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /): ...
+def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /):
+    ...
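With the annotations in place, attribute access on these structures is no longer typed as Any, so checkers such as mypy or pyright can validate call sites. A small illustration of the effect, using a local stand-in rather than the real llama_grammar_element:

    import ctypes
    from typing import TYPE_CHECKING


    class grammar_element_demo(ctypes.Structure):
        # Local stand-in with the same two fields as llama_grammar_element.
        if TYPE_CHECKING:
            type: int
            value: int

        _fields_ = [
            ("type", ctypes.c_int),
            ("value", ctypes.c_uint32),
        ]


    def describe(el: grammar_element_demo) -> str:
        # A checker now knows both attributes are int, so a typo such as
        # el.val, or passing el.value where a str is required, is reported
        # before runtime instead of silently typing as Any.
        return f"type={el.type} value={el.value}"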