From 56071c956a947355ae63fe878448fe03df3b7586 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 9 Apr 2024 09:53:49 -0400 Subject: [PATCH] feat: Update llama.cpp --- llama_cpp/llama_cpp.py | 237 +++++++++++++++++++++++++++++++++++++++-- vendor/llama.cpp | 2 +- 2 files changed, 231 insertions(+), 8 deletions(-) diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index accc02c..8793085 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -237,11 +237,18 @@ LLAMA_FILE_MAGIC_GGLA = 0x67676C61 # define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' LLAMA_FILE_MAGIC_GGSN = 0x6767736E +#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' +LLAMA_FILE_MAGIC_GGSQ = 0x67677371 + # define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN # define LLAMA_SESSION_VERSION 5 LLAMA_SESSION_VERSION = 5 +#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ +LLAMA_STATE_SEQ_MAGIC = LLAMA_FILE_MAGIC_GGSQ +#define LLAMA_STATE_SEQ_VERSION 1 +LLAMA_STATE_SEQ_VERSION = 1 # struct llama_model; llama_model_p = NewType("llama_model_p", int) @@ -1467,6 +1474,7 @@ def llama_kv_cache_clear(ctx: llama_context_p, /): # // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) +# // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails # // seq_id < 0 : match any sequence # // p0 < 0 : [0, p1] # // p1 < 0 : [p0, inf) @@ -1493,6 +1501,9 @@ def llama_kv_cache_seq_rm( /, ) -> bool: """Removes all tokens that belong to the specified sequence and have positions in [p0, p1) + + Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails + seq_id < 0 : match any sequence p0 < 0 : [0, p1] p1 < 0 : [p0, inf)""" @@ -1652,7 +1663,16 @@ def llama_kv_cache_update(ctx: llama_context_p, /): # Returns the maximum size in bytes of the state (rng, logits, embedding # and kv_cache) - will often be smaller after compacting tokens -# LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx); +# LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx); +@ctypes_function("llama_state_get_size", [llama_context_p_ctypes], ctypes.c_size_t) +def llama_state_get_size(ctx: llama_context_p, /) -> int: + """Returns the maximum size in bytes of the state (rng, logits, embedding + and kv_cache) - will often be smaller after compacting tokens""" + ... + + +# LLAMA_API DEPRECATED(size_t llama_get_state_size(const struct llama_context * ctx), +# "use llama_state_get_size instead"); @ctypes_function("llama_get_state_size", [llama_context_p_ctypes], ctypes.c_size_t) def llama_get_state_size(ctx: llama_context_p, /) -> int: """Returns the maximum size in bytes of the state (rng, logits, embedding @@ -1663,9 +1683,30 @@ def llama_get_state_size(ctx: llama_context_p, /) -> int: # Copies the state to the specified destination address. # Destination needs to have allocated enough memory. # Returns the number of bytes copied -# LLAMA_API size_t llama_copy_state_data( +# LLAMA_API size_t llama_state_get_data( # struct llama_context * ctx, # uint8_t * dst); +@ctypes_function( + "llama_state_get_data", + [ + llama_context_p_ctypes, + ctypes.POINTER(ctypes.c_uint8), + ], + ctypes.c_size_t, +) +def llama_state_get_data( + ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], / +) -> int: + """Copies the state to the specified destination address. + Destination needs to have allocated enough memory. + Returns the number of bytes copied""" + ... + + +# LLAMA_API DEPRECATED(size_t llama_copy_state_data( +# struct llama_context * ctx, +# uint8_t * dst), +# "use llama_state_get_data instead"); @ctypes_function( "llama_copy_state_data", [ @@ -1685,9 +1726,26 @@ def llama_copy_state_data( # // Set the state reading from the specified address # // Returns the number of bytes read -# LLAMA_API size_t llama_set_state_data( +# LLAMA_API size_t llama_state_set_data( # struct llama_context * ctx, # const uint8_t * src); +@ctypes_function( + "llama_state_set_data", + [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)], + ctypes.c_size_t, +) +def llama_state_set_data( + ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], / +) -> int: + """Set the state reading from the specified address + Returns the number of bytes read""" + ... + + +# LLAMA_API DEPRECATED(size_t llama_set_state_data( +# struct llama_context * ctx, +# const uint8_t * src), +# "use llama_state_set_data instead"); @ctypes_function( "llama_set_state_data", [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)], @@ -1701,12 +1759,40 @@ def llama_set_state_data( # Save/load session file -# LLAMA_API bool llama_load_session_file( +# LLAMA_API bool llama_state_load_file( # struct llama_context * ctx, # const char * path_session, # llama_token * tokens_out, # size_t n_token_capacity, # size_t * n_token_count_out); +@ctypes_function( + "llama_state_load_file", + [ + llama_context_p_ctypes, + ctypes.c_char_p, + llama_token_p, + ctypes.c_size_t, + ctypes.POINTER(ctypes.c_size_t), + ], + ctypes.c_bool, +) +def llama_state_load_file( + ctx: llama_context_p, + path_session: bytes, + tokens_out: CtypesArray[llama_token], + n_token_capacity: Union[ctypes.c_size_t, int], + n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t], + /, +) -> bool: ... + + +# LLAMA_API DEPRECATED(bool llama_load_session_file( +# struct llama_context * ctx, +# const char * path_session, +# llama_token * tokens_out, +# size_t n_token_capacity, +# size_t * n_token_count_out), +# "use llama_state_load_file instead"); @ctypes_function( "llama_load_session_file", [ @@ -1728,11 +1814,36 @@ def llama_load_session_file( ) -> int: ... -# LLAMA_API bool llama_save_session_file( +# LLAMA_API bool llama_state_save_file( # struct llama_context * ctx, # const char * path_session, # const llama_token * tokens, # size_t n_token_count); +@ctypes_function( + "llama_state_save_file", + [ + llama_context_p_ctypes, + ctypes.c_char_p, + llama_token_p, + ctypes.c_size_t, + ], + ctypes.c_bool, +) +def llama_state_save_file( + ctx: llama_context_p, + path_session: bytes, + tokens: CtypesArray[llama_token], + n_token_count: Union[ctypes.c_size_t, int], + /, +) -> bool: ... + + +# LLAMA_API DEPRECATED(bool llama_save_session_file( +# struct llama_context * ctx, +# const char * path_session, +# const llama_token * tokens, +# size_t n_token_count), +# "use llama_state_save_file instead"); @ctypes_function( "llama_save_session_file", [ @@ -1752,6 +1863,116 @@ def llama_save_session_file( ) -> int: ... +# // Get the exact size needed to copy the KV cache of a single sequence +# LLAMA_API size_t llama_state_seq_get_size( +# struct llama_context * ctx, +# llama_seq_id seq_id); +@ctypes_function( + "llama_state_seq_get_size", + [llama_context_p_ctypes, llama_seq_id], + ctypes.c_size_t, +) +def llama_state_seq_get_size(ctx: llama_context_p, seq_id: llama_seq_id, /) -> int: + """Get the exact size needed to copy the KV cache of a single sequence""" + ... + + +# // Copy the KV cache of a single sequence into the specified buffer +# LLAMA_API size_t llama_state_seq_get_data( +# struct llama_context * ctx, +# uint8_t * dst, +# llama_seq_id seq_id); +@ctypes_function( + "llama_state_seq_get_data", + [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8), llama_seq_id], + ctypes.c_size_t, +) +def llama_state_seq_get_data( + ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], seq_id: llama_seq_id, / +) -> int: + """Copy the KV cache of a single sequence into the specified buffer""" + ... + + +# // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence +# // Returns: +# // - Positive: Ok +# // - Zero: Failed to load +# LLAMA_API size_t llama_state_seq_set_data( +# struct llama_context * ctx, +# const uint8_t * src, +# llama_seq_id dest_seq_id); +@ctypes_function( + "llama_state_seq_set_data", + [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8), llama_seq_id], + ctypes.c_size_t, +) +def llama_state_seq_set_data( + ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], dest_seq_id: llama_seq_id, / +) -> int: + """Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence""" + ... + + +# LLAMA_API size_t llama_state_seq_save_file( +# struct llama_context * ctx, +# const char * filepath, +# llama_seq_id seq_id, +# const llama_token * tokens, +# size_t n_token_count); +@ctypes_function( + "llama_state_seq_save_file", + [ + llama_context_p_ctypes, + ctypes.c_char_p, + llama_seq_id, + llama_token_p, + ctypes.c_size_t, + ], + ctypes.c_size_t, +) +def llama_state_seq_save_file( + ctx: llama_context_p, + filepath: bytes, + seq_id: llama_seq_id, + tokens: CtypesArray[llama_token], + n_token_count: Union[ctypes.c_size_t, int], + /, +) -> int: + ... + + +# LLAMA_API size_t llama_state_seq_load_file( +# struct llama_context * ctx, +# const char * filepath, +# llama_seq_id dest_seq_id, +# llama_token * tokens_out, +# size_t n_token_capacity, +# size_t * n_token_count_out); +@ctypes_function( + "llama_state_seq_load_file", + [ + llama_context_p_ctypes, + ctypes.c_char_p, + llama_seq_id, + llama_token_p, + ctypes.c_size_t, + ctypes.POINTER(ctypes.c_size_t), + ], + ctypes.c_size_t, +) +def llama_state_seq_load_file( + ctx: llama_context_p, + filepath: bytes, + dest_seq_id: llama_seq_id, + tokens_out: CtypesArray[llama_token], + n_token_capacity: Union[ctypes.c_size_t, int], + n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t], + /, +) -> int: + ... + + # // # // Decoding # // @@ -1930,8 +2151,9 @@ def llama_get_logits(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]: ... -# // Logits for the ith token. Equivalent to: +# // Logits for the ith token. For positive indices, Equivalent to: # // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab +# // Negative indicies can be used to access logits in reverse order, -1 is the last logit. # // returns NULL for invalid ids. # LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); @ctypes_function( @@ -1963,8 +2185,9 @@ def llama_get_embeddings(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float] ... -# // Get the embeddings for the ith token. Equivalent to: +# // Get the embeddings for the ith token. For positive indices, Equivalent to: # // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd +# // Negative indicies can be used to access embeddings in reverse order, -1 is the last embedding. # // shape: [n_embd] (1-dimensional) # // returns NULL for invalid ids. # LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 75cd4c7..400d5d7 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 75cd4c77292034ecec587ecb401366f57338f7c0 +Subproject commit 400d5d722d7edf7de0cf24a18c42b183c65047d2