diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index eea26ac..c9d79b9 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -159,11 +159,13 @@ llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p) # struct llama_context_params { -# uint32_t seed; // RNG seed, -1 for random -# int32_t n_ctx; // text context -# int32_t n_batch; // prompt processing batch size -# int32_t n_gpu_layers; // number of layers to store in VRAM -# int32_t main_gpu; // the GPU that is used for scratch and small tensors +# uint32_t seed; // RNG seed, -1 for random +# int32_t n_ctx; // text context +# int32_t n_batch; // prompt processing batch size +# int32_t n_gqa; // grouped-query attention (TEMP - will be moved to model hparams) +# int32_t n_gpu_layers; // number of layers to store in VRAM +# int32_t main_gpu; // the GPU that is used for scratch and small tensors +# # const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES) # // ref: https://github.com/ggerganov/llama.cpp/pull/2054 @@ -190,6 +192,7 @@ class llama_context_params(Structure): ("seed", c_uint32), ("n_ctx", c_int32), ("n_batch", c_int32), + ("n_gqa", c_int32), ("n_gpu_layers", c_int32), ("main_gpu", c_int32), ("tensor_split", POINTER(c_float)), @@ -265,6 +268,57 @@ class llama_model_quantize_params(Structure): ] +# // grammar types +# struct llama_grammar; +llama_grammar_p = c_void_p + +# // grammar element type +# enum llama_gretype { +# // end of rule definition +# LLAMA_GRETYPE_END = 0, + +# // start of alternate definition for rule +# LLAMA_GRETYPE_ALT = 1, + +# // non-terminal element: reference to rule +# LLAMA_GRETYPE_RULE_REF = 2, + +# // terminal element: character (code point) +# LLAMA_GRETYPE_CHAR = 3, + +# // inverse char(s) ([^a], [^a-b] [^abc]) +# LLAMA_GRETYPE_CHAR_NOT = 4, + +# // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to +# // be an inclusive range ([a-z]) +# LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, + +# // modifies a preceding LLAMA_GRETYPE_CHAR or +# // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) +# LLAMA_GRETYPE_CHAR_ALT = 6, +# }; +LLAMA_GRETYPE_END = c_int(0) +LLAMA_GRETYPE_ALT = c_int(1) +LLAMA_GRETYPE_RULE_REF = c_int(2) +LLAMA_GRETYPE_CHAR = c_int(3) +LLAMA_GRETYPE_CHAR_NOT = c_int(4) +LLAMA_GRETYPE_CHAR_RNG_UPPER = c_int(5) +LLAMA_GRETYPE_CHAR_ALT = c_int(6) + + +# typedef struct llama_grammar_element { +# enum llama_gretype type; +# uint32_t value; // Unicode code point or rule ID +# } llama_grammar_element; +class llama_grammar_element(Structure): + _fields_ = [ + ("type", c_int), + ("value", c_uint32), + ] + + +llama_grammar_element_p = POINTER(llama_grammar_element) + # // performance timing information # struct llama_timings { # double t_start_ms; @@ -871,6 +925,37 @@ _lib.llama_token_nl.argtypes = [] _lib.llama_token_nl.restype = llama_token +# // Grammar +# // +# LLAMA_API struct llama_grammar * llama_grammar_init( +# const llama_grammar_element ** rules, +# size_t n_rules, +# size_t start_rule_index); +def llama_grammar_init( + rules, # type: Array[llama_grammar_element_p] # type: ignore + n_rules: c_size_t, + start_rule_index: c_size_t, +) -> llama_grammar_p: + return _lib.llama_grammar_init(rules, n_rules, start_rule_index) + + +_lib.llama_grammar_init.argtypes = [ + POINTER(llama_grammar_element_p), + c_size_t, + c_size_t, +] +_lib.llama_grammar_init.restype = llama_grammar_p + + +# LLAMA_API void llama_grammar_free(struct llama_grammar * grammar); +def llama_grammar_free(grammar: llama_grammar_p): + return _lib.llama_grammar_free(grammar) + + +_lib.llama_grammar_free.argtypes = [llama_grammar_p] +_lib.llama_grammar_free.restype = None + + # Sampling functions diff --git a/vendor/llama.cpp b/vendor/llama.cpp index d924522..84e09a7 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit d924522a46c5ef097af4a88087d91673e8e87e4d +Subproject commit 84e09a7d8bc4ab6d658b5cd81295ac0add60be78