Fix: types

This commit is contained in:
Andrei Betlen 2023-05-05 14:04:12 -04:00
parent 66e28eb548
commit 40501435c1

View file

@ -141,6 +141,11 @@ LLAMA_FTYPE_MOSTLY_Q8_0 = ctypes.c_int(7) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_0 = ctypes.c_int(8) # except 1d tensors LLAMA_FTYPE_MOSTLY_Q5_0 = ctypes.c_int(8) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_1 = ctypes.c_int(9) # except 1d tensors LLAMA_FTYPE_MOSTLY_Q5_1 = ctypes.c_int(9) # except 1d tensors
# Misc
c_float_p = POINTER(c_float)
c_uint8_p = POINTER(c_uint8)
c_size_t_p = POINTER(c_size_t)
# Functions # Functions
@ -257,7 +262,7 @@ def llama_copy_state_data(ctx: llama_context_p, dest: Array[c_uint8]) -> c_size_
return _lib.llama_copy_state_data(ctx, dest) return _lib.llama_copy_state_data(ctx, dest)
_lib.llama_copy_state_data.argtypes = [llama_context_p, POINTER(c_uint8)] _lib.llama_copy_state_data.argtypes = [llama_context_p, c_uint8_p]
_lib.llama_copy_state_data.restype = c_size_t _lib.llama_copy_state_data.restype = c_size_t
@ -269,7 +274,7 @@ def llama_set_state_data(
return _lib.llama_set_state_data(ctx, src) return _lib.llama_set_state_data(ctx, src)
_lib.llama_set_state_data.argtypes = [llama_context_p, POINTER(c_uint8)] _lib.llama_set_state_data.argtypes = [llama_context_p, c_uint8_p]
_lib.llama_set_state_data.restype = c_size_t _lib.llama_set_state_data.restype = c_size_t
@ -291,7 +296,7 @@ _lib.llama_load_session_file.argtypes = [
c_char_p, c_char_p,
llama_token_p, llama_token_p,
c_size_t, c_size_t,
POINTER(c_size_t), c_size_t_p,
] ]
_lib.llama_load_session_file.restype = c_size_t _lib.llama_load_session_file.restype = c_size_t
@ -340,7 +345,7 @@ _lib.llama_eval.restype = c_int
def llama_tokenize( def llama_tokenize(
ctx: llama_context_p, ctx: llama_context_p,
text: bytes, text: bytes,
tokens, # type: Array[llama_token] tokens: Array[llama_token],
n_max_tokens: c_int, n_max_tokens: c_int,
add_bos: c_bool, add_bos: c_bool,
) -> c_int: ) -> c_int:
@ -385,7 +390,7 @@ def llama_get_logits(ctx: llama_context_p):
_lib.llama_get_logits.argtypes = [llama_context_p] _lib.llama_get_logits.argtypes = [llama_context_p]
_lib.llama_get_logits.restype = POINTER(c_float) _lib.llama_get_logits.restype = c_float_p
# Get the embeddings for the input # Get the embeddings for the input
@ -395,7 +400,7 @@ def llama_get_embeddings(ctx: llama_context_p):
_lib.llama_get_embeddings.argtypes = [llama_context_p] _lib.llama_get_embeddings.argtypes = [llama_context_p]
_lib.llama_get_embeddings.restype = POINTER(c_float) _lib.llama_get_embeddings.restype = c_float_p
# Token Id -> String. Uses the vocabulary in the provided context # Token Id -> String. Uses the vocabulary in the provided context
@ -614,7 +619,7 @@ _lib.llama_sample_token_mirostat.argtypes = [
c_float, c_float,
c_float, c_float,
c_int, c_int,
POINTER(c_float), c_float_p,
] ]
_lib.llama_sample_token_mirostat.restype = llama_token _lib.llama_sample_token_mirostat.restype = llama_token
@ -639,7 +644,7 @@ _lib.llama_sample_token_mirostat_v2.argtypes = [
llama_token_data_array_p, llama_token_data_array_p,
c_float, c_float,
c_float, c_float,
POINTER(c_float), c_float_p,
] ]
_lib.llama_sample_token_mirostat_v2.restype = llama_token _lib.llama_sample_token_mirostat_v2.restype = llama_token