Update llama.cpp

This commit is contained in:
Andrei Betlen 2023-07-24 13:04:34 -04:00
parent 231123ee1e
commit 985d559971
2 changed files with 91 additions and 6 deletions

View file

@ -159,11 +159,13 @@ llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p)
# struct llama_context_params { # struct llama_context_params {
# uint32_t seed; // RNG seed, -1 for random # uint32_t seed; // RNG seed, -1 for random
# int32_t n_ctx; // text context # int32_t n_ctx; // text context
# int32_t n_batch; // prompt processing batch size # int32_t n_batch; // prompt processing batch size
# int32_t n_gpu_layers; // number of layers to store in VRAM # int32_t n_gqa; // grouped-query attention (TEMP - will be moved to model hparams)
# int32_t main_gpu; // the GPU that is used for scratch and small tensors # int32_t n_gpu_layers; // number of layers to store in VRAM
# int32_t main_gpu; // the GPU that is used for scratch and small tensors
#
# const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES) # const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
# // ref: https://github.com/ggerganov/llama.cpp/pull/2054 # // ref: https://github.com/ggerganov/llama.cpp/pull/2054
@ -190,6 +192,7 @@ class llama_context_params(Structure):
("seed", c_uint32), ("seed", c_uint32),
("n_ctx", c_int32), ("n_ctx", c_int32),
("n_batch", c_int32), ("n_batch", c_int32),
("n_gqa", c_int32),
("n_gpu_layers", c_int32), ("n_gpu_layers", c_int32),
("main_gpu", c_int32), ("main_gpu", c_int32),
("tensor_split", POINTER(c_float)), ("tensor_split", POINTER(c_float)),
@ -265,6 +268,57 @@ class llama_model_quantize_params(Structure):
] ]
# // grammar types
# struct llama_grammar;
llama_grammar_p = c_void_p
# // grammar element type
# enum llama_gretype {
# // end of rule definition
# LLAMA_GRETYPE_END = 0,
# // start of alternate definition for rule
# LLAMA_GRETYPE_ALT = 1,
# // non-terminal element: reference to rule
# LLAMA_GRETYPE_RULE_REF = 2,
# // terminal element: character (code point)
# LLAMA_GRETYPE_CHAR = 3,
# // inverse char(s) ([^a], [^a-b] [^abc])
# LLAMA_GRETYPE_CHAR_NOT = 4,
# // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
# // be an inclusive range ([a-z])
# LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
# // modifies a preceding LLAMA_GRETYPE_CHAR or
# // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
# LLAMA_GRETYPE_CHAR_ALT = 6,
# };
LLAMA_GRETYPE_END = c_int(0)
LLAMA_GRETYPE_ALT = c_int(1)
LLAMA_GRETYPE_RULE_REF = c_int(2)
LLAMA_GRETYPE_CHAR = c_int(3)
LLAMA_GRETYPE_CHAR_NOT = c_int(4)
LLAMA_GRETYPE_CHAR_RNG_UPPER = c_int(5)
LLAMA_GRETYPE_CHAR_ALT = c_int(6)
# typedef struct llama_grammar_element {
# enum llama_gretype type;
# uint32_t value; // Unicode code point or rule ID
# } llama_grammar_element;
class llama_grammar_element(Structure):
_fields_ = [
("type", c_int),
("value", c_uint32),
]
llama_grammar_element_p = POINTER(llama_grammar_element)
# // performance timing information # // performance timing information
# struct llama_timings { # struct llama_timings {
# double t_start_ms; # double t_start_ms;
@ -871,6 +925,37 @@ _lib.llama_token_nl.argtypes = []
_lib.llama_token_nl.restype = llama_token _lib.llama_token_nl.restype = llama_token
# // Grammar
# //
# LLAMA_API struct llama_grammar * llama_grammar_init(
# const llama_grammar_element ** rules,
# size_t n_rules,
# size_t start_rule_index);
def llama_grammar_init(
rules, # type: Array[llama_grammar_element_p] # type: ignore
n_rules: c_size_t,
start_rule_index: c_size_t,
) -> llama_grammar_p:
return _lib.llama_grammar_init(rules, n_rules, start_rule_index)
_lib.llama_grammar_init.argtypes = [
POINTER(llama_grammar_element_p),
c_size_t,
c_size_t,
]
_lib.llama_grammar_init.restype = llama_grammar_p
# LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
def llama_grammar_free(grammar: llama_grammar_p):
return _lib.llama_grammar_free(grammar)
_lib.llama_grammar_free.argtypes = [llama_grammar_p]
_lib.llama_grammar_free.restype = None
# Sampling functions # Sampling functions

2
vendor/llama.cpp vendored

@ -1 +1 @@
Subproject commit d924522a46c5ef097af4a88087d91673e8e87e4d Subproject commit 84e09a7d8bc4ab6d658b5cd81295ac0add60be78