diff --git a/Makefile b/Makefile
index 66d93f3..1be35cf 100644
--- a/Makefile
+++ b/Makefile
@@ -33,6 +33,9 @@ deploy.gh-docs:
 	mkdocs build
 	mkdocs gh-deploy
 
+test:
+	python3 -m pytest
+
 clean:
 	- cd vendor/llama.cpp && make clean
 	- cd vendor/llama.cpp && rm libllama.so
diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index c68fb18..17c6319 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -2,6 +2,7 @@ import sys
 import os
 import ctypes
 from ctypes import (
+    c_double,
     c_int,
     c_float,
     c_char_p,
@@ -169,6 +170,7 @@ llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p)
 #     // context pointer passed to the progress callback
 #     void * progress_callback_user_data;
 
+#     // Keep the booleans together to avoid misalignment during copy-by-value.
 #     bool low_vram;  // if true, reduce VRAM usage at the cost of performance
 #     bool f16_kv;    // use fp16 for KV cache
 
@@ -256,6 +258,34 @@ class llama_model_quantize_params(Structure):
     ]
 
 
+# // performance timing information
+# struct llama_timings {
+#     double t_start_ms;
+#     double t_end_ms;
+#     double t_load_ms;
+#     double t_sample_ms;
+#     double t_p_eval_ms;
+#     double t_eval_ms;
+
+
+#     int32_t n_sample;
+#     int32_t n_p_eval;
+#     int32_t n_eval;
+# };
+class llama_timings(Structure):
+    _fields_ = [
+        ("t_start_ms", c_double),
+        ("t_end_ms", c_double),
+        ("t_load_ms", c_double),
+        ("t_sample_ms", c_double),
+        ("t_p_eval_ms", c_double),
+        ("t_eval_ms", c_double),
+        ("n_sample", c_int32),
+        ("n_p_eval", c_int32),
+        ("n_eval", c_int32),
+    ]
+
+
 # LLAMA_API struct llama_context_params llama_context_default_params();
 def llama_context_default_params() -> llama_context_params:
     return _lib.llama_context_default_params()
@@ -991,6 +1021,15 @@ _lib.llama_sample_token.restype = llama_token
 
 # Performance information
 
+
+# LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
+def llama_get_timings(ctx: llama_context_p) -> llama_timings:
+    return _lib.llama_get_timings(ctx)
+
+
+_lib.llama_get_timings.argtypes = [llama_context_p]
+_lib.llama_get_timings.restype = llama_timings
+
+
 # LLAMA_API void llama_print_timings(struct llama_context * ctx);
 def llama_print_timings(ctx: llama_context_p):
     _lib.llama_print_timings(ctx)
diff --git a/vendor/llama.cpp b/vendor/llama.cpp
index 7f0e9a7..dfd9fce 160000
--- a/vendor/llama.cpp
+++ b/vendor/llama.cpp
@@ -1 +1 @@
-Subproject commit 7f0e9a775ecc4c6ade271c217f63d6dc93e79eaa
+Subproject commit dfd9fce6d65599bf33df43e616e85aa639bdae4c
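
For reviewers, a minimal usage sketch of the new `llama_get_timings` binding (not part of the diff). It assumes `ctx` is a valid `llama_context_p` for a loaded model that has already evaluated some tokens, and that the `t_*_ms` fields hold cumulative milliseconds as the names suggest, mirroring what `llama_print_timings` reports:

```python
import llama_cpp

def eval_tokens_per_second(ctx) -> float:
    # llama_get_timings returns the llama_timings struct by value;
    # ctypes supports this because restype is set to the Structure subclass.
    timings = llama_cpp.llama_get_timings(ctx)
    if timings.n_eval <= 0 or timings.t_eval_ms <= 0.0:
        return 0.0  # nothing evaluated yet; avoid division by zero
    # n_eval tokens evaluated in t_eval_ms milliseconds (assumed cumulative)
    return timings.n_eval / (timings.t_eval_ms / 1000.0)
```

Returning the struct by value, rather than only printing as `llama_print_timings` does, lets callers log or aggregate throughput numbers programmatically, which is the point of exposing `llama_timings` here.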