Fix: types
This commit is contained in:
parent
66e28eb548
commit
40501435c1
1 changed files with 13 additions and 8 deletions
|
@ -141,6 +141,11 @@ LLAMA_FTYPE_MOSTLY_Q8_0 = ctypes.c_int(7) # except 1d tensors
|
|||
LLAMA_FTYPE_MOSTLY_Q5_0 = ctypes.c_int(8) # except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_1 = ctypes.c_int(9) # except 1d tensors
|
||||
|
||||
# Misc
|
||||
c_float_p = POINTER(c_float)
|
||||
c_uint8_p = POINTER(c_uint8)
|
||||
c_size_t_p = POINTER(c_size_t)
|
||||
|
||||
# Functions
|
||||
|
||||
|
||||
|
@ -257,7 +262,7 @@ def llama_copy_state_data(ctx: llama_context_p, dest: Array[c_uint8]) -> c_size_
|
|||
return _lib.llama_copy_state_data(ctx, dest)
|
||||
|
||||
|
||||
_lib.llama_copy_state_data.argtypes = [llama_context_p, POINTER(c_uint8)]
|
||||
_lib.llama_copy_state_data.argtypes = [llama_context_p, c_uint8_p]
|
||||
_lib.llama_copy_state_data.restype = c_size_t
|
||||
|
||||
|
||||
|
@ -269,7 +274,7 @@ def llama_set_state_data(
|
|||
return _lib.llama_set_state_data(ctx, src)
|
||||
|
||||
|
||||
_lib.llama_set_state_data.argtypes = [llama_context_p, POINTER(c_uint8)]
|
||||
_lib.llama_set_state_data.argtypes = [llama_context_p, c_uint8_p]
|
||||
_lib.llama_set_state_data.restype = c_size_t
|
||||
|
||||
|
||||
|
@ -291,7 +296,7 @@ _lib.llama_load_session_file.argtypes = [
|
|||
c_char_p,
|
||||
llama_token_p,
|
||||
c_size_t,
|
||||
POINTER(c_size_t),
|
||||
c_size_t_p,
|
||||
]
|
||||
_lib.llama_load_session_file.restype = c_size_t
|
||||
|
||||
|
@ -340,7 +345,7 @@ _lib.llama_eval.restype = c_int
|
|||
def llama_tokenize(
|
||||
ctx: llama_context_p,
|
||||
text: bytes,
|
||||
tokens, # type: Array[llama_token]
|
||||
tokens: Array[llama_token],
|
||||
n_max_tokens: c_int,
|
||||
add_bos: c_bool,
|
||||
) -> c_int:
|
||||
|
@ -385,7 +390,7 @@ def llama_get_logits(ctx: llama_context_p):
|
|||
|
||||
|
||||
_lib.llama_get_logits.argtypes = [llama_context_p]
|
||||
_lib.llama_get_logits.restype = POINTER(c_float)
|
||||
_lib.llama_get_logits.restype = c_float_p
|
||||
|
||||
|
||||
# Get the embeddings for the input
|
||||
|
@ -395,7 +400,7 @@ def llama_get_embeddings(ctx: llama_context_p):
|
|||
|
||||
|
||||
_lib.llama_get_embeddings.argtypes = [llama_context_p]
|
||||
_lib.llama_get_embeddings.restype = POINTER(c_float)
|
||||
_lib.llama_get_embeddings.restype = c_float_p
|
||||
|
||||
|
||||
# Token Id -> String. Uses the vocabulary in the provided context
|
||||
|
@ -614,7 +619,7 @@ _lib.llama_sample_token_mirostat.argtypes = [
|
|||
c_float,
|
||||
c_float,
|
||||
c_int,
|
||||
POINTER(c_float),
|
||||
c_float_p,
|
||||
]
|
||||
_lib.llama_sample_token_mirostat.restype = llama_token
|
||||
|
||||
|
@ -639,7 +644,7 @@ _lib.llama_sample_token_mirostat_v2.argtypes = [
|
|||
llama_token_data_array_p,
|
||||
c_float,
|
||||
c_float,
|
||||
POINTER(c_float),
|
||||
c_float_p,
|
||||
]
|
||||
_lib.llama_sample_token_mirostat_v2.restype = llama_token
|
||||
|
||||
|
|
Loading…
Reference in a new issue