diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 92b9676..0176e49 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -429,10 +429,12 @@ class llama_batch(ctypes.Structure): The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens Attributes: + n_tokens (int): number of tokens token (ctypes.Array[llama_token]): the token ids of the input (used when embd is NULL) embd (ctypes.Array[ctypes.ctypes.c_float]): token embeddings (i.e. float vector of size n_embd) (used when token is NULL) pos (ctypes.Array[ctypes.Array[llama_pos]]): the positions of the respective token in the sequence seq_id (ctypes.Array[ctypes.Array[llama_seq_id]]): the sequence to which the respective token belongs + logits (ctypes.Array[ctypes.ctypes.c_int8]): if zero, the logits for the respective token will not be output """ _fields_ = [ @@ -547,6 +549,7 @@ class llama_model_params(ctypes.Structure): # uint32_t seed; // RNG seed, -1 for random # uint32_t n_ctx; // text context, 0 = from model # uint32_t n_batch; // prompt processing maximum batch size +# uint32_t n_parallel; // number of parallel sequences (i.e. distinct states for recurrent models) # uint32_t n_threads; // number of threads to use for generation # uint32_t n_threads_batch; // number of threads to use for batch processing @@ -588,6 +591,7 @@ class llama_context_params(ctypes.Structure): seed (int): RNG seed, -1 for random n_ctx (int): text context, 0 = from model n_batch (int): prompt processing maximum batch size + n_parallel (int): number of parallel sequences (i.e. distinct states for recurrent models) n_threads (int): number of threads to use for generation n_threads_batch (int): number of threads to use for batch processing rope_scaling_type (int): RoPE scaling type, from `enum llama_rope_scaling_type` @@ -615,6 +619,7 @@ class llama_context_params(ctypes.Structure): ("seed", ctypes.c_uint32), ("n_ctx", ctypes.c_uint32), ("n_batch", ctypes.c_uint32), + ("n_parallel", ctypes.c_uint32), ("n_threads", ctypes.c_uint32), ("n_threads_batch", ctypes.c_uint32), ("rope_scaling_type", ctypes.c_int), @@ -1322,7 +1327,7 @@ def llama_kv_cache_clear(ctx: llama_context_p, /): # // seq_id < 0 : match any sequence # // p0 < 0 : [0, p1] # // p1 < 0 : [p0, inf) -# LLAMA_API void llama_kv_cache_seq_rm( +# LLAMA_API bool llama_kv_cache_seq_rm( # struct llama_context * ctx, # llama_seq_id seq_id, # llama_pos p0, @@ -1335,7 +1340,7 @@ def llama_kv_cache_clear(ctx: llama_context_p, /): llama_pos, llama_pos, ], - None, + ctypes.c_bool, ) def llama_kv_cache_seq_rm( ctx: llama_context_p, @@ -1343,7 +1348,7 @@ def llama_kv_cache_seq_rm( p0: Union[llama_pos, int], p1: Union[llama_pos, int], /, -): +) -> bool: """Removes all tokens that belong to the specified sequence and have positions in [p0, p1) seq_id < 0 : match any sequence p0 < 0 : [0, p1] @@ -1754,7 +1759,10 @@ def llama_get_logits(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]: The logits for the last token are stored in the last row Logits for which llama_batch.logits[i] == 0 are undefined Rows: n_tokens provided with llama_batch - Cols: n_vocab""" + Cols: n_vocab + + Returns: + Pointer to the logits buffer of shape (n_tokens, n_vocab)""" ... diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 8ced9f7..c2101a2 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 8ced9f7e3225adb8501e9821ed1bbd92e3a5c7ae +Subproject commit c2101a2e909ac7c08976d414e64e96c90ee5fa9e