feat: Add typechecking for ctypes structure attributes

This commit is contained in:
Andrei Betlen 2024-04-10 02:40:41 -04:00
parent 889d0e8981
commit 1347e1d050

View file

@ -237,7 +237,7 @@ LLAMA_FILE_MAGIC_GGLA = 0x67676C61
# define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
LLAMA_FILE_MAGIC_GGSN = 0x6767736E
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
# define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
LLAMA_FILE_MAGIC_GGSQ = 0x67677371
# define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
@ -245,9 +245,9 @@ LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN
# define LLAMA_SESSION_VERSION 5
LLAMA_SESSION_VERSION = 5
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
# define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
LLAMA_STATE_SEQ_MAGIC = LLAMA_FILE_MAGIC_GGSQ
#define LLAMA_STATE_SEQ_VERSION 1
# define LLAMA_STATE_SEQ_VERSION 1
LLAMA_STATE_SEQ_VERSION = 1
# struct llama_model;
@ -431,6 +431,11 @@ class llama_token_data(ctypes.Structure):
logit (float): log-odds of the token
p (float): probability of the token"""
if TYPE_CHECKING:
id: llama_token
logit: float
p: float
_fields_ = [
("id", llama_token),
("logit", ctypes.c_float),
@ -454,6 +459,11 @@ class llama_token_data_array(ctypes.Structure):
size (int): size of the array
sorted (bool): whether the array is sorted"""
if TYPE_CHECKING:
data: CtypesArray[llama_token_data]
size: int
sorted: bool
_fields_ = [
("data", llama_token_data_p),
("size", ctypes.c_size_t),
@ -515,6 +525,15 @@ class llama_batch(ctypes.Structure):
logits (ctypes.Array[ctypes.ctypes.c_int8]): if zero, the logits for the respective token will not be output
"""
if TYPE_CHECKING:
n_tokens: int
token: CtypesArray[llama_token]
embd: CtypesArray[ctypes.c_float]
pos: CtypesArray[CtypesArray[llama_pos]]
n_seq_id: CtypesArray[ctypes.c_int]
seq_id: CtypesArray[CtypesArray[llama_seq_id]]
logits: CtypesArray[ctypes.c_int8]
_fields_ = [
("n_tokens", ctypes.c_int32),
("token", ctypes.POINTER(llama_token)),
@ -609,6 +628,18 @@ class llama_model_params(ctypes.Structure):
use_mmap (bool): use mmap if possible
use_mlock (bool): force system to keep model in RAM"""
if TYPE_CHECKING:
n_gpu_layers: int
split_mode: int
main_gpu: int
tensor_split: CtypesArray[ctypes.c_float]
progress_callback: Callable[[float, ctypes.c_void_p], bool]
progress_callback_user_data: ctypes.c_void_p
kv_overrides: CtypesArray[llama_model_kv_override]
vocab_only: bool
use_mmap: bool
use_mlock: bool
_fields_ = [
("n_gpu_layers", ctypes.c_int32),
("split_mode", ctypes.c_int),
@ -696,6 +727,34 @@ class llama_context_params(ctypes.Structure):
abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback
"""
if TYPE_CHECKING:
seed: int
n_ctx: int
n_batch: int
n_ubatch: int
n_seq_max: int
n_threads: int
n_threads_batch: int
rope_scaling_type: int
pooling_type: int
rope_freq_base: float
rope_freq_scale: float
yarn_ext_factor: float
yarn_attn_factor: float
yarn_beta_fast: float
yarn_beta_slow: float
yarn_orig_ctx: int
defrag_thold: float
cb_eval: Callable[[ctypes.c_void_p, bool], bool]
cb_eval_user_data: ctypes.c_void_p
type_k: int
type_v: int
logits_all: bool
embeddings: bool
offload_kqv: bool
abort_callback: Callable[[ctypes.c_void_p], bool]
abort_callback_data: ctypes.c_void_p
_fields_ = [
("seed", ctypes.c_uint32),
("n_ctx", ctypes.c_uint32),
@ -771,6 +830,18 @@ class llama_model_quantize_params(ctypes.Structure):
kv_overrides (ctypes.c_void_p): pointer to vector containing overrides
"""
if TYPE_CHECKING:
nthread: int
ftype: int
output_tensor_type: int
token_embedding_type: int
allow_requantize: bool
quantize_output_tensor: bool
only_copy: bool
pure: bool
imatrix: ctypes.c_void_p
kv_overrides: ctypes.c_void_p
_fields_ = [
("nthread", ctypes.c_int32),
("ftype", ctypes.c_int),
@ -828,6 +899,10 @@ LLAMA_GRETYPE_CHAR_ALT = 6
# uint32_t value; // Unicode code point or rule ID
# } llama_grammar_element;
class llama_grammar_element(ctypes.Structure):
if TYPE_CHECKING:
type: int
value: int
_fields_ = [
("type", ctypes.c_int),
("value", ctypes.c_uint32),
@ -851,6 +926,17 @@ llama_grammar_element_p = ctypes.POINTER(llama_grammar_element)
# int32_t n_eval;
# };
class llama_timings(ctypes.Structure):
if TYPE_CHECKING:
t_start_ms: float
t_end_ms: float
t_load_ms: float
t_sample_ms: float
t_p_eval_ms: float
t_eval_ms: float
n_sample: int
n_p_eval: int
n_eval: int
_fields_ = [
("t_start_ms", ctypes.c_double),
("t_end_ms", ctypes.c_double),
@ -951,7 +1037,8 @@ GGML_NUMA_STRATEGY_COUNT = 5
[ctypes.c_int],
None,
)
def llama_numa_init(numa: int, /): ...
def llama_numa_init(numa: int, /):
...
# // Call once at the end of the program - currently only used for MPI
@ -976,7 +1063,8 @@ def llama_backend_free():
)
def llama_load_model_from_file(
path_model: bytes, params: llama_model_params, /
) -> Optional[llama_model_p]: ...
) -> Optional[llama_model_p]:
...
# LLAMA_API void llama_free_model(struct llama_model * model);
@ -985,7 +1073,8 @@ def llama_load_model_from_file(
[llama_model_p_ctypes],
None,
)
def llama_free_model(model: llama_model_p, /): ...
def llama_free_model(model: llama_model_p, /):
...
# LLAMA_API struct llama_context * llama_new_context_with_model(
@ -998,7 +1087,8 @@ def llama_free_model(model: llama_model_p, /): ...
)
def llama_new_context_with_model(
model: llama_model_p, params: llama_context_params, /
) -> Optional[llama_context_p]: ...
) -> Optional[llama_context_p]:
...
# // Frees all allocated memory
@ -1019,82 +1109,98 @@ def llama_free(ctx: llama_context_p, /):
[],
ctypes.c_int64,
)
def llama_time_us() -> int: ...
def llama_time_us() -> int:
...
# LLAMA_API size_t llama_max_devices(void);
@ctypes_function("llama_max_devices", [], ctypes.c_size_t)
def llama_max_devices() -> int: ...
def llama_max_devices() -> int:
...
# LLAMA_API bool llama_supports_mmap (void);
@ctypes_function("llama_supports_mmap", [], ctypes.c_bool)
def llama_supports_mmap() -> bool: ...
def llama_supports_mmap() -> bool:
...
# LLAMA_API bool llama_supports_mlock (void);
@ctypes_function("llama_supports_mlock", [], ctypes.c_bool)
def llama_supports_mlock() -> bool: ...
def llama_supports_mlock() -> bool:
...
# LLAMA_API bool llama_supports_gpu_offload(void);
@ctypes_function("llama_supports_gpu_offload", [], ctypes.c_bool)
def llama_supports_gpu_offload() -> bool: ...
def llama_supports_gpu_offload() -> bool:
...
# LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
@ctypes_function("llama_get_model", [llama_context_p_ctypes], llama_model_p_ctypes)
def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: ...
def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]:
...
# LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
@ctypes_function("llama_n_ctx", [llama_context_p_ctypes], ctypes.c_uint32)
def llama_n_ctx(ctx: llama_context_p, /) -> int: ...
def llama_n_ctx(ctx: llama_context_p, /) -> int:
...
# LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
@ctypes_function("llama_n_batch", [llama_context_p_ctypes], ctypes.c_uint32)
def llama_n_batch(ctx: llama_context_p, /) -> int: ...
def llama_n_batch(ctx: llama_context_p, /) -> int:
...
# LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
@ctypes_function("llama_n_ubatch", [llama_context_p_ctypes], ctypes.c_uint32)
def llama_n_ubatch(ctx: llama_context_p, /) -> int: ...
def llama_n_ubatch(ctx: llama_context_p, /) -> int:
...
# LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
@ctypes_function("llama_n_seq_max", [llama_context_p_ctypes], ctypes.c_uint32)
def llama_n_seq_max(ctx: llama_context_p, /) -> int: ...
def llama_n_seq_max(ctx: llama_context_p, /) -> int:
...
# LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
@ctypes_function("llama_vocab_type", [llama_model_p_ctypes], ctypes.c_int)
def llama_vocab_type(model: llama_model_p, /) -> int: ...
def llama_vocab_type(model: llama_model_p, /) -> int:
...
# LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
@ctypes_function("llama_rope_type", [llama_model_p_ctypes], ctypes.c_int)
def llama_rope_type(model: llama_model_p, /) -> int: ...
def llama_rope_type(model: llama_model_p, /) -> int:
...
# LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
@ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_vocab(model: llama_model_p, /) -> int: ...
def llama_n_vocab(model: llama_model_p, /) -> int:
...
# LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
@ctypes_function("llama_n_ctx_train", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_ctx_train(model: llama_model_p, /) -> int: ...
def llama_n_ctx_train(model: llama_model_p, /) -> int:
...
# LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
@ctypes_function("llama_n_embd", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_embd(model: llama_model_p, /) -> int: ...
def llama_n_embd(model: llama_model_p, /) -> int:
...
# LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
@ctypes_function("llama_n_layer", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_layer(model: llama_model_p, /) -> int: ...
def llama_n_layer(model: llama_model_p, /) -> int:
...
# // Get the model's RoPE frequency scaling factor
@ -1358,6 +1464,9 @@ class llama_kv_cache_view_cell(ctypes.Structure):
pos (llama_pos): The position for this cell. Takes KV cache shifts into account.
May be negative if the cell is not populated."""
if TYPE_CHECKING:
pos: llama_pos
_fields_ = [("pos", llama_pos)]
@ -1394,6 +1503,16 @@ class llama_kv_cache_view_cell(ctypes.Structure):
# llama_seq_id * cells_sequences;
# };
class llama_kv_cache_view(ctypes.Structure):
if TYPE_CHECKING:
n_cells: int
n_max_seq: int
token_count: int
used_cells: int
max_contiguous: int
max_contiguous_idx: int
cells: CtypesArray[llama_kv_cache_view_cell]
cells_sequences: CtypesArray[llama_seq_id]
_fields_ = [
("n_cells", ctypes.c_int32),
("n_max_seq", ctypes.c_int32),
@ -1783,7 +1902,8 @@ def llama_state_load_file(
n_token_capacity: Union[ctypes.c_size_t, int],
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
/,
) -> bool: ...
) -> bool:
...
# LLAMA_API DEPRECATED(bool llama_load_session_file(
@ -1811,7 +1931,8 @@ def llama_load_session_file(
n_token_capacity: Union[ctypes.c_size_t, int],
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
/,
) -> int: ...
) -> int:
...
# LLAMA_API bool llama_state_save_file(
@ -1835,7 +1956,8 @@ def llama_state_save_file(
tokens: CtypesArray[llama_token],
n_token_count: Union[ctypes.c_size_t, int],
/,
) -> bool: ...
) -> bool:
...
# LLAMA_API DEPRECATED(bool llama_save_session_file(
@ -1860,7 +1982,8 @@ def llama_save_session_file(
tokens: CtypesArray[llama_token],
n_token_count: Union[ctypes.c_size_t, int],
/,
) -> int: ...
) -> int:
...
# // Get the exact size needed to copy the KV cache of a single sequence
@ -2233,7 +2356,8 @@ def llama_get_embeddings_seq(
)
def llama_token_get_text(
model: llama_model_p, token: Union[llama_token, int], /
) -> bytes: ...
) -> bytes:
...
# LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
@ -2242,7 +2366,8 @@ def llama_token_get_text(
)
def llama_token_get_score(
model: llama_model_p, token: Union[llama_token, int], /
) -> float: ...
) -> float:
...
# LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
@ -2251,7 +2376,8 @@ def llama_token_get_score(
)
def llama_token_get_type(
model: llama_model_p, token: Union[llama_token, int], /
) -> int: ...
) -> int:
...
# // Special tokens
@ -2318,17 +2444,20 @@ def llama_token_prefix(model: llama_model_p) -> int:
# LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
@ctypes_function("llama_token_middle", [llama_model_p_ctypes], llama_token)
def llama_token_middle(model: llama_model_p, /) -> int: ...
def llama_token_middle(model: llama_model_p, /) -> int:
...
# LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
@ctypes_function("llama_token_suffix", [llama_model_p_ctypes], llama_token)
def llama_token_suffix(model: llama_model_p, /) -> int: ...
def llama_token_suffix(model: llama_model_p, /) -> int:
...
# LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle
@ctypes_function("llama_token_eot", [llama_model_p_ctypes], llama_token)
def llama_token_eot(model: llama_model_p, /) -> int: ...
def llama_token_eot(model: llama_model_p, /) -> int:
...
# //
@ -2459,7 +2588,8 @@ def llama_chat_apply_template(
chat: CtypesArray[llama_chat_message],
n_msg: int,
/,
) -> int: ...
) -> int:
...
# //
@ -2989,6 +3119,12 @@ def llama_grammar_accept_token(
# bool eob; // Callback should set this to true when a beam is at end-of-beam.
# };
class llama_beam_view(ctypes.Structure):
if TYPE_CHECKING:
tokens: CtypesArray[llama_token]
n_tokens: int
p: float
eob: bool
_fields_ = [
("tokens", llama_token_p),
("n_tokens", ctypes.c_size_t),
@ -3008,6 +3144,12 @@ class llama_beam_view(ctypes.Structure):
# bool last_call; // True iff this is the last callback invocation.
# };
class llama_beams_state(ctypes.Structure):
if TYPE_CHECKING:
beam_views: CtypesArray[llama_beam_view]
n_beams: int
common_prefix_length: int
last_call: bool
_fields_ = [
("beam_views", ctypes.POINTER(llama_beam_view)),
("n_beams", ctypes.c_size_t),
@ -3060,7 +3202,8 @@ def llama_beam_search(
n_past: Union[ctypes.c_int, int],
n_predict: Union[ctypes.c_int, int],
/,
): ...
):
...
# /// @details Build a split GGUF final path for this chunk.
@ -3179,4 +3322,5 @@ def llama_log_set(
[ctypes.c_void_p, llama_context_p_ctypes],
None,
)
def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /): ...
def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /):
...