From 1545b22727f53abcf42b8737e1eea21fca34a50a Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 31 Mar 2023 02:08:20 -0400 Subject: [PATCH] Fix array type signatures --- llama_cpp/llama_cpp.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 66c60af..5980430 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -116,7 +116,7 @@ _lib.llama_model_quantize.restype = c_int # Returns 0 on success def llama_eval( ctx: llama_context_p, - tokens: llama_token_p, + tokens: ctypes.Array[llama_token], n_tokens: c_int, n_past: c_int, n_threads: c_int, @@ -136,7 +136,7 @@ _lib.llama_eval.restype = c_int def llama_tokenize( ctx: llama_context_p, text: bytes, - tokens: llama_token_p, + tokens: ctypes.Array[llama_token], n_max_tokens: c_int, add_bos: c_bool, ) -> c_int: @@ -176,7 +176,7 @@ _lib.llama_n_embd.restype = c_int # Can be mutated in order to change the probabilities of the next token # Rows: n_tokens # Cols: n_vocab -def llama_get_logits(ctx: llama_context_p): +def llama_get_logits(ctx: llama_context_p) -> ctypes.Array[c_float]: return _lib.llama_get_logits(ctx) @@ -186,7 +186,7 @@ _lib.llama_get_logits.restype = POINTER(c_float) # Get the embeddings for the input # shape: [n_embd] (1-dimensional) -def llama_get_embeddings(ctx: llama_context_p): +def llama_get_embeddings(ctx: llama_context_p) -> ctypes.Array[c_float]: return _lib.llama_get_embeddings(ctx) @@ -224,7 +224,7 @@ _lib.llama_token_eos.restype = llama_token # TODO: improve the last_n_tokens interface ? def llama_sample_top_p_top_k( ctx: llama_context_p, - last_n_tokens_data: llama_token_p, + last_n_tokens_data: ctypes.Array[llama_token], last_n_tokens_size: c_int, top_k: c_int, top_p: c_float,