feat: Add Google's Gemma formatting via chat_format="gemma" (#1210)

* Add Google's Gemma formatting via `chat_format="gemma"`

* Replace `raise ValueError` with `logger.debug`

Co-authored-by: Andrei <abetlen@gmail.com>

---------

Co-authored-by: Andrei <abetlen@gmail.com>
This commit is contained in:
Alvaro Bartolome 2024-02-23 18:40:52 +09:00 committed by GitHub
parent eebb102df7
commit 251a8a2cad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -14,6 +14,7 @@ import llama_cpp.llama as llama
import llama_cpp.llama_types as llama_types import llama_cpp.llama_types as llama_types
import llama_cpp.llama_grammar as llama_grammar import llama_cpp.llama_grammar as llama_grammar
from ._logger import logger
from ._utils import suppress_stdout_stderr, Singleton from ._utils import suppress_stdout_stderr, Singleton
### Common Chat Templates and Special Tokens ### ### Common Chat Templates and Special Tokens ###
@ -993,6 +994,26 @@ def format_saiga(
return ChatFormatterResponse(prompt=_prompt.strip()) return ChatFormatterResponse(prompt=_prompt.strip())
# Chat format for Google's Gemma models, see more details and available models:
# https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b
@register_chat_format("gemma")
def format_gemma(
messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any,
) -> ChatFormatterResponse:
system_message = _get_system_message(messages)
if system_message is not None and system_message != "":
logger.debug(
"`role='system'` messages are not allowed on Google's Gemma models."
)
_roles = dict(user="<start_of_turn>user\n", assistant="<start_of_turn>model\n")
_sep = "<end_of_turn>\n"
_messages = _map_roles(messages, _roles)
_messages.append((_roles["assistant"], None))
_prompt = _format_no_colon_single(system_message="", messages=_messages, sep=_sep)
return ChatFormatterResponse(prompt=_prompt, stop=_sep)
# Tricky chat formats that require custom chat handlers # Tricky chat formats that require custom chat handlers