Fix array type signatures

This commit is contained in:
Andrei Betlen 2023-03-31 02:08:20 -04:00
parent 4b9eb5c19e
commit 1545b22727

View file

@ -116,7 +116,7 @@ _lib.llama_model_quantize.restype = c_int
# Returns 0 on success # Returns 0 on success
def llama_eval( def llama_eval(
ctx: llama_context_p, ctx: llama_context_p,
tokens: llama_token_p, tokens: ctypes.Array[llama_token],
n_tokens: c_int, n_tokens: c_int,
n_past: c_int, n_past: c_int,
n_threads: c_int, n_threads: c_int,
@ -136,7 +136,7 @@ _lib.llama_eval.restype = c_int
def llama_tokenize( def llama_tokenize(
ctx: llama_context_p, ctx: llama_context_p,
text: bytes, text: bytes,
tokens: llama_token_p, tokens: ctypes.Array[llama_token],
n_max_tokens: c_int, n_max_tokens: c_int,
add_bos: c_bool, add_bos: c_bool,
) -> c_int: ) -> c_int:
@ -176,7 +176,7 @@ _lib.llama_n_embd.restype = c_int
# Can be mutated in order to change the probabilities of the next token # Can be mutated in order to change the probabilities of the next token
# Rows: n_tokens # Rows: n_tokens
# Cols: n_vocab # Cols: n_vocab
def llama_get_logits(ctx: llama_context_p): def llama_get_logits(ctx: llama_context_p) -> ctypes.Array[c_float]:
return _lib.llama_get_logits(ctx) return _lib.llama_get_logits(ctx)
@ -186,7 +186,7 @@ _lib.llama_get_logits.restype = POINTER(c_float)
# Get the embeddings for the input # Get the embeddings for the input
# shape: [n_embd] (1-dimensional) # shape: [n_embd] (1-dimensional)
def llama_get_embeddings(ctx: llama_context_p): def llama_get_embeddings(ctx: llama_context_p) -> ctypes.Array[c_float]:
return _lib.llama_get_embeddings(ctx) return _lib.llama_get_embeddings(ctx)
@ -224,7 +224,7 @@ _lib.llama_token_eos.restype = llama_token
# TODO: improve the last_n_tokens interface ? # TODO: improve the last_n_tokens interface ?
def llama_sample_top_p_top_k( def llama_sample_top_p_top_k(
ctx: llama_context_p, ctx: llama_context_p,
last_n_tokens_data: llama_token_p, last_n_tokens_data: ctypes.Array[llama_token],
last_n_tokens_size: c_int, last_n_tokens_size: c_int,
top_k: c_int, top_k: c_int,
top_p: c_float, top_p: c_float,