diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 62ddbf4..1731878 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -568,13 +568,33 @@ _lib.llama_model_n_embd.restype = c_int # // Get a string describing the model type -# LLAMA_API int llama_model_type(const struct llama_model * model, char * buf, size_t buf_size); -def llama_model_type(model: llama_model_p, buf: bytes, buf_size: c_size_t) -> int: - return _lib.llama_model_type(model, buf, buf_size) +# LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size); +def llama_model_desc(model: llama_model_p, buf: bytes, buf_size: c_size_t) -> int: + return _lib.llama_model_desc(model, buf, buf_size) -_lib.llama_model_type.argtypes = [llama_model_p, c_char_p, c_size_t] -_lib.llama_model_type.restype = c_int +_lib.llama_model_desc.argtypes = [llama_model_p, c_char_p, c_size_t] +_lib.llama_model_desc.restype = c_int + + +# // Returns the total size of all the tensors in the model in bytes +# LLAMA_API uint64_t llama_model_size(const struct llama_model * model); +def llama_model_size(model: llama_model_p) -> int: + return _lib.llama_model_size(model) + + +_lib.llama_model_size.argtypes = [llama_model_p] +_lib.llama_model_size.restype = ctypes.c_uint64 + + +# // Returns the total number of parameters in the model +# LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model); +def llama_model_n_params(model: llama_model_p) -> int: + return _lib.llama_model_n_params(model) + + +_lib.llama_model_n_params.argtypes = [llama_model_p] +_lib.llama_model_n_params.restype = ctypes.c_uint64 # // Returns 0 on success @@ -1029,6 +1049,74 @@ def llama_grammar_free(grammar: llama_grammar_p): _lib.llama_grammar_free.argtypes = [llama_grammar_p] _lib.llama_grammar_free.restype = None +# // +# // Beam search +# // + + +# struct llama_beam_view { +# const llama_token * tokens; +# size_t n_tokens; +# float p; // Cumulative beam probability (renormalized relative to all beams) +# bool eob; // Callback should set this to true when a beam is at end-of-beam. +# }; +class llama_beam_view(ctypes.Structure): + _fields_ = [ + ("tokens", llama_token_p), + ("n_tokens", c_size_t), + ("p", c_float), + ("eob", c_bool), + ] + + +# // Passed to beam_search_callback function. +# // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams +# // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks. +# // These pointers are valid only during the synchronous callback, so should not be saved. +# struct llama_beams_state { +# struct llama_beam_view * beam_views; +# size_t n_beams; // Number of elements in beam_views[]. +# size_t common_prefix_length; // Current max length of prefix tokens shared by all beams. +# bool last_call; // True iff this is the last callback invocation. +# }; +class llama_beams_state(ctypes.Structure): + _fields_ = [ + ("beam_views", POINTER(llama_beam_view)), + ("n_beams", c_size_t), + ("common_prefix_length", c_size_t), + ("last_call", c_bool), + ] + + +# // Type of pointer to the beam_search_callback function. +# // void* callback_data is any custom data passed to llama_beam_search, that is subsequently +# // passed back to beam_search_callback. This avoids having to use global variables in the callback. +# typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, llama_beams_state); +llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, c_void_p, llama_beams_state) + + +# /// @details Deterministically returns entire sentence constructed by a beam search. +# /// @param ctx Pointer to the llama_context. +# /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state. +# /// @param callback_data A pointer that is simply passed back to callback. +# /// @param n_beams Number of beams to use. +# /// @param n_past Number of tokens already evaluated. +# /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier. +# /// @param n_threads Number of threads as passed to llama_eval(). +# LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads); +def llama_beam_search( + ctx: llama_context_p, + callback: "ctypes._CFuncPtr[None, c_void_p, llama_beams_state]", # type: ignore + callback_data: c_void_p, + n_beams: c_size_t, + n_past: c_int, + n_predict: c_int, + n_threads: c_int, +): + return _lib.llama_beam_search( + ctx, callback, callback_data, n_beams, n_past, n_predict, n_threads + ) + # // # // Sampling functions diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 2e5f70a..232caf3 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 2e5f70a25fc4576e9ed78603fe493eb7702c37a3 +Subproject commit 232caf3c1581a6cb023571780ff41dc2d66d1ca0