feat: Update llama.cpp

Andrei Betlen · 2024-03-08 20:58:50 -05:00
parent 93dc56ace8 · commit 40c6b54f68
2 changed files with 13 additions and 5 deletions

llama_cpp/llama_cpp.py

@@ -429,10 +429,12 @@ class llama_batch(ctypes.Structure):
     The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
 
     Attributes:
+        n_tokens (int): number of tokens
         token (ctypes.Array[llama_token]): the token ids of the input (used when embd is NULL)
         embd (ctypes.Array[ctypes.ctypes.c_float]): token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
         pos (ctypes.Array[ctypes.Array[llama_pos]]): the positions of the respective token in the sequence
         seq_id (ctypes.Array[ctypes.Array[llama_seq_id]]): the sequence to which the respective token belongs
+        logits (ctypes.Array[ctypes.ctypes.c_int8]): if zero, the logits for the respective token will not be output
     """
 
     _fields_ = [
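As a usage note (not part of the diff): a minimal sketch of filling a llama_batch by hand through the low-level bindings, including the newly documented n_tokens and logits fields. The token ids and capacity are illustrative only; llama_batch_init, llama_decode, and llama_batch_free are the library's real entry points.

import llama_cpp

# room for up to 512 tokens, token-id input (embd=0), one sequence max
batch = llama_cpp.llama_batch_init(512, 0, 1)
tokens = [1, 15043, 2787]  # hypothetical token ids
batch.n_tokens = len(tokens)
for i, tok in enumerate(tokens):
    batch.token[i] = tok
    batch.pos[i] = i          # position of the token within its sequence
    batch.n_seq_id[i] = 1
    batch.seq_id[i][0] = 0    # every token belongs to sequence 0
    batch.logits[i] = 0       # 0 = do not output logits for this token
batch.logits[batch.n_tokens - 1] = 1  # request logits only for the last token
# ... llama_cpp.llama_decode(ctx, batch) ...
llama_cpp.llama_batch_free(batch)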
@@ -547,6 +549,7 @@ class llama_model_params(ctypes.Structure):
 # uint32_t seed; // RNG seed, -1 for random
 # uint32_t n_ctx; // text context, 0 = from model
 # uint32_t n_batch; // prompt processing maximum batch size
+# uint32_t n_parallel; // number of parallel sequences (i.e. distinct states for recurrent models)
 # uint32_t n_threads; // number of threads to use for generation
 # uint32_t n_threads_batch; // number of threads to use for batch processing
@@ -588,6 +591,7 @@ class llama_context_params(ctypes.Structure):
         seed (int): RNG seed, -1 for random
         n_ctx (int): text context, 0 = from model
         n_batch (int): prompt processing maximum batch size
+        n_parallel (int): number of parallel sequences (i.e. distinct states for recurrent models)
         n_threads (int): number of threads to use for generation
         n_threads_batch (int): number of threads to use for batch processing
         rope_scaling_type (int): RoPE scaling type, from `enum llama_rope_scaling_type`
@@ -615,6 +619,7 @@ class llama_context_params(ctypes.Structure):
         ("seed", ctypes.c_uint32),
         ("n_ctx", ctypes.c_uint32),
         ("n_batch", ctypes.c_uint32),
+        ("n_parallel", ctypes.c_uint32),
         ("n_threads", ctypes.c_uint32),
         ("n_threads_batch", ctypes.c_uint32),
         ("rope_scaling_type", ctypes.c_int),
@@ -1322,7 +1327,7 @@ def llama_kv_cache_clear(ctx: llama_context_p, /):
 # // seq_id < 0 : match any sequence
 # // p0 < 0 : [0, p1]
 # // p1 < 0 : [p0, inf)
-# LLAMA_API void llama_kv_cache_seq_rm(
+# LLAMA_API bool llama_kv_cache_seq_rm(
 #         struct llama_context * ctx,
 #                 llama_seq_id   seq_id,
 #                    llama_pos   p0,
@@ -1335,7 +1340,7 @@ def llama_kv_cache_clear(ctx: llama_context_p, /):
         llama_pos,
         llama_pos,
     ],
-    None,
+    ctypes.c_bool,
 )
 def llama_kv_cache_seq_rm(
     ctx: llama_context_p,
@@ -1343,7 +1348,7 @@ def llama_kv_cache_seq_rm(
     p0: Union[llama_pos, int],
     p1: Union[llama_pos, int],
     /,
-):
+) -> bool:
     """Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
     seq_id < 0 : match any sequence
     p0 < 0 : [0, p1]
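With the return type now bool, callers can react to a failed removal; a sketch, reusing `ctx` from the context-setup sketch above:

# p1 < 0 means [p0, inf): drop sequence 0's cached tokens from position 32 onward
removed = llama_cpp.llama_kv_cache_seq_rm(ctx, 0, 32, -1)
if not removed:
    # partial removal can fail (e.g. the cache doesn't support it); fall back
    llama_cpp.llama_kv_cache_clear(ctx)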
@@ -1754,7 +1759,10 @@ def llama_get_logits(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]:
     The logits for the last token are stored in the last row
     Logits for which llama_batch.logits[i] == 0 are undefined
     Rows: n_tokens provided with llama_batch
-    Cols: n_vocab"""
+    Cols: n_vocab
+
+    Returns:
+        Pointer to the logits buffer of shape (n_tokens, n_vocab)"""
     ...
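A sketch of consuming the buffer per the new Returns note; `model`, `ctx`, and `batch` are assumed from the earlier sketches:

n_vocab = llama_cpp.llama_n_vocab(model)
logits = llama_cpp.llama_get_logits(ctx)  # flat buffer of n_tokens * n_vocab floats
# row i holds the logits for the i-th token of the last decoded batch
last_row = logits[(batch.n_tokens - 1) * n_vocab : batch.n_tokens * n_vocab]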

vendor/llama.cpp (vendored)

@@ -1 +1 @@
-Subproject commit 8ced9f7e3225adb8501e9821ed1bbd92e3a5c7ae
+Subproject commit c2101a2e909ac7c08976d414e64e96c90ee5fa9e