feat: Add Jinja2ChatFormatter
This commit is contained in:
parent
5a34c57e54
commit
be09318c26
1 changed file with 188 additions and 135 deletions
@@ -121,14 +121,21 @@ def register_chat_completion_handler(name: str):
 @dataclasses.dataclass
 class ChatFormatterResponse:
+    """Dataclass that stores completion parameters for a given chat format and
+    create_chat_completion request.
+
+    prompt contains the formatted prompt generated from the chat format and messages.
+    stop contains the stop token or list of stop tokens to use for the chat format."""
+
     prompt: str
     stop: Optional[Union[str, List[str]]] = None
 
 
 class ChatFormatter(Protocol):
     """Base Protocol for a chat formatter. A chat formatter is a function that
-    takes a list of messages and returns a formatted prompt. It can also return
-    a stop token or list of stop tokens to use for the completion."""
+    takes a list of messages and returns a chat format response which can be used
+    to generate a completion. The response can also include a stop token or list
+    of stop tokens to use for the completion."""
 
     def __call__(
         self,
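As a reading aid (not part of the diff): any keyword-only callable of this shape satisfies the ChatFormatter protocol. The formatter below is a hypothetical sketch that assumes the module's existing imports; its prompt layout is illustrative.

# Hypothetical sketch, not from this commit: a minimal ChatFormatter.
def simple_formatter(
    *,
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    # Join "role: content" lines and leave the assistant turn open.
    prompt = ""
    for message in messages:
        prompt += f"{message['role']}: {message.get('content') or ''}\n"
    prompt += "assistant:"
    return ChatFormatterResponse(prompt=prompt, stop="\n")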
@@ -139,131 +146,43 @@ class ChatFormatter(Protocol):
         ...
 
 
-### Utility functions for formatting chat prompts ###
-
-
+class Jinja2ChatFormatter(ChatFormatter):
+    def __init__(
+        self,
+        template: str,
+        eos_token: str,
+        bos_token: str,
+    ):
+        """A chat formatter that uses jinja2 templates to format the prompt."""
+        self.template = template
+        self.eos_token = eos_token
+        self.bos_token = bos_token
+
+        self._environment = jinja2.Environment(
+            loader=jinja2.BaseLoader(),
+            trim_blocks=True,
+            lstrip_blocks=True,
+        ).from_string(self.template)
+
-def _get_system_message(
-    messages: List[llama_types.ChatCompletionRequestMessage],
-) -> str:
-    """Get the first system message."""
-    for message in messages:
-        if message["role"] == "system":
-            return message["content"] or ""
-    return ""
+    def __call__(
+        self,
+        *,
+        messages: List[llama_types.ChatCompletionRequestMessage],
+        **kwargs: Any,
+    ) -> ChatFormatterResponse:
+        messages = [
+            *messages,
+            llama_types.ChatCompletionRequestAssistantMessage(
+                role="assistant", content=""
+            ),
+        ]
+        prompt = self._environment.render(
+            messages=messages, eos_token=self.eos_token, bos_token=self.bos_token
+        )
+        return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
-
-
-def _map_roles(
-    messages: List[llama_types.ChatCompletionRequestMessage],
-    role_map: Dict[str, str],
-) -> List[Tuple[str, Optional[str]]]:
-    """Map the message roles."""
-    output: List[Tuple[str, Optional[str]]] = []
-    for message in messages:
-        role = message["role"]
-        if role in role_map:
-            content: str | None = (
-                message["content"] if isinstance(message["content"], str) else None
-            )
-            output.append((role_map[role], content))
-    return output
-
-
-def _format_llama2(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
-) -> str:
-    """Format the prompt with the llama2 style."""
-    seps = [sep, sep2]
-    ret = system_message + sep
-    for i, (role, message) in enumerate(messages):
-        if system_message and i == 0:
-            m = message or ""
-            ret += m + seps[i % 2]
-        elif message:
-            ret += role + message + " " + seps[i % 2]
-        else:
-            ret += role + " "
-    return ret
-
-
-def _format_add_colon_single(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the add-colon-single style."""
-    ret = system_message + sep
-    for role, message in messages:
-        if message:
-            ret += role + ": " + message + sep
-        else:
-            ret += role + ":"
-    return ret
-
-
-def _format_add_colon_two(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
-) -> str:
-    """Format the prompt with the add-colon-two style."""
-    seps = [sep, sep2]
-    ret = system_message + seps[0]
-    for i, (role, message) in enumerate(messages):
-        if message:
-            ret += role + ": " + message + seps[i % 2]
-        else:
-            ret += role + ":"
-    return ret
-
-
-def _format_no_colon_single(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the no-colon-single style."""
-    ret = system_message
-    for role, message in messages:
-        if message:
-            ret += role + message + sep
-        else:
-            ret += role
-    return ret
-
-
-def _format_add_colon_space_single(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the add-colon-space-single style."""
-    ret = system_message + sep
-    for role, message in messages:
-        if message:
-            ret += role + ": " + message + sep
-        else:
-            ret += role + ": "  # must end with a space
-    return ret
-
-
-def _format_chatml(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the chatml style."""
-    ret = "" if system_message == "" else system_message + sep + "\n"
-    for role, message in messages:
-        if message:
-            ret += role + "\n" + message + sep + "\n"
-        else:
-            ret += role + "\n"
-    return ret
-
-
-def _format_chatglm3(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the chatglm3 style."""
-    ret = ""
-    if system_message:
-        ret += system_message
-    for role, message in messages:
-        if message:
-            ret += role + "\n" + " " + message
-        else:
-            ret += role
-    return ret
+
+    def to_chat_handler(self) -> LlamaChatCompletionHandler:
+        return chat_formatter_to_chat_completion_handler(self)
 
 
 def _convert_text_completion_to_chat(
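A usage sketch for the new class (illustrative, not from this commit): the ChatML-style template string and special tokens below are assumptions. jinja2 resolves message.role and message.content via item access on the message dicts.

# Illustrative usage of Jinja2ChatFormatter; template and tokens are assumed.
template = (
    "{% for message in messages %}"
    "<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n"
    "{% endfor %}"
)
formatter = Jinja2ChatFormatter(
    template=template,
    eos_token="<|im_end|>",
    bos_token="<s>",
)
response = formatter(messages=[{"role": "user", "content": "Hello!"}])
# __call__ appends an assistant message with empty content before rendering,
# so the template also renders a trailing assistant turn.
# response.stop == ["<|im_end|>"]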
@@ -426,16 +345,6 @@ def chat_formatter_to_chat_completion_handler(
     return chat_completion_handler
 
 
-def register_chat_format(name: str):
-    def decorator(f: ChatFormatter):
-        chat_completion_handler = chat_formatter_to_chat_completion_handler(f)
-        LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
-            name, chat_completion_handler
-        )
-        return f
-    return decorator
-
-
 def hf_autotokenizer_to_chat_formatter(
     pretrained_model_name_or_path: Union[str, os.PathLike[str]]
 ) -> ChatFormatter:
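For context, a hypothetical use of the function whose signature closes this hunk; the model id is illustrative and resolving it requires the optional transformers dependency.

# Hypothetical usage; the model id is an assumption.
formatter = hf_autotokenizer_to_chat_formatter("meta-llama/Llama-2-7b-chat-hf")
response = formatter(messages=[{"role": "user", "content": "Hi"}])
print(response.prompt)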
@@ -466,7 +375,9 @@ def hf_autotokenizer_to_chat_completion_handler(
     return chat_formatter_to_chat_completion_handler(chat_formatter)
 
 
-def hf_tokenizer_config_to_chat_formatter(tokenizer_config: Dict[str, Any]) -> ChatFormatter:
+def hf_tokenizer_config_to_chat_formatter(
+    tokenizer_config: Dict[str, Any]
+) -> ChatFormatter:
     assert isinstance(tokenizer_config, dict)
 
     assert "chat_template" in tokenizer_config
@@ -504,6 +415,7 @@ def hf_tokenizer_config_to_chat_formatter(tokenizer_config: Dict[str, Any]) -> ChatFormatter:
             eos_token=eos_token,
         )
+        return ChatFormatterResponse(prompt=prompt, stop=eos_token)
 
     return format_autotokenizer
 
 
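A sketch of how this function might be driven (the file path is an assumption; per the asserts above, the config must be a dict containing "chat_template"):

# Illustrative usage; the path is an assumption.
import json

with open("path/to/tokenizer_config.json") as f:
    tokenizer_config = json.load(f)

formatter = hf_tokenizer_config_to_chat_formatter(tokenizer_config)
response = formatter(messages=[{"role": "user", "content": "Hello"}])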
@@ -514,6 +426,147 @@ def hf_tokenizer_config_to_chat_completion_handler(
     return chat_formatter_to_chat_completion_handler(chat_formatter)
 
 
+### Utility functions for formatting chat prompts ###
+
+
+def _get_system_message(
+    messages: List[llama_types.ChatCompletionRequestMessage],
+) -> str:
+    """Get the first system message."""
+    for message in messages:
+        if message["role"] == "system":
+            return message["content"] or ""
+    return ""
+
+
+def _map_roles(
+    messages: List[llama_types.ChatCompletionRequestMessage],
+    role_map: Dict[str, str],
+) -> List[Tuple[str, Optional[str]]]:
+    """Map the message roles."""
+    output: List[Tuple[str, Optional[str]]] = []
+    for message in messages:
+        role = message["role"]
+        if role in role_map:
+            content: str | None = (
+                message["content"] if isinstance(message["content"], str) else None
+            )
+            output.append((role_map[role], content))
+    return output
+
+
+def _format_llama2(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
+) -> str:
+    """Format the prompt with the llama2 style."""
+    seps = [sep, sep2]
+    ret = system_message + sep
+    for i, (role, message) in enumerate(messages):
+        if system_message and i == 0:
+            m = message or ""
+            ret += m + seps[i % 2]
+        elif message:
+            ret += role + message + " " + seps[i % 2]
+        else:
+            ret += role + " "
+    return ret
+
+
+def _format_add_colon_single(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the add-colon-single style."""
+    ret = system_message + sep
+    for role, message in messages:
+        if message:
+            ret += role + ": " + message + sep
+        else:
+            ret += role + ":"
+    return ret
+
+
+def _format_add_colon_two(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
+) -> str:
+    """Format the prompt with the add-colon-two style."""
+    seps = [sep, sep2]
+    ret = system_message + seps[0]
+    for i, (role, message) in enumerate(messages):
+        if message:
+            ret += role + ": " + message + seps[i % 2]
+        else:
+            ret += role + ":"
+    return ret
+
+
+def _format_no_colon_single(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the no-colon-single style."""
+    ret = system_message
+    for role, message in messages:
+        if message:
+            ret += role + message + sep
+        else:
+            ret += role
+    return ret
+
+
+def _format_add_colon_space_single(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the add-colon-space-single style."""
+    ret = system_message + sep
+    for role, message in messages:
+        if message:
+            ret += role + ": " + message + sep
+        else:
+            ret += role + ": "  # must end with a space
+    return ret
+
+
+def _format_chatml(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the chatml style."""
+    ret = "" if system_message == "" else system_message + sep + "\n"
+    for role, message in messages:
+        if message:
+            ret += role + "\n" + message + sep + "\n"
+        else:
+            ret += role + "\n"
+    return ret
+
+
+def _format_chatglm3(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the chatglm3 style."""
+    ret = ""
+    if system_message:
+        ret += system_message
+    for role, message in messages:
+        if message:
+            ret += role + "\n" + " " + message
+        else:
+            ret += role
+    return ret
+
+
+### Chat Formats ###
+
+
+def register_chat_format(name: str):
+    def decorator(f: ChatFormatter):
+        chat_completion_handler = chat_formatter_to_chat_completion_handler(f)
+        LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
+            name, chat_completion_handler
+        )
+        return f
+
+    return decorator
+
+
 # see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
 # system prompt is "embedded" in the first message
 @register_chat_format("llama-2")
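To show how the moved helpers and the decorator compose, a hypothetical sketch; the format name "chatml-demo" and the ChatML-style role/token strings are assumptions, not part of this commit.

# Hypothetical sketch: a custom format registered via register_chat_format.
@register_chat_format("chatml-demo")
def format_chatml_demo(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    system_message = _get_system_message(messages)
    if system_message:
        system_message = "<|im_start|>system\n" + system_message
    _messages = _map_roles(
        messages,
        {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"},
    )
    _messages.append(("<|im_start|>assistant", None))
    # _format_chatml appends the separator after the system message and
    # after each completed turn.
    prompt = _format_chatml(system_message, _messages, "<|im_end|>")
    return ChatFormatterResponse(prompt=prompt, stop="<|im_end|>")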