From 56071c956a947355ae63fe878448fe03df3b7586 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 9 Apr 2024 09:53:49 -0400 Subject: [PATCH 01/29] feat: Update llama.cpp --- llama_cpp/llama_cpp.py | 237 +++++++++++++++++++++++++++++++++++++++-- vendor/llama.cpp | 2 +- 2 files changed, 231 insertions(+), 8 deletions(-) diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index accc02c..8793085 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -237,11 +237,18 @@ LLAMA_FILE_MAGIC_GGLA = 0x67676C61 # define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' LLAMA_FILE_MAGIC_GGSN = 0x6767736E +#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' +LLAMA_FILE_MAGIC_GGSQ = 0x67677371 + # define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN # define LLAMA_SESSION_VERSION 5 LLAMA_SESSION_VERSION = 5 +#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ +LLAMA_STATE_SEQ_MAGIC = LLAMA_FILE_MAGIC_GGSQ +#define LLAMA_STATE_SEQ_VERSION 1 +LLAMA_STATE_SEQ_VERSION = 1 # struct llama_model; llama_model_p = NewType("llama_model_p", int) @@ -1467,6 +1474,7 @@ def llama_kv_cache_clear(ctx: llama_context_p, /): # // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) +# // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails # // seq_id < 0 : match any sequence # // p0 < 0 : [0, p1] # // p1 < 0 : [p0, inf) @@ -1493,6 +1501,9 @@ def llama_kv_cache_seq_rm( /, ) -> bool: """Removes all tokens that belong to the specified sequence and have positions in [p0, p1) + + Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails + seq_id < 0 : match any sequence p0 < 0 : [0, p1] p1 < 0 : [p0, inf)""" @@ -1652,7 +1663,16 @@ def llama_kv_cache_update(ctx: llama_context_p, /): # Returns the maximum size in bytes of the state (rng, logits, embedding # and kv_cache) - will often be smaller after compacting tokens -# LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx); +# LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx); +@ctypes_function("llama_state_get_size", [llama_context_p_ctypes], ctypes.c_size_t) +def llama_state_get_size(ctx: llama_context_p, /) -> int: + """Returns the maximum size in bytes of the state (rng, logits, embedding + and kv_cache) - will often be smaller after compacting tokens""" + ... + + +# LLAMA_API DEPRECATED(size_t llama_get_state_size(const struct llama_context * ctx), +# "use llama_state_get_size instead"); @ctypes_function("llama_get_state_size", [llama_context_p_ctypes], ctypes.c_size_t) def llama_get_state_size(ctx: llama_context_p, /) -> int: """Returns the maximum size in bytes of the state (rng, logits, embedding @@ -1663,9 +1683,30 @@ def llama_get_state_size(ctx: llama_context_p, /) -> int: # Copies the state to the specified destination address. # Destination needs to have allocated enough memory. # Returns the number of bytes copied -# LLAMA_API size_t llama_copy_state_data( +# LLAMA_API size_t llama_state_get_data( # struct llama_context * ctx, # uint8_t * dst); +@ctypes_function( + "llama_state_get_data", + [ + llama_context_p_ctypes, + ctypes.POINTER(ctypes.c_uint8), + ], + ctypes.c_size_t, +) +def llama_state_get_data( + ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], / +) -> int: + """Copies the state to the specified destination address. + Destination needs to have allocated enough memory. + Returns the number of bytes copied""" + ... 
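For illustration, a minimal round-trip through the new state API might look like the sketch below. This is not part of the patch: it assumes `ctx` is a valid `llama_context_p` obtained from `llama_new_context_with_model`, and the restore half uses `llama_state_set_data`, which is declared further down in this file.

import ctypes

import llama_cpp

def snapshot_state(ctx) -> bytes:
    # llama_state_get_size reports the maximum size; llama_state_get_data
    # returns the number of bytes actually written, which may be smaller
    # after compacting tokens.
    size = llama_cpp.llama_state_get_size(ctx)
    buf = (ctypes.c_uint8 * size)()
    n_written = llama_cpp.llama_state_get_data(ctx, buf)
    return bytes(buf[:n_written])

def restore_state(ctx, blob: bytes) -> int:
    # Copy a snapshot back into a context; returns the number of bytes read.
    src = (ctypes.c_uint8 * len(blob)).from_buffer_copy(blob)
    return llama_cpp.llama_state_set_data(ctx, src)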
+ + +# LLAMA_API DEPRECATED(size_t llama_copy_state_data( +# struct llama_context * ctx, +# uint8_t * dst), +# "use llama_state_get_data instead"); @ctypes_function( "llama_copy_state_data", [ @@ -1685,9 +1726,26 @@ def llama_copy_state_data( # // Set the state reading from the specified address # // Returns the number of bytes read -# LLAMA_API size_t llama_set_state_data( +# LLAMA_API size_t llama_state_set_data( # struct llama_context * ctx, # const uint8_t * src); +@ctypes_function( + "llama_state_set_data", + [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)], + ctypes.c_size_t, +) +def llama_state_set_data( + ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], / +) -> int: + """Set the state reading from the specified address + Returns the number of bytes read""" + ... + + +# LLAMA_API DEPRECATED(size_t llama_set_state_data( +# struct llama_context * ctx, +# const uint8_t * src), +# "use llama_state_set_data instead"); @ctypes_function( "llama_set_state_data", [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8)], @@ -1701,12 +1759,40 @@ def llama_set_state_data( # Save/load session file -# LLAMA_API bool llama_load_session_file( +# LLAMA_API bool llama_state_load_file( # struct llama_context * ctx, # const char * path_session, # llama_token * tokens_out, # size_t n_token_capacity, # size_t * n_token_count_out); +@ctypes_function( + "llama_state_load_file", + [ + llama_context_p_ctypes, + ctypes.c_char_p, + llama_token_p, + ctypes.c_size_t, + ctypes.POINTER(ctypes.c_size_t), + ], + ctypes.c_bool, +) +def llama_state_load_file( + ctx: llama_context_p, + path_session: bytes, + tokens_out: CtypesArray[llama_token], + n_token_capacity: Union[ctypes.c_size_t, int], + n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t], + /, +) -> bool: ... + + +# LLAMA_API DEPRECATED(bool llama_load_session_file( +# struct llama_context * ctx, +# const char * path_session, +# llama_token * tokens_out, +# size_t n_token_capacity, +# size_t * n_token_count_out), +# "use llama_state_load_file instead"); @ctypes_function( "llama_load_session_file", [ @@ -1728,11 +1814,36 @@ def llama_load_session_file( ) -> int: ... -# LLAMA_API bool llama_save_session_file( +# LLAMA_API bool llama_state_save_file( # struct llama_context * ctx, # const char * path_session, # const llama_token * tokens, # size_t n_token_count); +@ctypes_function( + "llama_state_save_file", + [ + llama_context_p_ctypes, + ctypes.c_char_p, + llama_token_p, + ctypes.c_size_t, + ], + ctypes.c_bool, +) +def llama_state_save_file( + ctx: llama_context_p, + path_session: bytes, + tokens: CtypesArray[llama_token], + n_token_count: Union[ctypes.c_size_t, int], + /, +) -> bool: ... + + +# LLAMA_API DEPRECATED(bool llama_save_session_file( +# struct llama_context * ctx, +# const char * path_session, +# const llama_token * tokens, +# size_t n_token_count), +# "use llama_state_save_file instead"); @ctypes_function( "llama_save_session_file", [ @@ -1752,6 +1863,116 @@ def llama_save_session_file( ) -> int: ... +# // Get the exact size needed to copy the KV cache of a single sequence +# LLAMA_API size_t llama_state_seq_get_size( +# struct llama_context * ctx, +# llama_seq_id seq_id); +@ctypes_function( + "llama_state_seq_get_size", + [llama_context_p_ctypes, llama_seq_id], + ctypes.c_size_t, +) +def llama_state_seq_get_size(ctx: llama_context_p, seq_id: llama_seq_id, /) -> int: + """Get the exact size needed to copy the KV cache of a single sequence""" + ... 
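As a usage sketch for the renamed session-file API above (illustrative rather than part of the patch; the file name and token capacity are arbitrary placeholders, and `tokens` is assumed to be a ctypes array of `llama_token` holding the already-evaluated prompt). The per-sequence variants declared next follow the same size/get/set pattern.

import ctypes

import llama_cpp

def save_session(ctx, tokens, n_tokens: int, path: bytes = b"session.bin") -> bool:
    # Persist the evaluated tokens plus the context state to disk.
    return llama_cpp.llama_state_save_file(ctx, path, tokens, n_tokens)

def load_session(ctx, path: bytes = b"session.bin", capacity: int = 2048):
    tokens_out = (llama_cpp.llama_token * capacity)()
    n_out = ctypes.c_size_t(0)
    ok = llama_cpp.llama_state_load_file(
        ctx, path, tokens_out, capacity, ctypes.byref(n_out)
    )
    return list(tokens_out[: n_out.value]) if ok else None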
+
+
+# // Copy the KV cache of a single sequence into the specified buffer
+# LLAMA_API size_t llama_state_seq_get_data(
+# struct llama_context * ctx,
+# uint8_t * dst,
+# llama_seq_id seq_id);
+@ctypes_function(
+    "llama_state_seq_get_data",
+    [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8), llama_seq_id],
+    ctypes.c_size_t,
+)
+def llama_state_seq_get_data(
+    ctx: llama_context_p, dst: CtypesArray[ctypes.c_uint8], seq_id: llama_seq_id, /
+) -> int:
+    """Copy the KV cache of a single sequence into the specified buffer"""
+    ...
+
+
+# // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
+# // Returns:
+# // - Positive: Ok
+# // - Zero: Failed to load
+# LLAMA_API size_t llama_state_seq_set_data(
+# struct llama_context * ctx,
+# const uint8_t * src,
+# llama_seq_id dest_seq_id);
+@ctypes_function(
+    "llama_state_seq_set_data",
+    [llama_context_p_ctypes, ctypes.POINTER(ctypes.c_uint8), llama_seq_id],
+    ctypes.c_size_t,
+)
+def llama_state_seq_set_data(
+    ctx: llama_context_p, src: CtypesArray[ctypes.c_uint8], dest_seq_id: llama_seq_id, /
+) -> int:
+    """Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence"""
+    ...
+
+
+# LLAMA_API size_t llama_state_seq_save_file(
+# struct llama_context * ctx,
+# const char * filepath,
+# llama_seq_id seq_id,
+# const llama_token * tokens,
+# size_t n_token_count);
+@ctypes_function(
+    "llama_state_seq_save_file",
+    [
+        llama_context_p_ctypes,
+        ctypes.c_char_p,
+        llama_seq_id,
+        llama_token_p,
+        ctypes.c_size_t,
+    ],
+    ctypes.c_size_t,
+)
+def llama_state_seq_save_file(
+    ctx: llama_context_p,
+    filepath: bytes,
+    seq_id: llama_seq_id,
+    tokens: CtypesArray[llama_token],
+    n_token_count: Union[ctypes.c_size_t, int],
+    /,
+) -> int:
+    ...
+
+
+# LLAMA_API size_t llama_state_seq_load_file(
+# struct llama_context * ctx,
+# const char * filepath,
+# llama_seq_id dest_seq_id,
+# llama_token * tokens_out,
+# size_t n_token_capacity,
+# size_t * n_token_count_out);
+@ctypes_function(
+    "llama_state_seq_load_file",
+    [
+        llama_context_p_ctypes,
+        ctypes.c_char_p,
+        llama_seq_id,
+        llama_token_p,
+        ctypes.c_size_t,
+        ctypes.POINTER(ctypes.c_size_t),
+    ],
+    ctypes.c_size_t,
+)
+def llama_state_seq_load_file(
+    ctx: llama_context_p,
+    filepath: bytes,
+    dest_seq_id: llama_seq_id,
+    tokens_out: CtypesArray[llama_token],
+    n_token_capacity: Union[ctypes.c_size_t, int],
+    n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
+    /,
+) -> int:
+    ...
+
+
 # //
 # // Decoding
 # //
@@ -1930,8 +2151,9 @@ def llama_get_logits(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]:
     ...


-# // Logits for the ith token. Equivalent to:
+# // Logits for the ith token. For positive indices, equivalent to:
 # // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
+# // Negative indices can be used to access logits in reverse order, -1 is the last logit.
 # // returns NULL for invalid ids.
 # LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
 @ctypes_function(
@@ -1963,8 +2185,9 @@ def llama_get_embeddings(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]
     ...


-# // Get the embeddings for the ith token. Equivalent to:
+# // Get the embeddings for the ith token. For positive indices, equivalent to:
 # // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
+# // Negative indices can be used to access embeddings in reverse order, -1 is the last embedding.
 # // shape: [n_embd] (1-dimensional)
 # // returns NULL for invalid ids.
# LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 75cd4c7..400d5d7 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 75cd4c77292034ecec587ecb401366f57338f7c0 +Subproject commit 400d5d722d7edf7de0cf24a18c42b183c65047d2 From 889d0e8981641b603259e9a9ee40ce3b3ee8db51 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 10 Apr 2024 02:25:58 -0400 Subject: [PATCH 02/29] feat: Update llama.cpp --- llama_cpp/llama_cpp.py | 31 ++++++++++++++++++++++--------- vendor/llama.cpp | 2 +- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 8793085..99ae7de 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -2271,6 +2271,20 @@ def llama_token_eos(model: llama_model_p, /) -> int: ... +# LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification +@ctypes_function("llama_token_cls", [llama_model_p_ctypes], llama_token) +def llama_token_cls(model: llama_model_p, /) -> int: + """classification""" + ... + + +# LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator +@ctypes_function("llama_token_sep", [llama_model_p_ctypes], llama_token) +def llama_token_sep(model: llama_model_p, /) -> int: + """sentence separator""" + ... + + # LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line @ctypes_function("llama_token_nl", [llama_model_p_ctypes], llama_token) def llama_token_nl(model: llama_model_p, /) -> int: @@ -2326,16 +2340,16 @@ def llama_token_eot(model: llama_model_p, /) -> int: ... # /// @param tokens The tokens pointer must be large enough to hold the resulting tokens. # /// @return Returns the number of tokens on success, no more than n_tokens_max # /// @return Returns a negative number on failure - the number of tokens that would have been returned -# /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. -# /// Does not insert a leading space. +# /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated +# /// as plaintext. Does not insert a leading space. # LLAMA_API int32_t llama_tokenize( # const struct llama_model * model, # const char * text, # int32_t text_len, # llama_token * tokens, # int32_t n_tokens_max, -# bool add_bos, -# bool special); +# bool add_special, +# bool parse_special); @ctypes_function( "llama_tokenize", [ @@ -2355,8 +2369,8 @@ def llama_tokenize( text_len: Union[ctypes.c_int, int], tokens: CtypesArray[llama_token], n_tokens_max: Union[ctypes.c_int, int], - add_bos: Union[ctypes.c_bool, bool], - special: Union[ctypes.c_bool, bool], + add_special: Union[ctypes.c_bool, bool], + parse_special: Union[ctypes.c_bool, bool], /, ) -> int: """Convert the provided text into tokens. @@ -2367,9 +2381,8 @@ def llama_tokenize( text_len: The length of the text. tokens: The tokens pointer must be large enough to hold the resulting tokens. n_max_tokens: The maximum number of tokens to return. - add_bos: Whether to add a beginning-of-sentence token. - special: Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. - Does not insert a leading space. + add_special: Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space. 
+ parse_special: Allow parsing special tokens. Returns: Returns the number of tokens on success, no more than n_tokens_max diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 400d5d7..ba5e134 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 400d5d722d7edf7de0cf24a18c42b183c65047d2 +Subproject commit ba5e134e073ec6837078c874aba44a702944a676 From 1347e1d050fc5a9a32ffe0bb3e22858da28003bd Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 10 Apr 2024 02:40:41 -0400 Subject: [PATCH 03/29] feat: Add typechecking for ctypes structure attributes --- llama_cpp/llama_cpp.py | 216 ++++++++++++++++++++++++++++++++++------- 1 file changed, 180 insertions(+), 36 deletions(-) diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 99ae7de..2450d11 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -237,7 +237,7 @@ LLAMA_FILE_MAGIC_GGLA = 0x67676C61 # define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' LLAMA_FILE_MAGIC_GGSN = 0x6767736E -#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' +# define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' LLAMA_FILE_MAGIC_GGSQ = 0x67677371 # define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN @@ -245,9 +245,9 @@ LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN # define LLAMA_SESSION_VERSION 5 LLAMA_SESSION_VERSION = 5 -#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ +# define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ LLAMA_STATE_SEQ_MAGIC = LLAMA_FILE_MAGIC_GGSQ -#define LLAMA_STATE_SEQ_VERSION 1 +# define LLAMA_STATE_SEQ_VERSION 1 LLAMA_STATE_SEQ_VERSION = 1 # struct llama_model; @@ -431,6 +431,11 @@ class llama_token_data(ctypes.Structure): logit (float): log-odds of the token p (float): probability of the token""" + if TYPE_CHECKING: + id: llama_token + logit: float + p: float + _fields_ = [ ("id", llama_token), ("logit", ctypes.c_float), @@ -454,6 +459,11 @@ class llama_token_data_array(ctypes.Structure): size (int): size of the array sorted (bool): whether the array is sorted""" + if TYPE_CHECKING: + data: CtypesArray[llama_token_data] + size: int + sorted: bool + _fields_ = [ ("data", llama_token_data_p), ("size", ctypes.c_size_t), @@ -515,6 +525,15 @@ class llama_batch(ctypes.Structure): logits (ctypes.Array[ctypes.ctypes.c_int8]): if zero, the logits for the respective token will not be output """ + if TYPE_CHECKING: + n_tokens: int + token: CtypesArray[llama_token] + embd: CtypesArray[ctypes.c_float] + pos: CtypesArray[CtypesArray[llama_pos]] + n_seq_id: CtypesArray[ctypes.c_int] + seq_id: CtypesArray[CtypesArray[llama_seq_id]] + logits: CtypesArray[ctypes.c_int8] + _fields_ = [ ("n_tokens", ctypes.c_int32), ("token", ctypes.POINTER(llama_token)), @@ -609,6 +628,18 @@ class llama_model_params(ctypes.Structure): use_mmap (bool): use mmap if possible use_mlock (bool): force system to keep model in RAM""" + if TYPE_CHECKING: + n_gpu_layers: int + split_mode: int + main_gpu: int + tensor_split: CtypesArray[ctypes.c_float] + progress_callback: Callable[[float, ctypes.c_void_p], bool] + progress_callback_user_data: ctypes.c_void_p + kv_overrides: CtypesArray[llama_model_kv_override] + vocab_only: bool + use_mmap: bool + use_mlock: bool + _fields_ = [ ("n_gpu_layers", ctypes.c_int32), ("split_mode", ctypes.c_int), @@ -696,6 +727,34 @@ class llama_context_params(ctypes.Structure): abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback """ + if TYPE_CHECKING: + seed: int + n_ctx: int + n_batch: int + n_ubatch: int + n_seq_max: int + n_threads: int + n_threads_batch: int + 
rope_scaling_type: int + pooling_type: int + rope_freq_base: float + rope_freq_scale: float + yarn_ext_factor: float + yarn_attn_factor: float + yarn_beta_fast: float + yarn_beta_slow: float + yarn_orig_ctx: int + defrag_thold: float + cb_eval: Callable[[ctypes.c_void_p, bool], bool] + cb_eval_user_data: ctypes.c_void_p + type_k: int + type_v: int + logits_all: bool + embeddings: bool + offload_kqv: bool + abort_callback: Callable[[ctypes.c_void_p], bool] + abort_callback_data: ctypes.c_void_p + _fields_ = [ ("seed", ctypes.c_uint32), ("n_ctx", ctypes.c_uint32), @@ -771,6 +830,18 @@ class llama_model_quantize_params(ctypes.Structure): kv_overrides (ctypes.c_void_p): pointer to vector containing overrides """ + if TYPE_CHECKING: + nthread: int + ftype: int + output_tensor_type: int + token_embedding_type: int + allow_requantize: bool + quantize_output_tensor: bool + only_copy: bool + pure: bool + imatrix: ctypes.c_void_p + kv_overrides: ctypes.c_void_p + _fields_ = [ ("nthread", ctypes.c_int32), ("ftype", ctypes.c_int), @@ -828,6 +899,10 @@ LLAMA_GRETYPE_CHAR_ALT = 6 # uint32_t value; // Unicode code point or rule ID # } llama_grammar_element; class llama_grammar_element(ctypes.Structure): + if TYPE_CHECKING: + type: int + value: int + _fields_ = [ ("type", ctypes.c_int), ("value", ctypes.c_uint32), @@ -851,6 +926,17 @@ llama_grammar_element_p = ctypes.POINTER(llama_grammar_element) # int32_t n_eval; # }; class llama_timings(ctypes.Structure): + if TYPE_CHECKING: + t_start_ms: float + t_end_ms: float + t_load_ms: float + t_sample_ms: float + t_p_eval_ms: float + t_eval_ms: float + n_sample: int + n_p_eval: int + n_eval: int + _fields_ = [ ("t_start_ms", ctypes.c_double), ("t_end_ms", ctypes.c_double), @@ -951,7 +1037,8 @@ GGML_NUMA_STRATEGY_COUNT = 5 [ctypes.c_int], None, ) -def llama_numa_init(numa: int, /): ... +def llama_numa_init(numa: int, /): + ... # // Call once at the end of the program - currently only used for MPI @@ -976,7 +1063,8 @@ def llama_backend_free(): ) def llama_load_model_from_file( path_model: bytes, params: llama_model_params, / -) -> Optional[llama_model_p]: ... +) -> Optional[llama_model_p]: + ... # LLAMA_API void llama_free_model(struct llama_model * model); @@ -985,7 +1073,8 @@ def llama_load_model_from_file( [llama_model_p_ctypes], None, ) -def llama_free_model(model: llama_model_p, /): ... +def llama_free_model(model: llama_model_p, /): + ... # LLAMA_API struct llama_context * llama_new_context_with_model( @@ -998,7 +1087,8 @@ def llama_free_model(model: llama_model_p, /): ... ) def llama_new_context_with_model( model: llama_model_p, params: llama_context_params, / -) -> Optional[llama_context_p]: ... +) -> Optional[llama_context_p]: + ... # // Frees all allocated memory @@ -1019,82 +1109,98 @@ def llama_free(ctx: llama_context_p, /): [], ctypes.c_int64, ) -def llama_time_us() -> int: ... +def llama_time_us() -> int: + ... # LLAMA_API size_t llama_max_devices(void); @ctypes_function("llama_max_devices", [], ctypes.c_size_t) -def llama_max_devices() -> int: ... +def llama_max_devices() -> int: + ... # LLAMA_API bool llama_supports_mmap (void); @ctypes_function("llama_supports_mmap", [], ctypes.c_bool) -def llama_supports_mmap() -> bool: ... +def llama_supports_mmap() -> bool: + ... # LLAMA_API bool llama_supports_mlock (void); @ctypes_function("llama_supports_mlock", [], ctypes.c_bool) -def llama_supports_mlock() -> bool: ... +def llama_supports_mlock() -> bool: + ... 
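The `if TYPE_CHECKING:` pattern this patch applies to every structure can be shown in isolation. Below is a minimal, self-contained sketch with a hypothetical `Point` structure (not part of the bindings): ctypes creates the attributes at runtime from `_fields_`, so the guarded annotations exist purely to make those attributes visible to static type checkers.

import ctypes
from typing import TYPE_CHECKING

class Point(ctypes.Structure):
    # Visible only to type checkers; has no runtime effect.
    if TYPE_CHECKING:
        x: int
        y: int

    _fields_ = [
        ("x", ctypes.c_int32),
        ("y", ctypes.c_int32),
    ]

p = Point(1, 2)
print(p.x + p.y)  # a checker now knows p.x and p.y are ints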
# LLAMA_API bool llama_supports_gpu_offload(void); @ctypes_function("llama_supports_gpu_offload", [], ctypes.c_bool) -def llama_supports_gpu_offload() -> bool: ... +def llama_supports_gpu_offload() -> bool: + ... # LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); @ctypes_function("llama_get_model", [llama_context_p_ctypes], llama_model_p_ctypes) -def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: ... +def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: + ... # LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx); @ctypes_function("llama_n_ctx", [llama_context_p_ctypes], ctypes.c_uint32) -def llama_n_ctx(ctx: llama_context_p, /) -> int: ... +def llama_n_ctx(ctx: llama_context_p, /) -> int: + ... # LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); @ctypes_function("llama_n_batch", [llama_context_p_ctypes], ctypes.c_uint32) -def llama_n_batch(ctx: llama_context_p, /) -> int: ... +def llama_n_batch(ctx: llama_context_p, /) -> int: + ... # LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); @ctypes_function("llama_n_ubatch", [llama_context_p_ctypes], ctypes.c_uint32) -def llama_n_ubatch(ctx: llama_context_p, /) -> int: ... +def llama_n_ubatch(ctx: llama_context_p, /) -> int: + ... # LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); @ctypes_function("llama_n_seq_max", [llama_context_p_ctypes], ctypes.c_uint32) -def llama_n_seq_max(ctx: llama_context_p, /) -> int: ... +def llama_n_seq_max(ctx: llama_context_p, /) -> int: + ... # LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model); @ctypes_function("llama_vocab_type", [llama_model_p_ctypes], ctypes.c_int) -def llama_vocab_type(model: llama_model_p, /) -> int: ... +def llama_vocab_type(model: llama_model_p, /) -> int: + ... # LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); @ctypes_function("llama_rope_type", [llama_model_p_ctypes], ctypes.c_int) -def llama_rope_type(model: llama_model_p, /) -> int: ... +def llama_rope_type(model: llama_model_p, /) -> int: + ... # LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); @ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32) -def llama_n_vocab(model: llama_model_p, /) -> int: ... +def llama_n_vocab(model: llama_model_p, /) -> int: + ... # LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); @ctypes_function("llama_n_ctx_train", [llama_model_p_ctypes], ctypes.c_int32) -def llama_n_ctx_train(model: llama_model_p, /) -> int: ... +def llama_n_ctx_train(model: llama_model_p, /) -> int: + ... # LLAMA_API int32_t llama_n_embd (const struct llama_model * model); @ctypes_function("llama_n_embd", [llama_model_p_ctypes], ctypes.c_int32) -def llama_n_embd(model: llama_model_p, /) -> int: ... +def llama_n_embd(model: llama_model_p, /) -> int: + ... # LLAMA_API int32_t llama_n_layer (const struct llama_model * model); @ctypes_function("llama_n_layer", [llama_model_p_ctypes], ctypes.c_int32) -def llama_n_layer(model: llama_model_p, /) -> int: ... +def llama_n_layer(model: llama_model_p, /) -> int: + ... # // Get the model's RoPE frequency scaling factor @@ -1358,6 +1464,9 @@ class llama_kv_cache_view_cell(ctypes.Structure): pos (llama_pos): The position for this cell. Takes KV cache shifts into account. 
May be negative if the cell is not populated.""" + if TYPE_CHECKING: + pos: llama_pos + _fields_ = [("pos", llama_pos)] @@ -1394,6 +1503,16 @@ class llama_kv_cache_view_cell(ctypes.Structure): # llama_seq_id * cells_sequences; # }; class llama_kv_cache_view(ctypes.Structure): + if TYPE_CHECKING: + n_cells: int + n_max_seq: int + token_count: int + used_cells: int + max_contiguous: int + max_contiguous_idx: int + cells: CtypesArray[llama_kv_cache_view_cell] + cells_sequences: CtypesArray[llama_seq_id] + _fields_ = [ ("n_cells", ctypes.c_int32), ("n_max_seq", ctypes.c_int32), @@ -1783,7 +1902,8 @@ def llama_state_load_file( n_token_capacity: Union[ctypes.c_size_t, int], n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t], /, -) -> bool: ... +) -> bool: + ... # LLAMA_API DEPRECATED(bool llama_load_session_file( @@ -1811,7 +1931,8 @@ def llama_load_session_file( n_token_capacity: Union[ctypes.c_size_t, int], n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t], /, -) -> int: ... +) -> int: + ... # LLAMA_API bool llama_state_save_file( @@ -1835,7 +1956,8 @@ def llama_state_save_file( tokens: CtypesArray[llama_token], n_token_count: Union[ctypes.c_size_t, int], /, -) -> bool: ... +) -> bool: + ... # LLAMA_API DEPRECATED(bool llama_save_session_file( @@ -1860,7 +1982,8 @@ def llama_save_session_file( tokens: CtypesArray[llama_token], n_token_count: Union[ctypes.c_size_t, int], /, -) -> int: ... +) -> int: + ... # // Get the exact size needed to copy the KV cache of a single sequence @@ -2233,7 +2356,8 @@ def llama_get_embeddings_seq( ) def llama_token_get_text( model: llama_model_p, token: Union[llama_token, int], / -) -> bytes: ... +) -> bytes: + ... # LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token); @@ -2242,7 +2366,8 @@ def llama_token_get_text( ) def llama_token_get_score( model: llama_model_p, token: Union[llama_token, int], / -) -> float: ... +) -> float: + ... # LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token); @@ -2251,7 +2376,8 @@ def llama_token_get_score( ) def llama_token_get_type( model: llama_model_p, token: Union[llama_token, int], / -) -> int: ... +) -> int: + ... # // Special tokens @@ -2318,17 +2444,20 @@ def llama_token_prefix(model: llama_model_p) -> int: # LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle @ctypes_function("llama_token_middle", [llama_model_p_ctypes], llama_token) -def llama_token_middle(model: llama_model_p, /) -> int: ... +def llama_token_middle(model: llama_model_p, /) -> int: + ... # LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix @ctypes_function("llama_token_suffix", [llama_model_p_ctypes], llama_token) -def llama_token_suffix(model: llama_model_p, /) -> int: ... +def llama_token_suffix(model: llama_model_p, /) -> int: + ... # LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle @ctypes_function("llama_token_eot", [llama_model_p_ctypes], llama_token) -def llama_token_eot(model: llama_model_p, /) -> int: ... +def llama_token_eot(model: llama_model_p, /) -> int: + ... # // @@ -2459,7 +2588,8 @@ def llama_chat_apply_template( chat: CtypesArray[llama_chat_message], n_msg: int, /, -) -> int: ... +) -> int: + ... # // @@ -2989,6 +3119,12 @@ def llama_grammar_accept_token( # bool eob; // Callback should set this to true when a beam is at end-of-beam. 
# }; class llama_beam_view(ctypes.Structure): + if TYPE_CHECKING: + tokens: CtypesArray[llama_token] + n_tokens: int + p: float + eob: bool + _fields_ = [ ("tokens", llama_token_p), ("n_tokens", ctypes.c_size_t), @@ -3008,6 +3144,12 @@ class llama_beam_view(ctypes.Structure): # bool last_call; // True iff this is the last callback invocation. # }; class llama_beams_state(ctypes.Structure): + if TYPE_CHECKING: + beam_views: CtypesArray[llama_beam_view] + n_beams: int + common_prefix_length: int + last_call: bool + _fields_ = [ ("beam_views", ctypes.POINTER(llama_beam_view)), ("n_beams", ctypes.c_size_t), @@ -3060,7 +3202,8 @@ def llama_beam_search( n_past: Union[ctypes.c_int, int], n_predict: Union[ctypes.c_int, int], /, -): ... +): + ... # /// @details Build a split GGUF final path for this chunk. @@ -3179,4 +3322,5 @@ def llama_log_set( [ctypes.c_void_p, llama_context_p_ctypes], None, ) -def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /): ... +def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /): + ... From 060bfa64d529ade2af9b1f4e207a3937bbc4138f Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 10 Apr 2024 02:47:01 -0400 Subject: [PATCH 04/29] feat: Add support for yaml based configs --- examples/batch-processing/server.py | 30 +++++++++++++++++++++++++++++ llama_cpp/server/__main__.py | 11 ++++++++++- llama_cpp/server/app.py | 10 +++++++++- pyproject.toml | 1 + 4 files changed, 50 insertions(+), 2 deletions(-) create mode 100644 examples/batch-processing/server.py diff --git a/examples/batch-processing/server.py b/examples/batch-processing/server.py new file mode 100644 index 0000000..d353669 --- /dev/null +++ b/examples/batch-processing/server.py @@ -0,0 +1,30 @@ +"""llama-cpp-python server from scratch in a single file. 
+""" + +# import llama_cpp + +# path = b"../../models/Qwen1.5-0.5B-Chat-GGUF/qwen1_5-0_5b-chat-q8_0.gguf" + +# model_params = llama_cpp.llama_model_default_params() +# model = llama_cpp.llama_load_model_from_file(path, model_params) + +# if model is None: +# raise RuntimeError(f"Failed to load model from file: {path}") + + +# ctx_params = llama_cpp.llama_context_default_params() +# ctx = llama_cpp.llama_new_context_with_model(model, ctx_params) + +# if ctx is None: +# raise RuntimeError("Failed to create context") + + +from fastapi import FastAPI + +app = FastAPI() + +import openai.types.chat as types + +@app.post("/v1/chat/completions") +def create_chat_completions(): + return {"message": "Hello World"} diff --git a/llama_cpp/server/__main__.py b/llama_cpp/server/__main__.py index fadfc5f..a6f1f4e 100644 --- a/llama_cpp/server/__main__.py +++ b/llama_cpp/server/__main__.py @@ -59,7 +59,16 @@ def main(): if not os.path.exists(config_file): raise ValueError(f"Config file {config_file} not found!") with open(config_file, "rb") as f: - config_file_settings = ConfigFileSettings.model_validate_json(f.read()) + # Check if yaml file + if config_file.endswith(".yaml") or config_file.endswith(".yml"): + import yaml + import json + + config_file_settings = ConfigFileSettings.model_validate_json( + json.dumps(yaml.safe_load(f)) + ) + else: + config_file_settings = ConfigFileSettings.model_validate_json(f.read()) server_settings = ServerSettings.model_validate(config_file_settings) model_settings = config_file_settings.models else: diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index 815ed3c..8211323 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -97,7 +97,15 @@ def create_app( if not os.path.exists(config_file): raise ValueError(f"Config file {config_file} not found!") with open(config_file, "rb") as f: - config_file_settings = ConfigFileSettings.model_validate_json(f.read()) + # Check if yaml file + if config_file.endswith(".yaml") or config_file.endswith(".yml"): + import yaml + + config_file_settings = ConfigFileSettings.model_validate_json( + json.dumps(yaml.safe_load(f)) + ) + else: + config_file_settings = ConfigFileSettings.model_validate_json(f.read()) server_settings = ServerSettings.model_validate(config_file_settings) model_settings = config_file_settings.models diff --git a/pyproject.toml b/pyproject.toml index 2f3d3ce..e2bbb4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ server = [ "pydantic-settings>=2.0.1", "sse-starlette>=1.6.1", "starlette-context>=0.3.6,<0.4", + "PyYAML>=5.1", ] test = [ "pytest>=7.4.0", From bb65b4d76411112c6fb0bf759efd746f99ef3c6b Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 10 Apr 2024 03:41:55 -0400 Subject: [PATCH 05/29] fix: pass correct type to chat handlers for chat completion logprobs --- llama_cpp/llama.py | 3 ++- llama_cpp/llama_chat_format.py | 24 ++++++++++++++++-------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index e07d57a..466dc22 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1664,7 +1664,8 @@ class Llama: top_k=top_k, min_p=min_p, typical_p=typical_p, - logprobs=top_logprobs if logprobs else None, + logprobs=logprobs, + top_logprobs=top_logprobs, stream=stream, stop=stop, seed=seed, diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 705202e..519d2f5 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -77,6 +77,8 @@ class 
LlamaChatCompletionHandler(Protocol): mirostat_eta: float = 0.1, logits_processor: Optional[llama.LogitsProcessorList] = None, grammar: Optional[llama.LlamaGrammar] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, **kwargs, # type: ignore ) -> Union[ llama_types.CreateChatCompletionResponse, @@ -338,7 +340,7 @@ def _convert_completion_to_chat_function( } ], }, - "logprobs": None, + "logprobs": completion["choices"][0]["logprobs"], "finish_reason": "tool_calls", } ], @@ -391,7 +393,7 @@ def _convert_completion_to_chat_function( { "index": 0, "finish_reason": None, - "logprobs": None, + "logprobs": chunk["choices"][0]["logprobs"], "delta": { "role": None, "content": None, @@ -426,7 +428,7 @@ def _convert_completion_to_chat_function( { "index": 0, "finish_reason": None, - "logprobs": None, + "logprobs": chunk["choices"][0]["logprobs"], "delta": { "role": None, "content": None, @@ -491,7 +493,6 @@ def chat_formatter_to_chat_completion_handler( temperature: float = 0.2, top_p: float = 0.95, top_k: int = 40, - logprobs: int = 0, min_p: float = 0.05, typical_p: float = 1.0, stream: bool = False, @@ -512,6 +513,8 @@ def chat_formatter_to_chat_completion_handler( logits_processor: Optional[llama.LogitsProcessorList] = None, grammar: Optional[llama.LlamaGrammar] = None, logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, **kwargs, # type: ignore ) -> Union[ llama_types.CreateChatCompletionResponse, @@ -581,7 +584,7 @@ def chat_formatter_to_chat_completion_handler( top_k=top_k, min_p=min_p, typical_p=typical_p, - logprobs=logprobs, + logprobs=top_logprobs if logprobs else None, stream=stream, stop=stop, seed=seed, @@ -1628,7 +1631,7 @@ def functionary_chat_handler( } ], }, - "logprobs": None, + "logprobs": completion["choices"][0]["logprobs"], "finish_reason": "tool_calls", } ], @@ -2085,7 +2088,7 @@ def functionary_v1_v2_chat_handler( choices=[ { "index": 0, - "logprobs": None, + "logprobs": completion["choices"][0]["logprobs"], "message": { "role": "assistant", "content": None if content == "" else content, @@ -2311,11 +2314,14 @@ def chatml_function_calling( model: Optional[str] = None, logits_processor: Optional[llama.LogitsProcessorList] = None, grammar: Optional[llama.LlamaGrammar] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, **kwargs, # type: ignore ) -> Union[ llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse], ]: + print(logprobs) function_calling_template = ( "{% for message in messages %}" "<|im_start|>{{ message.role }}\n" @@ -2437,6 +2443,7 @@ def chatml_function_calling( model=model, logits_processor=logits_processor, grammar=grammar, + logprobs=top_logprobs if logprobs else None, ), stream=stream, ) @@ -2549,6 +2556,7 @@ def chatml_function_calling( typical_p=typical_p, stream=stream, stop=["<|im_end|>"], + logprobs=top_logprobs if logprobs else None, max_tokens=None, presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, @@ -2660,7 +2668,7 @@ def chatml_function_calling( { "finish_reason": "tool_calls", "index": 0, - "logprobs": None, + "logprobs": completion["choices"][0]["logprobs"], "message": { "role": "assistant", "content": None, From ef29235d453ecf552f015f6ec04270b233bb22c9 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 10 Apr 2024 03:44:46 -0400 Subject: [PATCH 06/29] chore: Bump version --- CHANGELOG.md | 7 +++++++ llama_cpp/__init__.py | 2 +- 2 files changed, 8 
insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dcb7a81..c67498e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.2.61] + +- feat: Update llama.cpp to ggerganov/llama.cpp@ba5e134e073ec6837078c874aba44a702944a676 +- fix: pass correct type to chat handlers for chat completion logprobs by @abetlen in bb65b4d76411112c6fb0bf759efd746f99ef3c6b +- feat: Add support for yaml based server configs by @abetlen in 060bfa64d529ade2af9b1f4e207a3937bbc4138f +- feat: Add typechecking for ctypes structure attributes by @abetlen in 1347e1d050fc5a9a32ffe0bb3e22858da28003bd + ## [0.2.60] - feat: Update llama.cpp to ggerganov/llama.cpp@75cd4c77292034ecec587ecb401366f57338f7c0 diff --git a/llama_cpp/__init__.py b/llama_cpp/__init__.py index d5db993..2382db9 100644 --- a/llama_cpp/__init__.py +++ b/llama_cpp/__init__.py @@ -1,4 +1,4 @@ from .llama_cpp import * from .llama import * -__version__ = "0.2.60" \ No newline at end of file +__version__ = "0.2.61" \ No newline at end of file From 2e9ffd28fd99958a58fccad6db4614d8c6b555a5 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 12 Apr 2024 21:09:12 -0400 Subject: [PATCH 07/29] feat: Update llama.cpp --- vendor/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/llama.cpp b/vendor/llama.cpp index ba5e134..ab9a324 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit ba5e134e073ec6837078c874aba44a702944a676 +Subproject commit ab9a3240a9da941fdef5cd4a25f2b97c2f5a67aa From 90dceaba8a9938ca4d7e7d2cb0997e613bad7040 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Sun, 14 Apr 2024 11:35:57 -0400 Subject: [PATCH 08/29] feat: Update llama.cpp --- vendor/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/llama.cpp b/vendor/llama.cpp index ab9a324..f184dd9 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit ab9a3240a9da941fdef5cd4a25f2b97c2f5a67aa +Subproject commit f184dd920852d6d372b754f871ee06cfe6f977ad From a420f9608bbd3b76e8bfbb6cdcf4d3fa69efe5c0 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Sun, 14 Apr 2024 19:14:09 -0400 Subject: [PATCH 09/29] feat: Update llama.cpp --- vendor/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/llama.cpp b/vendor/llama.cpp index f184dd9..1958f7e 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit f184dd920852d6d372b754f871ee06cfe6f977ad +Subproject commit 1958f7e06ca2d2e3ab5698cc67513ba359144d8e From c96b2daebf7b5dbe3cfd9194b249ab579a28d633 Mon Sep 17 00:00:00 2001 From: ddh0 <40664579+ddh0@users.noreply.github.com> Date: Wed, 17 Apr 2024 09:04:33 -0500 Subject: [PATCH 10/29] feat: Use all available CPUs for batch processing (#1345) --- llama_cpp/llama.py | 4 +--- llama_cpp/server/settings.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 466dc22..dfac9bb 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -262,9 +262,7 @@ class Llama: self.n_batch = min(n_ctx, n_batch) # ??? 
self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1) - self.n_threads_batch = n_threads_batch or max( - multiprocessing.cpu_count() // 2, 1 - ) + self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count() # Context Params self.context_params = llama_cpp.llama_context_default_params() diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py index 9ebdd0d..811c6ca 100644 --- a/llama_cpp/server/settings.py +++ b/llama_cpp/server/settings.py @@ -70,7 +70,7 @@ class ModelSettings(BaseSettings): description="The number of threads to use.", ) n_threads_batch: int = Field( - default=max(multiprocessing.cpu_count() // 2, 1), + default=max(multiprocessing.cpu_count(), 1), ge=0, description="The number of threads to use when batch processing.", ) From 9842cbf99d048a121cff03f01340592ad4b7fc1e Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 17 Apr 2024 10:06:15 -0400 Subject: [PATCH 11/29] feat: Update llama.cpp --- vendor/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 1958f7e..8dd1ec8 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 1958f7e06ca2d2e3ab5698cc67513ba359144d8e +Subproject commit 8dd1ec8b3ffbfa2d26e82e672cea89f5eeb2f141 From 4924455decd79273c8c695a8ff796306ac0df30d Mon Sep 17 00:00:00 2001 From: tc-wolf <50339167+tc-wolf@users.noreply.github.com> Date: Wed, 17 Apr 2024 09:06:50 -0500 Subject: [PATCH 12/29] feat: Make saved state more compact on-disk (#1296) * State load/save changes - Only store up to `n_tokens` logits instead of full `(n_ctx, n_vocab)` sized array. - Difference between ~350MB and ~1500MB for example prompt with ~300 tokens (makes sense lol) - Auto-formatting changes * Back out formatting changes --- llama_cpp/llama.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index dfac9bb..5a0111b 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -18,6 +18,7 @@ from typing import ( Iterator, Deque, Callable, + Dict, ) from collections import deque from pathlib import Path @@ -1791,7 +1792,7 @@ class Llama: file=sys.stderr, ) return LlamaState( - scores=self.scores.copy(), + scores=self._scores.copy(), input_ids=self.input_ids.copy(), n_tokens=self.n_tokens, llama_state=bytes(llama_state_compact), @@ -1800,7 +1801,9 @@ class Llama: def load_state(self, state: LlamaState) -> None: assert self._ctx.ctx is not None - self.scores = state.scores.copy() + # Only filling in up to `n_tokens` and then zero-ing out the rest + self.scores[: state.n_tokens, :] = state.scores.copy() + self.scores[state.n_tokens :, :] = 0.0 self.input_ids = state.input_ids.copy() self.n_tokens = state.n_tokens state_size = state.llama_state_size @@ -1951,7 +1954,6 @@ class Llama: local_dir_use_symlinks=local_dir_use_symlinks, cache_dir=cache_dir, local_files_only=True, - ) else: model_path = os.path.join(local_dir, filename) From b73c73c0c67f559fd9c7d620ad3d4e24d5c4bc4c Mon Sep 17 00:00:00 2001 From: khimaros Date: Wed, 17 Apr 2024 14:08:19 +0000 Subject: [PATCH 13/29] feat: add `disable_ping_events` flag (#1257) for backward compatibility, this is false by default it can be set to true to disable EventSource pings which are not supported by some OpenAI clients. 
fixes https://github.com/abetlen/llama-cpp-python/issues/1256 --- llama_cpp/server/app.py | 12 ++++++++++++ llama_cpp/server/settings.py | 4 ++++ 2 files changed, 16 insertions(+) diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index 8211323..b6ed9b1 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -87,6 +87,13 @@ def get_llama_proxy(): llama_outer_lock.release() +_ping_message_factory = None + +def set_ping_message_factory(factory): + global _ping_message_factory + _ping_message_factory = factory + + def create_app( settings: Settings | None = None, server_settings: ServerSettings | None = None, @@ -138,6 +145,9 @@ def create_app( assert model_settings is not None set_llama_proxy(model_settings=model_settings) + if server_settings.disable_ping_events: + set_ping_message_factory(lambda: bytes()) + return app @@ -302,6 +312,7 @@ async def create_completion( iterator=iterator(), ), sep="\n", + ping_message_factory=_ping_message_factory, ) else: return iterator_or_completion @@ -470,6 +481,7 @@ async def create_chat_completion( iterator=iterator(), ), sep="\n", + ping_message_factory=_ping_message_factory, ) else: return iterator_or_completion diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py index 811c6ca..934aecd 100644 --- a/llama_cpp/server/settings.py +++ b/llama_cpp/server/settings.py @@ -195,6 +195,10 @@ class ServerSettings(BaseSettings): default=True, description="Whether to interrupt requests when a new request is received.", ) + disable_ping_events: bool = Field( + default=False, + description="Disable EventSource pings (may be needed for some clients).", + ) class Settings(ServerSettings, ModelSettings): From 610a592f708f2a5a8b0e1d6f0900f6337b9beb39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucca=20Zen=C3=B3bio?= Date: Wed, 17 Apr 2024 11:10:21 -0300 Subject: [PATCH 14/29] feat: Update json to grammar (#1350) * feat: improve function calling * feat:grammar --- llama_cpp/llama_chat_format.py | 2 +- llama_cpp/llama_grammar.py | 612 +++++++++++++++++++++++++++------ 2 files changed, 514 insertions(+), 100 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 519d2f5..eb98cbf 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -2709,4 +2709,4 @@ def chatml_function_calling( }, } - raise ValueError("Automatic streaming tool choice is not supported") + raise ValueError("Automatic streaming tool choice is not supported") \ No newline at end of file diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 9cc48a9..8c0f8aa 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -5,11 +5,12 @@ from pathlib import Path import sys from ctypes import * # type: ignore from enum import Enum -from itertools import islice +from itertools import islice, groupby from typing import ( Any, Callable, Dict, + Set, Generic, List, Optional, @@ -1391,139 +1392,552 @@ from typing import List, Optional # whitespace. Also maybe improves generation quality? SPACE_RULE = '" "?' -PRIMITIVE_RULES = { - "boolean": '("true" | "false") space', - "number": '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space', - "integer": '("-"? 
([0-9] | [1-9] [0-9]*)) space', - "string": r""" "\"" ( - [^"\\] | - "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) - )* "\"" space """, - "null": '"null" space', -} INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+") GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'} +# whitespace is constrained to a single space char to prevent model "running away" in +# whitespace. Also maybe improves generation quality? +SPACE_RULE = '" "?' + + +def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False): + if not separator_rule: + if min_items == 0 and max_items == 1: + return f'{item_rule}?' + elif min_items == 1 and max_items is None: + return f'{item_rule}+' + + result = '' + + if min_items > 0: + if item_rule_is_literal and separator_rule is None: + result = '"' + (item_rule[1:-1] * min_items) + '"' + else: + result = (f' {separator_rule} ' if separator_rule else ' ').join([item_rule] * min_items) + + def opt_repetitions(up_to_n, prefix_with_sep=False): + ''' + - n=4, no sep: '(a (a (a (a)?)?)?)?' + - n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?' + - n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?' + ''' + + content = f'{separator_rule} {item_rule}' if prefix_with_sep and separator_rule else item_rule + if up_to_n == 0: + return '' + elif up_to_n == 1: + return f'({content})?' + elif separator_rule and not prefix_with_sep: + return f'({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?' + else: + return (f'({content} ' * up_to_n).rstrip() + (')?' * up_to_n) + + if min_items > 0 and max_items != min_items: + result += ' ' + + if max_items is not None: + result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0) + else: + item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})' + + if min_items == 0 and separator_rule: + result = f'({item_rule} {item_operator}*)?' + else: + result += f'{item_operator}*' + + return result + + + +class BuiltinRule: + def __init__(self, content: str, deps: list = None): + self.content = content + self.deps = deps or [] + +_up_to_15_digits = _build_repetition('[0-9]', 0, 15) + +PRIMITIVE_RULES = { + 'boolean' : BuiltinRule('("true" | "false") space', []), + 'decimal-part' : BuiltinRule('[0-9] ' + _up_to_15_digits, []), + 'integral-part': BuiltinRule('[0-9] | [1-9] ' + _up_to_15_digits, []), + 'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']), + 'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']), + 'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']), + 'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']), + 'array' : BuiltinRule('"[" space ( value ("," space value)* )? 
"]" space', ['value']), + 'uuid' : BuiltinRule(r'"\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + r' "\"" space', []), + 'char' : BuiltinRule(r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])', []), + 'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']), + 'null' : BuiltinRule('"null" space', []), +} + +# TODO: support "uri", "email" string formats +STRING_FORMAT_RULES = { + 'date' : BuiltinRule('[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []), + 'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []), + 'date-time' : BuiltinRule('date "T" time', ['date', 'time']), + 'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']), + 'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']), + 'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']), +} + +DOTALL = '[\\U00000000-\\U0010FFFF]' +DOT = '[^\\x0A\\x0D]' + +RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()]) + + +NON_LITERAL_SET = set('|.()[]{}*+?') +ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?') + + + class SchemaConverter: def __init__(self, prop_order): self._prop_order = prop_order self._rules = {"space": SPACE_RULE} self._defs: Dict[str, Any] = {} + self._refs = {} + self._refs_being_resolved = set() - def _format_literal(self, literal: str): - escaped: str = GRAMMAR_LITERAL_ESCAPE_RE.sub( - lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), json.dumps(literal) + def _format_literal(self, literal): + escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( + lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal ) return f'"{escaped}"' - def _add_rule(self, name: str, rule: str): - esc_name = INVALID_RULE_CHARS_RE.sub("-", name) + def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str: + ''' + not_literal('a') -> '[^a]' + not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?' + ''' + assert len(literal) > 0, 'Empty literal not supported' + def recurse(i: int): + c = literal[i] + if maybe_escaped_underscores and c == '_': + yield f'[^{c}\\\\]' + yield ' | ' + yield f'"\\\\"? "{c}"' + else: + yield f'[^{c}]' + if i < len(literal) - 1: + yield ' | ' + yield self._format_literal(c) + yield ' (' + yield from recurse(i + 1) + yield ')?' + + return ''.join(('(', *recurse(0), ')')) + + def _add_rule(self, name, rule): + esc_name = INVALID_RULE_CHARS_RE.sub('-', name) if esc_name not in self._rules or self._rules[esc_name] == rule: key = esc_name else: i = 0 - while f"{esc_name}{i}" in self._rules: + while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule: i += 1 - key = f"{esc_name}{i}" + key = f'{esc_name}{i}' self._rules[key] = rule return key - def visit(self, schema: Dict[str, Any], name: str) -> str: - rule_name = name or "root" + def resolve_refs(self, schema: dict, url: str): + ''' + Resolves all $ref fields in the given schema, fetching any remote schemas, + replacing $ref with absolute reference URL and populating self._refs with the + respective referenced (sub)schema dictionaries. 
+ ''' + def visit(n: dict): + if isinstance(n, list): + return [visit(x) for x in n] + elif isinstance(n, dict): + ref = n.get('$ref') + if ref is not None and ref not in self._refs: + if ref.startswith('https://'): + assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)' + import requests - if "$defs" in schema: - # add defs to self._defs for later inlining - for def_name, def_schema in schema["$defs"].items(): - self._defs[def_name] = def_schema + frag_split = ref.split('#') + base_url = frag_split[0] - if "oneOf" in schema or "anyOf" in schema: - rule = " | ".join( - ( - self.visit(alt_schema, f'{name}{"-" if name else ""}{i}') - for i, alt_schema in enumerate( - schema.get("oneOf") or schema["anyOf"] - ) - ) + target = self._refs.get(base_url) + if target is None: + target = self.resolve_refs(requests.get(ref).json(), base_url) + self._refs[base_url] = target + + if len(frag_split) == 1 or frag_split[-1] == '': + return target + elif ref.startswith('#/'): + target = schema + ref = f'{url}{ref}' + n['$ref'] = ref + else: + raise ValueError(f'Unsupported ref {ref}') + + for sel in ref.split('#')[-1].split('/')[1:]: + assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}' + target = target[sel] + + self._refs[ref] = target + else: + for v in n.values(): + visit(v) + + return n + return visit(schema) + + def _generate_union_rule(self, name, alt_schemas): + return ' | '.join(( + self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}') + for i, alt_schema in enumerate(alt_schemas) + )) + + def _visit_pattern(self, pattern, name): + ''' + Transforms a regular expression pattern into a GBNF rule. + + Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions + Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md + + Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers. + + Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which + we define sub-rules to keep the output lean. + ''' + + assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"' + pattern = pattern[1:-1] + sub_rule_ids = {} + + i = 0 + length = len(pattern) + + def to_rule(s: Tuple[str, bool]) -> str: + (txt, is_literal) = s + return "\"" + txt + "\"" if is_literal else txt + + def transform() -> Tuple[str, bool]: + ''' + Parse a unit at index i (advancing it), and return its string representation + whether it's a literal. + ''' + nonlocal i + nonlocal pattern + nonlocal sub_rule_ids + + start = i + # For each component of this sequence, store its string representation and whether it's a literal. + # We only need a flat structure here to apply repetition operators to the last item, and + # to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially + # (GBNF's syntax is luckily very close to regular expressions!) + seq: list[Tuple[str, bool]] = [] + + def get_dot(): + if self._dotall: + rule = DOTALL + else: + # Accept any character... 
except \n and \r line break chars (\x0A and \xOD) + rule = DOT + return self._add_rule(f'dot', rule) + + def join_seq(): + nonlocal seq + ret = [] + for is_literal, g in groupby(seq, lambda x: x[1]): + if is_literal: + ret.append((''.join(x[0] for x in g), True)) + else: + ret.extend(g) + if len(ret) == 1: + return ret[0] + return (' '.join(to_rule(x) for x in seq), False) + + while i < length: + c = pattern[i] + if c == '.': + seq.append((get_dot(), False)) + i += 1 + elif c == '(': + i += 1 + if i < length: + assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/' + seq.append((f'({to_rule(transform())})', False)) + elif c == ')': + i += 1 + assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}' + return join_seq() + elif c == '[': + square_brackets = c + i += 1 + while i < length and pattern[i] != ']': + if pattern[i] == '\\': + square_brackets += pattern[i:i+2] + i += 2 + else: + square_brackets += pattern[i] + i += 1 + assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}' + square_brackets += ']' + i += 1 + seq.append((square_brackets, False)) + elif c == '|': + seq.append(('|', False)) + i += 1 + elif c in ('*', '+', '?'): + seq[-1] = (to_rule(seq[-1]) + c, False) + i += 1 + elif c == '{': + curly_brackets = c + i += 1 + while i < length and pattern[i] != '}': + curly_brackets += pattern[i] + i += 1 + assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}' + curly_brackets += '}' + i += 1 + nums = [s.strip() for s in curly_brackets[1:-1].split(',')] + min_times = 0 + max_times = None + try: + if len(nums) == 1: + min_times = int(nums[0]) + max_times = min_times + else: + assert len(nums) == 2 + min_times = int(nums[0]) if nums[0] else 0 + max_times = int(nums[1]) if nums[1] else None + except ValueError: + raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/') + + (sub, sub_is_literal) = seq[-1] + + if not sub_is_literal: + id = sub_rule_ids.get(sub) + if id is None: + id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub) + sub_rule_ids[sub] = id + sub = id + + seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times, item_rule_is_literal=sub_is_literal), False) + else: + literal = '' + while i < length: + if pattern[i] == '\\' and i < length - 1: + next = pattern[i + 1] + if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS: + i += 1 + literal += pattern[i] + i += 1 + else: + literal += pattern[i:i+2] + i += 2 + elif pattern[i] == '"' and not self._raw_pattern: + literal += '\\"' + i += 1 + elif pattern[i] not in NON_LITERAL_SET and \ + (i == length - 1 or literal == '' or pattern[i+1] == '.' 
or pattern[i+1] not in NON_LITERAL_SET): + literal += pattern[i] + i += 1 + else: + break + if literal: + seq.append((literal, True)) + + return join_seq() + + return self._add_rule( + name, + to_rule(transform()) if self._raw_pattern \ + else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space") + + + def _resolve_ref(self, ref): + ref_name = ref.split('/')[-1] + if ref_name not in self._rules and ref not in self._refs_being_resolved: + self._refs_being_resolved.add(ref) + resolved = self._refs[ref] + ref_name = self.visit(resolved, ref_name) + self._refs_being_resolved.remove(ref) + return ref_name + + def _generate_constant_rule(self, value): + return self._format_literal(json.dumps(value)) + + def visit(self, schema, name): + schema_type = schema.get('type') + schema_format = schema.get('format') + rule_name = name + '-' if name in RESERVED_NAMES else name or 'root' + + if (ref := schema.get('$ref')) is not None: + return self._add_rule(rule_name, self._resolve_ref(ref)) + + elif 'oneOf' in schema or 'anyOf' in schema: + return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf'])) + + elif isinstance(schema_type, list): + return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type])) + + elif 'const' in schema: + return self._add_rule(rule_name, self._generate_constant_rule(schema['const'])) + + elif 'enum' in schema: + rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + return self._add_rule(rule_name, rule) + + elif schema_type in (None, 'object') and \ + ('properties' in schema or \ + ('additionalProperties' in schema and schema['additionalProperties'] is not True)): + required = set(schema.get('required', [])) + properties = list(schema.get('properties', {}).items()) + return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties'))) + + elif schema_type in (None, 'object') and 'allOf' in schema: + required = set() + properties = [] + hybrid_name = name + def add_component(comp_schema, is_required): + if (ref := comp_schema.get('$ref')) is not None: + comp_schema = self._refs[ref] + + if 'properties' in comp_schema: + for prop_name, prop_schema in comp_schema['properties'].items(): + properties.append((prop_name, prop_schema)) + if is_required: + required.add(prop_name) + + for t in schema['allOf']: + if 'anyOf' in t: + for tt in t['anyOf']: + add_component(tt, is_required=False) + else: + add_component(t, is_required=True) + + return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=[])) + + elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema): + items = schema.get('items') or schema['prefixItems'] + if isinstance(items, list): + return self._add_rule( + rule_name, + '"[" space ' + + ' "," space '.join( + self.visit(item, f'{name}{"-" if name else ""}tuple-{i}') + for i, item in enumerate(items)) + + ' "]" space') + else: + item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item') + min_items = schema.get("minItems", 0) + max_items = schema.get("maxItems") + return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space') + + elif schema_type in (None, 'string') and 'pattern' in schema: + return self._visit_pattern(schema['pattern'], rule_name) + + elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or 
''): + return self._add_primitive( + 'root' if rule_name == 'root' else schema_format, + PRIMITIVE_RULES['uuid'] ) - return self._add_rule(rule_name, rule) - elif "const" in schema: - return self._add_rule(rule_name, self._format_literal(schema["const"])) + elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES: + prim_name = f'{schema_format}-string' + return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name])) - elif "enum" in schema: - rule = " | ".join((self._format_literal(v) for v in schema["enum"])) - return self._add_rule(rule_name, rule) + elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema): + char_rule = self._add_primitive('char', PRIMITIVE_RULES['char']) + min_len = schema.get('minLength', 0) + max_len = schema.get('maxLength') - elif "$ref" in schema: - ref = schema["$ref"] - assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}" - # inline $defs - def_name = ref[len("#/$defs/") :] - def_schema = self._defs[def_name] - return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}') + return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space') - - schema_type: Optional[str] = schema.get("type") # type: ignore - assert isinstance(schema_type, str), f"Unrecognized schema: {schema}" - - if schema_type == "object" and "properties" in schema: - # TODO: `required` keyword - if self._prop_order: - prop_order = self._prop_order - prop_pairs = sorted( - schema["properties"].items(), - # sort by position in prop_order (if specified) then by key - key=lambda kv: (prop_order.get(kv[0], len(prop_order)), kv[0]), - ) - else: - prop_pairs = schema["properties"].items() - - rule = '"{" space' - for i, (prop_name, prop_schema) in enumerate(prop_pairs): - prop_rule_name = self.visit( - prop_schema, f'{name}{"-" if name else ""}{prop_name}' - ) - if i > 0: - rule += ' "," space' - rule += rf' {self._format_literal(prop_name)} space ":" space {prop_rule_name}' - rule += ' "}" space' - - return self._add_rule(rule_name, rule) - - elif schema_type == "array" and "items" in schema: - # TODO `prefixItems` keyword - item_rule_name = self.visit( - schema["items"], f'{name}{"-" if name else ""}item' - ) - list_item_operator = f'("," space {item_rule_name})' - successive_items = "" - min_items = schema.get("minItems", 0) - if min_items > 0: - first_item = f"({item_rule_name})" - successive_items = list_item_operator * (min_items - 1) - min_items -= 1 - else: - first_item = f"({item_rule_name})?" 
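
To make the new minLength/maxLength branch above concrete, here is a minimal sketch of the repetition it wraps in quotes. It assumes the patched module is importable as llama_cpp.llama_grammar; the printed rule text is illustrative, not a recorded fixture.

from llama_cpp.llama_grammar import _build_repetition

# {"type": "string", "minLength": 2, "maxLength": 4} constrains the char rule:
print(_build_repetition("char", 2, 4))
# -> char char (char (char)?)?
# which the converter then emits as: "\"" char char (char (char)?)? "\"" space
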
- max_items = schema.get("maxItems") - if max_items is not None and max_items > min_items: - successive_items += (list_item_operator + "?") * (max_items - min_items - 1) - else: - successive_items += list_item_operator + "*" - rule = f'"[" space {first_item} {successive_items} "]" space' - return self._add_rule(rule_name, rule) + elif (schema_type == 'object') or (len(schema) == 0): + return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object'])) else: - assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}" - return self._add_rule( - "root" if rule_name == "root" else schema_type, - PRIMITIVE_RULES[schema_type], + assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}' + # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero + return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type]) + + def _add_primitive(self, name: str, rule: BuiltinRule): + n = self._add_rule(name, rule.content) + + for dep in rule.deps: + dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep) + assert dep_rule, f'Rule {dep} not known' + if dep not in self._rules: + self._add_primitive(dep, dep_rule) + return n + + def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]): + prop_order = self._prop_order + # sort by position in prop_order (if specified) then by original order + sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))] + + prop_kv_rule_names = {} + for prop_name, prop_schema in properties: + prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}') + prop_kv_rule_names[prop_name] = self._add_rule( + f'{name}{"-" if name else ""}{prop_name}-kv', + fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}' ) + required_props = [k for k in sorted_props if k in required] + optional_props = [k for k in sorted_props if k not in required] + + if additional_properties == True or isinstance(additional_properties, dict): + sub_name = f'{name}{"-" if name else ""}additional' + value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value') + prop_kv_rule_names["*"] = self._add_rule( + f'{sub_name}-kv', + self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}' + ) + optional_props.append("*") + + rule = '"{" space ' + rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props) + + if optional_props: + rule += ' (' + if required_props: + rule += ' "," space ( ' + + def get_recursive_refs(ks, first_is_optional): + [k, *rest] = ks + kv_rule_name = prop_kv_rule_names[k] + if k == '*': + res = self._add_rule( + f'{name}{"-" if name else ""}additional-kvs', + f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*' + ) + elif first_is_optional: + res = f'( "," space {kv_rule_name} )?' + else: + res = kv_rule_name + if len(rest) > 0: + res += ' ' + self._add_rule( + f'{name}{"-" if name else ""}{k}-rest', + get_recursive_refs(rest, first_is_optional=True) + ) + return res + + rule += ' | '.join( + get_recursive_refs(optional_props[i:], first_is_optional=False) + for i in range(len(optional_props)) + ) + if required_props: + rule += ' )' + rule += ' )?' 
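
Before the object rule is closed with its final brace below, a short sketch shows the shape this required/optional machinery produces. It assumes the patched module is importable as llama_cpp.llama_grammar; rule names in the comment are abbreviated.

import json
from llama_cpp.llama_grammar import SchemaConverter

conv = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
schema = {
    "type": "object",
    "properties": {"id": {"type": "integer"}, "tag": {"type": "string"}},
    "required": ["id"],
}
conv.visit(conv.resolve_refs(schema, ""), "")
print(conv.format_grammar())
# Key line: root ::= "{" space id-kv ( "," space ( tag-kv ) )? "}" space
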
+ + rule += ' "}" space' + + return rule def format_grammar(self): - return "\n".join((f"{name} ::= {rule}" for name, rule in self._rules.items())) + return '\n'.join( + f'{name} ::= {rule}' + for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0]) + ) def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None): From fa4bb0cf81594d3c5bfd0cdf883fc88f820cbc54 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 17 Apr 2024 16:18:16 -0400 Subject: [PATCH 15/29] Revert "feat: Update json to grammar (#1350)" This reverts commit 610a592f708f2a5a8b0e1d6f0900f6337b9beb39. --- llama_cpp/llama_chat_format.py | 2 +- llama_cpp/llama_grammar.py | 600 +++++---------------------------- 2 files changed, 94 insertions(+), 508 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index eb98cbf..519d2f5 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -2709,4 +2709,4 @@ def chatml_function_calling( }, } - raise ValueError("Automatic streaming tool choice is not supported") \ No newline at end of file + raise ValueError("Automatic streaming tool choice is not supported") diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 8c0f8aa..9cc48a9 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -5,12 +5,11 @@ from pathlib import Path import sys from ctypes import * # type: ignore from enum import Enum -from itertools import islice, groupby +from itertools import islice from typing import ( Any, Callable, Dict, - Set, Generic, List, Optional, @@ -1392,552 +1391,139 @@ from typing import List, Optional # whitespace. Also maybe improves generation quality? SPACE_RULE = '" "?' +PRIMITIVE_RULES = { + "boolean": '("true" | "false") space', + "number": '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space', + "integer": '("-"? ([0-9] | [1-9] [0-9]*)) space', + "string": r""" "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) + )* "\"" space """, + "null": '"null" space', +} INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+") GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'} -# whitespace is constrained to a single space char to prevent model "running away" in -# whitespace. Also maybe improves generation quality? -SPACE_RULE = '" "?' - - -def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False): - if not separator_rule: - if min_items == 0 and max_items == 1: - return f'{item_rule}?' - elif min_items == 1 and max_items is None: - return f'{item_rule}+' - - result = '' - - if min_items > 0: - if item_rule_is_literal and separator_rule is None: - result = '"' + (item_rule[1:-1] * min_items) + '"' - else: - result = (f' {separator_rule} ' if separator_rule else ' ').join([item_rule] * min_items) - - def opt_repetitions(up_to_n, prefix_with_sep=False): - ''' - - n=4, no sep: '(a (a (a (a)?)?)?)?' - - n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?' - - n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?' - ''' - - content = f'{separator_rule} {item_rule}' if prefix_with_sep and separator_rule else item_rule - if up_to_n == 0: - return '' - elif up_to_n == 1: - return f'({content})?' - elif separator_rule and not prefix_with_sep: - return f'({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?' - else: - return (f'({content} ' * up_to_n).rstrip() + (')?' 
* up_to_n) - - if min_items > 0 and max_items != min_items: - result += ' ' - - if max_items is not None: - result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0) - else: - item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})' - - if min_items == 0 and separator_rule: - result = f'({item_rule} {item_operator}*)?' - else: - result += f'{item_operator}*' - - return result - - - -class BuiltinRule: - def __init__(self, content: str, deps: list = None): - self.content = content - self.deps = deps or [] - -_up_to_15_digits = _build_repetition('[0-9]', 0, 15) - -PRIMITIVE_RULES = { - 'boolean' : BuiltinRule('("true" | "false") space', []), - 'decimal-part' : BuiltinRule('[0-9] ' + _up_to_15_digits, []), - 'integral-part': BuiltinRule('[0-9] | [1-9] ' + _up_to_15_digits, []), - 'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']), - 'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']), - 'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']), - 'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']), - 'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']), - 'uuid' : BuiltinRule(r'"\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + r' "\"" space', []), - 'char' : BuiltinRule(r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])', []), - 'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']), - 'null' : BuiltinRule('"null" space', []), -} - -# TODO: support "uri", "email" string formats -STRING_FORMAT_RULES = { - 'date' : BuiltinRule('[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []), - 'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []), - 'date-time' : BuiltinRule('date "T" time', ['date', 'time']), - 'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']), - 'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']), - 'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']), -} - -DOTALL = '[\\U00000000-\\U0010FFFF]' -DOT = '[^\\x0A\\x0D]' - -RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()]) - - -NON_LITERAL_SET = set('|.()[]{}*+?') -ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?') - - - class SchemaConverter: def __init__(self, prop_order): self._prop_order = prop_order self._rules = {"space": SPACE_RULE} self._defs: Dict[str, Any] = {} - self._refs = {} - self._refs_being_resolved = set() - def _format_literal(self, literal): - escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( - lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal + def _format_literal(self, literal: str): + escaped: str = GRAMMAR_LITERAL_ESCAPE_RE.sub( + lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), json.dumps(literal) ) return f'"{escaped}"' - def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str: - ''' - not_literal('a') -> '[^a]' - not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?' 
- ''' - assert len(literal) > 0, 'Empty literal not supported' - def recurse(i: int): - c = literal[i] - if maybe_escaped_underscores and c == '_': - yield f'[^{c}\\\\]' - yield ' | ' - yield f'"\\\\"? "{c}"' - else: - yield f'[^{c}]' - if i < len(literal) - 1: - yield ' | ' - yield self._format_literal(c) - yield ' (' - yield from recurse(i + 1) - yield ')?' - - return ''.join(('(', *recurse(0), ')')) - - def _add_rule(self, name, rule): - esc_name = INVALID_RULE_CHARS_RE.sub('-', name) + def _add_rule(self, name: str, rule: str): + esc_name = INVALID_RULE_CHARS_RE.sub("-", name) if esc_name not in self._rules or self._rules[esc_name] == rule: key = esc_name else: i = 0 - while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule: + while f"{esc_name}{i}" in self._rules: i += 1 - key = f'{esc_name}{i}' + key = f"{esc_name}{i}" self._rules[key] = rule return key - def resolve_refs(self, schema: dict, url: str): - ''' - Resolves all $ref fields in the given schema, fetching any remote schemas, - replacing $ref with absolute reference URL and populating self._refs with the - respective referenced (sub)schema dictionaries. - ''' - def visit(n: dict): - if isinstance(n, list): - return [visit(x) for x in n] - elif isinstance(n, dict): - ref = n.get('$ref') - if ref is not None and ref not in self._refs: - if ref.startswith('https://'): - assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)' - import requests + def visit(self, schema: Dict[str, Any], name: str) -> str: + rule_name = name or "root" - frag_split = ref.split('#') - base_url = frag_split[0] + if "$defs" in schema: + # add defs to self._defs for later inlining + for def_name, def_schema in schema["$defs"].items(): + self._defs[def_name] = def_schema - target = self._refs.get(base_url) - if target is None: - target = self.resolve_refs(requests.get(ref).json(), base_url) - self._refs[base_url] = target - - if len(frag_split) == 1 or frag_split[-1] == '': - return target - elif ref.startswith('#/'): - target = schema - ref = f'{url}{ref}' - n['$ref'] = ref - else: - raise ValueError(f'Unsupported ref {ref}') - - for sel in ref.split('#')[-1].split('/')[1:]: - assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}' - target = target[sel] - - self._refs[ref] = target - else: - for v in n.values(): - visit(v) - - return n - return visit(schema) - - def _generate_union_rule(self, name, alt_schemas): - return ' | '.join(( - self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}') - for i, alt_schema in enumerate(alt_schemas) - )) - - def _visit_pattern(self, pattern, name): - ''' - Transforms a regular expression pattern into a GBNF rule. - - Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions - Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md - - Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers. - - Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which - we define sub-rules to keep the output lean. 
- ''' - - assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"' - pattern = pattern[1:-1] - sub_rule_ids = {} - - i = 0 - length = len(pattern) - - def to_rule(s: Tuple[str, bool]) -> str: - (txt, is_literal) = s - return "\"" + txt + "\"" if is_literal else txt - - def transform() -> Tuple[str, bool]: - ''' - Parse a unit at index i (advancing it), and return its string representation + whether it's a literal. - ''' - nonlocal i - nonlocal pattern - nonlocal sub_rule_ids - - start = i - # For each component of this sequence, store its string representation and whether it's a literal. - # We only need a flat structure here to apply repetition operators to the last item, and - # to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially - # (GBNF's syntax is luckily very close to regular expressions!) - seq: list[Tuple[str, bool]] = [] - - def get_dot(): - if self._dotall: - rule = DOTALL - else: - # Accept any character... except \n and \r line break chars (\x0A and \xOD) - rule = DOT - return self._add_rule(f'dot', rule) - - def join_seq(): - nonlocal seq - ret = [] - for is_literal, g in groupby(seq, lambda x: x[1]): - if is_literal: - ret.append((''.join(x[0] for x in g), True)) - else: - ret.extend(g) - if len(ret) == 1: - return ret[0] - return (' '.join(to_rule(x) for x in seq), False) - - while i < length: - c = pattern[i] - if c == '.': - seq.append((get_dot(), False)) - i += 1 - elif c == '(': - i += 1 - if i < length: - assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/' - seq.append((f'({to_rule(transform())})', False)) - elif c == ')': - i += 1 - assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}' - return join_seq() - elif c == '[': - square_brackets = c - i += 1 - while i < length and pattern[i] != ']': - if pattern[i] == '\\': - square_brackets += pattern[i:i+2] - i += 2 - else: - square_brackets += pattern[i] - i += 1 - assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}' - square_brackets += ']' - i += 1 - seq.append((square_brackets, False)) - elif c == '|': - seq.append(('|', False)) - i += 1 - elif c in ('*', '+', '?'): - seq[-1] = (to_rule(seq[-1]) + c, False) - i += 1 - elif c == '{': - curly_brackets = c - i += 1 - while i < length and pattern[i] != '}': - curly_brackets += pattern[i] - i += 1 - assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}' - curly_brackets += '}' - i += 1 - nums = [s.strip() for s in curly_brackets[1:-1].split(',')] - min_times = 0 - max_times = None - try: - if len(nums) == 1: - min_times = int(nums[0]) - max_times = min_times - else: - assert len(nums) == 2 - min_times = int(nums[0]) if nums[0] else 0 - max_times = int(nums[1]) if nums[1] else None - except ValueError: - raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/') - - (sub, sub_is_literal) = seq[-1] - - if not sub_is_literal: - id = sub_rule_ids.get(sub) - if id is None: - id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub) - sub_rule_ids[sub] = id - sub = id - - seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times, item_rule_is_literal=sub_is_literal), False) - else: - literal = '' - while i < length: - if pattern[i] == '\\' and i < length - 1: - next = pattern[i + 1] - if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS: - i 
+= 1 - literal += pattern[i] - i += 1 - else: - literal += pattern[i:i+2] - i += 2 - elif pattern[i] == '"' and not self._raw_pattern: - literal += '\\"' - i += 1 - elif pattern[i] not in NON_LITERAL_SET and \ - (i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET): - literal += pattern[i] - i += 1 - else: - break - if literal: - seq.append((literal, True)) - - return join_seq() - - return self._add_rule( - name, - to_rule(transform()) if self._raw_pattern \ - else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space") - - - def _resolve_ref(self, ref): - ref_name = ref.split('/')[-1] - if ref_name not in self._rules and ref not in self._refs_being_resolved: - self._refs_being_resolved.add(ref) - resolved = self._refs[ref] - ref_name = self.visit(resolved, ref_name) - self._refs_being_resolved.remove(ref) - return ref_name - - def _generate_constant_rule(self, value): - return self._format_literal(json.dumps(value)) - - def visit(self, schema, name): - schema_type = schema.get('type') - schema_format = schema.get('format') - rule_name = name + '-' if name in RESERVED_NAMES else name or 'root' - - if (ref := schema.get('$ref')) is not None: - return self._add_rule(rule_name, self._resolve_ref(ref)) - - elif 'oneOf' in schema or 'anyOf' in schema: - return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf'])) - - elif isinstance(schema_type, list): - return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type])) - - elif 'const' in schema: - return self._add_rule(rule_name, self._generate_constant_rule(schema['const'])) - - elif 'enum' in schema: - rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + if "oneOf" in schema or "anyOf" in schema: + rule = " | ".join( + ( + self.visit(alt_schema, f'{name}{"-" if name else ""}{i}') + for i, alt_schema in enumerate( + schema.get("oneOf") or schema["anyOf"] + ) + ) + ) return self._add_rule(rule_name, rule) - elif schema_type in (None, 'object') and \ - ('properties' in schema or \ - ('additionalProperties' in schema and schema['additionalProperties'] is not True)): - required = set(schema.get('required', [])) - properties = list(schema.get('properties', {}).items()) - return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties'))) + elif "const" in schema: + return self._add_rule(rule_name, self._format_literal(schema["const"])) - elif schema_type in (None, 'object') and 'allOf' in schema: - required = set() - properties = [] - hybrid_name = name - def add_component(comp_schema, is_required): - if (ref := comp_schema.get('$ref')) is not None: - comp_schema = self._refs[ref] + elif "enum" in schema: + rule = " | ".join((self._format_literal(v) for v in schema["enum"])) + return self._add_rule(rule_name, rule) - if 'properties' in comp_schema: - for prop_name, prop_schema in comp_schema['properties'].items(): - properties.append((prop_name, prop_schema)) - if is_required: - required.add(prop_name) + elif "$ref" in schema: + ref = schema["$ref"] + assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}" + # inline $defs + def_name = ref[len("#/$defs/") :] + def_schema = self._defs[def_name] + return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}') - for t in schema['allOf']: - if 'anyOf' in t: - for tt in t['anyOf']: - add_component(tt, is_required=False) - else: - add_component(t, is_required=True) - return 
self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=[])) + schema_type: Optional[str] = schema.get("type") # type: ignore + assert isinstance(schema_type, str), f"Unrecognized schema: {schema}" - elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema): - items = schema.get('items') or schema['prefixItems'] - if isinstance(items, list): - return self._add_rule( - rule_name, - '"[" space ' + - ' "," space '.join( - self.visit(item, f'{name}{"-" if name else ""}tuple-{i}') - for i, item in enumerate(items)) + - ' "]" space') + if schema_type == "object" and "properties" in schema: + # TODO: `required` keyword + if self._prop_order: + prop_order = self._prop_order + prop_pairs = sorted( + schema["properties"].items(), + # sort by position in prop_order (if specified) then by key + key=lambda kv: (prop_order.get(kv[0], len(prop_order)), kv[0]), + ) else: - item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item') - min_items = schema.get("minItems", 0) - max_items = schema.get("maxItems") - return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space') + prop_pairs = schema["properties"].items() - elif schema_type in (None, 'string') and 'pattern' in schema: - return self._visit_pattern(schema['pattern'], rule_name) + rule = '"{" space' + for i, (prop_name, prop_schema) in enumerate(prop_pairs): + prop_rule_name = self.visit( + prop_schema, f'{name}{"-" if name else ""}{prop_name}' + ) + if i > 0: + rule += ' "," space' + rule += rf' {self._format_literal(prop_name)} space ":" space {prop_rule_name}' + rule += ' "}" space' - elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''): - return self._add_primitive( - 'root' if rule_name == 'root' else schema_format, - PRIMITIVE_RULES['uuid'] + return self._add_rule(rule_name, rule) + + elif schema_type == "array" and "items" in schema: + # TODO `prefixItems` keyword + item_rule_name = self.visit( + schema["items"], f'{name}{"-" if name else ""}item' ) - - elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES: - prim_name = f'{schema_format}-string' - return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name])) - - elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema): - char_rule = self._add_primitive('char', PRIMITIVE_RULES['char']) - min_len = schema.get('minLength', 0) - max_len = schema.get('maxLength') - - return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space') - - elif (schema_type == 'object') or (len(schema) == 0): - return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object'])) + list_item_operator = f'("," space {item_rule_name})' + successive_items = "" + min_items = schema.get("minItems", 0) + if min_items > 0: + first_item = f"({item_rule_name})" + successive_items = list_item_operator * (min_items - 1) + min_items -= 1 + else: + first_item = f"({item_rule_name})?" 
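
A small, self-contained sketch of the array expansion being restored here (it continues below with the maxItems handling). The item rule name "integer" is an assumption standing in for whatever visit() generates.

item = "integer"                                # stands in for the generated item rule
first_item = f"({item})"                        # minItems = 1: first element is required
successive = (f'("," space {item})' + "?") * 2  # up to maxItems = 3 elements in total
print(f'"[" space {first_item} {successive} "]" space')
# "[" space (integer) ("," space integer)?("," space integer)? "]" space
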
+ max_items = schema.get("maxItems") + if max_items is not None and max_items > min_items: + successive_items += (list_item_operator + "?") * (max_items - min_items - 1) + else: + successive_items += list_item_operator + "*" + rule = f'"[" space {first_item} {successive_items} "]" space' + return self._add_rule(rule_name, rule) else: - assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}' - # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero - return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type]) - - def _add_primitive(self, name: str, rule: BuiltinRule): - n = self._add_rule(name, rule.content) - - for dep in rule.deps: - dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep) - assert dep_rule, f'Rule {dep} not known' - if dep not in self._rules: - self._add_primitive(dep, dep_rule) - return n - - def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]): - prop_order = self._prop_order - # sort by position in prop_order (if specified) then by original order - sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))] - - prop_kv_rule_names = {} - for prop_name, prop_schema in properties: - prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}') - prop_kv_rule_names[prop_name] = self._add_rule( - f'{name}{"-" if name else ""}{prop_name}-kv', - fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}' + assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}" + return self._add_rule( + "root" if rule_name == "root" else schema_type, + PRIMITIVE_RULES[schema_type], ) - required_props = [k for k in sorted_props if k in required] - optional_props = [k for k in sorted_props if k not in required] - - if additional_properties == True or isinstance(additional_properties, dict): - sub_name = f'{name}{"-" if name else ""}additional' - value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value') - prop_kv_rule_names["*"] = self._add_rule( - f'{sub_name}-kv', - self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}' - ) - optional_props.append("*") - - rule = '"{" space ' - rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props) - - if optional_props: - rule += ' (' - if required_props: - rule += ' "," space ( ' - - def get_recursive_refs(ks, first_is_optional): - [k, *rest] = ks - kv_rule_name = prop_kv_rule_names[k] - if k == '*': - res = self._add_rule( - f'{name}{"-" if name else ""}additional-kvs', - f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*' - ) - elif first_is_optional: - res = f'( "," space {kv_rule_name} )?' - else: - res = kv_rule_name - if len(rest) > 0: - res += ' ' + self._add_rule( - f'{name}{"-" if name else ""}{k}-rest', - get_recursive_refs(rest, first_is_optional=True) - ) - return res - - rule += ' | '.join( - get_recursive_refs(optional_props[i:], first_is_optional=False) - for i in range(len(optional_props)) - ) - if required_props: - rule += ' )' - rule += ' )?' 
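
The format_grammar hunk just below is the other visible behavioural change of this revert: the #1353 variant sorts rules by name for deterministic output, while the restored variant keeps insertion order. A tiny illustrative sketch of the sorted form:

rules = {"space": '" "?', "root": '"{" space "}" space'}
print("\n".join(f"{name} ::= {rule}" for name, rule in sorted(rules.items())))
# root ::= "{" space "}" space
# space ::= " "?
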
- - rule += ' "}" space' - - return rule def format_grammar(self): - return '\n'.join( - f'{name} ::= {rule}' - for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0]) - ) + return "\n".join((f"{name} ::= {rule}" for name, rule in self._rules.items())) def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None): From 4f4266495549f691194abb357b792898fd12695e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucca=20Zen=C3=B3bio?= Date: Thu, 18 Apr 2024 02:36:25 -0300 Subject: [PATCH 16/29] feat: update grammar schema converter to match llama.cpp (#1353) * feat: improve function calling * feat:grammar * fix * fix * fix --- llama_cpp/llama_chat_format.py | 2 +- llama_cpp/llama_grammar.py | 627 +++++++++++++++++++++++++++------ 2 files changed, 523 insertions(+), 106 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 519d2f5..eb98cbf 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -2709,4 +2709,4 @@ def chatml_function_calling( }, } - raise ValueError("Automatic streaming tool choice is not supported") + raise ValueError("Automatic streaming tool choice is not supported") \ No newline at end of file diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 9cc48a9..6c7b57a 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -5,11 +5,12 @@ from pathlib import Path import sys from ctypes import * # type: ignore from enum import Enum -from itertools import islice +from itertools import islice, groupby from typing import ( Any, Callable, Dict, + Set, Generic, List, Optional, @@ -1391,145 +1392,561 @@ from typing import List, Optional # whitespace. Also maybe improves generation quality? SPACE_RULE = '" "?' -PRIMITIVE_RULES = { - "boolean": '("true" | "false") space', - "number": '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space', - "integer": '("-"? ([0-9] | [1-9] [0-9]*)) space', - "string": r""" "\"" ( - [^"\\] | - "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) - )* "\"" space """, - "null": '"null" space', -} INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+") GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'} +# whitespace is constrained to a single space char to prevent model "running away" in +# whitespace. Also maybe improves generation quality? +SPACE_RULE = '" "?' + + +def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False): + if not separator_rule: + if min_items == 0 and max_items == 1: + return f'{item_rule}?' + elif min_items == 1 and max_items is None: + return f'{item_rule}+' + + result = '' + + if min_items > 0: + if item_rule_is_literal and separator_rule is None: + result = '"' + (item_rule[1:-1] * min_items) + '"' + else: + result = (f' {separator_rule} ' if separator_rule else ' ').join([item_rule] * min_items) + + def opt_repetitions(up_to_n, prefix_with_sep=False): + ''' + - n=4, no sep: '(a (a (a (a)?)?)?)?' + - n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?' + - n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?' + ''' + + content = f'{separator_rule} {item_rule}' if prefix_with_sep and separator_rule else item_rule + if up_to_n == 0: + return '' + elif up_to_n == 1: + return f'({content})?' + elif separator_rule and not prefix_with_sep: + return f'({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?' 
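
For intuition about the separator-aware path of the helper being re-added here (its else branch continues below), a hedged sketch, assuming the patched module is importable as llama_cpp.llama_grammar:

from llama_cpp.llama_grammar import _build_repetition

# Between 1 and 3 values joined by a comma separator:
print(_build_repetition("value", 1, 3, separator_rule='"," space'))
# -> value ("," space value ("," space value)?)?
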
+ else: + return (f'({content} ' * up_to_n).rstrip() + (')?' * up_to_n) + + if min_items > 0 and max_items != min_items: + result += ' ' + + if max_items is not None: + result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0) + else: + item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})' + + if min_items == 0 and separator_rule: + result = f'({item_rule} {item_operator}*)?' + else: + result += f'{item_operator}*' + + return result + + + +class BuiltinRule: + def __init__(self, content: str, deps: list = None): + self.content = content + self.deps = deps or [] + +_up_to_15_digits = _build_repetition('[0-9]', 0, 15) + +PRIMITIVE_RULES = { + 'boolean' : BuiltinRule('("true" | "false") space', []), + 'decimal-part' : BuiltinRule('[0-9] ' + _up_to_15_digits, []), + 'integral-part': BuiltinRule('[0-9] | [1-9] ' + _up_to_15_digits, []), + 'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']), + 'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']), + 'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']), + 'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']), + 'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']), + 'uuid' : BuiltinRule(r'"\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + r' "\"" space', []), + 'char' : BuiltinRule(r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])', []), + 'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']), + 'null' : BuiltinRule('"null" space', []), +} + +# TODO: support "uri", "email" string formats +STRING_FORMAT_RULES = { + 'date' : BuiltinRule('[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []), + 'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? 
( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []), + 'date-time' : BuiltinRule('date "T" time', ['date', 'time']), + 'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']), + 'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']), + 'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']), +} + +DOTALL = '[\\U00000000-\\U0010FFFF]' +DOT = '[^\\x0A\\x0D]' + +RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()]) + + +NON_LITERAL_SET = set('|.()[]{}*+?') +ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?') + + + class SchemaConverter: - def __init__(self, prop_order): + def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern): self._prop_order = prop_order - self._rules = {"space": SPACE_RULE} - self._defs: Dict[str, Any] = {} + self._allow_fetch = allow_fetch + self._dotall = dotall + self._raw_pattern = raw_pattern + self._rules = { + 'space': SPACE_RULE, + } + self._refs = {} + self._refs_being_resolved = set() - def _format_literal(self, literal: str): - escaped: str = GRAMMAR_LITERAL_ESCAPE_RE.sub( - lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), json.dumps(literal) + def _format_literal(self, literal): + escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( + lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal ) return f'"{escaped}"' - def _add_rule(self, name: str, rule: str): - esc_name = INVALID_RULE_CHARS_RE.sub("-", name) + def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str: + ''' + not_literal('a') -> '[^a]' + not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?' + ''' + assert len(literal) > 0, 'Empty literal not supported' + def recurse(i: int): + c = literal[i] + if maybe_escaped_underscores and c == '_': + yield f'[^{c}\\\\]' + yield ' | ' + yield f'"\\\\"? "{c}"' + else: + yield f'[^{c}]' + if i < len(literal) - 1: + yield ' | ' + yield self._format_literal(c) + yield ' (' + yield from recurse(i + 1) + yield ')?' + + return ''.join(('(', *recurse(0), ')')) + + def _add_rule(self, name, rule): + esc_name = INVALID_RULE_CHARS_RE.sub('-', name) if esc_name not in self._rules or self._rules[esc_name] == rule: key = esc_name else: i = 0 - while f"{esc_name}{i}" in self._rules: + while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule: i += 1 - key = f"{esc_name}{i}" + key = f'{esc_name}{i}' self._rules[key] = rule return key - def visit(self, schema: Dict[str, Any], name: str) -> str: - rule_name = name or "root" + def resolve_refs(self, schema: dict, url: str): + ''' + Resolves all $ref fields in the given schema, fetching any remote schemas, + replacing $ref with absolute reference URL and populating self._refs with the + respective referenced (sub)schema dictionaries. 
+ ''' + def visit(n: dict): + if isinstance(n, list): + return [visit(x) for x in n] + elif isinstance(n, dict): + ref = n.get('$ref') + if ref is not None and ref not in self._refs: + if ref.startswith('https://'): + assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)' + import requests - if "$defs" in schema: - # add defs to self._defs for later inlining - for def_name, def_schema in schema["$defs"].items(): - self._defs[def_name] = def_schema + frag_split = ref.split('#') + base_url = frag_split[0] - if "oneOf" in schema or "anyOf" in schema: - rule = " | ".join( - ( - self.visit(alt_schema, f'{name}{"-" if name else ""}{i}') - for i, alt_schema in enumerate( - schema.get("oneOf") or schema["anyOf"] - ) - ) + target = self._refs.get(base_url) + if target is None: + target = self.resolve_refs(requests.get(ref).json(), base_url) + self._refs[base_url] = target + + if len(frag_split) == 1 or frag_split[-1] == '': + return target + elif ref.startswith('#/'): + target = schema + ref = f'{url}{ref}' + n['$ref'] = ref + else: + raise ValueError(f'Unsupported ref {ref}') + + for sel in ref.split('#')[-1].split('/')[1:]: + assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}' + target = target[sel] + + self._refs[ref] = target + else: + for v in n.values(): + visit(v) + + return n + return visit(schema) + + def _generate_union_rule(self, name, alt_schemas): + return ' | '.join(( + self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}') + for i, alt_schema in enumerate(alt_schemas) + )) + + def _visit_pattern(self, pattern, name): + ''' + Transforms a regular expression pattern into a GBNF rule. + + Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions + Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md + + Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers. + + Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which + we define sub-rules to keep the output lean. + ''' + + assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"' + pattern = pattern[1:-1] + sub_rule_ids = {} + + i = 0 + length = len(pattern) + + def to_rule(s: Tuple[str, bool]) -> str: + (txt, is_literal) = s + return "\"" + txt + "\"" if is_literal else txt + + def transform() -> Tuple[str, bool]: + ''' + Parse a unit at index i (advancing it), and return its string representation + whether it's a literal. + ''' + nonlocal i + nonlocal pattern + nonlocal sub_rule_ids + + start = i + # For each component of this sequence, store its string representation and whether it's a literal. + # We only need a flat structure here to apply repetition operators to the last item, and + # to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially + # (GBNF's syntax is luckily very close to regular expressions!) + seq: list[Tuple[str, bool]] = [] + + def get_dot(): + if self._dotall: + rule = DOTALL + else: + # Accept any character... 
except \n and \r line break chars (\x0A and \x0D)
+ rule = DOT
+ return self._add_rule(f'dot', rule)
+
+ def join_seq():
+ nonlocal seq
+ ret = []
+ for is_literal, g in groupby(seq, lambda x: x[1]):
+ if is_literal:
+ ret.append((''.join(x[0] for x in g), True))
+ else:
+ ret.extend(g)
+ if len(ret) == 1:
+ return ret[0]
+ return (' '.join(to_rule(x) for x in seq), False)
+
+ while i < length:
+ c = pattern[i]
+ if c == '.':
+ seq.append((get_dot(), False))
+ i += 1
+ elif c == '(':
+ i += 1
+ if i < length:
+ assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
+ seq.append((f'({to_rule(transform())})', False))
+ elif c == ')':
+ i += 1
+ assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
+ return join_seq()
+ elif c == '[':
+ square_brackets = c
+ i += 1
+ while i < length and pattern[i] != ']':
+ if pattern[i] == '\\':
+ square_brackets += pattern[i:i+2]
+ i += 2
+ else:
+ square_brackets += pattern[i]
+ i += 1
+ assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
+ square_brackets += ']'
+ i += 1
+ seq.append((square_brackets, False))
+ elif c == '|':
+ seq.append(('|', False))
+ i += 1
+ elif c in ('*', '+', '?'):
+ seq[-1] = (to_rule(seq[-1]) + c, False)
+ i += 1
+ elif c == '{':
+ curly_brackets = c
+ i += 1
+ while i < length and pattern[i] != '}':
+ curly_brackets += pattern[i]
+ i += 1
+ assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
+ curly_brackets += '}'
+ i += 1
+ nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
+ min_times = 0
+ max_times = None
+ try:
+ if len(nums) == 1:
+ min_times = int(nums[0])
+ max_times = min_times
+ else:
+ assert len(nums) == 2
+ min_times = int(nums[0]) if nums[0] else 0
+ max_times = int(nums[1]) if nums[1] else None
+ except ValueError:
+ raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/')
+
+ (sub, sub_is_literal) = seq[-1]
+
+ if not sub_is_literal:
+ id = sub_rule_ids.get(sub)
+ if id is None:
+ id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
+ sub_rule_ids[sub] = id
+ sub = id
+
+ seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times, item_rule_is_literal=sub_is_literal), False)
+ else:
+ literal = ''
+ while i < length:
+ if pattern[i] == '\\' and i < length - 1:
+ next = pattern[i + 1]
+ if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
+ i += 1
+ literal += pattern[i]
+ i += 1
+ else:
+ literal += pattern[i:i+2]
+ i += 2
+ elif pattern[i] == '"' and not self._raw_pattern:
+ literal += '\\"'
+ i += 1
+ elif pattern[i] not in NON_LITERAL_SET and \
+ (i == length - 1 or literal == '' or pattern[i+1] == '.'
or pattern[i+1] not in NON_LITERAL_SET): + literal += pattern[i] + i += 1 + else: + break + if literal: + seq.append((literal, True)) + + return join_seq() + + return self._add_rule( + name, + to_rule(transform()) if self._raw_pattern \ + else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space") + + + def _resolve_ref(self, ref): + ref_name = ref.split('/')[-1] + if ref_name not in self._rules and ref not in self._refs_being_resolved: + self._refs_being_resolved.add(ref) + resolved = self._refs[ref] + ref_name = self.visit(resolved, ref_name) + self._refs_being_resolved.remove(ref) + return ref_name + + def _generate_constant_rule(self, value): + return self._format_literal(json.dumps(value)) + + def visit(self, schema, name): + schema_type = schema.get('type') + schema_format = schema.get('format') + rule_name = name + '-' if name in RESERVED_NAMES else name or 'root' + + if (ref := schema.get('$ref')) is not None: + return self._add_rule(rule_name, self._resolve_ref(ref)) + + elif 'oneOf' in schema or 'anyOf' in schema: + return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf'])) + + elif isinstance(schema_type, list): + return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type])) + + elif 'const' in schema: + return self._add_rule(rule_name, self._generate_constant_rule(schema['const'])) + + elif 'enum' in schema: + rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + return self._add_rule(rule_name, rule) + + elif schema_type in (None, 'object') and \ + ('properties' in schema or \ + ('additionalProperties' in schema and schema['additionalProperties'] is not True)): + required = set(schema.get('required', [])) + properties = list(schema.get('properties', {}).items()) + return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties'))) + + elif schema_type in (None, 'object') and 'allOf' in schema: + required = set() + properties = [] + hybrid_name = name + def add_component(comp_schema, is_required): + if (ref := comp_schema.get('$ref')) is not None: + comp_schema = self._refs[ref] + + if 'properties' in comp_schema: + for prop_name, prop_schema in comp_schema['properties'].items(): + properties.append((prop_name, prop_schema)) + if is_required: + required.add(prop_name) + + for t in schema['allOf']: + if 'anyOf' in t: + for tt in t['anyOf']: + add_component(tt, is_required=False) + else: + add_component(t, is_required=True) + + return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=[])) + + elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema): + items = schema.get('items') or schema['prefixItems'] + if isinstance(items, list): + return self._add_rule( + rule_name, + '"[" space ' + + ' "," space '.join( + self.visit(item, f'{name}{"-" if name else ""}tuple-{i}') + for i, item in enumerate(items)) + + ' "]" space') + else: + item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item') + min_items = schema.get("minItems", 0) + max_items = schema.get("maxItems") + return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space') + + elif schema_type in (None, 'string') and 'pattern' in schema: + return self._visit_pattern(schema['pattern'], rule_name) + + elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or 
''): + return self._add_primitive( + 'root' if rule_name == 'root' else schema_format, + PRIMITIVE_RULES['uuid'] ) - return self._add_rule(rule_name, rule) - elif "const" in schema: - return self._add_rule(rule_name, self._format_literal(schema["const"])) + elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES: + prim_name = f'{schema_format}-string' + return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name])) - elif "enum" in schema: - rule = " | ".join((self._format_literal(v) for v in schema["enum"])) - return self._add_rule(rule_name, rule) + elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema): + char_rule = self._add_primitive('char', PRIMITIVE_RULES['char']) + min_len = schema.get('minLength', 0) + max_len = schema.get('maxLength') - elif "$ref" in schema: - ref = schema["$ref"] - assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}" - # inline $defs - def_name = ref[len("#/$defs/") :] - def_schema = self._defs[def_name] - return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}') + return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space') - - schema_type: Optional[str] = schema.get("type") # type: ignore - assert isinstance(schema_type, str), f"Unrecognized schema: {schema}" - - if schema_type == "object" and "properties" in schema: - # TODO: `required` keyword - if self._prop_order: - prop_order = self._prop_order - prop_pairs = sorted( - schema["properties"].items(), - # sort by position in prop_order (if specified) then by key - key=lambda kv: (prop_order.get(kv[0], len(prop_order)), kv[0]), - ) - else: - prop_pairs = schema["properties"].items() - - rule = '"{" space' - for i, (prop_name, prop_schema) in enumerate(prop_pairs): - prop_rule_name = self.visit( - prop_schema, f'{name}{"-" if name else ""}{prop_name}' - ) - if i > 0: - rule += ' "," space' - rule += rf' {self._format_literal(prop_name)} space ":" space {prop_rule_name}' - rule += ' "}" space' - - return self._add_rule(rule_name, rule) - - elif schema_type == "array" and "items" in schema: - # TODO `prefixItems` keyword - item_rule_name = self.visit( - schema["items"], f'{name}{"-" if name else ""}item' - ) - list_item_operator = f'("," space {item_rule_name})' - successive_items = "" - min_items = schema.get("minItems", 0) - if min_items > 0: - first_item = f"({item_rule_name})" - successive_items = list_item_operator * (min_items - 1) - min_items -= 1 - else: - first_item = f"({item_rule_name})?" 
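
As an example of the re-landed string-format dispatch above, a minimal sketch using the module's top-level helper; the output comment is abridged and assumes the patched llama_cpp.llama_grammar.

import json
from llama_cpp.llama_grammar import json_schema_to_gbnf

print(json_schema_to_gbnf(json.dumps({"type": "string", "format": "date-time"})))
# Abridged output:
#   root ::= date-time-string
#   date-time-string ::= "\"" date-time "\"" space
#   date-time ::= date "T" time
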
- max_items = schema.get("maxItems") - if max_items is not None and max_items > min_items: - successive_items += (list_item_operator + "?") * (max_items - min_items - 1) - else: - successive_items += list_item_operator + "*" - rule = f'"[" space {first_item} {successive_items} "]" space' - return self._add_rule(rule_name, rule) + elif (schema_type == 'object') or (len(schema) == 0): + return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object'])) else: - assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}" - return self._add_rule( - "root" if rule_name == "root" else schema_type, - PRIMITIVE_RULES[schema_type], + assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}' + # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero + return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type]) + + def _add_primitive(self, name: str, rule: BuiltinRule): + n = self._add_rule(name, rule.content) + + for dep in rule.deps: + dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep) + assert dep_rule, f'Rule {dep} not known' + if dep not in self._rules: + self._add_primitive(dep, dep_rule) + return n + + def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]): + prop_order = self._prop_order + # sort by position in prop_order (if specified) then by original order + sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))] + + prop_kv_rule_names = {} + for prop_name, prop_schema in properties: + prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}') + prop_kv_rule_names[prop_name] = self._add_rule( + f'{name}{"-" if name else ""}{prop_name}-kv', + fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}' ) + required_props = [k for k in sorted_props if k in required] + optional_props = [k for k in sorted_props if k not in required] + + if additional_properties == True or isinstance(additional_properties, dict): + sub_name = f'{name}{"-" if name else ""}additional' + value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value') + prop_kv_rule_names["*"] = self._add_rule( + f'{sub_name}-kv', + self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}' + ) + optional_props.append("*") + + rule = '"{" space ' + rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props) + + if optional_props: + rule += ' (' + if required_props: + rule += ' "," space ( ' + + def get_recursive_refs(ks, first_is_optional): + [k, *rest] = ks + kv_rule_name = prop_kv_rule_names[k] + if k == '*': + res = self._add_rule( + f'{name}{"-" if name else ""}additional-kvs', + f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*' + ) + elif first_is_optional: + res = f'( "," space {kv_rule_name} )?' + else: + res = kv_rule_name + if len(rest) > 0: + res += ' ' + self._add_rule( + f'{name}{"-" if name else ""}{k}-rest', + get_recursive_refs(rest, first_is_optional=True) + ) + return res + + rule += ' | '.join( + get_recursive_refs(optional_props[i:], first_is_optional=False) + for i in range(len(optional_props)) + ) + if required_props: + rule += ' )' + rule += ' )?' 
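
Before the rule is closed below, one more sketch: additionalProperties: true routes through the wildcard "*" entry added above, appending an open-ended key-value chain after the declared properties. It assumes the patched llama_cpp.llama_grammar; rule names in the comment are abbreviated.

import json
from llama_cpp.llama_grammar import json_schema_to_gbnf

schema = {
    "type": "object",
    "properties": {"id": {"type": "integer"}},
    "required": ["id"],
    "additionalProperties": True,
}
print(json_schema_to_gbnf(json.dumps(schema)))
# Key lines (abridged):
#   root ::= "{" space id-kv ( "," space ( additional-kvs ) )? "}" space
#   additional-kvs ::= additional-kv ( "," space additional-kv )*
#   additional-kv ::= string ":" space additional-value
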
+ + rule += ' "}" space' + + return rule def format_grammar(self): - return "\n".join((f"{name} ::= {rule}" for name, rule in self._rules.items())) - - + return '\n'.join( + f'{name} ::= {rule}' + for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0]) + ) def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None): prop_order = prop_order or [] schema = json.loads(schema) prop_order = {name: idx for idx, name in enumerate(prop_order)} - converter = SchemaConverter(prop_order) + converter = SchemaConverter(prop_order=prop_order, allow_fetch=False, dotall=False, raw_pattern=False) + schema = converter.resolve_refs(schema, "stdin") converter.visit(schema, "") return converter.format_grammar() From a128c80500046f31d6411f75438d0caf66060ace Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 18 Apr 2024 01:39:45 -0400 Subject: [PATCH 17/29] feat: Update llama.cpp --- vendor/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 8dd1ec8..3b8f1ec 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 8dd1ec8b3ffbfa2d26e82e672cea89f5eeb2f141 +Subproject commit 3b8f1ec4b18770531d0b1d792f3edf08254e4f0c From 893a27a7364366cc195dcc50aadfa75d95bb9319 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 18 Apr 2024 01:43:39 -0400 Subject: [PATCH 18/29] chore: Bump version --- CHANGELOG.md | 8 ++++++++ llama_cpp/__init__.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c67498e..cf1efee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.2.62] + +- feat: Update llama.cpp to ggerganov/llama.cpp@3b8f1ec4b18770531d0b1d792f3edf08254e4f0c +- feat: update grammar schema converter to match llama.cpp by @themrzmaster in #1353 +- feat: add disable_ping_events flag by @khimaros in #1257 +- feat: Make saved state more compact on-disk by @tc-wolf in #1296 +- feat: Use all available CPUs for batch processing by @ddh0 in #1345 + ## [0.2.61] - feat: Update llama.cpp to ggerganov/llama.cpp@ba5e134e073ec6837078c874aba44a702944a676 diff --git a/llama_cpp/__init__.py b/llama_cpp/__init__.py index 2382db9..828390a 100644 --- a/llama_cpp/__init__.py +++ b/llama_cpp/__init__.py @@ -1,4 +1,4 @@ from .llama_cpp import * from .llama import * -__version__ = "0.2.61" \ No newline at end of file +__version__ = "0.2.62" \ No newline at end of file From d17c1887a364201d292e2959c45a7f144ff240b2 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 19 Apr 2024 23:58:16 -0400 Subject: [PATCH 19/29] feat: Update llama.cpp --- vendor/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 3b8f1ec..0e4802b 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 3b8f1ec4b18770531d0b1d792f3edf08254e4f0c +Subproject commit 0e4802b2ecbaab04b4f829fde4a3096ca19c84b5 From cc81afebf04d26ca1ac3cf72f23f18da6ab58588 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Sat, 20 Apr 2024 00:00:53 -0400 Subject: [PATCH 20/29] feat: Add stopping_criteria to ChatFormatter, allow stopping on arbitrary token ids, fixes llama3 instruct --- llama_cpp/llama.py | 5 ++++- llama_cpp/llama_chat_format.py | 22 +++++++++++++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 5a0111b..818be82 100644 --- 
a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -426,7 +426,10 @@ class Llama: print(f"Using chat bos_token: {bos_token}", file=sys.stderr) self.chat_handler = llama_chat_format.Jinja2ChatFormatter( - template=template, eos_token=eos_token, bos_token=bos_token + template=template, + eos_token=eos_token, + bos_token=bos_token, + stop_token_ids=[eos_token_id], ).to_chat_handler() if self.chat_format is None and self.chat_handler is None: diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index eb98cbf..189ccb0 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -10,6 +10,9 @@ from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, P import jinja2 +import numpy as np +import numpy.typing as npt + import llama_cpp.llama as llama import llama_cpp.llama_types as llama_types import llama_cpp.llama_grammar as llama_grammar @@ -150,6 +153,7 @@ class ChatFormatterResponse: prompt: str stop: Optional[Union[str, List[str]]] = None + stopping_criteria: Optional[llama.StoppingCriteriaList] = None class ChatFormatter(Protocol): @@ -173,12 +177,14 @@ class Jinja2ChatFormatter(ChatFormatter): eos_token: str, bos_token: str, add_generation_prompt: bool = True, + stop_token_ids: Optional[List[int]] = None, ): """A chat formatter that uses jinja2 templates to format the prompt.""" self.template = template self.eos_token = eos_token self.bos_token = bos_token self.add_generation_prompt = add_generation_prompt + self.stop_token_ids = set(stop_token_ids) if stop_token_ids is not None else None self._environment = jinja2.Environment( loader=jinja2.BaseLoader(), @@ -211,7 +217,16 @@ class Jinja2ChatFormatter(ChatFormatter): tool_choice=tool_choice, ) - return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token]) + stopping_criteria = None + if self.stop_token_ids is not None: + def stop_on_last_token( + tokens: npt.NDArray[np.intc], + logits: npt.NDArray[np.single] + ) -> bool: + return tokens[-1] in self.stop_token_ids + stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token]) + + return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria) def to_chat_handler(self) -> LlamaChatCompletionHandler: return chat_formatter_to_chat_completion_handler(self) @@ -533,6 +548,10 @@ def chat_formatter_to_chat_completion_handler( rstop = result.stop if isinstance(result.stop, list) else [result.stop] stop = stop + rstop + stopping_criteria = None + if result.stopping_criteria is not None: + stopping_criteria = result.stopping_criteria + if response_format is not None and response_format["type"] == "json_object": grammar = _grammar_for_response_format(response_format, verbose=llama.verbose) @@ -598,6 +617,7 @@ def chat_formatter_to_chat_completion_handler( mirostat_eta=mirostat_eta, model=model, logits_processor=logits_processor, + stopping_criteria=stopping_criteria, grammar=grammar, logit_bias=logit_bias, ) From 02812148635bf6337ffc7d1abb34093f4065df88 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Sat, 20 Apr 2024 00:09:37 -0400 Subject: [PATCH 21/29] chore: Bump version --- CHANGELOG.md | 5 +++++ llama_cpp/__init__.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cf1efee..4c7e267 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.2.63] + +- feat: Update llama.cpp to 
ggerganov/llama.cpp@0e4802b2ecbaab04b4f829fde4a3096ca19c84b5 +- feat: Add stopping_criteria to ChatFormatter, allow stopping on arbitrary token ids, fixes llama3 instruct by @abetlen in cc81afebf04d26ca1ac3cf72f23f18da6ab58588 + ## [0.2.62] - feat: Update llama.cpp to ggerganov/llama.cpp@3b8f1ec4b18770531d0b1d792f3edf08254e4f0c diff --git a/llama_cpp/__init__.py b/llama_cpp/__init__.py index 828390a..f2c7aba 100644 --- a/llama_cpp/__init__.py +++ b/llama_cpp/__init__.py @@ -1,4 +1,4 @@ from .llama_cpp import * from .llama import * -__version__ = "0.2.62" \ No newline at end of file +__version__ = "0.2.63" \ No newline at end of file From 159cc4e5d924804c9776355011d5378dc8d5d9f4 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Sun, 21 Apr 2024 20:46:40 -0400 Subject: [PATCH 22/29] feat: Update llama.cpp --- llama_cpp/_internals.py | 14 +++++++------- llama_cpp/llama_cpp.py | 28 +++++++++++++++++++++++++--- vendor/llama.cpp | 2 +- 3 files changed, 33 insertions(+), 11 deletions(-) diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py index 79f6543..ff2d657 100644 --- a/llama_cpp/_internals.py +++ b/llama_cpp/_internals.py @@ -181,20 +181,20 @@ class _LlamaModel: ) return list(tokens[:n_tokens]) - def token_to_piece(self, token: int) -> bytes: + def token_to_piece(self, token: int, special: bool = False) -> bytes: assert self.model is not None buf = ctypes.create_string_buffer(32) - llama_cpp.llama_token_to_piece(self.model, token, buf, 32) + llama_cpp.llama_token_to_piece(self.model, token, buf, 32, special) return bytes(buf) - def detokenize(self, tokens: List[int]) -> bytes: + def detokenize(self, tokens: List[int], special: bool = False) -> bytes: assert self.model is not None output = b"" size = 32 buffer = (ctypes.c_char * size)() for token in tokens: n = llama_cpp.llama_token_to_piece( - self.model, llama_cpp.llama_token(token), buffer, size + self.model, llama_cpp.llama_token(token), buffer, size, special ) assert n <= size output += bytes(buffer[:n]) @@ -597,13 +597,13 @@ def _tokenize(model: _LlamaModel, text: str, add_bos: bool, special: bool) -> li return list(result) -def _token_to_piece(model: _LlamaModel, token: int) -> str: +def _token_to_piece(model: _LlamaModel, token: int, special: bool = False) -> str: assert model.model is not None result = (ctypes.c_char * 8)(0) - n_tokens = llama_cpp.llama_token_to_piece(model.model, token, result, len(result)) + n_tokens = llama_cpp.llama_token_to_piece(model.model, token, result, len(result), special) if n_tokens < 0: result = (ctypes.c_char * -n_tokens)(0) - check = llama_cpp.llama_token_to_piece(model.model, token, result, len(result)) + check = llama_cpp.llama_token_to_piece(model.model, token, result, len(result), special) if check != -n_tokens: raise RuntimeError(f"Failed to get piece: token={token}") else: diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 2450d11..c2b909e 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -2380,6 +2380,18 @@ def llama_token_get_type( ... +# // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.) +# LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token); +@ctypes_function( + "llama_token_is_eog", [llama_model_p_ctypes, llama_token], ctypes.c_bool +) +def llama_token_is_eog( + model: llama_model_p, token: Union[llama_token, int], / +) -> bool: + """Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)""" + ... 
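A minimal sketch of how the new check can drive a hand-rolled generation loop; `sample_next_token` is a hypothetical sampling callback and `model` an already-loaded `llama_model_p`, neither of which is part of this patch:

from typing import Callable, List

from llama_cpp.llama_cpp import llama_model_p, llama_token_is_eog


def generate_until_eog(
    model: llama_model_p,
    sample_next_token: Callable[[], int],
    max_tokens: int = 256,
) -> List[int]:
    tokens: List[int] = []
    for _ in range(max_tokens):
        token = sample_next_token()
        # Stop on EOS, EOT, or any other end-of-generation token the model
        # defines, rather than comparing against a single hard-coded EOS id.
        if llama_token_is_eog(model, token):
            break
        tokens.append(token)
    return tokens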
+ + # // Special tokens @@ -2434,7 +2446,7 @@ def llama_add_eos_token(model: llama_model_p, /) -> int: ... -# // codellama infill tokens +# // Codellama infill tokens # LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix @ctypes_function("llama_token_prefix", [llama_model_p_ctypes], llama_token) def llama_token_prefix(model: llama_model_p) -> int: @@ -2524,11 +2536,13 @@ def llama_tokenize( # // Uses the vocabulary in the provided context. # // Does not write null terminator to the buffer. # // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens. +# // @param special If true, special tokens are rendered in the output. # LLAMA_API int32_t llama_token_to_piece( # const struct llama_model * model, # llama_token token, # char * buf, -# int32_t length); +# int32_t length, +# bool special); @ctypes_function( "llama_token_to_piece", [ @@ -2536,6 +2550,7 @@ def llama_tokenize( llama_token, ctypes.c_char_p, ctypes.c_int32, + ctypes.c_bool, ], ctypes.c_int32, ) @@ -2544,13 +2559,20 @@ def llama_token_to_piece( token: Union[llama_token, int], buf: Union[ctypes.c_char_p, bytes, CtypesArray[ctypes.c_char]], length: Union[ctypes.c_int, int], + special: Union[ctypes.c_bool, bool], /, ) -> int: """Token Id -> Piece. Uses the vocabulary in the provided context. Does not write null terminator to the buffer. User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens. - """ + + Args: + model: The model to use for tokenization. + token: The token to convert. + buf: The buffer to write the token to. + length: The length of the buffer. + special: If true, special tokens are rendered in the output.""" ... 
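A sketch of the grow-and-retry pattern the internal helpers above use with the updated five-argument signature; a negative return value is the negated buffer size required, and `model` is assumed to be a loaded `llama_model_p`:

import ctypes

import llama_cpp


def token_to_text(model, token: int, special: bool = False) -> bytes:
    size = 32
    buf = (ctypes.c_char * size)()
    n = llama_cpp.llama_token_to_piece(
        model, llama_cpp.llama_token(token), buf, size, special
    )
    if n < 0:
        # The buffer was too small: retry once with the exact required size.
        size = -n
        buf = (ctypes.c_char * size)()
        n = llama_cpp.llama_token_to_piece(
            model, llama_cpp.llama_token(token), buf, size, special
        )
    return bytes(buf[:n])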
diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 3b8f1ec..5cf5e7d 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 3b8f1ec4b18770531d0b1d792f3edf08254e4f0c +Subproject commit 5cf5e7d490dfdd2e70bface2d35dfd14aa44b4fb From d40a250ef3cfaa8224d12c83776a2f1de96ae3d1 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Mon, 22 Apr 2024 00:35:47 -0400 Subject: [PATCH 23/29] feat: Use new llama_token_is_eog in create_completions --- llama_cpp/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 818be82..0a576d4 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1034,7 +1034,8 @@ class Llama: logits_processor=logits_processor, grammar=grammar, ): - if token == self._token_eos: + assert self._model.model is not None + if llama_cpp.llama_token_is_eog(self._model.model, token): text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens) finish_reason = "stop" break From 617d536e1c4f9e84e551828897f5e2b8004539cb Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 23 Apr 2024 02:31:40 -0400 Subject: [PATCH 24/29] feat: Update llama.cpp --- vendor/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 5cf5e7d..4e96a81 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 5cf5e7d490dfdd2e70bface2d35dfd14aa44b4fb +Subproject commit 4e96a812b3ce7322a29a3008db2ed73d9087b176 From 8559e8ce88b7c7343004eeccb7333b806034b01c Mon Sep 17 00:00:00 2001 From: abk16 Date: Tue, 23 Apr 2024 06:33:29 +0000 Subject: [PATCH 25/29] feat: Add Llama-3 chat format (#1371) * feat: Add Llama-3 chat format * feat: Auto-detect Llama-3 chat format from gguf template * feat: Update llama.cpp to b2715 Includes proper Llama-3 <|eot_id|> token handling. 
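A brief usage sketch for the new format through the high-level API; the model path is an assumed local file, not something this patch ships:

from llama_cpp import Llama

llm = Llama(
    model_path="./Meta-Llama-3-8B-Instruct.Q4_K_M.gguf",  # assumed local gguf
    chat_format="llama-3",
)
response = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
)
print(response["choices"][0]["message"]["content"])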
--------- Co-authored-by: Andrei Betlen --- llama_cpp/llama_chat_format.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 189ccb0..17b570a 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -35,6 +35,9 @@ MISTRAL_INSTRUCT_EOS_TOKEN = "</s>" # Source: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json MIXTRAL_INSTRUCT_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
 
+# Source: https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
+LLAMA3_INSTRUCT_CHAT_TEMPLATE = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
+
 ### Chat Completion Handler ###
 
@@ -729,6 +732,9 @@ def guess_chat_format_from_gguf_metadata(metadata: Dict[str, str]) -> Optional[s
         metadata["tokenizer.chat_template"] == MIXTRAL_INSTRUCT_CHAT_TEMPLATE):
         return "mistral-instruct"
 
+    if metadata["tokenizer.chat_template"] == LLAMA3_INSTRUCT_CHAT_TEMPLATE:
+        return "llama-3"
+
     return None
 
@@ -920,6 +926,26 @@ def format_llama2(
     return ChatFormatterResponse(prompt=_prompt)
 
+# Chat format for Llama-3 models, see more details at:
+# https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py#L202-L229
+@register_chat_format("llama-3")
+def format_llama3(
+    messages: List[llama_types.ChatCompletionRequestMessage],
+    **kwargs: Any,
+) -> ChatFormatterResponse:
+    _roles = dict(
+        system="<|start_header_id|>system<|end_header_id|>\n\n",
+        user="<|start_header_id|>user<|end_header_id|>\n\n",
+        assistant="<|start_header_id|>assistant<|end_header_id|>\n\n",
+    )
+    _begin_token = "<|begin_of_text|>"
+    _sep = "<|eot_id|>"
+    _messages = _map_roles(messages, _roles)
+    _messages.append((_roles["assistant"], None))
+    _prompt = _format_no_colon_single(_begin_token, _messages, _sep)
+    return ChatFormatterResponse(prompt=_prompt, stop=_sep)
+
+
 @register_chat_format("alpaca")
 def format_alpaca(
     messages: List[llama_types.ChatCompletionRequestMessage],

From 507c1da0663e80d01a5a608355f523c0a791971d Mon Sep 17 00:00:00 2001
From: Geza Velkey
Date: Tue, 23 Apr 2024 08:34:15 +0200
Subject: [PATCH 26/29] fix: Update scikit-build-core build dependency avoid bug in 0.9.1 (#1370)

The 0.9.1 bug causes builds to fail with cmake's usage message:

  cmake [options] <path-to-source>
  cmake [options] <path-to-existing-build>
  cmake [options] -S <path-to-source> -B <path-to-build>

Specify a source directory to (re-)generate a build system for it in the current working directory. Specify an existing build directory to re-generate its build system.

Run 'cmake --help' for more information.
--- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e2bbb4b..8345cb1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["scikit-build-core[pyproject]>=0.5.1"] +requires = ["scikit-build-core[pyproject]>=0.9.2"] build-backend = "scikit_build_core.build" [project] From 53ebcc8bb5beb912bda4baa44ae27a1e3a7eabd0 Mon Sep 17 00:00:00 2001 From: Sean Bailey <34511443+sean-bailey@users.noreply.github.com> Date: Tue, 23 Apr 2024 02:35:38 -0400 Subject: [PATCH 27/29] feat(server): Provide ability to dynamically allocate all threads if desired using `-1` (#1364) --- llama_cpp/server/settings.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py index 934aecd..eab5a8a 100644 --- a/llama_cpp/server/settings.py +++ b/llama_cpp/server/settings.py @@ -3,7 +3,7 @@ from __future__ import annotations import multiprocessing from typing import Optional, List, Literal, Union -from pydantic import Field +from pydantic import Field, root_validator from pydantic_settings import BaseSettings import llama_cpp @@ -67,12 +67,12 @@ class ModelSettings(BaseSettings): n_threads: int = Field( default=max(multiprocessing.cpu_count() // 2, 1), ge=1, - description="The number of threads to use.", + description="The number of threads to use. Use -1 for max cpu threads", ) n_threads_batch: int = Field( default=max(multiprocessing.cpu_count(), 1), ge=0, - description="The number of threads to use when batch processing.", + description="The number of threads to use when batch processing. Use -1 for max cpu threads", ) rope_scaling_type: int = Field( default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED @@ -173,6 +173,16 @@ class ModelSettings(BaseSettings): default=True, description="Whether to print debug information." ) + @root_validator(pre=True) # pre=True to ensure this runs before any other validation + def set_dynamic_defaults(cls, values): + # If n_threads or n_threads_batch is -1, set it to multiprocessing.cpu_count() + cpu_count = multiprocessing.cpu_count() + if values.get('n_threads', 0) == -1: + values['n_threads'] = cpu_count + if values.get('n_threads_batch', 0) == -1: + values['n_threads_batch'] = cpu_count + return values + class ServerSettings(BaseSettings): """Server settings used to configure the FastAPI and Uvicorn server.""" From 611781f5319719a3d05fefccbbf0cc321742a026 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 23 Apr 2024 02:48:09 -0400 Subject: [PATCH 28/29] ci: Build arm64 wheels.
Closes #1342 --- .github/workflows/build-and-release.yaml | 31 +++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build-and-release.yaml b/.github/workflows/build-and-release.yaml index 76b5f7f..07742f1 100644 --- a/.github/workflows/build-and-release.yaml +++ b/.github/workflows/build-and-release.yaml @@ -41,6 +41,35 @@ jobs: with: path: ./wheelhouse/*.whl + build_arm64_wheels: + name: Build arm64 wheels + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + submodules: "recursive" + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + with: + platforms: linux/arm64 + + - name: Build wheels + uses: pypa/cibuildwheel@v2.16.5 + env: + CIBW_SKIP: "*musllinux* pp*" + CIBW_REPAIR_WHEEL_COMMAND: "" + CIBW_ARCHS: "aarch64" + CIBW_BUILD: "cp38-* cp39-* cp310-* cp311-* cp312-*" + with: + output-dir: wheelhouse/ + + - name: Upload wheels as artifacts + uses: actions/upload-artifact@v4 + with: + name: wheels-${{ matrix.version }} + path: wheelhouse/*.whl + build_sdist: name: Build source distribution runs-on: ubuntu-latest @@ -65,7 +94,7 @@ jobs: release: name: Release - needs: [build_wheels, build_sdist] + needs: [build_wheels, build_arm64_wheels, build_sdist] runs-on: ubuntu-latest steps: From c50d3300d2a09c98765be7f2c05b7e4fd0b4232e Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 23 Apr 2024 02:53:20 -0400 Subject: [PATCH 29/29] chore: Bump version --- CHANGELOG.md | 9 +++++++++ llama_cpp/__init__.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c7e267..25b8835 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.2.64] + +- feat: Update llama.cpp to ggerganov/llama.cpp@4e96a812b3ce7322a29a3008db2ed73d9087b176 +- feat: Add `llama-3` chat format by @andreabak in #1371 +- feat: Use new llama_token_is_eog in create_completions by @abetlen in d40a250ef3cfaa8224d12c83776a2f1de96ae3d1 +- feat(server): Provide ability to dynamically allocate all threads if desired using -1 by @sean-bailey in #1364 +- ci: Build arm64 wheels by @gaby in 611781f5319719a3d05fefccbbf0cc321742a026 +- fix: Update scikit-build-core build dependency avoid bug in 0.9.1 by @evelkey in #1370 + ## [0.2.63] - feat: Update llama.cpp to ggerganov/llama.cpp@0e4802b2ecbaab04b4f829fde4a3096ca19c84b5 diff --git a/llama_cpp/__init__.py b/llama_cpp/__init__.py index f2c7aba..f736458 100644 --- a/llama_cpp/__init__.py +++ b/llama_cpp/__init__.py @@ -1,4 +1,4 @@ from .llama_cpp import * from .llama import * -__version__ = "0.2.63" \ No newline at end of file +__version__ = "0.2.64" \ No newline at end of file
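
As a closing usage sketch, the new `-1` thread settings from #1364 resolve to the machine's CPU count before field validation runs; this is a minimal sketch assuming the server extras (pydantic-settings) are installed, and the model path is illustrative:

import multiprocessing

from llama_cpp.server.settings import ModelSettings

settings = ModelSettings(model="./model.gguf", n_threads=-1, n_threads_batch=-1)
# The pre-validation hook rewrites -1 before the ge=1 / ge=0 constraints apply.
assert settings.n_threads == multiprocessing.cpu_count()
assert settings.n_threads_batch == multiprocessing.cpu_count()

The updated grammar schema converter shipped in 0.2.62 can likewise be exercised directly; this assumes it is exported from llama_cpp.llama_grammar as in this repo:

import json

from llama_cpp.llama_grammar import json_schema_to_gbnf

schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name"],
}
# Prints a GBNF grammar constraining generation to JSON matching the schema.
print(json_schema_to_gbnf(json.dumps(schema), prop_order=["name", "age"]))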