diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 8dd0ddf..16bccb9 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -14,6 +14,7 @@
 import llama_cpp.llama as llama
 import llama_cpp.llama_types as llama_types
 import llama_cpp.llama_grammar as llama_grammar
+from ._logger import logger
 from ._utils import suppress_stdout_stderr, Singleton
 
 ### Common Chat Templates and Special Tokens ###
@@ -993,6 +994,26 @@ def format_saiga(
     return ChatFormatterResponse(prompt=_prompt.strip())
 
 
+# Chat format for Google's Gemma models; see details and available models at:
+# https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b
+@register_chat_format("gemma")
+def format_gemma(
+    messages: List[llama_types.ChatCompletionRequestMessage],
+    **kwargs: Any,
+) -> ChatFormatterResponse:
+    system_message = _get_system_message(messages)
+    if system_message is not None and system_message != "":
+        logger.debug(
+            "`role='system'` messages are not allowed on Google's Gemma models."
+        )
+    _roles = dict(user="<start_of_turn>user\n", assistant="<start_of_turn>model\n")
+    _sep = "<end_of_turn>\n"
+    _messages = _map_roles(messages, _roles)
+    _messages.append((_roles["assistant"], None))
+    _prompt = _format_no_colon_single(system_message="", messages=_messages, sep=_sep)
+    return ChatFormatterResponse(prompt=_prompt, stop=_sep)
+
+
 # Tricky chat formats that require custom chat handlers
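For reference, a minimal usage sketch (not part of the diff) of how the new format would be exercised through the high-level API. The model path is a hypothetical local Gemma GGUF file; the trailing comment shows the prompt that _format_no_colon_single renders from the roles and separator above:

    from llama_cpp import Llama

    # Select the new "gemma" chat format explicitly; model_path is hypothetical.
    llm = Llama(model_path="./gemma-2b-it.Q4_K_M.gguf", chat_format="gemma")

    response = llm.create_chat_completion(
        messages=[{"role": "user", "content": "Write a haiku about llamas."}],
    )
    print(response["choices"][0]["message"]["content"])

    # Expected rendered prompt for the single user turn above:
    # <start_of_turn>user
    # Write a haiku about llamas.<end_of_turn>
    # <start_of_turn>model

Note that any `role='system'` message is only logged and dropped rather than raising, since Gemma's turn template defines no system role.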