feat: Add Google's Gemma formatting via chat_format="gemma" (#1210)
* Add Google's Gemma formatting via `chat_format="gemma"`
* Replace `raise ValueError` with `logger.debug`

Co-authored-by: Andrei <abetlen@gmail.com>
parent eebb102df7
commit 251a8a2cad
1 changed file with 21 additions and 0 deletions
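As context for the change, a minimal usage sketch of the new option, assuming a local Gemma instruction-tuned GGUF file (the path, model file, and generation parameters below are illustrative and not part of this commit):

from llama_cpp import Llama

# Select the new Gemma prompt template by name when constructing the model.
llm = Llama(model_path="./gemma-2b-it.gguf", chat_format="gemma")

# The high-level chat API then renders the messages with Gemma's turn markers.
result = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Write a haiku about llamas."}],
    max_tokens=64,
)
print(result["choices"][0]["message"]["content"])

Note that, per the diff below, a `role='system'` message is not supported by Gemma models; with this change it is reported at debug level and omitted from the rendered prompt instead of raising a ValueError.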
@@ -14,6 +14,7 @@ import llama_cpp.llama as llama
 import llama_cpp.llama_types as llama_types
 import llama_cpp.llama_grammar as llama_grammar

+from ._logger import logger
 from ._utils import suppress_stdout_stderr, Singleton

 ### Common Chat Templates and Special Tokens ###

@@ -993,6 +994,26 @@ def format_saiga(
     return ChatFormatterResponse(prompt=_prompt.strip())


+# Chat format for Google's Gemma models, see more details and available models:
+# https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b
+@register_chat_format("gemma")
+def format_gemma(
+    messages: List[llama_types.ChatCompletionRequestMessage],
+    **kwargs: Any,
+) -> ChatFormatterResponse:
+    system_message = _get_system_message(messages)
+    if system_message is not None and system_message != "":
+        logger.debug(
+            "`role='system'` messages are not allowed on Google's Gemma models."
+        )
+    _roles = dict(user="<start_of_turn>user\n", assistant="<start_of_turn>model\n")
+    _sep = "<end_of_turn>\n"
+    _messages = _map_roles(messages, _roles)
+    _messages.append((_roles["assistant"], None))
+    _prompt = _format_no_colon_single(system_message="", messages=_messages, sep=_sep)
+    return ChatFormatterResponse(prompt=_prompt, stop=_sep)
+
+
 # Tricky chat formats that require custom chat handlers
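For reference, a rough sketch of what the new formatter is expected to produce for a short conversation; it assumes format_gemma is importable from llama_cpp.llama_chat_format (where it is registered) and that _format_no_colon_single simply concatenates each role prefix, the message text, and the separator:

from llama_cpp.llama_chat_format import format_gemma

response = format_gemma(
    messages=[
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
        {"role": "user", "content": "How are you?"},
    ]
)
# response.prompt is expected to look roughly like this, with the final
# assistant turn left open for the model to complete:
#   <start_of_turn>user
#   Hello<end_of_turn>
#   <start_of_turn>model
#   Hi there!<end_of_turn>
#   <start_of_turn>user
#   How are you?<end_of_turn>
#   <start_of_turn>model
# response.stop is "<end_of_turn>\n", so generation halts at the end of the model's turn.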