misc: additional type annotations for low level api
This commit is contained in:
parent
3632241e98
commit
aefcb8f71a
1 changed files with 227 additions and 125 deletions
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import os
|
||||
import ctypes
|
||||
|
@ -6,7 +8,16 @@ from ctypes import (
|
|||
Array,
|
||||
)
|
||||
import pathlib
|
||||
from typing import List, Union, NewType, Optional
|
||||
from typing import (
|
||||
List,
|
||||
Union,
|
||||
NewType,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
TypeAlias,
|
||||
Generic,
|
||||
)
|
||||
|
||||
|
||||
# Load the library
|
||||
|
@ -56,7 +67,7 @@ def _load_shared_library(lib_base_name: str):
|
|||
for _lib_path in _lib_paths:
|
||||
if _lib_path.exists():
|
||||
try:
|
||||
return ctypes.CDLL(str(_lib_path), **cdll_args) # type: ignore
|
||||
return ctypes.CDLL(str(_lib_path), **cdll_args) # type: ignore
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
|
||||
|
||||
|
@ -71,14 +82,39 @@ _lib_base_name = "llama"
|
|||
# Load the library
|
||||
_lib = _load_shared_library(_lib_base_name)
|
||||
|
||||
# Misc
|
||||
c_float_p = ctypes.POINTER(ctypes.c_float)
|
||||
c_uint8_p = ctypes.POINTER(ctypes.c_uint8)
|
||||
c_size_t_p = ctypes.POINTER(ctypes.c_size_t)
|
||||
|
||||
# ctypes sane type hint helpers
|
||||
#
|
||||
# - Generic Pointer and Array types
|
||||
# - PointerOrRef type with a type hinted byref function
|
||||
#
|
||||
# NOTE: Only use these for static type checking not for runtime checks
|
||||
# no good will come of that
|
||||
|
||||
if TYPE_CHECKING:
|
||||
CtypesCData = TypeVar("CtypesCData", bound=ctypes._CData) # type: ignore
|
||||
|
||||
CtypesArray: TypeAlias = ctypes.Array[CtypesCData] # type: ignore
|
||||
|
||||
CtypesPointer: TypeAlias = ctypes._Pointer[CtypesCData] # type: ignore
|
||||
|
||||
CtypesVoidPointer: TypeAlias = ctypes.c_void_p
|
||||
|
||||
class CtypesRef(Generic[CtypesCData]):
|
||||
pass
|
||||
|
||||
CtypesPointerOrRef: TypeAlias = Union[
|
||||
CtypesPointer[CtypesCData], CtypesRef[CtypesCData]
|
||||
]
|
||||
|
||||
CtypesFuncPointer: TypeAlias = ctypes._FuncPointer # type: ignore
|
||||
|
||||
|
||||
# from ggml-backend.h
|
||||
# typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
|
||||
ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_void_p, ctypes.c_bool, ctypes.c_void_p)
|
||||
ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE(
|
||||
ctypes.c_bool, ctypes.c_void_p, ctypes.c_bool, ctypes.c_void_p
|
||||
)
|
||||
|
||||
# llama.h bindings
|
||||
|
||||
|
@ -286,7 +322,9 @@ class llama_token_data_array(ctypes.Structure):
|
|||
llama_token_data_array_p = ctypes.POINTER(llama_token_data_array)
|
||||
|
||||
# typedef bool (*llama_progress_callback)(float progress, void *ctx);
|
||||
llama_progress_callback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_float, ctypes.c_void_p)
|
||||
llama_progress_callback = ctypes.CFUNCTYPE(
|
||||
ctypes.c_bool, ctypes.c_float, ctypes.c_void_p
|
||||
)
|
||||
|
||||
|
||||
# // Input data for llama_decode
|
||||
|
@ -336,7 +374,7 @@ class llama_batch(ctypes.Structure):
|
|||
_fields_ = [
|
||||
("n_tokens", ctypes.c_int32),
|
||||
("token", ctypes.POINTER(llama_token)),
|
||||
("embd", c_float_p),
|
||||
("embd", ctypes.POINTER(ctypes.c_float)),
|
||||
("pos", ctypes.POINTER(llama_pos)),
|
||||
("n_seq_id", ctypes.POINTER(ctypes.c_int32)),
|
||||
("seq_id", ctypes.POINTER(ctypes.POINTER(llama_seq_id))),
|
||||
|
@ -431,7 +469,7 @@ class llama_model_params(ctypes.Structure):
|
|||
("n_gpu_layers", ctypes.c_int32),
|
||||
("split_mode", ctypes.c_int),
|
||||
("main_gpu", ctypes.c_int32),
|
||||
("tensor_split", c_float_p),
|
||||
("tensor_split", ctypes.POINTER(ctypes.c_float)),
|
||||
("progress_callback", llama_progress_callback),
|
||||
("progress_callback_user_data", ctypes.c_void_p),
|
||||
("kv_overrides", ctypes.POINTER(llama_model_kv_override)),
|
||||
|
@ -532,7 +570,9 @@ class llama_context_params(ctypes.Structure):
|
|||
# // if it exists.
|
||||
# // It might not exist for progress report where '.' is output repeatedly.
|
||||
# typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data);
|
||||
llama_log_callback = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p)
|
||||
llama_log_callback = ctypes.CFUNCTYPE(
|
||||
None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p
|
||||
)
|
||||
"""Signature for logging events
|
||||
Note that text includes the new line character at the end for most events.
|
||||
If your logging mechanism cannot handle that, check if the last character is '\n' and strip it
|
||||
|
@ -966,14 +1006,23 @@ llama_rope_freq_scale_train.restype = ctypes.c_float
|
|||
# // Get metadata value as a string by key name
|
||||
# LLAMA_API int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size);
|
||||
def llama_model_meta_val_str(
|
||||
model: llama_model_p, key: Union[ctypes.c_char_p, bytes], buf: bytes, buf_size: int, /
|
||||
model: llama_model_p,
|
||||
key: Union[ctypes.c_char_p, bytes],
|
||||
buf: bytes,
|
||||
buf_size: int,
|
||||
/,
|
||||
) -> int:
|
||||
"""Get metadata value as a string by key name"""
|
||||
...
|
||||
|
||||
|
||||
llama_model_meta_val_str = _lib.llama_model_meta_val_str
|
||||
llama_model_meta_val_str.argtypes = [llama_model_p_ctypes, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_size_t]
|
||||
llama_model_meta_val_str.argtypes = [
|
||||
llama_model_p_ctypes,
|
||||
ctypes.c_char_p,
|
||||
ctypes.c_char_p,
|
||||
ctypes.c_size_t,
|
||||
]
|
||||
llama_model_meta_val_str.restype = ctypes.c_int32
|
||||
|
||||
|
||||
|
@ -1087,8 +1136,8 @@ llama_get_model_tensor.restype = ctypes.c_void_p
|
|||
def llama_model_quantize(
|
||||
fname_inp: bytes,
|
||||
fname_out: bytes,
|
||||
params, # type: ctypes.POINTER(llama_model_quantize_params) # type: ignore
|
||||
/
|
||||
params: CtypesPointerOrRef[llama_model_quantize_params],
|
||||
/,
|
||||
) -> int:
|
||||
"""Returns 0 on success"""
|
||||
...
|
||||
|
@ -1121,8 +1170,8 @@ def llama_apply_lora_from_file(
|
|||
path_lora: Union[ctypes.c_char_p, bytes],
|
||||
scale: Union[ctypes.c_float, float],
|
||||
path_base_model: Union[ctypes.c_char_p, bytes],
|
||||
n_threads: Union[ctypes.c_int, int],
|
||||
/
|
||||
n_threads: Union[ctypes.c_int32, int],
|
||||
/,
|
||||
) -> int:
|
||||
"""Apply a LoRA adapter to a loaded model
|
||||
path_base_model is the path to a higher quality model to use as a base for
|
||||
|
@ -1155,8 +1204,8 @@ def llama_model_apply_lora_from_file(
|
|||
path_lora: Union[ctypes.c_char_p, bytes],
|
||||
scale: Union[ctypes.c_float, float],
|
||||
path_base_model: Union[ctypes.c_char_p, bytes],
|
||||
n_threads: Union[ctypes.c_int, int],
|
||||
/
|
||||
n_threads: Union[ctypes.c_int32, int],
|
||||
/,
|
||||
) -> int:
|
||||
...
|
||||
|
||||
|
@ -1262,7 +1311,7 @@ llama_kv_cache_view_free.restype = None
|
|||
|
||||
# // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
|
||||
# LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
|
||||
def llama_kv_cache_view_update(ctx: llama_context_p, view: "ctypes.pointer[llama_kv_cache_view]", /): # type: ignore
|
||||
def llama_kv_cache_view_update(ctx: llama_context_p, view: CtypesPointerOrRef[llama_kv_cache_view], /): # type: ignore
|
||||
"""Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)"""
|
||||
...
|
||||
|
||||
|
@ -1326,7 +1375,7 @@ def llama_kv_cache_seq_rm(
|
|||
seq_id: Union[llama_seq_id, int],
|
||||
p0: Union[llama_pos, int],
|
||||
p1: Union[llama_pos, int],
|
||||
/
|
||||
/,
|
||||
):
|
||||
"""Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
seq_id < 0 : match any sequence
|
||||
|
@ -1361,7 +1410,7 @@ def llama_kv_cache_seq_cp(
|
|||
seq_id_dst: Union[llama_seq_id, int],
|
||||
p0: Union[llama_pos, int],
|
||||
p1: Union[llama_pos, int],
|
||||
/
|
||||
/,
|
||||
):
|
||||
"""Copy all tokens that belong to the specified sequence to another sequence
|
||||
Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
|
||||
|
@ -1385,11 +1434,7 @@ llama_kv_cache_seq_cp.restype = None
|
|||
# LLAMA_API void llama_kv_cache_seq_keep(
|
||||
# struct llama_context * ctx,
|
||||
# llama_seq_id seq_id);
|
||||
def llama_kv_cache_seq_keep(
|
||||
ctx: llama_context_p,
|
||||
seq_id: Union[llama_seq_id, int],
|
||||
/
|
||||
):
|
||||
def llama_kv_cache_seq_keep(ctx: llama_context_p, seq_id: Union[llama_seq_id, int], /):
|
||||
"""Removes all tokens that do not belong to the specified sequence"""
|
||||
...
|
||||
|
||||
|
@ -1415,7 +1460,7 @@ def llama_kv_cache_seq_shift(
|
|||
p0: Union[llama_pos, int],
|
||||
p1: Union[llama_pos, int],
|
||||
delta: Union[llama_pos, int],
|
||||
/
|
||||
/,
|
||||
):
|
||||
"""Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
If the KV cache is RoPEd, the KV data is updated accordingly
|
||||
|
@ -1451,7 +1496,7 @@ def llama_kv_cache_seq_div(
|
|||
p0: Union[llama_pos, int],
|
||||
p1: Union[llama_pos, int],
|
||||
d: Union[ctypes.c_int, int],
|
||||
/
|
||||
/,
|
||||
):
|
||||
"""Integer division of the positions by factor of `d > 1`
|
||||
If the KV cache is RoPEd, the KV data is updated accordingly
|
||||
|
@ -1496,8 +1541,7 @@ llama_get_state_size.restype = ctypes.c_size_t
|
|||
# struct llama_context * ctx,
|
||||
# uint8_t * dst);
|
||||
def llama_copy_state_data(
|
||||
ctx: llama_context_p, dst, # type: Array[ctypes.c_uint8]
|
||||
/
|
||||
ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], /
|
||||
) -> int:
|
||||
"""Copies the state to the specified destination address.
|
||||
Destination needs to have allocated enough memory.
|
||||
|
@ -1506,7 +1550,10 @@ def llama_copy_state_data(
|
|||
|
||||
|
||||
llama_copy_state_data = _lib.llama_copy_state_data
|
||||
llama_copy_state_data.argtypes = [llama_context_p_ctypes, c_uint8_p]
|
||||
llama_copy_state_data.argtypes = [
|
||||
llama_context_p_ctypes,
|
||||
ctypes.POINTER(ctypes.c_uint8),
|
||||
]
|
||||
llama_copy_state_data.restype = ctypes.c_size_t
|
||||
|
||||
|
||||
|
@ -1516,15 +1563,14 @@ llama_copy_state_data.restype = ctypes.c_size_t
|
|||
# struct llama_context * ctx,
|
||||
# uint8_t * src);
|
||||
def llama_set_state_data(
|
||||
ctx: llama_context_p, src, # type: Array[ctypes.c_uint8]
|
||||
/
|
||||
ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], /
|
||||
) -> int:
|
||||
"""Set the state reading from the specified address"""
|
||||
...
|
||||
|
||||
|
||||
llama_set_state_data = _lib.llama_set_state_data
|
||||
llama_set_state_data.argtypes = [llama_context_p_ctypes, c_uint8_p]
|
||||
llama_set_state_data.argtypes = [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)]
|
||||
llama_set_state_data.restype = ctypes.c_size_t
|
||||
|
||||
|
||||
|
@ -1538,10 +1584,10 @@ llama_set_state_data.restype = ctypes.c_size_t
|
|||
def llama_load_session_file(
|
||||
ctx: llama_context_p,
|
||||
path_session: bytes,
|
||||
tokens_out, # type: Array[llama_token]
|
||||
tokens_out: CtypesArray[llama_token],
|
||||
n_token_capacity: Union[ctypes.c_size_t, int],
|
||||
n_token_count_out, # type: _Pointer[ctypes.c_size_t]
|
||||
/
|
||||
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
|
||||
/,
|
||||
) -> int:
|
||||
...
|
||||
|
||||
|
@ -1552,7 +1598,7 @@ llama_load_session_file.argtypes = [
|
|||
ctypes.c_char_p,
|
||||
llama_token_p,
|
||||
ctypes.c_size_t,
|
||||
c_size_t_p,
|
||||
ctypes.POINTER(ctypes.c_size_t),
|
||||
]
|
||||
llama_load_session_file.restype = ctypes.c_size_t
|
||||
|
||||
|
@ -1565,9 +1611,9 @@ llama_load_session_file.restype = ctypes.c_size_t
|
|||
def llama_save_session_file(
|
||||
ctx: llama_context_p,
|
||||
path_session: bytes,
|
||||
tokens, # type: Array[llama_token]
|
||||
tokens: CtypesArray[llama_token],
|
||||
n_token_count: Union[ctypes.c_size_t, int],
|
||||
/
|
||||
/,
|
||||
) -> int:
|
||||
...
|
||||
|
||||
|
@ -1599,10 +1645,10 @@ llama_save_session_file.restype = ctypes.c_size_t
|
|||
# "use llama_decode() instead");
|
||||
def llama_eval(
|
||||
ctx: llama_context_p,
|
||||
tokens, # type: Array[llama_token]
|
||||
tokens: CtypesArray[llama_token],
|
||||
n_tokens: Union[ctypes.c_int, int],
|
||||
n_past: Union[ctypes.c_int, int],
|
||||
/
|
||||
/,
|
||||
) -> int:
|
||||
"""Run the llama inference to obtain the logits and probabilities for the next token(s).
|
||||
tokens + n_tokens is the provided batch of new tokens to process
|
||||
|
@ -1613,7 +1659,12 @@ def llama_eval(
|
|||
|
||||
|
||||
llama_eval = _lib.llama_eval
|
||||
llama_eval.argtypes = [llama_context_p_ctypes, llama_token_p, ctypes.c_int32, ctypes.c_int32]
|
||||
llama_eval.argtypes = [
|
||||
llama_context_p_ctypes,
|
||||
llama_token_p,
|
||||
ctypes.c_int32,
|
||||
ctypes.c_int32,
|
||||
]
|
||||
llama_eval.restype = ctypes.c_int
|
||||
|
||||
|
||||
|
@ -1627,10 +1678,10 @@ llama_eval.restype = ctypes.c_int
|
|||
# "use llama_decode() instead");
|
||||
def llama_eval_embd(
|
||||
ctx: llama_context_p,
|
||||
embd, # type: Array[ctypes.c_float]
|
||||
embd: CtypesArray[ctypes.c_float],
|
||||
n_tokens: Union[ctypes.c_int, int],
|
||||
n_past: Union[ctypes.c_int, int],
|
||||
/
|
||||
/,
|
||||
) -> int:
|
||||
"""Same as llama_eval, but use float matrix input directly.
|
||||
DEPRECATED: use llama_decode() instead"""
|
||||
|
@ -1638,7 +1689,12 @@ def llama_eval_embd(
|
|||
|
||||
|
||||
llama_eval_embd = _lib.llama_eval_embd
|
||||
llama_eval_embd.argtypes = [llama_context_p_ctypes, c_float_p, ctypes.c_int32, ctypes.c_int32]
|
||||
llama_eval_embd.argtypes = [
|
||||
llama_context_p_ctypes,
|
||||
ctypes.POINTER(ctypes.c_float),
|
||||
ctypes.c_int32,
|
||||
ctypes.c_int32,
|
||||
]
|
||||
llama_eval_embd.restype = ctypes.c_int
|
||||
|
||||
|
||||
|
@ -1652,11 +1708,11 @@ llama_eval_embd.restype = ctypes.c_int
|
|||
# llama_pos pos_0,
|
||||
# llama_seq_id seq_id);
|
||||
def llama_batch_get_one(
|
||||
tokens, # type: Array[llama_token]
|
||||
tokens: CtypesArray[llama_token],
|
||||
n_tokens: Union[ctypes.c_int, int],
|
||||
pos_0: Union[llama_pos, int],
|
||||
seq_id: llama_seq_id,
|
||||
/
|
||||
/,
|
||||
) -> llama_batch:
|
||||
"""Return batch for single sequence of tokens starting at pos_0
|
||||
|
||||
|
@ -1690,7 +1746,7 @@ def llama_batch_init(
|
|||
n_tokens: Union[ctypes.c_int32, int],
|
||||
embd: Union[ctypes.c_int32, int],
|
||||
n_seq_max: Union[ctypes.c_int32, int],
|
||||
/
|
||||
/,
|
||||
) -> llama_batch:
|
||||
"""Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
|
||||
Each token can be assigned up to n_seq_max sequence ids
|
||||
|
@ -1747,7 +1803,7 @@ def llama_set_n_threads(
|
|||
ctx: llama_context_p,
|
||||
n_threads: Union[ctypes.c_uint32, int],
|
||||
n_threads_batch: Union[ctypes.c_uint32, int],
|
||||
/
|
||||
/,
|
||||
):
|
||||
"""Set the number of threads used for decoding
|
||||
n_threads is the number of threads used for generation (single token)
|
||||
|
@ -1757,7 +1813,11 @@ def llama_set_n_threads(
|
|||
|
||||
|
||||
llama_set_n_threads = _lib.llama_set_n_threads
|
||||
llama_set_n_threads.argtypes = [llama_context_p_ctypes, ctypes.c_uint32, ctypes.c_uint32]
|
||||
llama_set_n_threads.argtypes = [
|
||||
llama_context_p_ctypes,
|
||||
ctypes.c_uint32,
|
||||
ctypes.c_uint32,
|
||||
]
|
||||
llama_set_n_threads.restype = None
|
||||
|
||||
|
||||
|
@ -1768,8 +1828,7 @@ llama_set_n_threads.restype = None
|
|||
# // Cols: n_vocab
|
||||
# LLAMA_API float * llama_get_logits(struct llama_context * ctx);
|
||||
def llama_get_logits(
|
||||
ctx: llama_context_p,
|
||||
/
|
||||
ctx: llama_context_p, /
|
||||
): # type: (...) -> Array[float] # type: ignore
|
||||
"""Token logits obtained from the last call to llama_eval()
|
||||
The logits for the last token are stored in the last row
|
||||
|
@ -1781,15 +1840,14 @@ def llama_get_logits(
|
|||
|
||||
llama_get_logits = _lib.llama_get_logits
|
||||
llama_get_logits.argtypes = [llama_context_p_ctypes]
|
||||
llama_get_logits.restype = c_float_p
|
||||
llama_get_logits.restype = ctypes.POINTER(ctypes.c_float)
|
||||
|
||||
|
||||
# // Logits for the ith token. Equivalent to:
|
||||
# // llama_get_logits(ctx) + i*n_vocab
|
||||
# LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
|
||||
def llama_get_logits_ith(
|
||||
ctx: llama_context_p, i: Union[ctypes.c_int32, int]
|
||||
, /
|
||||
ctx: llama_context_p, i: Union[ctypes.c_int32, int], /
|
||||
): # type: (...) -> Array[float] # type: ignore
|
||||
"""Logits for the ith token. Equivalent to:
|
||||
llama_get_logits(ctx) + i*n_vocab"""
|
||||
|
@ -1798,7 +1856,7 @@ def llama_get_logits_ith(
|
|||
|
||||
llama_get_logits_ith = _lib.llama_get_logits_ith
|
||||
llama_get_logits_ith.argtypes = [llama_context_p_ctypes, ctypes.c_int32]
|
||||
llama_get_logits_ith.restype = c_float_p
|
||||
llama_get_logits_ith.restype = ctypes.POINTER(ctypes.c_float)
|
||||
|
||||
|
||||
# Get the embeddings for the input
|
||||
|
@ -1814,7 +1872,7 @@ def llama_get_embeddings(
|
|||
|
||||
llama_get_embeddings = _lib.llama_get_embeddings
|
||||
llama_get_embeddings.argtypes = [llama_context_p_ctypes]
|
||||
llama_get_embeddings.restype = c_float_p
|
||||
llama_get_embeddings.restype = ctypes.POINTER(ctypes.c_float)
|
||||
|
||||
|
||||
# // Get the embeddings for the ith sequence
|
||||
|
@ -1830,7 +1888,7 @@ def llama_get_embeddings_ith(
|
|||
|
||||
llama_get_embeddings_ith = _lib.llama_get_embeddings_ith
|
||||
llama_get_embeddings_ith.argtypes = [llama_context_p_ctypes, ctypes.c_int32]
|
||||
llama_get_embeddings_ith.restype = c_float_p
|
||||
llama_get_embeddings_ith.restype = ctypes.POINTER(ctypes.c_float)
|
||||
|
||||
|
||||
# //
|
||||
|
@ -1839,7 +1897,9 @@ llama_get_embeddings_ith.restype = c_float_p
|
|||
|
||||
|
||||
# LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token);
|
||||
def llama_token_get_text(model: llama_model_p, token: Union[llama_token, int], /) -> bytes:
|
||||
def llama_token_get_text(
|
||||
model: llama_model_p, token: Union[llama_token, int], /
|
||||
) -> bytes:
|
||||
...
|
||||
|
||||
|
||||
|
@ -1861,7 +1921,9 @@ llama_token_get_score.restype = ctypes.c_float
|
|||
|
||||
|
||||
# LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
|
||||
def llama_token_get_type(model: llama_model_p, token: Union[llama_token, int], /) -> int:
|
||||
def llama_token_get_type(
|
||||
model: llama_model_p, token: Union[llama_token, int], /
|
||||
) -> int:
|
||||
...
|
||||
|
||||
|
||||
|
@ -1995,11 +2057,11 @@ def llama_tokenize(
|
|||
model: llama_model_p,
|
||||
text: bytes,
|
||||
text_len: Union[ctypes.c_int, int],
|
||||
tokens, # type: Array[llama_token]
|
||||
tokens: CtypesArray[llama_token],
|
||||
n_max_tokens: Union[ctypes.c_int, int],
|
||||
add_bos: Union[ctypes.c_bool, bool],
|
||||
special: Union[ctypes.c_bool, bool],
|
||||
/
|
||||
/,
|
||||
) -> int:
|
||||
"""Convert the provided text into tokens."""
|
||||
...
|
||||
|
@ -2032,7 +2094,7 @@ def llama_token_to_piece(
|
|||
token: Union[llama_token, int],
|
||||
buf: Union[ctypes.c_char_p, bytes],
|
||||
length: Union[ctypes.c_int, int],
|
||||
/
|
||||
/,
|
||||
) -> int:
|
||||
"""Token Id -> Piece.
|
||||
Uses the vocabulary in the provided context.
|
||||
|
@ -2043,7 +2105,12 @@ def llama_token_to_piece(
|
|||
|
||||
|
||||
llama_token_to_piece = _lib.llama_token_to_piece
|
||||
llama_token_to_piece.argtypes = [llama_model_p_ctypes, llama_token, ctypes.c_char_p, ctypes.c_int32]
|
||||
llama_token_to_piece.argtypes = [
|
||||
llama_model_p_ctypes,
|
||||
llama_token,
|
||||
ctypes.c_char_p,
|
||||
ctypes.c_int32,
|
||||
]
|
||||
llama_token_to_piece.restype = ctypes.c_int32
|
||||
|
||||
|
||||
|
@ -2066,25 +2133,25 @@ llama_token_to_piece.restype = ctypes.c_int32
|
|||
# char * buf,
|
||||
# int32_t length);
|
||||
def llama_chat_apply_template(
|
||||
model: llama_model_p,
|
||||
tmpl: bytes,
|
||||
chat: "ctypes._Pointer[llama_chat_message]", # type: ignore
|
||||
n_msg: int,
|
||||
/
|
||||
model: llama_model_p,
|
||||
tmpl: bytes,
|
||||
chat: CtypesArray[llama_chat_message],
|
||||
n_msg: int,
|
||||
/,
|
||||
) -> int:
|
||||
...
|
||||
|
||||
|
||||
llama_chat_apply_template = _lib.llama_chat_apply_template
|
||||
llama_chat_apply_template.argtypes = [
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_char_p,
|
||||
ctypes.POINTER(llama_chat_message),
|
||||
ctypes.c_size_t
|
||||
ctypes.c_size_t,
|
||||
]
|
||||
llama_chat_apply_template.restype = ctypes.c_int32
|
||||
|
||||
|
||||
|
||||
# //
|
||||
# // Grammar
|
||||
# //
|
||||
|
@ -2095,10 +2162,12 @@ llama_chat_apply_template.restype = ctypes.c_int32
|
|||
# size_t n_rules,
|
||||
# size_t start_rule_index);
|
||||
def llama_grammar_init(
|
||||
rules, # type: Array[llama_grammar_element_p] # type: ignore
|
||||
rules: CtypesArray[
|
||||
CtypesPointer[llama_grammar_element]
|
||||
], # NOTE: This might be wrong type sig
|
||||
n_rules: Union[ctypes.c_size_t, int],
|
||||
start_rule_index: Union[ctypes.c_size_t, int],
|
||||
/
|
||||
/,
|
||||
) -> llama_grammar_p:
|
||||
"""Initialize a grammar from a set of rules."""
|
||||
...
|
||||
|
@ -2163,13 +2232,15 @@ llama_set_rng_seed.restype = None
|
|||
# float penalty_present);
|
||||
def llama_sample_repetition_penalties(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
last_tokens_data, # type: Array[llama_token]
|
||||
candidates: Union[
|
||||
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||
],
|
||||
last_tokens_data: CtypesArray[llama_token],
|
||||
penalty_last_n: Union[ctypes.c_size_t, int],
|
||||
penalty_repeat: Union[ctypes.c_float, float],
|
||||
penalty_freq: Union[ctypes.c_float, float],
|
||||
penalty_present: Union[ctypes.c_float, float],
|
||||
/
|
||||
/,
|
||||
):
|
||||
"""Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
||||
Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
|
||||
|
@ -2201,10 +2272,10 @@ llama_sample_repetition_penalties.restype = None
|
|||
# float scale);
|
||||
def llama_sample_apply_guidance(
|
||||
ctx: llama_context_p,
|
||||
logits, # type: _Pointer[ctypes.c_float]
|
||||
logits_guidance, # type: _Pointer[ctypes.c_float]
|
||||
logits: CtypesArray[ctypes.c_float],
|
||||
logits_guidance: CtypesArray[ctypes.c_float],
|
||||
scale: Union[ctypes.c_float, float],
|
||||
/
|
||||
/,
|
||||
):
|
||||
"""Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806"""
|
||||
...
|
||||
|
@ -2213,8 +2284,8 @@ def llama_sample_apply_guidance(
|
|||
llama_sample_apply_guidance = _lib.llama_sample_apply_guidance
|
||||
llama_sample_apply_guidance.argtypes = [
|
||||
llama_context_p_ctypes,
|
||||
c_float_p,
|
||||
c_float_p,
|
||||
ctypes.POINTER(ctypes.c_float),
|
||||
ctypes.POINTER(ctypes.c_float),
|
||||
ctypes.c_float,
|
||||
]
|
||||
llama_sample_apply_guidance.restype = None
|
||||
|
@ -2228,10 +2299,12 @@ llama_sample_apply_guidance.restype = None
|
|||
# "use llama_sample_apply_guidance() instead");
|
||||
def llama_sample_classifier_free_guidance(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
candidates: Union[
|
||||
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||
],
|
||||
guidance_ctx: llama_context_p,
|
||||
scale: Union[ctypes.c_float, float],
|
||||
/
|
||||
/,
|
||||
):
|
||||
"""Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806"""
|
||||
...
|
||||
|
@ -2252,8 +2325,11 @@ llama_sample_classifier_free_guidance.restype = None
|
|||
# struct llama_context * ctx,
|
||||
# llama_token_data_array * candidates);
|
||||
def llama_sample_softmax(
|
||||
ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data]
|
||||
/
|
||||
ctx: llama_context_p,
|
||||
candidates: Union[
|
||||
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||
],
|
||||
/,
|
||||
):
|
||||
"""Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits."""
|
||||
...
|
||||
|
@ -2275,10 +2351,12 @@ llama_sample_softmax.restype = None
|
|||
# size_t min_keep);
|
||||
def llama_sample_top_k(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
candidates: Union[
|
||||
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||
],
|
||||
k: Union[ctypes.c_int, int],
|
||||
min_keep: Union[ctypes.c_size_t, int],
|
||||
/
|
||||
/,
|
||||
):
|
||||
"""Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751"""
|
||||
...
|
||||
|
@ -2302,10 +2380,12 @@ llama_sample_top_k.restype = None
|
|||
# size_t min_keep);
|
||||
def llama_sample_top_p(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
candidates: Union[
|
||||
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||
],
|
||||
p: Union[ctypes.c_float, float],
|
||||
min_keep: Union[ctypes.c_size_t, int],
|
||||
/
|
||||
/,
|
||||
):
|
||||
"""Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751"""
|
||||
...
|
||||
|
@ -2329,10 +2409,12 @@ llama_sample_top_p.restype = None
|
|||
# size_t min_keep);
|
||||
def llama_sample_min_p(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
candidates: Union[
|
||||
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||
],
|
||||
p: Union[ctypes.c_float, float],
|
||||
min_keep: Union[ctypes.c_size_t, int],
|
||||
/
|
||||
/,
|
||||
):
|
||||
"""Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841"""
|
||||
...
|
||||
|
@ -2356,10 +2438,12 @@ llama_sample_min_p.restype = None
|
|||
# size_t min_keep);
|
||||
def llama_sample_tail_free(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
candidates: Union[
|
||||
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||
],
|
||||
z: Union[ctypes.c_float, float],
|
||||
min_keep: Union[ctypes.c_size_t, int],
|
||||
/
|
||||
/,
|
||||
):
|
||||
"""Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/."""
|
||||
...
|
||||
|
@ -2383,10 +2467,12 @@ llama_sample_tail_free.restype = None
|
|||
# size_t min_keep);
|
||||
def llama_sample_typical(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
candidates: Union[
|
||||
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||
],
|
||||
p: Union[ctypes.c_float, float],
|
||||
min_keep: Union[ctypes.c_size_t, int],
|
||||
/
|
||||
/,
|
||||
):
|
||||
"""Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666."""
|
||||
...
|
||||
|
@ -2411,11 +2497,13 @@ llama_sample_typical.restype = None
|
|||
# float exponent_val);
|
||||
def llama_sample_entropy(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
candidates: Union[
|
||||
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||
],
|
||||
min_temp: Union[ctypes.c_float, float],
|
||||
max_temp: Union[ctypes.c_float, float],
|
||||
exponent_val: Union[ctypes.c_float, float],
|
||||
/
|
||||
/,
|
||||
):
|
||||
"""Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772."""
|
||||
...
|
||||
|
@ -2438,9 +2526,11 @@ llama_sample_entropy.restype = None
|
|||
# float temp);
|
||||
def llama_sample_temp(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
candidates: Union[
|
||||
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||
],
|
||||
temp: Union[ctypes.c_float, float],
|
||||
/
|
||||
/,
|
||||
):
|
||||
"""Temperature sampling described in academic paper "Generating Long Sequences with Sparse Transformers" https://arxiv.org/abs/1904.10509
|
||||
|
||||
|
@ -2467,9 +2557,11 @@ llama_sample_temp.restype = None
|
|||
# "use llama_sample_temp instead");
|
||||
def llama_sample_temperature(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
candidates: Union[
|
||||
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||
],
|
||||
temp: Union[ctypes.c_float, float],
|
||||
/
|
||||
/,
|
||||
):
|
||||
"""use llama_sample_temp instead"""
|
||||
...
|
||||
|
@ -2491,9 +2583,11 @@ llama_sample_temperature.restype = None
|
|||
# const struct llama_grammar * grammar);
|
||||
def llama_sample_grammar(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
candidates: Union[
|
||||
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||
],
|
||||
grammar, # type: llama_grammar_p
|
||||
/
|
||||
/,
|
||||
):
|
||||
"""Apply constraints from grammar
|
||||
|
||||
|
@ -2528,12 +2622,14 @@ llama_sample_grammar.restype = None
|
|||
# float * mu);
|
||||
def llama_sample_token_mirostat(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
candidates: Union[
|
||||
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||
],
|
||||
tau: Union[ctypes.c_float, float],
|
||||
eta: Union[ctypes.c_float, float],
|
||||
m: Union[ctypes.c_int, int],
|
||||
mu, # type: _Pointer[ctypes.c_float]
|
||||
/
|
||||
mu: CtypesPointerOrRef[ctypes.c_float],
|
||||
/,
|
||||
) -> int:
|
||||
"""Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
|
||||
|
@ -2554,7 +2650,7 @@ llama_sample_token_mirostat.argtypes = [
|
|||
ctypes.c_float,
|
||||
ctypes.c_float,
|
||||
ctypes.c_int32,
|
||||
c_float_p,
|
||||
ctypes.POINTER(ctypes.c_float),
|
||||
]
|
||||
llama_sample_token_mirostat.restype = llama_token
|
||||
|
||||
|
@ -2572,11 +2668,13 @@ llama_sample_token_mirostat.restype = llama_token
|
|||
# float * mu);
|
||||
def llama_sample_token_mirostat_v2(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
candidates: Union[
|
||||
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||
],
|
||||
tau: Union[ctypes.c_float, float],
|
||||
eta: Union[ctypes.c_float, float],
|
||||
mu, # type: _Pointer[ctypes.c_float]
|
||||
/
|
||||
/,
|
||||
) -> int:
|
||||
"""Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
|
||||
|
@ -2595,7 +2693,7 @@ llama_sample_token_mirostat_v2.argtypes = [
|
|||
llama_token_data_array_p,
|
||||
ctypes.c_float,
|
||||
ctypes.c_float,
|
||||
c_float_p,
|
||||
ctypes.POINTER(ctypes.c_float),
|
||||
]
|
||||
llama_sample_token_mirostat_v2.restype = llama_token
|
||||
|
||||
|
@ -2607,8 +2705,10 @@ llama_sample_token_mirostat_v2.restype = llama_token
|
|||
# llama_token_data_array * candidates);
|
||||
def llama_sample_token_greedy(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
/
|
||||
candidates: Union[
|
||||
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||
],
|
||||
/,
|
||||
) -> int:
|
||||
"""Selects the token with the highest probability."""
|
||||
...
|
||||
|
@ -2628,8 +2728,10 @@ llama_sample_token_greedy.restype = llama_token
|
|||
# llama_token_data_array * candidates);
|
||||
def llama_sample_token(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
/
|
||||
candidates: Union[
|
||||
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
|
||||
],
|
||||
/,
|
||||
) -> int:
|
||||
"""Randomly selects a token from the candidates based on their probabilities."""
|
||||
...
|
||||
|
@ -2649,10 +2751,7 @@ llama_sample_token.restype = llama_token
|
|||
# struct llama_grammar * grammar,
|
||||
# llama_token token);
|
||||
def llama_grammar_accept_token(
|
||||
ctx: llama_context_p,
|
||||
grammar: llama_grammar_p,
|
||||
token: Union[llama_token, int],
|
||||
/
|
||||
ctx: llama_context_p, grammar: llama_grammar_p, token: Union[llama_token, int], /
|
||||
) -> None:
|
||||
"""Accepts the sampled token into the grammar"""
|
||||
...
|
||||
|
@ -2711,7 +2810,9 @@ class llama_beams_state(ctypes.Structure):
|
|||
# // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
|
||||
# // passed back to beam_search_callback. This avoids having to use global variables in the callback.
|
||||
# typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state);
|
||||
llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, ctypes.c_void_p, llama_beams_state)
|
||||
llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(
|
||||
None, ctypes.c_void_p, llama_beams_state
|
||||
)
|
||||
|
||||
|
||||
# /// @details Deterministically returns entire sentence constructed by a beam search.
|
||||
|
@ -2731,12 +2832,12 @@ llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, ctypes.c_void_p, llama_
|
|||
# int32_t n_predict);
|
||||
def llama_beam_search(
|
||||
ctx: llama_context_p,
|
||||
callback: "ctypes._CFuncPtr[None, ctypes.c_void_p, llama_beams_state]", # type: ignore
|
||||
callback: CtypesFuncPointer,
|
||||
callback_data: ctypes.c_void_p,
|
||||
n_beams: Union[ctypes.c_size_t, int],
|
||||
n_past: Union[ctypes.c_int, int],
|
||||
n_predict: Union[ctypes.c_int, int],
|
||||
/
|
||||
/,
|
||||
):
|
||||
...
|
||||
|
||||
|
@ -2806,8 +2907,9 @@ llama_print_system_info.restype = ctypes.c_char_p
|
|||
# // If this is not called, or NULL is supplied, everything is output on stderr.
|
||||
# LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
|
||||
def llama_log_set(
|
||||
log_callback: Union["ctypes._FuncPointer", ctypes.c_void_p], user_data: ctypes.c_void_p, # type: ignore
|
||||
/
|
||||
log_callback: Optional[CtypesFuncPointer],
|
||||
user_data: ctypes.c_void_p, # type: ignore
|
||||
/,
|
||||
):
|
||||
"""Set callback for all future logging events.
|
||||
|
||||
|
|
Loading…
Reference in a new issue