misc: additional type annotations for low level api

Andrei Betlen 2024-02-22 02:00:09 -05:00
parent 3632241e98
commit aefcb8f71a


@ -1,3 +1,5 @@
from __future__ import annotations
import sys
import os
import ctypes
@ -6,7 +8,16 @@ from ctypes import (
Array,
)
import pathlib
from typing import List, Union, NewType, Optional
from typing import (
List,
Union,
NewType,
Optional,
TYPE_CHECKING,
TypeVar,
TypeAlias,
Generic,
)
# Load the library
@ -56,7 +67,7 @@ def _load_shared_library(lib_base_name: str):
for _lib_path in _lib_paths:
if _lib_path.exists():
try:
return ctypes.CDLL(str(_lib_path), **cdll_args) # type: ignore
return ctypes.CDLL(str(_lib_path), **cdll_args) # type: ignore
except Exception as e:
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
@ -71,14 +82,39 @@ _lib_base_name = "llama"
# Load the library
_lib = _load_shared_library(_lib_base_name)
# Misc
c_float_p = ctypes.POINTER(ctypes.c_float)
c_uint8_p = ctypes.POINTER(ctypes.c_uint8)
c_size_t_p = ctypes.POINTER(ctypes.c_size_t)
# Sane ctypes type hint helpers
#
# - Generic Pointer and Array types
# - PointerOrRef type with a type-hinted byref function
#
# NOTE: Only use these for static type checking, not for runtime checks;
# no good will come of that
if TYPE_CHECKING:
CtypesCData = TypeVar("CtypesCData", bound=ctypes._CData) # type: ignore
CtypesArray: TypeAlias = ctypes.Array[CtypesCData] # type: ignore
CtypesPointer: TypeAlias = ctypes._Pointer[CtypesCData] # type: ignore
CtypesVoidPointer: TypeAlias = ctypes.c_void_p
class CtypesRef(Generic[CtypesCData]):
pass
CtypesPointerOrRef: TypeAlias = Union[
CtypesPointer[CtypesCData], CtypesRef[CtypesCData]
]
CtypesFuncPointer: TypeAlias = ctypes._FuncPointer # type: ignore
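These aliases stay behind TYPE_CHECKING because they lean on private, typeshed-only names such as ctypes._CData and ctypes._Pointer; with `from __future__ import annotations` in effect (added at the top of this diff), annotations are never evaluated at runtime, so referencing the aliases in signatures is safe. A minimal self-contained sketch of the pattern, independent of this module:

    from __future__ import annotations  # annotations stay unevaluated strings

    import ctypes
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Visible to the type checker only; this branch never runs.
        FloatArray = ctypes.Array[ctypes.c_float]

    def scale(buf: FloatArray, factor: float) -> None:
        for i in range(len(buf)):
            buf[i] *= factor

    data = (ctypes.c_float * 3)(1.0, 2.0, 3.0)
    scale(data, 2.0)
    print(list(data))  # [2.0, 4.0, 6.0]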
# from ggml-backend.h
# typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_void_p, ctypes.c_bool, ctypes.c_void_p)
ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE(
ctypes.c_bool, ctypes.c_void_p, ctypes.c_bool, ctypes.c_void_p
)
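A hedged sketch of how a callback of this type would be constructed (registration is outside this diff). The C side does not keep the CFUNCTYPE wrapper alive, so the Python reference must outlive any use:

    def _eval_cb(tensor: ctypes.c_void_p, ask: bool, user_data: ctypes.c_void_p) -> bool:
        # Return True to let the scheduler keep evaluating.
        return True

    eval_cb = ggml_backend_sched_eval_callback(_eval_cb)  # keep this reference alive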
# llama.h bindings
@ -286,7 +322,9 @@ class llama_token_data_array(ctypes.Structure):
llama_token_data_array_p = ctypes.POINTER(llama_token_data_array)
# typedef bool (*llama_progress_callback)(float progress, void *ctx);
llama_progress_callback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_float, ctypes.c_void_p)
llama_progress_callback = ctypes.CFUNCTYPE(
ctypes.c_bool, ctypes.c_float, ctypes.c_void_p
)
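For illustration, a progress callback wired the way llama_model_params.progress_callback (further down in this diff) expects; in upstream llama.cpp, returning False aborts the load:

    def _on_progress(progress: float, ctx: ctypes.c_void_p) -> bool:
        print(f"load: {progress * 100:.0f}%")
        return True  # False would abort model loading

    progress_cb = llama_progress_callback(_on_progress)
    # e.g. params.progress_callback = progress_cb; keep progress_cb referenced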
# // Input data for llama_decode
@ -336,7 +374,7 @@ class llama_batch(ctypes.Structure):
_fields_ = [
("n_tokens", ctypes.c_int32),
("token", ctypes.POINTER(llama_token)),
("embd", c_float_p),
("embd", ctypes.POINTER(ctypes.c_float)),
("pos", ctypes.POINTER(llama_pos)),
("n_seq_id", ctypes.POINTER(ctypes.c_int32)),
("seq_id", ctypes.POINTER(ctypes.POINTER(llama_seq_id))),
@ -431,7 +469,7 @@ class llama_model_params(ctypes.Structure):
("n_gpu_layers", ctypes.c_int32),
("split_mode", ctypes.c_int),
("main_gpu", ctypes.c_int32),
("tensor_split", c_float_p),
("tensor_split", ctypes.POINTER(ctypes.c_float)),
("progress_callback", llama_progress_callback),
("progress_callback_user_data", ctypes.c_void_p),
("kv_overrides", ctypes.POINTER(llama_model_kv_override)),
@ -532,7 +570,9 @@ class llama_context_params(ctypes.Structure):
# // if it exists.
# // It might not exist for progress report where '.' is output repeatedly.
# typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data);
llama_log_callback = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p)
llama_log_callback = ctypes.CFUNCTYPE(
None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p
)
"""Signature for logging events
Note that text includes the new line character at the end for most events.
If your logging mechanism cannot handle that, check if the last character is '\n' and strip it
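A sketch implementing that note, forwarding to Python's logging module and stripping the trailing newline when present:

    import logging

    def _log(level: int, text: bytes, user_data: ctypes.c_void_p) -> None:
        message = text.decode("utf-8", errors="replace")
        if message.endswith("\n"):
            message = message[:-1]  # strip the newline noted above
        logging.getLogger("llama").info("[%d] %s", level, message)

    log_cb = llama_log_callback(_log)
    # Register via llama_log_set (bound near the end of this file); keep log_cb alive.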
@ -966,14 +1006,23 @@ llama_rope_freq_scale_train.restype = ctypes.c_float
# // Get metadata value as a string by key name
# LLAMA_API int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size);
def llama_model_meta_val_str(
model: llama_model_p, key: Union[ctypes.c_char_p, bytes], buf: bytes, buf_size: int, /
model: llama_model_p,
key: Union[ctypes.c_char_p, bytes],
buf: bytes,
buf_size: int,
/,
) -> int:
"""Get metadata value as a string by key name"""
...
llama_model_meta_val_str = _lib.llama_model_meta_val_str
llama_model_meta_val_str.argtypes = [llama_model_p_ctypes, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_size_t]
llama_model_meta_val_str.argtypes = [
llama_model_p_ctypes,
ctypes.c_char_p,
ctypes.c_char_p,
ctypes.c_size_t,
]
llama_model_meta_val_str.restype = ctypes.c_int32
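Illustrative usage, assuming `model` is a loaded llama_model_p and the key exists: the C function writes a NUL-terminated string into a caller-provided buffer and returns its length (negative on failure):

    # `model` and the "general.name" key are assumptions for illustration.
    buf = ctypes.create_string_buffer(256)
    n = llama_model_meta_val_str(model, b"general.name", buf, ctypes.sizeof(buf))
    if n >= 0:
        print(buf.value.decode("utf-8"))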
@ -1087,8 +1136,8 @@ llama_get_model_tensor.restype = ctypes.c_void_p
def llama_model_quantize(
fname_inp: bytes,
fname_out: bytes,
params, # type: ctypes.POINTER(llama_model_quantize_params) # type: ignore
/
params: CtypesPointerOrRef[llama_model_quantize_params],
/,
) -> int:
"""Returns 0 on success"""
...
@ -1121,8 +1170,8 @@ def llama_apply_lora_from_file(
path_lora: Union[ctypes.c_char_p, bytes],
scale: Union[ctypes.c_float, float],
path_base_model: Union[ctypes.c_char_p, bytes],
n_threads: Union[ctypes.c_int, int],
/
n_threads: Union[ctypes.c_int32, int],
/,
) -> int:
"""Apply a LoRA adapter to a loaded model
path_base_model is the path to a higher quality model to use as a base for
@ -1155,8 +1204,8 @@ def llama_model_apply_lora_from_file(
path_lora: Union[ctypes.c_char_p, bytes],
scale: Union[ctypes.c_float, float],
path_base_model: Union[ctypes.c_char_p, bytes],
n_threads: Union[ctypes.c_int, int],
/
n_threads: Union[ctypes.c_int32, int],
/,
) -> int:
...
@ -1262,7 +1311,7 @@ llama_kv_cache_view_free.restype = None
# // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
# LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
def llama_kv_cache_view_update(ctx: llama_context_p, view: "ctypes.pointer[llama_kv_cache_view]", /): # type: ignore
def llama_kv_cache_view_update(ctx: llama_context_p, view: CtypesPointerOrRef[llama_kv_cache_view], /): # type: ignore
"""Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)"""
...
@ -1326,7 +1375,7 @@ def llama_kv_cache_seq_rm(
seq_id: Union[llama_seq_id, int],
p0: Union[llama_pos, int],
p1: Union[llama_pos, int],
/
/,
):
"""Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
seq_id < 0 : match any sequence
@ -1361,7 +1410,7 @@ def llama_kv_cache_seq_cp(
seq_id_dst: Union[llama_seq_id, int],
p0: Union[llama_pos, int],
p1: Union[llama_pos, int],
/
/,
):
"""Copy all tokens that belong to the specified sequence to another sequence
Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
@ -1385,11 +1434,7 @@ llama_kv_cache_seq_cp.restype = None
# LLAMA_API void llama_kv_cache_seq_keep(
# struct llama_context * ctx,
# llama_seq_id seq_id);
def llama_kv_cache_seq_keep(
ctx: llama_context_p,
seq_id: Union[llama_seq_id, int],
/
):
def llama_kv_cache_seq_keep(ctx: llama_context_p, seq_id: Union[llama_seq_id, int], /):
"""Removes all tokens that do not belong to the specified sequence"""
...
@ -1415,7 +1460,7 @@ def llama_kv_cache_seq_shift(
p0: Union[llama_pos, int],
p1: Union[llama_pos, int],
delta: Union[llama_pos, int],
/
/,
):
"""Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
If the KV cache is RoPEd, the KV data is updated accordingly
@ -1451,7 +1496,7 @@ def llama_kv_cache_seq_div(
p0: Union[llama_pos, int],
p1: Union[llama_pos, int],
d: Union[ctypes.c_int, int],
/
/,
):
"""Integer division of the positions by factor of `d > 1`
If the KV cache is RoPEd, the KV data is updated accordingly
@ -1496,8 +1541,7 @@ llama_get_state_size.restype = ctypes.c_size_t
# struct llama_context * ctx,
# uint8_t * dst);
def llama_copy_state_data(
ctx: llama_context_p, dst, # type: Array[ctypes.c_uint8]
/
ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], /
) -> int:
"""Copies the state to the specified destination address.
Destination needs to have allocated enough memory.
@ -1506,7 +1550,10 @@ def llama_copy_state_data(
llama_copy_state_data = _lib.llama_copy_state_data
llama_copy_state_data.argtypes = [llama_context_p_ctypes, c_uint8_p]
llama_copy_state_data.argtypes = [
llama_context_p_ctypes,
ctypes.POINTER(ctypes.c_uint8),
]
llama_copy_state_data.restype = ctypes.c_size_t
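An illustrative state round-trip, assuming `ctx` is a live llama_context_p: size the buffer with llama_get_state_size (bound just above this hunk), copy the state out, and restore it with llama_set_state_data (bound just below):

    size = llama_get_state_size(ctx)        # `ctx` is assumed
    buf = (ctypes.c_uint8 * size)()
    n_written = llama_copy_state_data(ctx, buf)
    snapshot = bytes(buf[:n_written])

    # ...later, restore:
    restore = (ctypes.c_uint8 * len(snapshot)).from_buffer_copy(snapshot)
    llama_set_state_data(ctx, restore)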
@ -1516,15 +1563,14 @@ llama_copy_state_data.restype = ctypes.c_size_t
# struct llama_context * ctx,
# uint8_t * src);
def llama_set_state_data(
ctx: llama_context_p, src, # type: Array[ctypes.c_uint8]
/
ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], /
) -> int:
"""Set the state reading from the specified address"""
...
llama_set_state_data = _lib.llama_set_state_data
llama_set_state_data.argtypes = [llama_context_p_ctypes, c_uint8_p]
llama_set_state_data.argtypes = [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)]
llama_set_state_data.restype = ctypes.c_size_t
@ -1538,10 +1584,10 @@ llama_set_state_data.restype = ctypes.c_size_t
def llama_load_session_file(
ctx: llama_context_p,
path_session: bytes,
tokens_out, # type: Array[llama_token]
tokens_out: CtypesArray[llama_token],
n_token_capacity: Union[ctypes.c_size_t, int],
n_token_count_out, # type: _Pointer[ctypes.c_size_t]
/
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
/,
) -> int:
...
@ -1552,7 +1598,7 @@ llama_load_session_file.argtypes = [
ctypes.c_char_p,
llama_token_p,
ctypes.c_size_t,
c_size_t_p,
ctypes.POINTER(ctypes.c_size_t),
]
llama_load_session_file.restype = ctypes.c_size_t
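A hedged sketch of the out-parameter pattern this binding now annotates, assuming `ctx` and a session file path: the token count comes back through a c_size_t passed by reference:

    n_ctx = 2048                         # capacity is an assumption
    tokens = (llama_token * n_ctx)()
    n_count = ctypes.c_size_t(0)
    ok = llama_load_session_file(
        ctx, b"session.bin", tokens, n_ctx, ctypes.byref(n_count)
    )
    loaded = list(tokens[: n_count.value]) if ok else []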
@ -1565,9 +1611,9 @@ llama_load_session_file.restype = ctypes.c_size_t
def llama_save_session_file(
ctx: llama_context_p,
path_session: bytes,
tokens, # type: Array[llama_token]
tokens: CtypesArray[llama_token],
n_token_count: Union[ctypes.c_size_t, int],
/
/,
) -> int:
...
@ -1599,10 +1645,10 @@ llama_save_session_file.restype = ctypes.c_size_t
# "use llama_decode() instead");
def llama_eval(
ctx: llama_context_p,
tokens, # type: Array[llama_token]
tokens: CtypesArray[llama_token],
n_tokens: Union[ctypes.c_int, int],
n_past: Union[ctypes.c_int, int],
/
/,
) -> int:
"""Run the llama inference to obtain the logits and probabilities for the next token(s).
tokens + n_tokens is the provided batch of new tokens to process
@ -1613,7 +1659,12 @@ def llama_eval(
llama_eval = _lib.llama_eval
llama_eval.argtypes = [llama_context_p_ctypes, llama_token_p, ctypes.c_int32, ctypes.c_int32]
llama_eval.argtypes = [
llama_context_p_ctypes,
llama_token_p,
ctypes.c_int32,
ctypes.c_int32,
]
llama_eval.restype = ctypes.c_int
@ -1627,10 +1678,10 @@ llama_eval.restype = ctypes.c_int
# "use llama_decode() instead");
def llama_eval_embd(
ctx: llama_context_p,
embd, # type: Array[ctypes.c_float]
embd: CtypesArray[ctypes.c_float],
n_tokens: Union[ctypes.c_int, int],
n_past: Union[ctypes.c_int, int],
/
/,
) -> int:
"""Same as llama_eval, but use float matrix input directly.
DEPRECATED: use llama_decode() instead"""
@ -1638,7 +1689,12 @@ def llama_eval_embd(
llama_eval_embd = _lib.llama_eval_embd
llama_eval_embd.argtypes = [llama_context_p_ctypes, c_float_p, ctypes.c_int32, ctypes.c_int32]
llama_eval_embd.argtypes = [
llama_context_p_ctypes,
ctypes.POINTER(ctypes.c_float),
ctypes.c_int32,
ctypes.c_int32,
]
llama_eval_embd.restype = ctypes.c_int
@ -1652,11 +1708,11 @@ llama_eval_embd.restype = ctypes.c_int
# llama_pos pos_0,
# llama_seq_id seq_id);
def llama_batch_get_one(
tokens, # type: Array[llama_token]
tokens: CtypesArray[llama_token],
n_tokens: Union[ctypes.c_int, int],
pos_0: Union[llama_pos, int],
seq_id: llama_seq_id,
/
/,
) -> llama_batch:
"""Return batch for single sequence of tokens starting at pos_0
@ -1690,7 +1746,7 @@ def llama_batch_init(
n_tokens: Union[ctypes.c_int32, int],
embd: Union[ctypes.c_int32, int],
n_seq_max: Union[ctypes.c_int32, int],
/
/,
) -> llama_batch:
"""Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
Each token can be assigned up to n_seq_max sequence ids
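A hedged allocation sketch (llama_batch_free is bound elsewhere in this file, not in this hunk):

    batch = llama_batch_init(512, 0, 1)  # token batch, no embeddings, 1 seq id per token
    # ...fill batch.n_tokens, batch.token[i], batch.pos[i], batch.seq_id[i][0]...
    llama_batch_free(batch)              # must be freed after use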
@ -1747,7 +1803,7 @@ def llama_set_n_threads(
ctx: llama_context_p,
n_threads: Union[ctypes.c_uint32, int],
n_threads_batch: Union[ctypes.c_uint32, int],
/
/,
):
"""Set the number of threads used for decoding
n_threads is the number of threads used for generation (single token)
@ -1757,7 +1813,11 @@ def llama_set_n_threads(
llama_set_n_threads = _lib.llama_set_n_threads
llama_set_n_threads.argtypes = [llama_context_p_ctypes, ctypes.c_uint32, ctypes.c_uint32]
llama_set_n_threads.argtypes = [
llama_context_p_ctypes,
ctypes.c_uint32,
ctypes.c_uint32,
]
llama_set_n_threads.restype = None
@ -1768,8 +1828,7 @@ llama_set_n_threads.restype = None
# // Cols: n_vocab
# LLAMA_API float * llama_get_logits(struct llama_context * ctx);
def llama_get_logits(
ctx: llama_context_p,
/
ctx: llama_context_p, /
): # type: (...) -> Array[float] # type: ignore
"""Token logits obtained from the last call to llama_eval()
The logits for the last token are stored in the last row
@ -1781,15 +1840,14 @@ def llama_get_logits(
llama_get_logits = _lib.llama_get_logits
llama_get_logits.argtypes = [llama_context_p_ctypes]
llama_get_logits.restype = c_float_p
llama_get_logits.restype = ctypes.POINTER(ctypes.c_float)
# // Logits for the ith token. Equivalent to:
# // llama_get_logits(ctx) + i*n_vocab
# LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
def llama_get_logits_ith(
ctx: llama_context_p, i: Union[ctypes.c_int32, int]
, /
ctx: llama_context_p, i: Union[ctypes.c_int32, int], /
): # type: (...) -> Array[float] # type: ignore
"""Logits for the ith token. Equivalent to:
llama_get_logits(ctx) + i*n_vocab"""
@ -1798,7 +1856,7 @@ def llama_get_logits_ith(
llama_get_logits_ith = _lib.llama_get_logits_ith
llama_get_logits_ith.argtypes = [llama_context_p_ctypes, ctypes.c_int32]
llama_get_logits_ith.restype = c_float_p
llama_get_logits_ith.restype = ctypes.POINTER(ctypes.c_float)
# Get the embeddings for the input
@ -1814,7 +1872,7 @@ def llama_get_embeddings(
llama_get_embeddings = _lib.llama_get_embeddings
llama_get_embeddings.argtypes = [llama_context_p_ctypes]
llama_get_embeddings.restype = c_float_p
llama_get_embeddings.restype = ctypes.POINTER(ctypes.c_float)
# // Get the embeddings for the ith sequence
@ -1830,7 +1888,7 @@ def llama_get_embeddings_ith(
llama_get_embeddings_ith = _lib.llama_get_embeddings_ith
llama_get_embeddings_ith.argtypes = [llama_context_p_ctypes, ctypes.c_int32]
llama_get_embeddings_ith.restype = c_float_p
llama_get_embeddings_ith.restype = ctypes.POINTER(ctypes.c_float)
# //
@ -1839,7 +1897,9 @@ llama_get_embeddings_ith.restype = c_float_p
# LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token);
def llama_token_get_text(model: llama_model_p, token: Union[llama_token, int], /) -> bytes:
def llama_token_get_text(
model: llama_model_p, token: Union[llama_token, int], /
) -> bytes:
...
@ -1861,7 +1921,9 @@ llama_token_get_score.restype = ctypes.c_float
# LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
def llama_token_get_type(model: llama_model_p, token: Union[llama_token, int], /) -> int:
def llama_token_get_type(
model: llama_model_p, token: Union[llama_token, int], /
) -> int:
...
@ -1995,11 +2057,11 @@ def llama_tokenize(
model: llama_model_p,
text: bytes,
text_len: Union[ctypes.c_int, int],
tokens, # type: Array[llama_token]
tokens: CtypesArray[llama_token],
n_max_tokens: Union[ctypes.c_int, int],
add_bos: Union[ctypes.c_bool, bool],
special: Union[ctypes.c_bool, bool],
/
/,
) -> int:
"""Convert the provided text into tokens."""
...
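Illustrative usage, assuming `model` is a loaded llama_model_p: the caller allocates the token array, and a negative return value means the buffer was too small (its magnitude being the required count):

    text = b"Hello, world"               # sample input, for illustration
    max_tokens = 32
    tokens = (llama_token * max_tokens)()
    n = llama_tokenize(model, text, len(text), tokens, max_tokens, True, False)
    if n >= 0:
        print(list(tokens[:n]))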
@ -2032,7 +2094,7 @@ def llama_token_to_piece(
token: Union[llama_token, int],
buf: Union[ctypes.c_char_p, bytes],
length: Union[ctypes.c_int, int],
/
/,
) -> int:
"""Token Id -> Piece.
Uses the vocabulary in the provided context.
@ -2043,7 +2105,12 @@ def llama_token_to_piece(
llama_token_to_piece = _lib.llama_token_to_piece
llama_token_to_piece.argtypes = [llama_model_p_ctypes, llama_token, ctypes.c_char_p, ctypes.c_int32]
llama_token_to_piece.argtypes = [
llama_model_p_ctypes,
llama_token,
ctypes.c_char_p,
ctypes.c_int32,
]
llama_token_to_piece.restype = ctypes.c_int32
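The illustrative counterpart, assuming `model` and a valid `token_id`: decode one token id back into its byte piece:

    buf = ctypes.create_string_buffer(32)   # `token_id` is assumed
    n = llama_token_to_piece(model, token_id, buf, ctypes.sizeof(buf))
    if n >= 0:
        print(buf.raw[:n].decode("utf-8", errors="replace"))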
@ -2066,25 +2133,25 @@ llama_token_to_piece.restype = ctypes.c_int32
# char * buf,
# int32_t length);
def llama_chat_apply_template(
model: llama_model_p,
tmpl: bytes,
chat: "ctypes._Pointer[llama_chat_message]", # type: ignore
n_msg: int,
/
model: llama_model_p,
tmpl: bytes,
chat: CtypesArray[llama_chat_message],
n_msg: int,
/,
) -> int:
...
llama_chat_apply_template = _lib.llama_chat_apply_template
llama_chat_apply_template.argtypes = [
ctypes.c_void_p,
ctypes.c_char_p,
ctypes.POINTER(llama_chat_message),
ctypes.c_size_t
ctypes.c_size_t,
]
llama_chat_apply_template.restype = ctypes.c_int32
# //
# // Grammar
# //
@ -2095,10 +2162,12 @@ llama_chat_apply_template.restype = ctypes.c_int32
# size_t n_rules,
# size_t start_rule_index);
def llama_grammar_init(
rules, # type: Array[llama_grammar_element_p] # type: ignore
rules: CtypesArray[
CtypesPointer[llama_grammar_element]
], # NOTE: This might be wrong type sig
n_rules: Union[ctypes.c_size_t, int],
start_rule_index: Union[ctypes.c_size_t, int],
/
/,
) -> llama_grammar_p:
"""Initialize a grammar from a set of rules."""
...
@ -2163,13 +2232,15 @@ llama_set_rng_seed.restype = None
# float penalty_present);
def llama_sample_repetition_penalties(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
last_tokens_data, # type: Array[llama_token]
candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
last_tokens_data: CtypesArray[llama_token],
penalty_last_n: Union[ctypes.c_size_t, int],
penalty_repeat: Union[ctypes.c_float, float],
penalty_freq: Union[ctypes.c_float, float],
penalty_present: Union[ctypes.c_float, float],
/
/,
):
"""Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
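A hedged sketch of the `candidates` setup shared by the sampler bindings below, assuming a live `ctx` and a known `n_vocab` (llama_token_data is defined earlier in this module): build a llama_token_data_array over the current logits and pass it by reference, matching the new CtypesPointerOrRef annotations:

    logits = llama_get_logits(ctx)           # `ctx`, `n_vocab` are assumed
    arr = (llama_token_data * n_vocab)()
    for i in range(n_vocab):
        arr[i].id, arr[i].logit, arr[i].p = i, logits[i], 0.0
    candidates = llama_token_data_array(arr, n_vocab, False)
    # keep `arr` alive for as long as `candidates` is used

    llama_sample_softmax(ctx, ctypes.byref(candidates))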
@ -2201,10 +2272,10 @@ llama_sample_repetition_penalties.restype = None
# float scale);
def llama_sample_apply_guidance(
ctx: llama_context_p,
logits, # type: _Pointer[ctypes.c_float]
logits_guidance, # type: _Pointer[ctypes.c_float]
logits: CtypesArray[ctypes.c_float],
logits_guidance: CtypesArray[ctypes.c_float],
scale: Union[ctypes.c_float, float],
/
/,
):
"""Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806"""
...
@ -2213,8 +2284,8 @@ def llama_sample_apply_guidance(
llama_sample_apply_guidance = _lib.llama_sample_apply_guidance
llama_sample_apply_guidance.argtypes = [
llama_context_p_ctypes,
c_float_p,
c_float_p,
ctypes.POINTER(ctypes.c_float),
ctypes.POINTER(ctypes.c_float),
ctypes.c_float,
]
llama_sample_apply_guidance.restype = None
@ -2228,10 +2299,12 @@ llama_sample_apply_guidance.restype = None
# "use llama_sample_apply_guidance() instead");
def llama_sample_classifier_free_guidance(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
guidance_ctx: llama_context_p,
scale: Union[ctypes.c_float, float],
/
/,
):
"""Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806"""
...
@ -2252,8 +2325,11 @@ llama_sample_classifier_free_guidance.restype = None
# struct llama_context * ctx,
# llama_token_data_array * candidates);
def llama_sample_softmax(
ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data]
/
ctx: llama_context_p,
candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
/,
):
"""Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits."""
...
@ -2275,10 +2351,12 @@ llama_sample_softmax.restype = None
# size_t min_keep);
def llama_sample_top_k(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
k: Union[ctypes.c_int, int],
min_keep: Union[ctypes.c_size_t, int],
/
/,
):
"""Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751"""
...
@ -2302,10 +2380,12 @@ llama_sample_top_k.restype = None
# size_t min_keep);
def llama_sample_top_p(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
p: Union[ctypes.c_float, float],
min_keep: Union[ctypes.c_size_t, int],
/
/,
):
"""Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751"""
...
@ -2329,10 +2409,12 @@ llama_sample_top_p.restype = None
# size_t min_keep);
def llama_sample_min_p(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
p: Union[ctypes.c_float, float],
min_keep: Union[ctypes.c_size_t, int],
/
/,
):
"""Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841"""
...
@ -2356,10 +2438,12 @@ llama_sample_min_p.restype = None
# size_t min_keep);
def llama_sample_tail_free(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
z: Union[ctypes.c_float, float],
min_keep: Union[ctypes.c_size_t, int],
/
/,
):
"""Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/."""
...
@ -2383,10 +2467,12 @@ llama_sample_tail_free.restype = None
# size_t min_keep);
def llama_sample_typical(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
p: Union[ctypes.c_float, float],
min_keep: Union[ctypes.c_size_t, int],
/
/,
):
"""Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666."""
...
@ -2411,11 +2497,13 @@ llama_sample_typical.restype = None
# float exponent_val);
def llama_sample_entropy(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
min_temp: Union[ctypes.c_float, float],
max_temp: Union[ctypes.c_float, float],
exponent_val: Union[ctypes.c_float, float],
/
/,
):
"""Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772."""
...
@ -2438,9 +2526,11 @@ llama_sample_entropy.restype = None
# float temp);
def llama_sample_temp(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
temp: Union[ctypes.c_float, float],
/
/,
):
"""Temperature sampling described in academic paper "Generating Long Sequences with Sparse Transformers" https://arxiv.org/abs/1904.10509
@ -2467,9 +2557,11 @@ llama_sample_temp.restype = None
# "use llama_sample_temp instead");
def llama_sample_temperature(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
temp: Union[ctypes.c_float, float],
/
/,
):
"""use llama_sample_temp instead"""
...
@ -2491,9 +2583,11 @@ llama_sample_temperature.restype = None
# const struct llama_grammar * grammar);
def llama_sample_grammar(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
grammar, # type: llama_grammar_p
/
/,
):
"""Apply constraints from grammar
@ -2528,12 +2622,14 @@ llama_sample_grammar.restype = None
# float * mu);
def llama_sample_token_mirostat(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
tau: Union[ctypes.c_float, float],
eta: Union[ctypes.c_float, float],
m: Union[ctypes.c_int, int],
mu, # type: _Pointer[ctypes.c_float]
/
mu: CtypesPointerOrRef[ctypes.c_float],
/,
) -> int:
"""Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
@ -2554,7 +2650,7 @@ llama_sample_token_mirostat.argtypes = [
ctypes.c_float,
ctypes.c_float,
ctypes.c_int32,
c_float_p,
ctypes.POINTER(ctypes.c_float),
]
llama_sample_token_mirostat.restype = llama_token
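An illustrative call, reusing `ctx` and `candidates` from the sketch above; the tau/eta/m values are only the common upstream defaults. `mu` is mutable state the sampler updates in place, so the same c_float should be passed on every call:

    tau, eta, m = 5.0, 0.1, 100              # conventional defaults, not mandated here
    mu = ctypes.c_float(2.0 * tau)
    token = llama_sample_token_mirostat(
        ctx, ctypes.byref(candidates), tau, eta, m, ctypes.byref(mu)
    )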
@ -2572,11 +2668,13 @@ llama_sample_token_mirostat.restype = llama_token
# float * mu);
def llama_sample_token_mirostat_v2(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
tau: Union[ctypes.c_float, float],
eta: Union[ctypes.c_float, float],
mu, # type: _Pointer[ctypes.c_float]
/
/,
) -> int:
"""Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
@ -2595,7 +2693,7 @@ llama_sample_token_mirostat_v2.argtypes = [
llama_token_data_array_p,
ctypes.c_float,
ctypes.c_float,
c_float_p,
ctypes.POINTER(ctypes.c_float),
]
llama_sample_token_mirostat_v2.restype = llama_token
@ -2607,8 +2705,10 @@ llama_sample_token_mirostat_v2.restype = llama_token
# llama_token_data_array * candidates);
def llama_sample_token_greedy(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
/
candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
/,
) -> int:
"""Selects the token with the highest probability."""
...
@ -2628,8 +2728,10 @@ llama_sample_token_greedy.restype = llama_token
# llama_token_data_array * candidates);
def llama_sample_token(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
/
candidates: Union[
CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array]
],
/,
) -> int:
"""Randomly selects a token from the candidates based on their probabilities."""
...
@ -2649,10 +2751,7 @@ llama_sample_token.restype = llama_token
# struct llama_grammar * grammar,
# llama_token token);
def llama_grammar_accept_token(
ctx: llama_context_p,
grammar: llama_grammar_p,
token: Union[llama_token, int],
/
ctx: llama_context_p, grammar: llama_grammar_p, token: Union[llama_token, int], /
) -> None:
"""Accepts the sampled token into the grammar"""
...
@ -2711,7 +2810,9 @@ class llama_beams_state(ctypes.Structure):
# // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
# // passed back to beam_search_callback. This avoids having to use global variables in the callback.
# typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state);
llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, ctypes.c_void_p, llama_beams_state)
llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(
None, ctypes.c_void_p, llama_beams_state
)
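Illustrative only: a Python callback matching this type. llama_beams_state is passed by value, so ctypes hands the function a structure instance:

    def _beam_cb(callback_data: ctypes.c_void_p, beams_state: llama_beams_state) -> None:
        pass  # inspect beams_state fields here (layout assumed from llama.h)

    beam_cb = llama_beam_search_callback_fn_t(_beam_cb)  # keep referenced while searching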
# /// @details Deterministically returns entire sentence constructed by a beam search.
@ -2731,12 +2832,12 @@ llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, ctypes.c_void_p, llama_
# int32_t n_predict);
def llama_beam_search(
ctx: llama_context_p,
callback: "ctypes._CFuncPtr[None, ctypes.c_void_p, llama_beams_state]", # type: ignore
callback: CtypesFuncPointer,
callback_data: ctypes.c_void_p,
n_beams: Union[ctypes.c_size_t, int],
n_past: Union[ctypes.c_int, int],
n_predict: Union[ctypes.c_int, int],
/
/,
):
...
@ -2806,8 +2907,9 @@ llama_print_system_info.restype = ctypes.c_char_p
# // If this is not called, or NULL is supplied, everything is output on stderr.
# LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
def llama_log_set(
log_callback: Union["ctypes._FuncPointer", ctypes.c_void_p], user_data: ctypes.c_void_p, # type: ignore
/
log_callback: Optional[CtypesFuncPointer],
user_data: ctypes.c_void_p, # type: ignore
/,
):
"""Set callback for all future logging events.