diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 61b40f8..8ce3c89 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -141,6 +141,11 @@ LLAMA_FTYPE_MOSTLY_Q8_0 = ctypes.c_int(7) # except 1d tensors LLAMA_FTYPE_MOSTLY_Q5_0 = ctypes.c_int(8) # except 1d tensors LLAMA_FTYPE_MOSTLY_Q5_1 = ctypes.c_int(9) # except 1d tensors +# Misc +c_float_p = POINTER(c_float) +c_uint8_p = POINTER(c_uint8) +c_size_t_p = POINTER(c_size_t) + # Functions @@ -257,7 +262,7 @@ def llama_copy_state_data(ctx: llama_context_p, dest: Array[c_uint8]) -> c_size_ return _lib.llama_copy_state_data(ctx, dest) -_lib.llama_copy_state_data.argtypes = [llama_context_p, POINTER(c_uint8)] +_lib.llama_copy_state_data.argtypes = [llama_context_p, c_uint8_p] _lib.llama_copy_state_data.restype = c_size_t @@ -269,7 +274,7 @@ def llama_set_state_data( return _lib.llama_set_state_data(ctx, src) -_lib.llama_set_state_data.argtypes = [llama_context_p, POINTER(c_uint8)] +_lib.llama_set_state_data.argtypes = [llama_context_p, c_uint8_p] _lib.llama_set_state_data.restype = c_size_t @@ -291,7 +296,7 @@ _lib.llama_load_session_file.argtypes = [ c_char_p, llama_token_p, c_size_t, - POINTER(c_size_t), + c_size_t_p, ] _lib.llama_load_session_file.restype = c_size_t @@ -340,7 +345,7 @@ _lib.llama_eval.restype = c_int def llama_tokenize( ctx: llama_context_p, text: bytes, - tokens, # type: Array[llama_token] + tokens: Array[llama_token], n_max_tokens: c_int, add_bos: c_bool, ) -> c_int: @@ -385,7 +390,7 @@ def llama_get_logits(ctx: llama_context_p): _lib.llama_get_logits.argtypes = [llama_context_p] -_lib.llama_get_logits.restype = POINTER(c_float) +_lib.llama_get_logits.restype = c_float_p # Get the embeddings for the input @@ -395,7 +400,7 @@ def llama_get_embeddings(ctx: llama_context_p): _lib.llama_get_embeddings.argtypes = [llama_context_p] -_lib.llama_get_embeddings.restype = POINTER(c_float) +_lib.llama_get_embeddings.restype = c_float_p # Token Id -> String. Uses the vocabulary in the provided context @@ -614,7 +619,7 @@ _lib.llama_sample_token_mirostat.argtypes = [ c_float, c_float, c_int, - POINTER(c_float), + c_float_p, ] _lib.llama_sample_token_mirostat.restype = llama_token @@ -639,7 +644,7 @@ _lib.llama_sample_token_mirostat_v2.argtypes = [ llama_token_data_array_p, c_float, c_float, - POINTER(c_float), + c_float_p, ] _lib.llama_sample_token_mirostat_v2.restype = llama_token