diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 2a1d3f0..e946adb 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -664,6 +664,18 @@ class llama_timings(Structure): ] +# // used in chat template +# typedef struct llama_chat_message { +# const char * role; +# const char * content; +# } llama_chat_message; +class llama_chat_message(Structure): + _fields_ = [ + ("role", c_char_p), + ("content", c_char_p), + ] + + # // Helpers for getting default parameters # LLAMA_API struct llama_model_params llama_model_default_params(void); def llama_model_default_params() -> llama_model_params: @@ -1956,6 +1968,47 @@ _lib.llama_token_to_piece.argtypes = [llama_model_p, llama_token, c_char_p, c_in _lib.llama_token_to_piece.restype = c_int32 +# /// Apply chat template. Inspired by hf apply_chat_template() on python. +# /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" +# /// NOTE: This function only support some known jinja templates. It is not a jinja parser. +# /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead. +# /// @param chat Pointer to a list of multiple llama_chat_message +# /// @param n_msg Number of llama_chat_message in this chat +# /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message. +# /// @param buf A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages) +# /// @param length The size of the allocated buffer +# /// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template. +# LLAMA_API int32_t llama_chat_apply_template( +# const struct llama_model * model, +# const char * tmpl, +# const struct llama_chat_message * chat, +# size_t n_msg, +# bool add_ass, +# char * buf, +# int32_t length); +def llama_chat_apply_template( + model: llama_model_p, + tmpl: bytes, + chat: "ctypes._Pointer[llama_chat_message]", + n_msg: int, +) -> int: + return _lib.llama_chat_apply_template( + model, + tmpl, + chat, + n_msg + ) + +_lib.llama_chat_apply_template.argtypes = [ + ctypes.c_void_p, + ctypes.c_char_p, + ctypes.POINTER(llama_chat_message), + ctypes.c_size_t +] +_lib.llama_chat_apply_template.restype = ctypes.c_int32 + + + # // # // Grammar # // diff --git a/vendor/llama.cpp b/vendor/llama.cpp index a0c2dad..f53119c 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit a0c2dad9d43456c677e205c6240a5f8afb0121ac +Subproject commit f53119cec4f073b6d214195ecbe1fad3abdf2b34