From 8c2b24d5aafffffadf37c3067ca12a45b59d95c4 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 30 Apr 2024 09:27:55 -0400 Subject: [PATCH] feat: Update llama.cpp --- llama_cpp/llama_cpp.py | 10 +++++++--- vendor/llama.cpp | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 9c8f778..46aa516 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -242,8 +242,8 @@ LLAMA_FILE_MAGIC_GGSQ = 0x67677371 # define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN -# define LLAMA_SESSION_VERSION 5 -LLAMA_SESSION_VERSION = 5 +# define LLAMA_SESSION_VERSION 6 +LLAMA_SESSION_VERSION = 6 # define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ LLAMA_STATE_SEQ_MAGIC = LLAMA_FILE_MAGIC_GGSQ @@ -730,6 +730,7 @@ class llama_model_params(ctypes.Structure): # bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) # bool embeddings; // if true, extract embeddings (together with logits) # bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU +# bool flash_attn; // whether to use flash attention # // Abort callback @@ -766,6 +767,7 @@ class llama_context_params(ctypes.Structure): logits_all (bool): the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) embeddings (bool): if true, extract embeddings (together with logits) offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU + flash_attn (bool): whether to use flash attention abort_callback (ggml_abort_callback): abort callback if it returns true, execution of llama_decode() will be aborted abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback """ @@ -795,6 +797,7 @@ class llama_context_params(ctypes.Structure): logits_all: bool embeddings: bool offload_kqv: bool + flash_attn: bool abort_callback: Callable[[ctypes.c_void_p], bool] abort_callback_data: ctypes.c_void_p @@ -823,6 +826,7 @@ class llama_context_params(ctypes.Structure): ("logits_all", ctypes.c_bool), ("embeddings", ctypes.c_bool), ("offload_kqv", ctypes.c_bool), + ("flash_attn", ctypes.c_bool), ("abort_callback", ggml_abort_callback), ("abort_callback_data", ctypes.c_void_p), ] @@ -1615,7 +1619,7 @@ def llama_get_kv_cache_used_cells(ctx: llama_context_p, /) -> int: ... -# // Clear the KV cache +# // Clear the KV cache - both cell info is erased and KV data is zeroed # LLAMA_API void llama_kv_cache_clear( # struct llama_context * ctx); @ctypes_function("llama_kv_cache_clear", [llama_context_p_ctypes], None) diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 8843a98..77e15be 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 8843a98c2ba97a25e93319a104f9ddfaf83ce4c4 +Subproject commit 77e15bec6217a39be59b9cc83d6b9afb6b0d8167