Fix: runtime type errors

This commit is contained in:
Andrei Betlen 2023-05-05 14:12:26 -04:00
parent e24c3d7447
commit 3e28e0e50c

View file

@ -258,7 +258,9 @@ _lib.llama_get_state_size.restype = c_size_t
# Copies the state to the specified destination address. # Copies the state to the specified destination address.
# Destination needs to have allocated enough memory. # Destination needs to have allocated enough memory.
# Returns the number of bytes copied # Returns the number of bytes copied
def llama_copy_state_data(ctx: llama_context_p, dest: Array[c_uint8]) -> c_size_t: def llama_copy_state_data(
ctx: llama_context_p, dest # type: Array[c_uint8]
) -> c_size_t:
return _lib.llama_copy_state_data(ctx, dest) return _lib.llama_copy_state_data(ctx, dest)
@ -282,9 +284,9 @@ _lib.llama_set_state_data.restype = c_size_t
def llama_load_session_file( def llama_load_session_file(
ctx: llama_context_p, ctx: llama_context_p,
path_session: bytes, path_session: bytes,
tokens_out: Array[llama_token], tokens_out, # type: Array[llama_token]
n_token_capacity: c_size_t, n_token_capacity: c_size_t,
n_token_count_out: _Pointer[c_size_t], n_token_count_out, # type: _Pointer[c_size_t]
) -> c_size_t: ) -> c_size_t:
return _lib.llama_load_session_file( return _lib.llama_load_session_file(
ctx, path_session, tokens_out, n_token_capacity, n_token_count_out ctx, path_session, tokens_out, n_token_capacity, n_token_count_out
@ -304,7 +306,7 @@ _lib.llama_load_session_file.restype = c_size_t
def llama_save_session_file( def llama_save_session_file(
ctx: llama_context_p, ctx: llama_context_p,
path_session: bytes, path_session: bytes,
tokens: Array[llama_token], tokens, # type: Array[llama_token]
n_token_count: c_size_t, n_token_count: c_size_t,
) -> c_size_t: ) -> c_size_t:
return _lib.llama_save_session_file(ctx, path_session, tokens, n_token_count) return _lib.llama_save_session_file(ctx, path_session, tokens, n_token_count)
@ -325,7 +327,7 @@ _lib.llama_save_session_file.restype = c_size_t
# Returns 0 on success # Returns 0 on success
def llama_eval( def llama_eval(
ctx: llama_context_p, ctx: llama_context_p,
tokens: Array[llama_token], tokens, # type: Array[llama_token]
n_tokens: c_int, n_tokens: c_int,
n_past: c_int, n_past: c_int,
n_threads: c_int, n_threads: c_int,
@ -345,7 +347,7 @@ _lib.llama_eval.restype = c_int
def llama_tokenize( def llama_tokenize(
ctx: llama_context_p, ctx: llama_context_p,
text: bytes, text: bytes,
tokens: Array[llama_token], tokens, # type: Array[llama_token]
n_max_tokens: c_int, n_max_tokens: c_int,
add_bos: c_bool, add_bos: c_bool,
) -> c_int: ) -> c_int:
@ -444,8 +446,8 @@ _lib.llama_token_nl.restype = llama_token
# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. # @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
def llama_sample_repetition_penalty( def llama_sample_repetition_penalty(
ctx: llama_context_p, ctx: llama_context_p,
candidates: _Pointer[llama_token_data_array], candidates, # type: _Pointer[llama_token_data_array]
last_tokens_data: Array[llama_token], last_tokens_data, # type: Array[llama_token]
last_tokens_size: c_int, last_tokens_size: c_int,
penalty: c_float, penalty: c_float,
): ):
@ -467,8 +469,8 @@ _lib.llama_sample_repetition_penalty.restype = None
# @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. # @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
def llama_sample_frequency_and_presence_penalties( def llama_sample_frequency_and_presence_penalties(
ctx: llama_context_p, ctx: llama_context_p,
candidates: _Pointer[llama_token_data_array], candidates, # type: _Pointer[llama_token_data_array]
last_tokens_data: Array[llama_token], last_tokens_data, # type: Array[llama_token]
last_tokens_size: c_int, last_tokens_size: c_int,
alpha_frequency: c_float, alpha_frequency: c_float,
alpha_presence: c_float, alpha_presence: c_float,
@ -495,7 +497,9 @@ _lib.llama_sample_frequency_and_presence_penalties.restype = None
# @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. # @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
def llama_sample_softmax(ctx: llama_context_p, candidates: _Pointer[llama_token_data]): def llama_sample_softmax(
ctx: llama_context_p, candidates # type: _Pointer[llama_token_data]
):
return _lib.llama_sample_softmax(ctx, candidates) return _lib.llama_sample_softmax(ctx, candidates)
@ -509,7 +513,7 @@ _lib.llama_sample_softmax.restype = None
# @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 # @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
def llama_sample_top_k( def llama_sample_top_k(
ctx: llama_context_p, ctx: llama_context_p,
candidates: _Pointer[llama_token_data_array], candidates, # type: _Pointer[llama_token_data_array]
k: c_int, k: c_int,
min_keep: c_size_t = c_size_t(1), min_keep: c_size_t = c_size_t(1),
): ):
@ -528,7 +532,7 @@ _lib.llama_sample_top_k.restype = None
# @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 # @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
def llama_sample_top_p( def llama_sample_top_p(
ctx: llama_context_p, ctx: llama_context_p,
candidates: _Pointer[llama_token_data_array], candidates, # type: _Pointer[llama_token_data_array]
p: c_float, p: c_float,
min_keep: c_size_t = c_size_t(1), min_keep: c_size_t = c_size_t(1),
): ):
@ -547,7 +551,7 @@ _lib.llama_sample_top_p.restype = None
# @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. # @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
def llama_sample_tail_free( def llama_sample_tail_free(
ctx: llama_context_p, ctx: llama_context_p,
candidates: _Pointer[llama_token_data_array], candidates, # type: _Pointer[llama_token_data_array]
z: c_float, z: c_float,
min_keep: c_size_t = c_size_t(1), min_keep: c_size_t = c_size_t(1),
): ):
@ -566,7 +570,7 @@ _lib.llama_sample_tail_free.restype = None
# @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. # @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
def llama_sample_typical( def llama_sample_typical(
ctx: llama_context_p, ctx: llama_context_p,
candidates: _Pointer[llama_token_data_array], candidates, # type: _Pointer[llama_token_data_array]
p: c_float, p: c_float,
min_keep: c_size_t = c_size_t(1), min_keep: c_size_t = c_size_t(1),
): ):
@ -583,7 +587,9 @@ _lib.llama_sample_typical.restype = None
def llama_sample_temperature( def llama_sample_temperature(
ctx: llama_context_p, candidates: _Pointer[llama_token_data_array], temp: c_float ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
temp: c_float,
): ):
return _lib.llama_sample_temperature(ctx, candidates, temp) return _lib.llama_sample_temperature(ctx, candidates, temp)
@ -604,11 +610,11 @@ _lib.llama_sample_temperature.restype = None
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. # @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
def llama_sample_token_mirostat( def llama_sample_token_mirostat(
ctx: llama_context_p, ctx: llama_context_p,
candidates: _Pointer[llama_token_data_array], candidates, # type: _Pointer[llama_token_data_array]
tau: c_float, tau: c_float,
eta: c_float, eta: c_float,
m: c_int, m: c_int,
mu: _Pointer[c_float], mu, # type: _Pointer[c_float]
) -> llama_token: ) -> llama_token:
return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu) return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu)
@ -631,10 +637,10 @@ _lib.llama_sample_token_mirostat.restype = llama_token
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. # @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
def llama_sample_token_mirostat_v2( def llama_sample_token_mirostat_v2(
ctx: llama_context_p, ctx: llama_context_p,
candidates: _Pointer[llama_token_data_array], candidates, # type: _Pointer[llama_token_data_array]
tau: c_float, tau: c_float,
eta: c_float, eta: c_float,
mu: _Pointer[c_float], mu, # type: _Pointer[c_float]
) -> llama_token: ) -> llama_token:
return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu) return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu)
@ -651,7 +657,8 @@ _lib.llama_sample_token_mirostat_v2.restype = llama_token
# @details Selects the token with the highest probability. # @details Selects the token with the highest probability.
def llama_sample_token_greedy( def llama_sample_token_greedy(
ctx: llama_context_p, candidates: _Pointer[llama_token_data_array] ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
) -> llama_token: ) -> llama_token:
return _lib.llama_sample_token_greedy(ctx, candidates) return _lib.llama_sample_token_greedy(ctx, candidates)
@ -665,7 +672,8 @@ _lib.llama_sample_token_greedy.restype = llama_token
# @details Randomly selects a token from the candidates based on their probabilities. # @details Randomly selects a token from the candidates based on their probabilities.
def llama_sample_token( def llama_sample_token(
ctx: llama_context_p, candidates: _Pointer[llama_token_data_array] ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
) -> llama_token: ) -> llama_token:
return _lib.llama_sample_token(ctx, candidates) return _lib.llama_sample_token(ctx, candidates)