Add types for all low-level api functions
This commit is contained in:
parent
5be0efa5f8
commit
b6a9a0b6ba
2 changed files with 62 additions and 21 deletions
|
@ -52,7 +52,7 @@ class LlamaState:
|
|||
self,
|
||||
eval_tokens: Deque[llama_cpp.llama_token],
|
||||
eval_logits: Deque[List[llama_cpp.c_float]],
|
||||
llama_state,
|
||||
llama_state, # type: llama_cpp.Array[llama_cpp.c_uint8]
|
||||
llama_state_size: llama_cpp.c_size_t,
|
||||
):
|
||||
self.eval_tokens = eval_tokens
|
||||
|
|
|
@ -17,7 +17,7 @@ import pathlib
|
|||
|
||||
|
||||
# Load the library
|
||||
def _load_shared_library(lib_base_name):
|
||||
def _load_shared_library(lib_base_name: str):
|
||||
# Determine the file extension based on the platform
|
||||
if sys.platform.startswith("linux"):
|
||||
lib_ext = ".so"
|
||||
|
@ -252,7 +252,9 @@ _lib.llama_get_state_size.restype = c_size_t
|
|||
# Copies the state to the specified destination address.
|
||||
# Destination needs to have allocated enough memory.
|
||||
# Returns the number of bytes copied
|
||||
def llama_copy_state_data(ctx: llama_context_p, dest) -> c_size_t:
|
||||
def llama_copy_state_data(
|
||||
ctx: llama_context_p, dest # type: Array[c_uint8]
|
||||
) -> c_size_t:
|
||||
return _lib.llama_copy_state_data(ctx, dest)
|
||||
|
||||
|
||||
|
@ -262,7 +264,9 @@ _lib.llama_copy_state_data.restype = c_size_t
|
|||
|
||||
# Set the state reading from the specified address
|
||||
# Returns the number of bytes read
|
||||
def llama_set_state_data(ctx: llama_context_p, src) -> c_size_t:
|
||||
def llama_set_state_data(
|
||||
ctx: llama_context_p, src # type: Array[c_uint8]
|
||||
) -> c_size_t:
|
||||
return _lib.llama_set_state_data(ctx, src)
|
||||
|
||||
|
||||
|
@ -274,9 +278,9 @@ _lib.llama_set_state_data.restype = c_size_t
|
|||
def llama_load_session_file(
|
||||
ctx: llama_context_p,
|
||||
path_session: bytes,
|
||||
tokens_out,
|
||||
tokens_out, # type: Array[llama_token]
|
||||
n_token_capacity: c_size_t,
|
||||
n_token_count_out,
|
||||
n_token_count_out, # type: Array[c_size_t]
|
||||
) -> c_size_t:
|
||||
return _lib.llama_load_session_file(
|
||||
ctx, path_session, tokens_out, n_token_capacity, n_token_count_out
|
||||
|
@ -294,7 +298,10 @@ _lib.llama_load_session_file.restype = c_size_t
|
|||
|
||||
|
||||
def llama_save_session_file(
|
||||
ctx: llama_context_p, path_session: bytes, tokens, n_token_count: c_size_t
|
||||
ctx: llama_context_p,
|
||||
path_session: bytes,
|
||||
tokens, # type: Array[llama_token]
|
||||
n_token_count: c_size_t,
|
||||
) -> c_size_t:
|
||||
return _lib.llama_save_session_file(ctx, path_session, tokens, n_token_count)
|
||||
|
||||
|
@ -433,8 +440,8 @@ _lib.llama_token_nl.restype = llama_token
|
|||
# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
||||
def llama_sample_repetition_penalty(
|
||||
ctx: llama_context_p,
|
||||
candidates,
|
||||
last_tokens_data,
|
||||
candidates, # type: Array[llama_token_data]
|
||||
last_tokens_data, # type: Array[llama_token]
|
||||
last_tokens_size: c_int,
|
||||
penalty: c_float,
|
||||
):
|
||||
|
@ -456,8 +463,8 @@ _lib.llama_sample_repetition_penalty.restype = None
|
|||
# @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
|
||||
def llama_sample_frequency_and_presence_penalties(
|
||||
ctx: llama_context_p,
|
||||
candidates,
|
||||
last_tokens_data,
|
||||
candidates, # type: Array[llama_token_data]
|
||||
last_tokens_data, # type: Array[llama_token]
|
||||
last_tokens_size: c_int,
|
||||
alpha_frequency: c_float,
|
||||
alpha_presence: c_float,
|
||||
|
@ -484,7 +491,10 @@ _lib.llama_sample_frequency_and_presence_penalties.restype = None
|
|||
|
||||
|
||||
# @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
||||
def llama_sample_softmax(ctx: llama_context_p, candidates):
|
||||
def llama_sample_softmax(
|
||||
ctx: llama_context_p,
|
||||
candidates # type: Array[llama_token_data]
|
||||
):
|
||||
return _lib.llama_sample_softmax(ctx, candidates)
|
||||
|
||||
|
||||
|
@ -497,7 +507,10 @@ _lib.llama_sample_softmax.restype = None
|
|||
|
||||
# @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
def llama_sample_top_k(
|
||||
ctx: llama_context_p, candidates, k: c_int, min_keep: c_size_t = c_size_t(1)
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: Array[llama_token_data]
|
||||
k: c_int,
|
||||
min_keep: c_size_t = c_size_t(1)
|
||||
):
|
||||
return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)
|
||||
|
||||
|
@ -513,7 +526,10 @@ _lib.llama_sample_top_k.restype = None
|
|||
|
||||
# @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
def llama_sample_top_p(
|
||||
ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: Array[llama_token_data]
|
||||
p: c_float,
|
||||
min_keep: c_size_t = c_size_t(1)
|
||||
):
|
||||
return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)
|
||||
|
||||
|
@ -529,7 +545,10 @@ _lib.llama_sample_top_p.restype = None
|
|||
|
||||
# @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
|
||||
def llama_sample_tail_free(
|
||||
ctx: llama_context_p, candidates, z: c_float, min_keep: c_size_t = c_size_t(1)
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: Array[llama_token_data]
|
||||
z: c_float,
|
||||
min_keep: c_size_t = c_size_t(1)
|
||||
):
|
||||
return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)
|
||||
|
||||
|
@ -545,7 +564,10 @@ _lib.llama_sample_tail_free.restype = None
|
|||
|
||||
# @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
||||
def llama_sample_typical(
|
||||
ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: Array[llama_token_data]
|
||||
p: c_float,
|
||||
min_keep: c_size_t = c_size_t(1)
|
||||
):
|
||||
return _lib.llama_sample_typical(ctx, candidates, p, min_keep)
|
||||
|
||||
|
@ -559,7 +581,11 @@ _lib.llama_sample_typical.argtypes = [
|
|||
_lib.llama_sample_typical.restype = None
|
||||
|
||||
|
||||
def llama_sample_temperature(ctx: llama_context_p, candidates, temp: c_float):
|
||||
def llama_sample_temperature(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: Array[llama_token_data]
|
||||
temp: c_float
|
||||
):
|
||||
return _lib.llama_sample_temperature(ctx, candidates, temp)
|
||||
|
||||
|
||||
|
@ -578,7 +604,12 @@ _lib.llama_sample_temperature.restype = None
|
|||
# @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
|
||||
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
||||
def llama_sample_token_mirostat(
|
||||
ctx: llama_context_p, candidates, tau: c_float, eta: c_float, m: c_int, mu
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: Array[llama_token_data]
|
||||
tau: c_float,
|
||||
eta: c_float,
|
||||
m: c_int,
|
||||
mu # type: Array[c_float]
|
||||
) -> llama_token:
|
||||
return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu)
|
||||
|
||||
|
@ -600,7 +631,11 @@ _lib.llama_sample_token_mirostat.restype = llama_token
|
|||
# @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
|
||||
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
||||
def llama_sample_token_mirostat_v2(
|
||||
ctx: llama_context_p, candidates, tau: c_float, eta: c_float, mu
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: Array[llama_token_data]
|
||||
tau: c_float,
|
||||
eta: c_float,
|
||||
mu # type: Array[c_float]
|
||||
) -> llama_token:
|
||||
return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu)
|
||||
|
||||
|
@ -616,7 +651,10 @@ _lib.llama_sample_token_mirostat_v2.restype = llama_token
|
|||
|
||||
|
||||
# @details Selects the token with the highest probability.
|
||||
def llama_sample_token_greedy(ctx: llama_context_p, candidates) -> llama_token:
|
||||
def llama_sample_token_greedy(
|
||||
ctx: llama_context_p,
|
||||
candidates # type: Array[llama_token_data]
|
||||
) -> llama_token:
|
||||
return _lib.llama_sample_token_greedy(ctx, candidates)
|
||||
|
||||
|
||||
|
@ -628,7 +666,10 @@ _lib.llama_sample_token_greedy.restype = llama_token
|
|||
|
||||
|
||||
# @details Randomly selects a token from the candidates based on their probabilities.
|
||||
def llama_sample_token(ctx: llama_context_p, candidates) -> llama_token:
|
||||
def llama_sample_token(
|
||||
ctx: llama_context_p,
|
||||
candidates # type: Array[llama_token_data]
|
||||
) -> llama_token:
|
||||
return _lib.llama_sample_token(ctx, candidates)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue