misc: additional type annotations for low level api

This commit is contained in:
Andrei Betlen 2024-02-22 02:00:09 -05:00
parent 3632241e98
commit aefcb8f71a

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import sys import sys
import os import os
import ctypes import ctypes
@ -6,7 +8,16 @@ from ctypes import (
Array, Array,
) )
import pathlib import pathlib
from typing import List, Union, NewType, Optional from typing import (
List,
Union,
NewType,
Optional,
TYPE_CHECKING,
TypeVar,
TypeAlias,
Generic,
)
# Load the library # Load the library
@ -56,7 +67,7 @@ def _load_shared_library(lib_base_name: str):
for _lib_path in _lib_paths: for _lib_path in _lib_paths:
if _lib_path.exists(): if _lib_path.exists():
try: try:
return ctypes.CDLL(str(_lib_path), **cdll_args) # type: ignore return ctypes.CDLL(str(_lib_path), **cdll_args) # type: ignore
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}") raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
@ -71,14 +82,39 @@ _lib_base_name = "llama"
# Load the library # Load the library
_lib = _load_shared_library(_lib_base_name) _lib = _load_shared_library(_lib_base_name)
# Misc
c_float_p = ctypes.POINTER(ctypes.c_float) # ctypes sane type hint helpers
c_uint8_p = ctypes.POINTER(ctypes.c_uint8) #
c_size_t_p = ctypes.POINTER(ctypes.c_size_t) # - Generic Pointer and Array types
# - PointerOrRef type with a type hinted byref function
#
# NOTE: Only use these for static type checking not for runtime checks
# no good will come of that
if TYPE_CHECKING:
CtypesCData = TypeVar("CtypesCData", bound=ctypes._CData) # type: ignore
CtypesArray: TypeAlias = ctypes.Array[CtypesCData] # type: ignore
CtypesPointer: TypeAlias = ctypes._Pointer[CtypesCData] # type: ignore
CtypesVoidPointer: TypeAlias = ctypes.c_void_p
class CtypesRef(Generic[CtypesCData]):
pass
CtypesPointerOrRef: TypeAlias = Union[
CtypesPointer[CtypesCData], CtypesRef[CtypesCData]
]
CtypesFuncPointer: TypeAlias = ctypes._FuncPointer # type: ignore
# from ggml-backend.h # from ggml-backend.h
# typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data); # typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_void_p, ctypes.c_bool, ctypes.c_void_p) ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE(
ctypes.c_bool, ctypes.c_void_p, ctypes.c_bool, ctypes.c_void_p
)
# llama.h bindings # llama.h bindings
@ -286,7 +322,9 @@ class llama_token_data_array(ctypes.Structure):
llama_token_data_array_p = ctypes.POINTER(llama_token_data_array) llama_token_data_array_p = ctypes.POINTER(llama_token_data_array)
# typedef bool (*llama_progress_callback)(float progress, void *ctx); # typedef bool (*llama_progress_callback)(float progress, void *ctx);
llama_progress_callback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_float, ctypes.c_void_p) llama_progress_callback = ctypes.CFUNCTYPE(
ctypes.c_bool, ctypes.c_float, ctypes.c_void_p
)
# // Input data for llama_decode # // Input data for llama_decode
@ -336,7 +374,7 @@ class llama_batch(ctypes.Structure):
_fields_ = [ _fields_ = [
("n_tokens", ctypes.c_int32), ("n_tokens", ctypes.c_int32),
("token", ctypes.POINTER(llama_token)), ("token", ctypes.POINTER(llama_token)),
("embd", c_float_p), ("embd", ctypes.POINTER(ctypes.c_float)),
("pos", ctypes.POINTER(llama_pos)), ("pos", ctypes.POINTER(llama_pos)),
("n_seq_id", ctypes.POINTER(ctypes.c_int32)), ("n_seq_id", ctypes.POINTER(ctypes.c_int32)),
("seq_id", ctypes.POINTER(ctypes.POINTER(llama_seq_id))), ("seq_id", ctypes.POINTER(ctypes.POINTER(llama_seq_id))),
@ -431,7 +469,7 @@ class llama_model_params(ctypes.Structure):
("n_gpu_layers", ctypes.c_int32), ("n_gpu_layers", ctypes.c_int32),
("split_mode", ctypes.c_int), ("split_mode", ctypes.c_int),
("main_gpu", ctypes.c_int32), ("main_gpu", ctypes.c_int32),
("tensor_split", c_float_p), ("tensor_split", ctypes.POINTER(ctypes.c_float)),
("progress_callback", llama_progress_callback), ("progress_callback", llama_progress_callback),
("progress_callback_user_data", ctypes.c_void_p), ("progress_callback_user_data", ctypes.c_void_p),
("kv_overrides", ctypes.POINTER(llama_model_kv_override)), ("kv_overrides", ctypes.POINTER(llama_model_kv_override)),
@ -532,7 +570,9 @@ class llama_context_params(ctypes.Structure):
# // if it exists. # // if it exists.
# // It might not exist for progress report where '.' is output repeatedly. # // It might not exist for progress report where '.' is output repeatedly.
# typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data); # typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data);
llama_log_callback = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p) llama_log_callback = ctypes.CFUNCTYPE(
None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p
)
"""Signature for logging events """Signature for logging events
Note that text includes the new line character at the end for most events. Note that text includes the new line character at the end for most events.
If your logging mechanism cannot handle that, check if the last character is '\n' and strip it If your logging mechanism cannot handle that, check if the last character is '\n' and strip it
@ -966,14 +1006,23 @@ llama_rope_freq_scale_train.restype = ctypes.c_float
# // Get metadata value as a string by key name # // Get metadata value as a string by key name
# LLAMA_API int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size); # LLAMA_API int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size);
def llama_model_meta_val_str( def llama_model_meta_val_str(
model: llama_model_p, key: Union[ctypes.c_char_p, bytes], buf: bytes, buf_size: int, / model: llama_model_p,
key: Union[ctypes.c_char_p, bytes],
buf: bytes,
buf_size: int,
/,
) -> int: ) -> int:
"""Get metadata value as a string by key name""" """Get metadata value as a string by key name"""
... ...
llama_model_meta_val_str = _lib.llama_model_meta_val_str llama_model_meta_val_str = _lib.llama_model_meta_val_str
llama_model_meta_val_str.argtypes = [llama_model_p_ctypes, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_size_t] llama_model_meta_val_str.argtypes = [
llama_model_p_ctypes,
ctypes.c_char_p,
ctypes.c_char_p,
ctypes.c_size_t,
]
llama_model_meta_val_str.restype = ctypes.c_int32 llama_model_meta_val_str.restype = ctypes.c_int32
@ -1087,8 +1136,8 @@ llama_get_model_tensor.restype = ctypes.c_void_p
def llama_model_quantize( def llama_model_quantize(
fname_inp: bytes, fname_inp: bytes,
fname_out: bytes, fname_out: bytes,
params, # type: ctypes.POINTER(llama_model_quantize_params) # type: ignore params: CtypesPointerOrRef[llama_model_quantize_params],
/ /,
) -> int: ) -> int:
"""Returns 0 on success""" """Returns 0 on success"""
... ...
@ -1121,8 +1170,8 @@ def llama_apply_lora_from_file(
path_lora: Union[ctypes.c_char_p, bytes], path_lora: Union[ctypes.c_char_p, bytes],
scale: Union[ctypes.c_float, float], scale: Union[ctypes.c_float, float],
path_base_model: Union[ctypes.c_char_p, bytes], path_base_model: Union[ctypes.c_char_p, bytes],
n_threads: Union[ctypes.c_int, int], n_threads: Union[ctypes.c_int32, int],
/ /,
) -> int: ) -> int:
"""Apply a LoRA adapter to a loaded model """Apply a LoRA adapter to a loaded model
path_base_model is the path to a higher quality model to use as a base for path_base_model is the path to a higher quality model to use as a base for
@ -1155,8 +1204,8 @@ def llama_model_apply_lora_from_file(
path_lora: Union[ctypes.c_char_p, bytes], path_lora: Union[ctypes.c_char_p, bytes],
scale: Union[ctypes.c_float, float], scale: Union[ctypes.c_float, float],
path_base_model: Union[ctypes.c_char_p, bytes], path_base_model: Union[ctypes.c_char_p, bytes],
n_threads: Union[ctypes.c_int, int], n_threads: Union[ctypes.c_int32, int],
/ /,
) -> int: ) -> int:
... ...
@ -1262,7 +1311,7 @@ llama_kv_cache_view_free.restype = None
# // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) # // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
# LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view); # LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
def llama_kv_cache_view_update(ctx: llama_context_p, view: "ctypes.pointer[llama_kv_cache_view]", /): # type: ignore def llama_kv_cache_view_update(ctx: llama_context_p, view: CtypesPointerOrRef[llama_kv_cache_view], /): # type: ignore
"""Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)""" """Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)"""
... ...
@ -1326,7 +1375,7 @@ def llama_kv_cache_seq_rm(
seq_id: Union[llama_seq_id, int], seq_id: Union[llama_seq_id, int],
p0: Union[llama_pos, int], p0: Union[llama_pos, int],
p1: Union[llama_pos, int], p1: Union[llama_pos, int],
/ /,
): ):
"""Removes all tokens that belong to the specified sequence and have positions in [p0, p1) """Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
seq_id < 0 : match any sequence seq_id < 0 : match any sequence
@ -1361,7 +1410,7 @@ def llama_kv_cache_seq_cp(
seq_id_dst: Union[llama_seq_id, int], seq_id_dst: Union[llama_seq_id, int],
p0: Union[llama_pos, int], p0: Union[llama_pos, int],
p1: Union[llama_pos, int], p1: Union[llama_pos, int],
/ /,
): ):
"""Copy all tokens that belong to the specified sequence to another sequence """Copy all tokens that belong to the specified sequence to another sequence
Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
@ -1385,11 +1434,7 @@ llama_kv_cache_seq_cp.restype = None
# LLAMA_API void llama_kv_cache_seq_keep( # LLAMA_API void llama_kv_cache_seq_keep(
# struct llama_context * ctx, # struct llama_context * ctx,
# llama_seq_id seq_id); # llama_seq_id seq_id);
def llama_kv_cache_seq_keep( def llama_kv_cache_seq_keep(ctx: llama_context_p, seq_id: Union[llama_seq_id, int], /):
ctx: llama_context_p,
seq_id: Union[llama_seq_id, int],
/
):
"""Removes all tokens that do not belong to the specified sequence""" """Removes all tokens that do not belong to the specified sequence"""
... ...
@ -1415,7 +1460,7 @@ def llama_kv_cache_seq_shift(
p0: Union[llama_pos, int], p0: Union[llama_pos, int],
p1: Union[llama_pos, int], p1: Union[llama_pos, int],
delta: Union[llama_pos, int], delta: Union[llama_pos, int],
/ /,
): ):
"""Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) """Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
If the KV cache is RoPEd, the KV data is updated accordingly If the KV cache is RoPEd, the KV data is updated accordingly
@ -1451,7 +1496,7 @@ def llama_kv_cache_seq_div(
p0: Union[llama_pos, int], p0: Union[llama_pos, int],
p1: Union[llama_pos, int], p1: Union[llama_pos, int],
d: Union[ctypes.c_int, int], d: Union[ctypes.c_int, int],
/ /,
): ):
"""Integer division of the positions by factor of `d > 1` """Integer division of the positions by factor of `d > 1`
If the KV cache is RoPEd, the KV data is updated accordingly If the KV cache is RoPEd, the KV data is updated accordingly
@ -1496,8 +1541,7 @@ llama_get_state_size.restype = ctypes.c_size_t
# struct llama_context * ctx, # struct llama_context * ctx,
# uint8_t * dst); # uint8_t * dst);
def llama_copy_state_data( def llama_copy_state_data(
ctx: llama_context_p, dst, # type: Array[ctypes.c_uint8] ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], /
/
) -> int: ) -> int:
"""Copies the state to the specified destination address. """Copies the state to the specified destination address.
Destination needs to have allocated enough memory. Destination needs to have allocated enough memory.
@ -1506,7 +1550,10 @@ def llama_copy_state_data(
llama_copy_state_data = _lib.llama_copy_state_data llama_copy_state_data = _lib.llama_copy_state_data
llama_copy_state_data.argtypes = [llama_context_p_ctypes, c_uint8_p] llama_copy_state_data.argtypes = [
llama_context_p_ctypes,
ctypes.POINTER(ctypes.c_uint8),
]
llama_copy_state_data.restype = ctypes.c_size_t llama_copy_state_data.restype = ctypes.c_size_t
@ -1516,15 +1563,14 @@ llama_copy_state_data.restype = ctypes.c_size_t
# struct llama_context * ctx, # struct llama_context * ctx,
# uint8_t * src); # uint8_t * src);
def llama_set_state_data( def llama_set_state_data(
ctx: llama_context_p, src, # type: Array[ctypes.c_uint8] ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], /
/
) -> int: ) -> int:
"""Set the state reading from the specified address""" """Set the state reading from the specified address"""
... ...
llama_set_state_data = _lib.llama_set_state_data llama_set_state_data = _lib.llama_set_state_data
llama_set_state_data.argtypes = [llama_context_p_ctypes, c_uint8_p] llama_set_state_data.argtypes = [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)]
llama_set_state_data.restype = ctypes.c_size_t llama_set_state_data.restype = ctypes.c_size_t
@ -1538,10 +1584,10 @@ llama_set_state_data.restype = ctypes.c_size_t
def llama_load_session_file( def llama_load_session_file(
ctx: llama_context_p, ctx: llama_context_p,
path_session: bytes, path_session: bytes,
tokens_out, # type: Array[llama_token] tokens_out: CtypesArray[llama_token],
n_token_capacity: Union[ctypes.c_size_t, int], n_token_capacity: Union[ctypes.c_size_t, int],
n_token_count_out, # type: _Pointer[ctypes.c_size_t] n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
/ /,
) -> int: ) -> int:
... ...
@ -1552,7 +1598,7 @@ llama_load_session_file.argtypes = [
ctypes.c_char_p, ctypes.c_char_p,
llama_token_p, llama_token_p,
ctypes.c_size_t, ctypes.c_size_t,
c_size_t_p, ctypes.POINTER(ctypes.c_size_t),
] ]
llama_load_session_file.restype = ctypes.c_size_t llama_load_session_file.restype = ctypes.c_size_t
@ -1565,9 +1611,9 @@ llama_load_session_file.restype = ctypes.c_size_t
def llama_save_session_file( def llama_save_session_file(
ctx: llama_context_p, ctx: llama_context_p,
path_session: bytes, path_session: bytes,
tokens, # type: Array[llama_token] tokens: CtypesArray[llama_token],
n_token_count: Union[ctypes.c_size_t, int], n_token_count: Union[ctypes.c_size_t, int],
/ /,
) -> int: ) -> int:
... ...
@ -1599,10 +1645,10 @@ llama_save_session_file.restype = ctypes.c_size_t
# "use llama_decode() instead"); # "use llama_decode() instead");
def llama_eval( def llama_eval(
ctx: llama_context_p, ctx: llama_context_p,
tokens, # type: Array[llama_token] tokens: CtypesArray[llama_token],
n_tokens: Union[ctypes.c_int, int], n_tokens: Union[ctypes.c_int, int],
n_past: Union[ctypes.c_int, int], n_past: Union[ctypes.c_int, int],
/ /,
) -> int: ) -> int:
"""Run the llama inference to obtain the logits and probabilities for the next token(s). """Run the llama inference to obtain the logits and probabilities for the next token(s).
tokens + n_tokens is the provided batch of new tokens to process tokens + n_tokens is the provided batch of new tokens to process
@ -1613,7 +1659,12 @@ def llama_eval(
llama_eval = _lib.llama_eval llama_eval = _lib.llama_eval
llama_eval.argtypes = [llama_context_p_ctypes, llama_token_p, ctypes.c_int32, ctypes.c_int32] llama_eval.argtypes = [
llama_context_p_ctypes,
llama_token_p,
ctypes.c_int32,
ctypes.c_int32,
]
llama_eval.restype = ctypes.c_int llama_eval.restype = ctypes.c_int
@ -1627,10 +1678,10 @@ llama_eval.restype = ctypes.c_int
# "use llama_decode() instead"); # "use llama_decode() instead");
def llama_eval_embd( def llama_eval_embd(
ctx: llama_context_p, ctx: llama_context_p,
embd, # type: Array[ctypes.c_float] embd: CtypesArray[ctypes.c_float],
n_tokens: Union[ctypes.c_int, int], n_tokens: Union[ctypes.c_int, int],
n_past: Union[ctypes.c_int, int], n_past: Union[ctypes.c_int, int],
/ /,
) -> int: ) -> int:
"""Same as llama_eval, but use float matrix input directly. """Same as llama_eval, but use float matrix input directly.
DEPRECATED: use llama_decode() instead""" DEPRECATED: use llama_decode() instead"""
@ -1638,7 +1689,12 @@ def llama_eval_embd(
llama_eval_embd = _lib.llama_eval_embd llama_eval_embd = _lib.llama_eval_embd
llama_eval_embd.argtypes = [llama_context_p_ctypes, c_float_p, ctypes.c_int32, ctypes.c_int32] llama_eval_embd.argtypes = [
llama_context_p_ctypes,
ctypes.POINTER(ctypes.c_float),
ctypes.c_int32,
ctypes.c_int32,
]
llama_eval_embd.restype = ctypes.c_int llama_eval_embd.restype = ctypes.c_int
@ -1652,11 +1708,11 @@ llama_eval_embd.restype = ctypes.c_int
# llama_pos pos_0, # llama_pos pos_0,
# llama_seq_id seq_id); # llama_seq_id seq_id);
def llama_batch_get_one( def llama_batch_get_one(
tokens, # type: Array[llama_token] tokens: CtypesArray[llama_token],
n_tokens: Union[ctypes.c_int, int], n_tokens: Union[ctypes.c_int, int],
pos_0: Union[llama_pos, int], pos_0: Union[llama_pos, int],
seq_id: llama_seq_id, seq_id: llama_seq_id,
/ /,
) -> llama_batch: ) -> llama_batch:
"""Return batch for single sequence of tokens starting at pos_0 """Return batch for single sequence of tokens starting at pos_0
@ -1690,7 +1746,7 @@ def llama_batch_init(
n_tokens: Union[ctypes.c_int32, int], n_tokens: Union[ctypes.c_int32, int],
embd: Union[ctypes.c_int32, int], embd: Union[ctypes.c_int32, int],
n_seq_max: Union[ctypes.c_int32, int], n_seq_max: Union[ctypes.c_int32, int],
/ /,
) -> llama_batch: ) -> llama_batch:
"""Allocates a batch of tokens on the heap that can hold a maximum of n_tokens """Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
Each token can be assigned up to n_seq_max sequence ids Each token can be assigned up to n_seq_max sequence ids
@ -1747,7 +1803,7 @@ def llama_set_n_threads(
ctx: llama_context_p, ctx: llama_context_p,
n_threads: Union[ctypes.c_uint32, int], n_threads: Union[ctypes.c_uint32, int],
n_threads_batch: Union[ctypes.c_uint32, int], n_threads_batch: Union[ctypes.c_uint32, int],
/ /,
): ):
"""Set the number of threads used for decoding """Set the number of threads used for decoding
n_threads is the number of threads used for generation (single token) n_threads is the number of threads used for generation (single token)
@ -1757,7 +1813,11 @@ def llama_set_n_threads(
llama_set_n_threads = _lib.llama_set_n_threads llama_set_n_threads = _lib.llama_set_n_threads
llama_set_n_threads.argtypes = [llama_context_p_ctypes, ctypes.c_uint32, ctypes.c_uint32] llama_set_n_threads.argtypes = [
llama_context_p_ctypes,
ctypes.c_uint32,
ctypes.c_uint32,
]
llama_set_n_threads.restype = None llama_set_n_threads.restype = None
@ -1768,8 +1828,7 @@ llama_set_n_threads.restype = None
# // Cols: n_vocab # // Cols: n_vocab
# LLAMA_API float * llama_get_logits(struct llama_context * ctx); # LLAMA_API float * llama_get_logits(struct llama_context * ctx);
def llama_get_logits( def llama_get_logits(
ctx: llama_context_p, ctx: llama_context_p, /
/
): # type: (...) -> Array[float] # type: ignore ): # type: (...) -> Array[float] # type: ignore
"""Token logits obtained from the last call to llama_eval() """Token logits obtained from the last call to llama_eval()
The logits for the last token are stored in the last row The logits for the last token are stored in the last row
@ -1781,15 +1840,14 @@ def llama_get_logits(
llama_get_logits = _lib.llama_get_logits llama_get_logits = _lib.llama_get_logits
llama_get_logits.argtypes = [llama_context_p_ctypes] llama_get_logits.argtypes = [llama_context_p_ctypes]
llama_get_logits.restype = c_float_p llama_get_logits.restype = ctypes.POINTER(ctypes.c_float)
# // Logits for the ith token. Equivalent to: # // Logits for the ith token. Equivalent to:
# // llama_get_logits(ctx) + i*n_vocab # // llama_get_logits(ctx) + i*n_vocab
# LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); # LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
def llama_get_logits_ith( def llama_get_logits_ith(
ctx: llama_context_p, i: Union[ctypes.c_int32, int] ctx: llama_context_p, i: Union[ctypes.c_int32, int], /
, /
): # type: (...) -> Array[float] # type: ignore ): # type: (...) -> Array[float] # type: ignore
"""Logits for the ith token. Equivalent to: """Logits for the ith token. Equivalent to:
llama_get_logits(ctx) + i*n_vocab""" llama_get_logits(ctx) + i*n_vocab"""
@ -1798,7 +1856,7 @@ def llama_get_logits_ith(
llama_get_logits_ith = _lib.llama_get_logits_ith llama_get_logits_ith = _lib.llama_get_logits_ith
llama_get_logits_ith.argtypes = [llama_context_p_ctypes, ctypes.c_int32] llama_get_logits_ith.argtypes = [llama_context_p_ctypes, ctypes.c_int32]
llama_get_logits_ith.restype = c_float_p llama_get_logits_ith.restype = ctypes.POINTER(ctypes.c_float)
# Get the embeddings for the input # Get the embeddings for the input
@ -1814,7 +1872,7 @@ def llama_get_embeddings(
llama_get_embeddings = _lib.llama_get_embeddings llama_get_embeddings = _lib.llama_get_embeddings
llama_get_embeddings.argtypes = [llama_context_p_ctypes] llama_get_embeddings.argtypes = [llama_context_p_ctypes]
llama_get_embeddings.restype = c_float_p llama_get_embeddings.restype = ctypes.POINTER(ctypes.c_float)
# // Get the embeddings for the ith sequence # // Get the embeddings for the ith sequence
@ -1830,7 +1888,7 @@ def llama_get_embeddings_ith(
llama_get_embeddings_ith = _lib.llama_get_embeddings_ith llama_get_embeddings_ith = _lib.llama_get_embeddings_ith
llama_get_embeddings_ith.argtypes = [llama_context_p_ctypes, ctypes.c_int32] llama_get_embeddings_ith.argtypes = [llama_context_p_ctypes, ctypes.c_int32]
llama_get_embeddings_ith.restype = c_float_p llama_get_embeddings_ith.restype = ctypes.POINTER(ctypes.c_float)
# // # //
@ -1839,7 +1897,9 @@ llama_get_embeddings_ith.restype = c_float_p
# LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token); # LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token);
def llama_token_get_text(model: llama_model_p, token: Union[llama_token, int], /) -> bytes: def llama_token_get_text(
model: llama_model_p, token: Union[llama_token, int], /
) -> bytes:
... ...
@ -1861,7 +1921,9 @@ llama_token_get_score.restype = ctypes.c_float
# LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token); # LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
def llama_token_get_type(model: llama_model_p, token: Union[llama_token, int], /) -> int: def llama_token_get_type(
model: llama_model_p, token: Union[llama_token, int], /
) -> int:
... ...
@ -1995,11 +2057,11 @@ def llama_tokenize(
model: llama_model_p, model: llama_model_p,
text: bytes, text: bytes,
text_len: Union[ctypes.c_int, int], text_len: Union[ctypes.c_int, int],
tokens, # type: Array[llama_token] tokens: CtypesArray[llama_token],
n_max_tokens: Union[ctypes.c_int, int], n_max_tokens: Union[ctypes.c_int, int],
add_bos: Union[ctypes.c_bool, bool], add_bos: Union[ctypes.c_bool, bool],
special: Union[ctypes.c_bool, bool], special: Union[ctypes.c_bool, bool],
/ /,
) -> int: ) -> int:
"""Convert the provided text into tokens.""" """Convert the provided text into tokens."""
... ...
@ -2032,7 +2094,7 @@ def llama_token_to_piece(
token: Union[llama_token, int], token: Union[llama_token, int],
buf: Union[ctypes.c_char_p, bytes], buf: Union[ctypes.c_char_p, bytes],
length: Union[ctypes.c_int, int], length: Union[ctypes.c_int, int],
/ /,
) -> int: ) -> int:
"""Token Id -> Piece. """Token Id -> Piece.
Uses the vocabulary in the provided context. Uses the vocabulary in the provided context.
@ -2043,7 +2105,12 @@ def llama_token_to_piece(
llama_token_to_piece = _lib.llama_token_to_piece llama_token_to_piece = _lib.llama_token_to_piece
llama_token_to_piece.argtypes = [llama_model_p_ctypes, llama_token, ctypes.c_char_p, ctypes.c_int32] llama_token_to_piece.argtypes = [
llama_model_p_ctypes,
llama_token,
ctypes.c_char_p,
ctypes.c_int32,
]
llama_token_to_piece.restype = ctypes.c_int32 llama_token_to_piece.restype = ctypes.c_int32
@ -2066,25 +2133,25 @@ llama_token_to_piece.restype = ctypes.c_int32
# char * buf, # char * buf,
# int32_t length); # int32_t length);
def llama_chat_apply_template( def llama_chat_apply_template(
model: llama_model_p, model: llama_model_p,
tmpl: bytes, tmpl: bytes,
chat: "ctypes._Pointer[llama_chat_message]", # type: ignore chat: CtypesArray[llama_chat_message],
n_msg: int, n_msg: int,
/ /,
) -> int: ) -> int:
... ...
llama_chat_apply_template = _lib.llama_chat_apply_template llama_chat_apply_template = _lib.llama_chat_apply_template
llama_chat_apply_template.argtypes = [ llama_chat_apply_template.argtypes = [
ctypes.c_void_p, ctypes.c_void_p,
ctypes.c_char_p, ctypes.c_char_p,
ctypes.POINTER(llama_chat_message), ctypes.POINTER(llama_chat_message),
ctypes.c_size_t ctypes.c_size_t,
] ]
llama_chat_apply_template.restype = ctypes.c_int32 llama_chat_apply_template.restype = ctypes.c_int32
# // # //
# // Grammar # // Grammar
# // # //
@ -2095,10 +2162,12 @@ llama_chat_apply_template.restype = ctypes.c_int32
# size_t n_rules, # size_t n_rules,
# size_t start_rule_index); # size_t start_rule_index);
def llama_grammar_init( def llama_grammar_init(
rules, # type: Array[llama_grammar_element_p] # type: ignore rules: CtypesArray[
CtypesPointer[llama_grammar_element]
], # NOTE: This might be wrong type sig
n_rules: Union[ctypes.c_size_t, int], n_rules: Union[ctypes.c_size_t, int],
start_rule_index: Union[ctypes.c_size_t, int], start_rule_index: Union[ctypes.c_size_t, int],
/ /,
) -> llama_grammar_p: ) -> llama_grammar_p:
"""Initialize a grammar from a set of rules.""" """Initialize a grammar from a set of rules."""
... ...
@ -2163,13 +2232,15 @@ llama_set_rng_seed.restype = None
# float penalty_present); # float penalty_present);
def llama_sample_repetition_penalties( def llama_sample_repetition_penalties(
ctx: llama_context_p, ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array] candidates: Union[
last_tokens_data, # type: Array[llama_token] CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
last_tokens_data: CtypesArray[llama_token],
penalty_last_n: Union[ctypes.c_size_t, int], penalty_last_n: Union[ctypes.c_size_t, int],
penalty_repeat: Union[ctypes.c_float, float], penalty_repeat: Union[ctypes.c_float, float],
penalty_freq: Union[ctypes.c_float, float], penalty_freq: Union[ctypes.c_float, float],
penalty_present: Union[ctypes.c_float, float], penalty_present: Union[ctypes.c_float, float],
/ /,
): ):
"""Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. """Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
@ -2201,10 +2272,10 @@ llama_sample_repetition_penalties.restype = None
# float scale); # float scale);
def llama_sample_apply_guidance( def llama_sample_apply_guidance(
ctx: llama_context_p, ctx: llama_context_p,
logits, # type: _Pointer[ctypes.c_float] logits: CtypesArray[ctypes.c_float],
logits_guidance, # type: _Pointer[ctypes.c_float] logits_guidance: CtypesArray[ctypes.c_float],
scale: Union[ctypes.c_float, float], scale: Union[ctypes.c_float, float],
/ /,
): ):
"""Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806""" """Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806"""
... ...
@ -2213,8 +2284,8 @@ def llama_sample_apply_guidance(
llama_sample_apply_guidance = _lib.llama_sample_apply_guidance llama_sample_apply_guidance = _lib.llama_sample_apply_guidance
llama_sample_apply_guidance.argtypes = [ llama_sample_apply_guidance.argtypes = [
llama_context_p_ctypes, llama_context_p_ctypes,
c_float_p, ctypes.POINTER(ctypes.c_float),
c_float_p, ctypes.POINTER(ctypes.c_float),
ctypes.c_float, ctypes.c_float,
] ]
llama_sample_apply_guidance.restype = None llama_sample_apply_guidance.restype = None
@ -2228,10 +2299,12 @@ llama_sample_apply_guidance.restype = None
# "use llama_sample_apply_guidance() instead"); # "use llama_sample_apply_guidance() instead");
def llama_sample_classifier_free_guidance( def llama_sample_classifier_free_guidance(
ctx: llama_context_p, ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array] candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
guidance_ctx: llama_context_p, guidance_ctx: llama_context_p,
scale: Union[ctypes.c_float, float], scale: Union[ctypes.c_float, float],
/ /,
): ):
"""Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806""" """Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806"""
... ...
@ -2252,8 +2325,11 @@ llama_sample_classifier_free_guidance.restype = None
# struct llama_context * ctx, # struct llama_context * ctx,
# llama_token_data_array * candidates); # llama_token_data_array * candidates);
def llama_sample_softmax( def llama_sample_softmax(
ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data] ctx: llama_context_p,
/ candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
/,
): ):
"""Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.""" """Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits."""
... ...
@ -2275,10 +2351,12 @@ llama_sample_softmax.restype = None
# size_t min_keep); # size_t min_keep);
def llama_sample_top_k( def llama_sample_top_k(
ctx: llama_context_p, ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array] candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
k: Union[ctypes.c_int, int], k: Union[ctypes.c_int, int],
min_keep: Union[ctypes.c_size_t, int], min_keep: Union[ctypes.c_size_t, int],
/ /,
): ):
"""Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751""" """Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751"""
... ...
@ -2302,10 +2380,12 @@ llama_sample_top_k.restype = None
# size_t min_keep); # size_t min_keep);
def llama_sample_top_p( def llama_sample_top_p(
ctx: llama_context_p, ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array] candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
p: Union[ctypes.c_float, float], p: Union[ctypes.c_float, float],
min_keep: Union[ctypes.c_size_t, int], min_keep: Union[ctypes.c_size_t, int],
/ /,
): ):
"""Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751""" """Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751"""
... ...
@ -2329,10 +2409,12 @@ llama_sample_top_p.restype = None
# size_t min_keep); # size_t min_keep);
def llama_sample_min_p( def llama_sample_min_p(
ctx: llama_context_p, ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array] candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
p: Union[ctypes.c_float, float], p: Union[ctypes.c_float, float],
min_keep: Union[ctypes.c_size_t, int], min_keep: Union[ctypes.c_size_t, int],
/ /,
): ):
"""Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841""" """Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841"""
... ...
@ -2356,10 +2438,12 @@ llama_sample_min_p.restype = None
# size_t min_keep); # size_t min_keep);
def llama_sample_tail_free( def llama_sample_tail_free(
ctx: llama_context_p, ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array] candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
z: Union[ctypes.c_float, float], z: Union[ctypes.c_float, float],
min_keep: Union[ctypes.c_size_t, int], min_keep: Union[ctypes.c_size_t, int],
/ /,
): ):
"""Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.""" """Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/."""
... ...
@ -2383,10 +2467,12 @@ llama_sample_tail_free.restype = None
# size_t min_keep); # size_t min_keep);
def llama_sample_typical( def llama_sample_typical(
ctx: llama_context_p, ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array] candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
p: Union[ctypes.c_float, float], p: Union[ctypes.c_float, float],
min_keep: Union[ctypes.c_size_t, int], min_keep: Union[ctypes.c_size_t, int],
/ /,
): ):
"""Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.""" """Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666."""
... ...
@ -2411,11 +2497,13 @@ llama_sample_typical.restype = None
# float exponent_val); # float exponent_val);
def llama_sample_entropy( def llama_sample_entropy(
ctx: llama_context_p, ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array] candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
min_temp: Union[ctypes.c_float, float], min_temp: Union[ctypes.c_float, float],
max_temp: Union[ctypes.c_float, float], max_temp: Union[ctypes.c_float, float],
exponent_val: Union[ctypes.c_float, float], exponent_val: Union[ctypes.c_float, float],
/ /,
): ):
"""Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.""" """Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772."""
... ...
@ -2438,9 +2526,11 @@ llama_sample_entropy.restype = None
# float temp); # float temp);
def llama_sample_temp( def llama_sample_temp(
ctx: llama_context_p, ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array] candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
temp: Union[ctypes.c_float, float], temp: Union[ctypes.c_float, float],
/ /,
): ):
"""Temperature sampling described in academic paper "Generating Long Sequences with Sparse Transformers" https://arxiv.org/abs/1904.10509 """Temperature sampling described in academic paper "Generating Long Sequences with Sparse Transformers" https://arxiv.org/abs/1904.10509
@ -2467,9 +2557,11 @@ llama_sample_temp.restype = None
# "use llama_sample_temp instead"); # "use llama_sample_temp instead");
def llama_sample_temperature( def llama_sample_temperature(
ctx: llama_context_p, ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array] candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
temp: Union[ctypes.c_float, float], temp: Union[ctypes.c_float, float],
/ /,
): ):
"""use llama_sample_temp instead""" """use llama_sample_temp instead"""
... ...
@ -2491,9 +2583,11 @@ llama_sample_temperature.restype = None
# const struct llama_grammar * grammar); # const struct llama_grammar * grammar);
def llama_sample_grammar( def llama_sample_grammar(
ctx: llama_context_p, ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array] candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
grammar, # type: llama_grammar_p grammar, # type: llama_grammar_p
/ /,
): ):
"""Apply constraints from grammar """Apply constraints from grammar
@ -2528,12 +2622,14 @@ llama_sample_grammar.restype = None
# float * mu); # float * mu);
def llama_sample_token_mirostat( def llama_sample_token_mirostat(
ctx: llama_context_p, ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array] candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
tau: Union[ctypes.c_float, float], tau: Union[ctypes.c_float, float],
eta: Union[ctypes.c_float, float], eta: Union[ctypes.c_float, float],
m: Union[ctypes.c_int, int], m: Union[ctypes.c_int, int],
mu, # type: _Pointer[ctypes.c_float] mu: CtypesPointerOrRef[ctypes.c_float],
/ /,
) -> int: ) -> int:
"""Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. """Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
@ -2554,7 +2650,7 @@ llama_sample_token_mirostat.argtypes = [
ctypes.c_float, ctypes.c_float,
ctypes.c_float, ctypes.c_float,
ctypes.c_int32, ctypes.c_int32,
c_float_p, ctypes.POINTER(ctypes.c_float),
] ]
llama_sample_token_mirostat.restype = llama_token llama_sample_token_mirostat.restype = llama_token
@ -2572,11 +2668,13 @@ llama_sample_token_mirostat.restype = llama_token
# float * mu); # float * mu);
def llama_sample_token_mirostat_v2( def llama_sample_token_mirostat_v2(
ctx: llama_context_p, ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array] candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
tau: Union[ctypes.c_float, float], tau: Union[ctypes.c_float, float],
eta: Union[ctypes.c_float, float], eta: Union[ctypes.c_float, float],
mu, # type: _Pointer[ctypes.c_float] mu, # type: _Pointer[ctypes.c_float]
/ /,
) -> int: ) -> int:
"""Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. """Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
@ -2595,7 +2693,7 @@ llama_sample_token_mirostat_v2.argtypes = [
llama_token_data_array_p, llama_token_data_array_p,
ctypes.c_float, ctypes.c_float,
ctypes.c_float, ctypes.c_float,
c_float_p, ctypes.POINTER(ctypes.c_float),
] ]
llama_sample_token_mirostat_v2.restype = llama_token llama_sample_token_mirostat_v2.restype = llama_token
@ -2607,8 +2705,10 @@ llama_sample_token_mirostat_v2.restype = llama_token
# llama_token_data_array * candidates); # llama_token_data_array * candidates);
def llama_sample_token_greedy( def llama_sample_token_greedy(
ctx: llama_context_p, ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array] candidates: Union[
/ CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
/,
) -> int: ) -> int:
"""Selects the token with the highest probability.""" """Selects the token with the highest probability."""
... ...
@ -2628,8 +2728,10 @@ llama_sample_token_greedy.restype = llama_token
# llama_token_data_array * candidates); # llama_token_data_array * candidates);
def llama_sample_token( def llama_sample_token(
ctx: llama_context_p, ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array] candidates: Union[
/ CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
/,
) -> int: ) -> int:
"""Randomly selects a token from the candidates based on their probabilities.""" """Randomly selects a token from the candidates based on their probabilities."""
... ...
@ -2649,10 +2751,7 @@ llama_sample_token.restype = llama_token
# struct llama_grammar * grammar, # struct llama_grammar * grammar,
# llama_token token); # llama_token token);
def llama_grammar_accept_token( def llama_grammar_accept_token(
ctx: llama_context_p, ctx: llama_context_p, grammar: llama_grammar_p, token: Union[llama_token, int], /
grammar: llama_grammar_p,
token: Union[llama_token, int],
/
) -> None: ) -> None:
"""Accepts the sampled token into the grammar""" """Accepts the sampled token into the grammar"""
... ...
@ -2711,7 +2810,9 @@ class llama_beams_state(ctypes.Structure):
# // void* callback_data is any custom data passed to llama_beam_search, that is subsequently # // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
# // passed back to beam_search_callback. This avoids having to use global variables in the callback. # // passed back to beam_search_callback. This avoids having to use global variables in the callback.
# typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state); # typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state);
llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, ctypes.c_void_p, llama_beams_state) llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(
None, ctypes.c_void_p, llama_beams_state
)
# /// @details Deterministically returns entire sentence constructed by a beam search. # /// @details Deterministically returns entire sentence constructed by a beam search.
@ -2731,12 +2832,12 @@ llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, ctypes.c_void_p, llama_
# int32_t n_predict); # int32_t n_predict);
def llama_beam_search( def llama_beam_search(
ctx: llama_context_p, ctx: llama_context_p,
callback: "ctypes._CFuncPtr[None, ctypes.c_void_p, llama_beams_state]", # type: ignore callback: CtypesFuncPointer,
callback_data: ctypes.c_void_p, callback_data: ctypes.c_void_p,
n_beams: Union[ctypes.c_size_t, int], n_beams: Union[ctypes.c_size_t, int],
n_past: Union[ctypes.c_int, int], n_past: Union[ctypes.c_int, int],
n_predict: Union[ctypes.c_int, int], n_predict: Union[ctypes.c_int, int],
/ /,
): ):
... ...
@ -2806,8 +2907,9 @@ llama_print_system_info.restype = ctypes.c_char_p
# // If this is not called, or NULL is supplied, everything is output on stderr. # // If this is not called, or NULL is supplied, everything is output on stderr.
# LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data); # LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
def llama_log_set( def llama_log_set(
log_callback: Union["ctypes._FuncPointer", ctypes.c_void_p], user_data: ctypes.c_void_p, # type: ignore log_callback: Optional[CtypesFuncPointer],
/ user_data: ctypes.c_void_p, # type: ignore
/,
): ):
"""Set callback for all future logging events. """Set callback for all future logging events.