From 952228407ebd68ef621ad747e3561c821d1c02d3 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Mon, 26 Jun 2023 08:50:38 -0400 Subject: [PATCH] Update llama.cpp --- llama_cpp/llama.py | 9 ++++-- llama_cpp/llama_cpp.py | 66 +++++++++++++++++++++++++++++++++++++++++- vendor/llama.cpp | 2 +- 3 files changed, 72 insertions(+), 5 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 3465cd4..3319cde 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -282,15 +282,18 @@ class Llama: if not os.path.exists(model_path): raise ValueError(f"Model path does not exist: {model_path}") - self.ctx = llama_cpp.llama_init_from_file( + self.model = llama_cpp.llama_load_model_from_file( self.model_path.encode("utf-8"), self.params ) + assert self.model is not None + + self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params) assert self.ctx is not None if self.lora_path: - if llama_cpp.llama_apply_lora_from_file( - self.ctx, + if llama_cpp.llama_model_apply_lora_from_file( + self.model, llama_cpp.c_char_p(self.lora_path.encode("utf-8")), llama_cpp.c_char_p(self.lora_base.encode("utf-8")) if self.lora_base is not None diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index a516829..23643e2 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -15,7 +15,7 @@ from ctypes import ( c_size_t, ) import pathlib -from typing import List +from typing import List, Union # Load the library @@ -105,6 +105,9 @@ LLAMA_FILE_MAGIC_UNVERSIONED = LLAMA_FILE_MAGIC_GGML LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN LLAMA_SESSION_VERSION = c_int(1) +# struct llama_model; +llama_model_p = c_void_p + # struct llama_context; llama_context_p = c_void_p @@ -161,6 +164,7 @@ llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p) # // context pointer passed to the progress callback # void * progress_callback_user_data; + # // Keep the booleans together to avoid misalignment during copy-by-value. # bool low_vram; // if true, reduce VRAM usage at the cost of performance # bool f16_kv; // use fp16 for KV cache @@ -296,6 +300,41 @@ _lib.llama_init_backend.argtypes = [] _lib.llama_init_backend.restype = None +# LLAMA_API struct llama_model * llama_load_model_from_file( +# const char * path_model, +# struct llama_context_params params); +def llama_load_model_from_file( + path_model: bytes, params: llama_context_params +) -> llama_model_p: + return _lib.llama_load_model_from_file(path_model, params) + + +_lib.llama_load_model_from_file.argtypes = [c_char_p, llama_context_params] +_lib.llama_load_model_from_file.restype = llama_model_p + + +# LLAMA_API void llama_free_model(struct llama_model * model); +def llama_free_model(model: llama_model_p): + return _lib.llama_free_model(model) + + +_lib.llama_free_model.argtypes = [llama_model_p] +_lib.llama_free_model.restype = None + + +# LLAMA_API struct llama_context * llama_new_context_with_model( +# struct llama_model * model, +# struct llama_context_params params); +def llama_new_context_with_model( + model: llama_model_p, params: llama_context_params +) -> llama_context_p: + return _lib.llama_new_context_with_model(model, params) + + +_lib.llama_new_context_with_model.argtypes = [llama_model_p, llama_context_params] +_lib.llama_new_context_with_model.restype = llama_context_p + + # LLAMA_API int64_t llama_time_us(); def llama_time_us() -> int: return _lib.llama_time_us() @@ -376,6 +415,31 @@ _lib.llama_apply_lora_from_file.argtypes = [llama_context_p, c_char_p, c_char_p, _lib.llama_apply_lora_from_file.restype = c_int +# LLAMA_API int llama_model_apply_lora_from_file( +# const struct llama_model * model, +# const char * path_lora, +# const char * path_base_model, +# int n_threads); +def llama_model_apply_lora_from_file( + model: llama_model_p, + path_lora: Union[c_char_p, bytes], + path_base_model: Union[c_char_p, bytes], + n_threads: c_int, +) -> int: + return _lib.llama_model_apply_lora_from_file( + model, path_lora, path_base_model, n_threads + ) + + +_lib.llama_model_apply_lora_from_file.argtypes = [ + llama_model_p, + c_char_p, + c_char_p, + c_int, +] +_lib.llama_model_apply_lora_from_file.restype = c_int + + # Returns the number of tokens in the KV cache # LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx); def llama_get_kv_cache_token_count(ctx: llama_context_p) -> int: diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 2322ec2..447ccbe 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 2322ec223a21625dfe9bd73ee677444a98a24ac9 +Subproject commit 447ccbe8c39332fcdd0d98a041b6e2ff6f06219d