feat: Update llama.cpp

This commit is contained in:
Andrei Betlen 2024-04-27 23:41:54 -04:00
parent c07db99e5b
commit c9b85bf098
2 changed files with 69 additions and 85 deletions

View file

@ -552,19 +552,25 @@ class llama_batch(ctypes.Structure):
# LLAMA_KV_OVERRIDE_TYPE_INT, # LLAMA_KV_OVERRIDE_TYPE_INT,
# LLAMA_KV_OVERRIDE_TYPE_FLOAT, # LLAMA_KV_OVERRIDE_TYPE_FLOAT,
# LLAMA_KV_OVERRIDE_TYPE_BOOL, # LLAMA_KV_OVERRIDE_TYPE_BOOL,
# LLAMA_KV_OVERRIDE_TYPE_STR,
# }; # };
LLAMA_KV_OVERRIDE_TYPE_INT = 0 LLAMA_KV_OVERRIDE_TYPE_INT = 0
LLAMA_KV_OVERRIDE_TYPE_FLOAT = 1 LLAMA_KV_OVERRIDE_TYPE_FLOAT = 1
LLAMA_KV_OVERRIDE_TYPE_BOOL = 2 LLAMA_KV_OVERRIDE_TYPE_BOOL = 2
LLAMA_KV_OVERRIDE_TYPE_STR = 3
# struct llama_model_kv_override { # struct llama_model_kv_override {
# char key[128];
# enum llama_model_kv_override_type tag; # enum llama_model_kv_override_type tag;
# char key[128];
# union { # union {
# int64_t int_value; # int64_t val_i64;
# double float_value; # double val_f64;
# bool bool_value; # bool val_bool;
# char val_str[128];
# }; # };
# }; # };
class llama_model_kv_override_value(ctypes.Union): class llama_model_kv_override_value(ctypes.Union):
@ -572,16 +578,28 @@ class llama_model_kv_override_value(ctypes.Union):
("int_value", ctypes.c_int64), ("int_value", ctypes.c_int64),
("float_value", ctypes.c_double), ("float_value", ctypes.c_double),
("bool_value", ctypes.c_bool), ("bool_value", ctypes.c_bool),
("str_value", ctypes.c_char * 128),
] ]
if TYPE_CHECKING:
int_value: int
float_value: float
bool_value: bool
str_value: bytes
class llama_model_kv_override(ctypes.Structure): class llama_model_kv_override(ctypes.Structure):
_fields_ = [ _fields_ = [
("key", ctypes.c_char * 128),
("tag", ctypes.c_int), ("tag", ctypes.c_int),
("key", ctypes.c_char * 128),
("value", llama_model_kv_override_value), ("value", llama_model_kv_override_value),
] ]
if TYPE_CHECKING:
tag: int
key: bytes
value: Union[int, float, bool, bytes]
# struct llama_model_params { # struct llama_model_params {
# int32_t n_gpu_layers; // number of layers to store in VRAM # int32_t n_gpu_layers; // number of layers to store in VRAM
@ -612,6 +630,7 @@ class llama_model_kv_override(ctypes.Structure):
# bool vocab_only; // only load the vocabulary, no weights # bool vocab_only; // only load the vocabulary, no weights
# bool use_mmap; // use mmap if possible # bool use_mmap; // use mmap if possible
# bool use_mlock; // force system to keep model in RAM # bool use_mlock; // force system to keep model in RAM
# bool check_tensors; // validate model tensor data
# }; # };
class llama_model_params(ctypes.Structure): class llama_model_params(ctypes.Structure):
"""Parameters for llama_model """Parameters for llama_model
@ -626,7 +645,8 @@ class llama_model_params(ctypes.Structure):
kv_overrides (ctypes.Array[llama_model_kv_override]): override key-value pairs of the model meta data kv_overrides (ctypes.Array[llama_model_kv_override]): override key-value pairs of the model meta data
vocab_only (bool): only load the vocabulary, no weights vocab_only (bool): only load the vocabulary, no weights
use_mmap (bool): use mmap if possible use_mmap (bool): use mmap if possible
use_mlock (bool): force system to keep model in RAM""" use_mlock (bool): force system to keep model in RAM
check_tensors (bool): validate model tensor data"""
if TYPE_CHECKING: if TYPE_CHECKING:
n_gpu_layers: int n_gpu_layers: int
@ -639,6 +659,7 @@ class llama_model_params(ctypes.Structure):
vocab_only: bool vocab_only: bool
use_mmap: bool use_mmap: bool
use_mlock: bool use_mlock: bool
check_tensors: bool
_fields_ = [ _fields_ = [
("n_gpu_layers", ctypes.c_int32), ("n_gpu_layers", ctypes.c_int32),
@ -651,6 +672,7 @@ class llama_model_params(ctypes.Structure):
("vocab_only", ctypes.c_bool), ("vocab_only", ctypes.c_bool),
("use_mmap", ctypes.c_bool), ("use_mmap", ctypes.c_bool),
("use_mlock", ctypes.c_bool), ("use_mlock", ctypes.c_bool),
("check_tensors", ctypes.c_bool),
] ]
@ -1041,8 +1063,7 @@ GGML_NUMA_STRATEGY_COUNT = 5
[ctypes.c_int], [ctypes.c_int],
None, None,
) )
def llama_numa_init(numa: int, /): def llama_numa_init(numa: int, /): ...
...
# // Call once at the end of the program - currently only used for MPI # // Call once at the end of the program - currently only used for MPI
@ -1067,8 +1088,7 @@ def llama_backend_free():
) )
def llama_load_model_from_file( def llama_load_model_from_file(
path_model: bytes, params: llama_model_params, / path_model: bytes, params: llama_model_params, /
) -> Optional[llama_model_p]: ) -> Optional[llama_model_p]: ...
...
# LLAMA_API void llama_free_model(struct llama_model * model); # LLAMA_API void llama_free_model(struct llama_model * model);
@ -1077,8 +1097,7 @@ def llama_load_model_from_file(
[llama_model_p_ctypes], [llama_model_p_ctypes],
None, None,
) )
def llama_free_model(model: llama_model_p, /): def llama_free_model(model: llama_model_p, /): ...
...
# LLAMA_API struct llama_context * llama_new_context_with_model( # LLAMA_API struct llama_context * llama_new_context_with_model(
@ -1091,8 +1110,7 @@ def llama_free_model(model: llama_model_p, /):
) )
def llama_new_context_with_model( def llama_new_context_with_model(
model: llama_model_p, params: llama_context_params, / model: llama_model_p, params: llama_context_params, /
) -> Optional[llama_context_p]: ) -> Optional[llama_context_p]: ...
...
# // Frees all allocated memory # // Frees all allocated memory
@ -1113,104 +1131,87 @@ def llama_free(ctx: llama_context_p, /):
[], [],
ctypes.c_int64, ctypes.c_int64,
) )
def llama_time_us() -> int: def llama_time_us() -> int: ...
...
# LLAMA_API size_t llama_max_devices(void); # LLAMA_API size_t llama_max_devices(void);
@ctypes_function("llama_max_devices", [], ctypes.c_size_t) @ctypes_function("llama_max_devices", [], ctypes.c_size_t)
def llama_max_devices() -> int: def llama_max_devices() -> int: ...
...
# LLAMA_API bool llama_supports_mmap (void); # LLAMA_API bool llama_supports_mmap (void);
@ctypes_function("llama_supports_mmap", [], ctypes.c_bool) @ctypes_function("llama_supports_mmap", [], ctypes.c_bool)
def llama_supports_mmap() -> bool: def llama_supports_mmap() -> bool: ...
...
# LLAMA_API bool llama_supports_mlock (void); # LLAMA_API bool llama_supports_mlock (void);
@ctypes_function("llama_supports_mlock", [], ctypes.c_bool) @ctypes_function("llama_supports_mlock", [], ctypes.c_bool)
def llama_supports_mlock() -> bool: def llama_supports_mlock() -> bool: ...
...
# LLAMA_API bool llama_supports_gpu_offload(void); # LLAMA_API bool llama_supports_gpu_offload(void);
@ctypes_function("llama_supports_gpu_offload", [], ctypes.c_bool) @ctypes_function("llama_supports_gpu_offload", [], ctypes.c_bool)
def llama_supports_gpu_offload() -> bool: def llama_supports_gpu_offload() -> bool: ...
...
# LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); # LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
@ctypes_function("llama_get_model", [llama_context_p_ctypes], llama_model_p_ctypes) @ctypes_function("llama_get_model", [llama_context_p_ctypes], llama_model_p_ctypes)
def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: ...
...
# LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx); # LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
@ctypes_function("llama_n_ctx", [llama_context_p_ctypes], ctypes.c_uint32) @ctypes_function("llama_n_ctx", [llama_context_p_ctypes], ctypes.c_uint32)
def llama_n_ctx(ctx: llama_context_p, /) -> int: def llama_n_ctx(ctx: llama_context_p, /) -> int: ...
...
# LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); # LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
@ctypes_function("llama_n_batch", [llama_context_p_ctypes], ctypes.c_uint32) @ctypes_function("llama_n_batch", [llama_context_p_ctypes], ctypes.c_uint32)
def llama_n_batch(ctx: llama_context_p, /) -> int: def llama_n_batch(ctx: llama_context_p, /) -> int: ...
...
# LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); # LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
@ctypes_function("llama_n_ubatch", [llama_context_p_ctypes], ctypes.c_uint32) @ctypes_function("llama_n_ubatch", [llama_context_p_ctypes], ctypes.c_uint32)
def llama_n_ubatch(ctx: llama_context_p, /) -> int: def llama_n_ubatch(ctx: llama_context_p, /) -> int: ...
...
# LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); # LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
@ctypes_function("llama_n_seq_max", [llama_context_p_ctypes], ctypes.c_uint32) @ctypes_function("llama_n_seq_max", [llama_context_p_ctypes], ctypes.c_uint32)
def llama_n_seq_max(ctx: llama_context_p, /) -> int: def llama_n_seq_max(ctx: llama_context_p, /) -> int: ...
...
# LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); # LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
@ctypes_function("llama_pooling_type", [llama_context_p_ctypes], ctypes.c_int) @ctypes_function("llama_pooling_type", [llama_context_p_ctypes], ctypes.c_int)
def llama_pooling_type(ctx: llama_context_p, /) -> int: def llama_pooling_type(ctx: llama_context_p, /) -> int: ...
...
# LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); # LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
@ctypes_function("llama_vocab_type", [llama_model_p_ctypes], ctypes.c_int) @ctypes_function("llama_vocab_type", [llama_model_p_ctypes], ctypes.c_int)
def llama_vocab_type(model: llama_model_p, /) -> int: def llama_vocab_type(model: llama_model_p, /) -> int: ...
...
# LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); # LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
@ctypes_function("llama_rope_type", [llama_model_p_ctypes], ctypes.c_int) @ctypes_function("llama_rope_type", [llama_model_p_ctypes], ctypes.c_int)
def llama_rope_type(model: llama_model_p, /) -> int: def llama_rope_type(model: llama_model_p, /) -> int: ...
...
# LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); # LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
@ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32) @ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_vocab(model: llama_model_p, /) -> int: def llama_n_vocab(model: llama_model_p, /) -> int: ...
...
# LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); # LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
@ctypes_function("llama_n_ctx_train", [llama_model_p_ctypes], ctypes.c_int32) @ctypes_function("llama_n_ctx_train", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_ctx_train(model: llama_model_p, /) -> int: def llama_n_ctx_train(model: llama_model_p, /) -> int: ...
...
# LLAMA_API int32_t llama_n_embd (const struct llama_model * model); # LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
@ctypes_function("llama_n_embd", [llama_model_p_ctypes], ctypes.c_int32) @ctypes_function("llama_n_embd", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_embd(model: llama_model_p, /) -> int: def llama_n_embd(model: llama_model_p, /) -> int: ...
...
# LLAMA_API int32_t llama_n_layer (const struct llama_model * model); # LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
@ctypes_function("llama_n_layer", [llama_model_p_ctypes], ctypes.c_int32) @ctypes_function("llama_n_layer", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_layer(model: llama_model_p, /) -> int: def llama_n_layer(model: llama_model_p, /) -> int: ...
...
# // Get the model's RoPE frequency scaling factor # // Get the model's RoPE frequency scaling factor
@ -1912,8 +1913,7 @@ def llama_state_load_file(
n_token_capacity: Union[ctypes.c_size_t, int], n_token_capacity: Union[ctypes.c_size_t, int],
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t], n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
/, /,
) -> bool: ) -> bool: ...
...
# LLAMA_API DEPRECATED(bool llama_load_session_file( # LLAMA_API DEPRECATED(bool llama_load_session_file(
@ -1941,8 +1941,7 @@ def llama_load_session_file(
n_token_capacity: Union[ctypes.c_size_t, int], n_token_capacity: Union[ctypes.c_size_t, int],
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t], n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
/, /,
) -> int: ) -> int: ...
...
# LLAMA_API bool llama_state_save_file( # LLAMA_API bool llama_state_save_file(
@ -1966,8 +1965,7 @@ def llama_state_save_file(
tokens: CtypesArray[llama_token], tokens: CtypesArray[llama_token],
n_token_count: Union[ctypes.c_size_t, int], n_token_count: Union[ctypes.c_size_t, int],
/, /,
) -> bool: ) -> bool: ...
...
# LLAMA_API DEPRECATED(bool llama_save_session_file( # LLAMA_API DEPRECATED(bool llama_save_session_file(
@ -1992,8 +1990,7 @@ def llama_save_session_file(
tokens: CtypesArray[llama_token], tokens: CtypesArray[llama_token],
n_token_count: Union[ctypes.c_size_t, int], n_token_count: Union[ctypes.c_size_t, int],
/, /,
) -> int: ) -> int: ...
...
# // Get the exact size needed to copy the KV cache of a single sequence # // Get the exact size needed to copy the KV cache of a single sequence
@ -2071,8 +2068,7 @@ def llama_state_seq_save_file(
tokens: CtypesArray[llama_token], tokens: CtypesArray[llama_token],
n_token_count: Union[ctypes.c_size_t, int], n_token_count: Union[ctypes.c_size_t, int],
/, /,
) -> int: ) -> int: ...
...
# LLAMA_API size_t llama_state_seq_load_file( # LLAMA_API size_t llama_state_seq_load_file(
@ -2102,8 +2098,7 @@ def llama_state_seq_load_file(
n_token_capacity: Union[ctypes.c_size_t, int], n_token_capacity: Union[ctypes.c_size_t, int],
n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t], n_token_count_out: CtypesPointerOrRef[ctypes.c_size_t],
/, /,
) -> int: ) -> int: ...
...
# // # //
@ -2366,8 +2361,7 @@ def llama_get_embeddings_seq(
) )
def llama_token_get_text( def llama_token_get_text(
model: llama_model_p, token: Union[llama_token, int], / model: llama_model_p, token: Union[llama_token, int], /
) -> bytes: ) -> bytes: ...
...
# LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token); # LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
@ -2376,8 +2370,7 @@ def llama_token_get_text(
) )
def llama_token_get_score( def llama_token_get_score(
model: llama_model_p, token: Union[llama_token, int], / model: llama_model_p, token: Union[llama_token, int], /
) -> float: ) -> float: ...
...
# LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token); # LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
@ -2386,8 +2379,7 @@ def llama_token_get_score(
) )
def llama_token_get_type( def llama_token_get_type(
model: llama_model_p, token: Union[llama_token, int], / model: llama_model_p, token: Union[llama_token, int], /
) -> int: ) -> int: ...
...
# // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.) # // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
@ -2395,9 +2387,7 @@ def llama_token_get_type(
@ctypes_function( @ctypes_function(
"llama_token_is_eog", [llama_model_p_ctypes, llama_token], ctypes.c_bool "llama_token_is_eog", [llama_model_p_ctypes, llama_token], ctypes.c_bool
) )
def llama_token_is_eog( def llama_token_is_eog(model: llama_model_p, token: Union[llama_token, int], /) -> bool:
model: llama_model_p, token: Union[llama_token, int], /
) -> bool:
"""Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)""" """Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)"""
... ...
@ -2466,20 +2456,17 @@ def llama_token_prefix(model: llama_model_p) -> int:
# LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle # LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
@ctypes_function("llama_token_middle", [llama_model_p_ctypes], llama_token) @ctypes_function("llama_token_middle", [llama_model_p_ctypes], llama_token)
def llama_token_middle(model: llama_model_p, /) -> int: def llama_token_middle(model: llama_model_p, /) -> int: ...
...
# LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix # LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
@ctypes_function("llama_token_suffix", [llama_model_p_ctypes], llama_token) @ctypes_function("llama_token_suffix", [llama_model_p_ctypes], llama_token)
def llama_token_suffix(model: llama_model_p, /) -> int: def llama_token_suffix(model: llama_model_p, /) -> int: ...
...
# LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle # LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle
@ctypes_function("llama_token_eot", [llama_model_p_ctypes], llama_token) @ctypes_function("llama_token_eot", [llama_model_p_ctypes], llama_token)
def llama_token_eot(model: llama_model_p, /) -> int: def llama_token_eot(model: llama_model_p, /) -> int: ...
...
# // # //
@ -2620,8 +2607,7 @@ def llama_chat_apply_template(
chat: CtypesArray[llama_chat_message], chat: CtypesArray[llama_chat_message],
n_msg: int, n_msg: int,
/, /,
) -> int: ) -> int: ...
...
# // # //
@ -3234,8 +3220,7 @@ def llama_beam_search(
n_past: Union[ctypes.c_int, int], n_past: Union[ctypes.c_int, int],
n_predict: Union[ctypes.c_int, int], n_predict: Union[ctypes.c_int, int],
/, /,
): ): ...
...
# /// @details Build a split GGUF final path for this chunk. # /// @details Build a split GGUF final path for this chunk.
@ -3354,5 +3339,4 @@ def llama_log_set(
[ctypes.c_void_p, llama_context_p_ctypes], [ctypes.c_void_p, llama_context_p_ctypes],
None, None,
) )
def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /): def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /): ...
...

2
vendor/llama.cpp vendored

@ -1 +1 @@
Subproject commit 46e12c4692a37bdd31a0432fc5153d7d22bc7f72 Subproject commit 4dba7e8114d84241c842b986e008af8b88d1a019