docs: Improve low-level docstrings
This commit is contained in:
parent
9c68b1804a
commit
396dbf0b2b
1 changed files with 95 additions and 2 deletions
|
@ -212,6 +212,12 @@ LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN
|
|||
# float p; // probability of the token
|
||||
# } llama_token_data;
|
||||
class llama_token_data(Structure):
|
||||
"""Used to store token data
|
||||
|
||||
Attributes:
|
||||
id (llama_token): token id
|
||||
logit (float): log-odds of the token
|
||||
p (float): probability of the token"""
|
||||
_fields_ = [
|
||||
("id", llama_token),
|
||||
("logit", c_float),
|
||||
|
@ -228,6 +234,12 @@ llama_token_data_p = POINTER(llama_token_data)
|
|||
# bool sorted;
|
||||
# } llama_token_data_array;
|
||||
class llama_token_data_array(Structure):
|
||||
"""Used to sample tokens given logits
|
||||
|
||||
Attributes:
|
||||
data (ctypes.Array[llama_token_data]): token data
|
||||
size (int): size of the array
|
||||
sorted (bool): whether the array is sorted"""
|
||||
_fields_ = [
|
||||
("data", llama_token_data_p),
|
||||
("size", c_size_t),
|
||||
|
@ -282,8 +294,7 @@ class llama_batch(Structure):
|
|||
token (ctypes.Array[llama_token]): the token ids of the input (used when embd is NULL)
|
||||
embd (ctypes.Array[ctypes.c_float]): token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
|
||||
pos (ctypes.Array[ctypes.Array[llama_pos]]): the positions of the respective token in the sequence
|
||||
seq_id (ctypes.Array[ctypes.Array[llama_seq_id]]): the sequence to which the respective token belongs
|
||||
"""
|
||||
seq_id (ctypes.Array[ctypes.Array[llama_seq_id]]): the sequence to which the respective token belongs"""
|
||||
|
||||
_fields_ = [
|
||||
("n_tokens", c_int32),
|
||||
|
@ -316,6 +327,17 @@ class llama_batch(Structure):
|
|||
# bool use_mlock; // force system to keep model in RAM
|
||||
# };
|
||||
class llama_model_params(Structure):
|
||||
"""Parameters for llama_model
|
||||
|
||||
Attributes:
|
||||
n_gpu_layers (int): number of layers to store in VRAM
|
||||
main_gpu (int): the GPU that is used for scratch and small tensors
|
||||
tensor_split (ctypes.Array[ctypes.c_float]): how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
|
||||
progress_callback (llama_progress_callback): called with a progress value between 0 and 1, pass NULL to disable
|
||||
progress_callback_user_data (ctypes.c_void_p): context pointer passed to the progress callback
|
||||
vocab_only (bool): only load the vocabulary, no weights
|
||||
use_mmap (bool): use mmap if possible
|
||||
use_mlock (bool): force system to keep model in RAM"""
|
||||
_fields_ = [
|
||||
("n_gpu_layers", c_int32),
|
||||
("main_gpu", c_int32),
|
||||
|
@ -353,6 +375,26 @@ class llama_model_params(Structure):
|
|||
# bool embedding; // embedding mode only
|
||||
# };
|
||||
class llama_context_params(Structure):
|
||||
"""Parameters for llama_context
|
||||
|
||||
Attributes:
|
||||
seed (int): RNG seed, -1 for random
|
||||
n_ctx (int): text context, 0 = from model
|
||||
n_batch (int): prompt processing maximum batch size
|
||||
n_threads (int): number of threads to use for generation
|
||||
n_threads_batch (int): number of threads to use for batch processing
|
||||
rope_scaling_type (int): RoPE scaling type, from `enum llama_rope_scaling_type`
|
||||
rope_freq_base (float): RoPE base frequency, 0 = from model
|
||||
rope_freq_scale (float): RoPE frequency scaling factor, 0 = from model
|
||||
yarn_ext_factor (float): YaRN extrapolation mix factor, negative = from model
|
||||
yarn_attn_factor (float): YaRN magnitude scaling factor
|
||||
yarn_beta_fast (float): YaRN low correction dim
|
||||
yarn_beta_slow (float): YaRN high correction dim
|
||||
yarn_orig_ctx (int): YaRN original context size
|
||||
mul_mat_q (bool): if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
|
||||
f16_kv (bool): use fp16 for KV cache, fp32 otherwise
|
||||
logits_all (bool): the llama_eval() call computes all logits, not just the last one
|
||||
embedding (bool): embedding mode only"""
|
||||
_fields_ = [
|
||||
("seed", c_uint32),
|
||||
("n_ctx", c_uint32),
|
||||
|
@ -398,6 +440,15 @@ It might not exist for progress report where '.' is output repeatedly."""
|
|||
# bool pure; // disable k-quant mixtures and quantize all tensors to the same type
|
||||
# } llama_model_quantize_params;
|
||||
class llama_model_quantize_params(Structure):
|
||||
"""Parameters for llama_model_quantize
|
||||
|
||||
Attributes:
|
||||
nthread (int): number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
|
||||
ftype (int): quantize to this llama_ftype
|
||||
allow_requantize (bool): allow quantizing non-f32/f16 tensors
|
||||
quantize_output_tensor (bool): quantize output.weight
|
||||
only_copy (bool): only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
|
||||
pure (bool): disable k-quant mixtures and quantize all tensors to the same type"""
|
||||
_fields_ = [
|
||||
("nthread", c_int),
|
||||
("ftype", c_int),
|
||||
|
@ -489,6 +540,7 @@ class llama_timings(Structure):
|
|||
# // Helpers for getting default parameters
|
||||
# LLAMA_API struct llama_model_params llama_model_default_params(void);
|
||||
def llama_model_default_params() -> llama_model_params:
|
||||
"""Get default parameters for llama_model"""
|
||||
return _lib.llama_model_default_params()
|
||||
|
||||
|
||||
|
@ -498,6 +550,7 @@ _lib.llama_model_default_params.restype = llama_model_params
|
|||
|
||||
# LLAMA_API struct llama_context_params llama_context_default_params(void);
|
||||
def llama_context_default_params() -> llama_context_params:
|
||||
"""Get default parameters for llama_context"""
|
||||
return _lib.llama_context_default_params()
|
||||
|
||||
|
||||
|
@ -507,6 +560,7 @@ _lib.llama_context_default_params.restype = llama_context_params
|
|||
|
||||
# LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
|
||||
def llama_model_quantize_default_params() -> llama_model_quantize_params:
|
||||
"""Get default parameters for llama_model_quantize"""
|
||||
return _lib.llama_model_quantize_default_params()
|
||||
|
||||
|
||||
|
@ -1668,6 +1722,7 @@ def llama_grammar_init(
|
|||
n_rules: Union[c_size_t, int],
|
||||
start_rule_index: Union[c_size_t, int],
|
||||
) -> llama_grammar_p:
|
||||
"""Initialize a grammar from a set of rules."""
|
||||
return _lib.llama_grammar_init(rules, n_rules, start_rule_index)
|
||||
|
||||
|
||||
|
@ -1681,6 +1736,7 @@ _lib.llama_grammar_init.restype = llama_grammar_p
|
|||
|
||||
# LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
|
||||
def llama_grammar_free(grammar: llama_grammar_p):
|
||||
"""Free a grammar."""
|
||||
return _lib.llama_grammar_free(grammar)
|
||||
|
||||
|
||||
|
@ -1690,6 +1746,7 @@ _lib.llama_grammar_free.restype = None
|
|||
|
||||
# LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
|
||||
def llama_grammar_copy(grammar: llama_grammar_p) -> llama_grammar_p:
|
||||
"""Copy a grammar."""
|
||||
return _lib.llama_grammar_copy(grammar)
|
||||
|
||||
|
||||
|
@ -1939,6 +1996,11 @@ def llama_sample_temp(
|
|||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
temp: Union[c_float, float],
|
||||
):
|
||||
"""Temperature sampling described in academic paper "Generating Long Sequences with Sparse Transformers" https://arxiv.org/abs/1904.10509
|
||||
|
||||
Parameters:
|
||||
candidates: A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
temp: The temperature value to use for the sampling. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text."""
|
||||
return _lib.llama_sample_temp(ctx, candidates, temp)
|
||||
|
||||
|
||||
|
@ -1960,6 +2022,7 @@ def llama_sample_temperature(
|
|||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
temp: Union[c_float, float],
|
||||
):
|
||||
"""use llama_sample_temp instead"""
|
||||
return _lib.llama_sample_temperature(ctx, candidates, temp)
|
||||
|
||||
|
||||
|
@ -1981,6 +2044,11 @@ def llama_sample_grammar(
|
|||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
grammar, # type: llama_grammar_p
|
||||
):
|
||||
"""Apply constraints from grammar
|
||||
|
||||
Parameters:
|
||||
candidates: A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
grammar: A grammar object containing the rules and constraints to apply to the generated text."""
|
||||
return _lib.llama_sample_grammar(ctx, candidates, grammar)
|
||||
|
||||
|
||||
|
@ -2013,6 +2081,14 @@ def llama_sample_token_mirostat(
|
|||
m: Union[c_int, int],
|
||||
mu, # type: _Pointer[c_float]
|
||||
) -> int:
|
||||
"""Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
|
||||
Parameters:
|
||||
candidates: A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||
eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
|
||||
m: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
|
||||
mu: Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal."""
|
||||
return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu)
|
||||
|
||||
|
||||
|
@ -2045,6 +2121,13 @@ def llama_sample_token_mirostat_v2(
|
|||
eta: Union[c_float, float],
|
||||
mu, # type: _Pointer[c_float]
|
||||
) -> int:
|
||||
"""Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
|
||||
Parameters:
|
||||
candidates: A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||
eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
|
||||
mu: Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal."""
|
||||
return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu)
|
||||
|
||||
|
||||
|
@ -2067,6 +2150,7 @@ def llama_sample_token_greedy(
|
|||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
) -> int:
|
||||
"""Selects the token with the highest probability."""
|
||||
return _lib.llama_sample_token_greedy(ctx, candidates)
|
||||
|
||||
|
||||
|
@ -2085,6 +2169,7 @@ def llama_sample_token(
|
|||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
) -> int:
|
||||
"""Randomly selects a token from the candidates based on their probabilities."""
|
||||
return _lib.llama_sample_token(ctx, candidates)
|
||||
|
||||
|
||||
|
@ -2105,6 +2190,7 @@ def llama_grammar_accept_token(
|
|||
grammar: llama_grammar_p,
|
||||
token: Union[llama_token, int],
|
||||
) -> None:
|
||||
"""Accepts the sampled token into the grammar"""
|
||||
_lib.llama_grammar_accept_token(ctx, grammar, token)
|
||||
|
||||
|
||||
|
@ -2207,6 +2293,7 @@ _lib.llama_beam_search.restype = None
|
|||
|
||||
# LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
|
||||
def llama_get_timings(ctx: llama_context_p) -> llama_timings:
|
||||
"""Get performance information"""
|
||||
return _lib.llama_get_timings(ctx)
|
||||
|
||||
|
||||
|
@ -2216,6 +2303,7 @@ _lib.llama_get_timings.restype = llama_timings
|
|||
|
||||
# LLAMA_API void llama_print_timings(struct llama_context * ctx);
|
||||
def llama_print_timings(ctx: llama_context_p):
|
||||
"""Print performance information"""
|
||||
_lib.llama_print_timings(ctx)
|
||||
|
||||
|
||||
|
@ -2225,6 +2313,7 @@ _lib.llama_print_timings.restype = None
|
|||
|
||||
# LLAMA_API void llama_reset_timings(struct llama_context * ctx);
|
||||
def llama_reset_timings(ctx: llama_context_p):
|
||||
"""Reset performance information"""
|
||||
_lib.llama_reset_timings(ctx)
|
||||
|
||||
|
||||
|
@ -2235,6 +2324,7 @@ _lib.llama_reset_timings.restype = None
|
|||
# Print system information
|
||||
# LLAMA_API const char * llama_print_system_info(void);
|
||||
def llama_print_system_info() -> bytes:
|
||||
"""Print system information"""
|
||||
return _lib.llama_print_system_info()
|
||||
|
||||
|
||||
|
@ -2249,6 +2339,9 @@ _lib.llama_print_system_info.restype = c_char_p
|
|||
def llama_log_set(
|
||||
log_callback: "ctypes._FuncPointer", user_data: c_void_p # type: ignore
|
||||
):
|
||||
"""Set callback for all future logging events.
|
||||
|
||||
If this is not called, or NULL is supplied, everything is output on stderr."""
|
||||
return _lib.llama_log_set(log_callback, user_data)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue