From 251a8a2cadb4c0df4671062144d168a7874086a2 Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome Date: Fri, 23 Feb 2024 18:40:52 +0900 Subject: [PATCH] feat: Add Google's Gemma formatting via `chat_format="gemma"` (#1210) * Add Google's Gemma formatting via `chat_format="gemma"` * Replace `raise ValueError` with `logger.debug` Co-authored-by: Andrei --------- Co-authored-by: Andrei --- llama_cpp/llama_chat_format.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 8dd0ddf..16bccb9 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -14,6 +14,7 @@ import llama_cpp.llama as llama import llama_cpp.llama_types as llama_types import llama_cpp.llama_grammar as llama_grammar +from ._logger import logger from ._utils import suppress_stdout_stderr, Singleton ### Common Chat Templates and Special Tokens ### @@ -993,6 +994,26 @@ def format_saiga( return ChatFormatterResponse(prompt=_prompt.strip()) +# Chat format for Google's Gemma models, see more details and available models: +# https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b +@register_chat_format("gemma") +def format_gemma( + messages: List[llama_types.ChatCompletionRequestMessage], + **kwargs: Any, +) -> ChatFormatterResponse: + system_message = _get_system_message(messages) + if system_message is not None and system_message != "": + logger.debug( + "`role='system'` messages are not allowed on Google's Gemma models." + ) + _roles = dict(user="user\n", assistant="model\n") + _sep = "\n" + _messages = _map_roles(messages, _roles) + _messages.append((_roles["assistant"], None)) + _prompt = _format_no_colon_single(system_message="", messages=_messages, sep=_sep) + return ChatFormatterResponse(prompt=_prompt, stop=_sep) + + # Tricky chat formats that require custom chat handlers