From 0e70984fb69d621c191913bf870b7d9201bcc3d5 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Sat, 2 Mar 2024 22:20:04 -0500
Subject: [PATCH] feat: Update llama.cpp

---
 llama_cpp/llama_cpp.py | 36 ++++++++++++++++++++++++++++++++++--
 vendor/llama.cpp       |  2 +-
 2 files changed, 35 insertions(+), 3 deletions(-)

diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index 1593256..88ba41c 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -148,6 +148,12 @@ ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE(
     ctypes.c_bool, ctypes.c_void_p, ctypes.c_bool, ctypes.c_void_p
 )
 
+# // Abort callback
+# // If not NULL, called before ggml computation
+# // If it returns true, the computation is aborted
+# typedef bool (*ggml_abort_callback)(void * data);
+ggml_abort_callback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_void_p)
+
 # llama.h bindings
 
 _lib.llama_max_devices.argtypes = []
@@ -560,10 +566,16 @@ class llama_model_params(ctypes.Structure):
 #     enum ggml_type type_v; // data type for V cache
 
 #     // Keep the booleans together to avoid misalignment during copy-by-value.
-#     bool logits_all;  // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
+#     bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
 #     bool embedding;   // embedding mode only
 #     bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
 #     bool do_pooling;  // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
+
+#     // Abort callback
+#     // if it returns true, execution of llama_decode() will be aborted
+#     // currently works only with CPU execution
+#     ggml_abort_callback abort_callback;
+#     void * abort_callback_data;
 # };
 class llama_context_params(ctypes.Structure):
     """Parameters for llama_context
@@ -591,6 +603,8 @@ class llama_context_params(ctypes.Structure):
         embedding (bool): embedding mode only
         offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU
         do_pooling (bool): whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
+        abort_callback (ggml_abort_callback): abort callback; if it returns true, execution of llama_decode() will be aborted
+        abort_callback_data (ctypes.c_void_p): data for abort_callback
     """
 
     _fields_ = [
@@ -616,6 +630,8 @@ class llama_context_params(ctypes.Structure):
         ("embedding", ctypes.c_bool),
         ("offload_kqv", ctypes.c_bool),
         ("do_pooling", ctypes.c_bool),
+        ("abort_callback", ggml_abort_callback),
+        ("abort_callback_data", ctypes.c_void_p),
     ]
 
 
@@ -1703,8 +1719,24 @@ def llama_set_n_threads(
     """
     ...
 
+# // Set abort callback
+# LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
+@ctypes_function(
+    "llama_set_abort_callback",
+    [llama_context_p_ctypes, ggml_abort_callback, ctypes.c_void_p],
+    None,
+)
+def llama_set_abort_callback(
+    ctx: llama_context_p,
+    abort_callback: Callable[[ctypes.c_void_p], bool],
+    abort_callback_data: ctypes.c_void_p,
+    /,
+):
+    """Set abort callback"""
+    ...
 
-# // Token logits obtained from the last call to llama_eval()
+
+# // Token logits obtained from the last call to llama_decode()
 # // The logits for the last token are stored in the last row
 # // Logits for which llama_batch.logits[i] == 0 are undefined
 # // Rows: n_tokens provided with llama_batch
diff --git a/vendor/llama.cpp b/vendor/llama.cpp
index c2224f0..9731134 160000
--- a/vendor/llama.cpp
+++ b/vendor/llama.cpp
@@ -1 +1 @@
-Subproject commit c2224f003bf9cf558b1a3c57033563e11a4de9a5
+Subproject commit 9731134296af3a6839cd682e51d9c2109a871de5
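
A minimal sketch of how the new abort-callback binding might be used from Python. It is not part of the patch itself: the `ctx` variable (an existing llama_context_p created through the low-level API) and the threading.Event used as a cancellation flag are illustrative assumptions; only ggml_abort_callback and llama_set_abort_callback come from the code above. The Python function has to be wrapped in the ggml_abort_callback CFUNCTYPE prototype, and the wrapper kept referenced so it is not garbage-collected while the context can still invoke it.

    import threading

    # Names below are the module-level bindings added by this patch;
    # `ctx` is assumed to exist already (created via the low-level API).
    from llama_cpp.llama_cpp import ggml_abort_callback, llama_set_abort_callback

    cancel_event = threading.Event()  # set from another thread to request cancellation

    def _should_abort(user_data):
        # Called before ggml computation; returning True aborts llama_decode()
        # (currently only honored for CPU execution, per the upstream comment).
        return cancel_event.is_set()

    # Wrap the Python function in the CFUNCTYPE prototype and keep the wrapper
    # referenced for as long as the context may call it.
    abort_cb = ggml_abort_callback(_should_abort)

    llama_set_abort_callback(ctx, abort_cb, None)

Passing None as abort_callback_data is enough here because the callback closes over its Python state; the void* argument is only needed when the callback is defined in C.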