diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 3dbe570..724126e 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -3,7 +3,6 @@ import ctypes from ctypes import ( c_int, c_float, - c_double, c_char_p, c_void_p, c_bool, @@ -40,7 +39,7 @@ class llama_token_data(Structure): llama_token_data_p = POINTER(llama_token_data) -llama_progress_callback = ctypes.CFUNCTYPE(None, c_double, c_void_p) +llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p) class llama_context_params(Structure): @@ -48,6 +47,7 @@ class llama_context_params(Structure): ("n_ctx", c_int), # text context ("n_parts", c_int), # -1 for default ("seed", c_int), # RNG seed, 0 for random + ("f16_kv", c_bool), # use fp16 for KV cache ( "logits_all", @@ -56,6 +56,7 @@ class llama_context_params(Structure): ("vocab_only", c_bool), # only load the vocabulary, no weights ("use_mlock", c_bool), # force system to keep model in RAM ("embedding", c_bool), # embedding mode only + # called with a progress value between 0 and 1, pass NULL to disable ("progress_callback", llama_progress_callback), # context pointer passed to the progress callback @@ -70,8 +71,7 @@ llama_context_params_p = POINTER(llama_context_params) def llama_context_default_params() -> llama_context_params: - params = _lib.llama_context_default_params() - return params + return _lib.llama_context_default_params() _lib.llama_context_default_params.argtypes = [] @@ -229,9 +229,9 @@ def llama_sample_top_p_top_k( last_n_tokens_data: llama_token_p, last_n_tokens_size: c_int, top_k: c_int, - top_p: c_double, - temp: c_double, - repeat_penalty: c_double, + top_p: c_float, + temp: c_float, + repeat_penalty: c_float, ) -> llama_token: return _lib.llama_sample_top_p_top_k( ctx, last_n_tokens_data, last_n_tokens_size, top_k, top_p, temp, repeat_penalty @@ -243,9 +243,9 @@ _lib.llama_sample_top_p_top_k.argtypes = [ llama_token_p, c_int, c_int, - c_double, - c_double, - c_double, + c_float, + c_float, + c_float, ] _lib.llama_sample_top_p_top_k.restype = llama_token diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 7e53955..5a5f8b1 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 7e5395575a3360598f2565c73c8a2ec0c0abbdb8 +Subproject commit 5a5f8b1501fbb34367225544010ddfc306d6d2fe