Fix return types and import comments
This commit is contained in:
parent
55d6308537
commit
7837c3fdc7
1 changed files with 38 additions and 34 deletions
|
@ -427,13 +427,16 @@ _lib.llama_token_nl.restype = llama_token
|
|||
|
||||
|
||||
# Sampling functions
|
||||
|
||||
|
||||
# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
||||
def llama_sample_repetition_penalty(
|
||||
ctx: llama_context_p,
|
||||
candidates,
|
||||
last_tokens_data,
|
||||
last_tokens_size: c_int,
|
||||
penalty: c_float,
|
||||
) -> llama_token:
|
||||
):
|
||||
return _lib.llama_sample_repetition_penalty(
|
||||
ctx, candidates, last_tokens_data, last_tokens_size, penalty
|
||||
)
|
||||
|
@ -446,10 +449,10 @@ _lib.llama_sample_repetition_penalty.argtypes = [
|
|||
c_int,
|
||||
c_float,
|
||||
]
|
||||
_lib.llama_sample_repetition_penalty.restype = llama_token
|
||||
_lib.llama_sample_repetition_penalty.restype = None
|
||||
|
||||
|
||||
# LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
|
||||
# @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
|
||||
def llama_sample_frequency_and_presence_penalties(
|
||||
ctx: llama_context_p,
|
||||
candidates,
|
||||
|
@ -457,7 +460,7 @@ def llama_sample_frequency_and_presence_penalties(
|
|||
last_tokens_size: c_int,
|
||||
alpha_frequency: c_float,
|
||||
alpha_presence: c_float,
|
||||
) -> llama_token:
|
||||
):
|
||||
return _lib.llama_sample_frequency_and_presence_penalties(
|
||||
ctx,
|
||||
candidates,
|
||||
|
@ -476,11 +479,11 @@ _lib.llama_sample_frequency_and_presence_penalties.argtypes = [
|
|||
c_float,
|
||||
c_float,
|
||||
]
|
||||
_lib.llama_sample_frequency_and_presence_penalties.restype = llama_token
|
||||
_lib.llama_sample_frequency_and_presence_penalties.restype = None
|
||||
|
||||
|
||||
# LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||
def llama_sample_softmax(ctx: llama_context_p, candidates) -> llama_token:
|
||||
# @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
||||
def llama_sample_softmax(ctx: llama_context_p, candidates):
|
||||
return _lib.llama_sample_softmax(ctx, candidates)
|
||||
|
||||
|
||||
|
@ -488,13 +491,11 @@ _lib.llama_sample_softmax.argtypes = [
|
|||
llama_context_p,
|
||||
llama_token_data_array_p,
|
||||
]
|
||||
_lib.llama_sample_softmax.restype = llama_token
|
||||
_lib.llama_sample_softmax.restype = None
|
||||
|
||||
|
||||
# LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep = 1);
|
||||
def llama_sample_top_k(
|
||||
ctx: llama_context_p, candidates, k: c_int, min_keep: c_int
|
||||
) -> llama_token:
|
||||
# @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
def llama_sample_top_k(ctx: llama_context_p, candidates, k: c_int, min_keep: c_int):
|
||||
return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)
|
||||
|
||||
|
||||
|
@ -504,12 +505,11 @@ _lib.llama_sample_top_k.argtypes = [
|
|||
c_int,
|
||||
c_int,
|
||||
]
|
||||
_lib.llama_sample_top_k.restype = None
|
||||
|
||||
|
||||
# LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
|
||||
def llama_sample_top_p(
|
||||
ctx: llama_context_p, candidates, p: c_float, min_keep: c_int
|
||||
) -> llama_token:
|
||||
# @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
def llama_sample_top_p(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
|
||||
return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)
|
||||
|
||||
|
||||
|
@ -519,13 +519,13 @@ _lib.llama_sample_top_p.argtypes = [
|
|||
c_float,
|
||||
c_int,
|
||||
]
|
||||
_lib.llama_sample_top_p.restype = llama_token
|
||||
_lib.llama_sample_top_p.restype = None
|
||||
|
||||
|
||||
# LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep = 1);
|
||||
# @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
|
||||
def llama_sample_tail_free(
|
||||
ctx: llama_context_p, candidates, z: c_float, min_keep: c_int
|
||||
) -> llama_token:
|
||||
):
|
||||
return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)
|
||||
|
||||
|
||||
|
@ -535,13 +535,11 @@ _lib.llama_sample_tail_free.argtypes = [
|
|||
c_float,
|
||||
c_int,
|
||||
]
|
||||
_lib.llama_sample_tail_free.restype = llama_token
|
||||
_lib.llama_sample_tail_free.restype = None
|
||||
|
||||
|
||||
# LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
|
||||
def llama_sample_typical(
|
||||
ctx: llama_context_p, candidates, p: c_float, min_keep: c_int
|
||||
) -> llama_token:
|
||||
# @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
||||
def llama_sample_typical(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
|
||||
return _lib.llama_sample_typical(ctx, candidates, p, min_keep)
|
||||
|
||||
|
||||
|
@ -551,13 +549,10 @@ _lib.llama_sample_typical.argtypes = [
|
|||
c_float,
|
||||
c_int,
|
||||
]
|
||||
_lib.llama_sample_typical.restype = llama_token
|
||||
_lib.llama_sample_typical.restype = None
|
||||
|
||||
|
||||
# LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
|
||||
def llama_sample_temperature(
|
||||
ctx: llama_context_p, candidates, temp: c_float
|
||||
) -> llama_token:
|
||||
def llama_sample_temperature(ctx: llama_context_p, candidates, temp: c_float):
|
||||
return _lib.llama_sample_temperature(ctx, candidates, temp)
|
||||
|
||||
|
||||
|
@ -566,10 +561,15 @@ _lib.llama_sample_temperature.argtypes = [
|
|||
llama_token_data_array_p,
|
||||
c_float,
|
||||
]
|
||||
_lib.llama_sample_temperature.restype = llama_token
|
||||
_lib.llama_sample_temperature.restype = None
|
||||
|
||||
|
||||
# LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
|
||||
# @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
# @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
# @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||
# @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
|
||||
# @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
|
||||
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
||||
def llama_sample_token_mirostat(
|
||||
ctx: llama_context_p, candidates, tau: c_float, eta: c_float, m: c_int, mu
|
||||
) -> llama_token:
|
||||
|
@ -587,7 +587,11 @@ _lib.llama_sample_token_mirostat.argtypes = [
|
|||
_lib.llama_sample_token_mirostat.restype = llama_token
|
||||
|
||||
|
||||
# LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
|
||||
# @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
# @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
# @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||
# @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
|
||||
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
||||
def llama_sample_token_mirostat_v2(
|
||||
ctx: llama_context_p, candidates, tau: c_float, eta: c_float, mu
|
||||
) -> llama_token:
|
||||
|
@ -604,7 +608,7 @@ _lib.llama_sample_token_mirostat_v2.argtypes = [
|
|||
_lib.llama_sample_token_mirostat_v2.restype = llama_token
|
||||
|
||||
|
||||
# LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||
# @details Selects the token with the highest probability.
|
||||
def llama_sample_token_greedy(ctx: llama_context_p, candidates) -> llama_token:
|
||||
return _lib.llama_sample_token_greedy(ctx, candidates)
|
||||
|
||||
|
@ -616,7 +620,7 @@ _lib.llama_sample_token_greedy.argtypes = [
|
|||
_lib.llama_sample_token_greedy.restype = llama_token
|
||||
|
||||
|
||||
# LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||
# @details Randomly selects a token from the candidates based on their probabilities.
|
||||
def llama_sample_token(ctx: llama_context_p, candidates) -> llama_token:
|
||||
return _lib.llama_sample_token(ctx, candidates)
|
||||
|
||||
|
|
Loading…
Reference in a new issue