feat: Add Jinja2ChatFormatter

This commit is contained in:
Andrei Betlen 2024-01-19 15:04:42 -05:00
parent 5a34c57e54
commit be09318c26

View file

@ -121,14 +121,21 @@ def register_chat_completion_handler(name: str):
@dataclasses.dataclass
class ChatFormatterResponse:
"""Dataclass that stores completion parameters for a given chat format and
create_chat_completion request.
prompt contains the formatted prompt generated from the chat format and messages.
stop contains the stop token or list of stop tokens to use for the chat format."""
prompt: str
stop: Optional[Union[str, List[str]]] = None
class ChatFormatter(Protocol):
"""Base Protocol for a chat formatter. A chat formatter is a function that
takes a list of messages and returns a formatted prompt. It can also return
a stop token or list of stop tokens to use for the completion."""
takes a list of messages and returns a chat format response which can be used
to generate a completion. The response can also include a stop token or list
of stop tokens to use for the completion."""
def __call__(
self,
@ -139,131 +146,43 @@ class ChatFormatter(Protocol):
...
### Utility functions for formatting chat prompts ###
class Jinja2ChatFormatter(ChatFormatter):
def __init__(
self,
template: str,
eos_token: str,
bos_token: str,
):
"""A chat formatter that uses jinja2 templates to format the prompt."""
self.template = template
self.eos_token = eos_token
self.bos_token = bos_token
self._environment = jinja2.Environment(
loader=jinja2.BaseLoader(),
trim_blocks=True,
lstrip_blocks=True,
).from_string(self.template)
def _get_system_message(
messages: List[llama_types.ChatCompletionRequestMessage],
) -> str:
"""Get the first system message."""
for message in messages:
if message["role"] == "system":
return message["content"] or ""
return ""
def __call__(
self,
*,
messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any,
) -> ChatFormatterResponse:
messages = [
*messages,
llama_types.ChatCompletionRequestAssistantMessage(
role="assistant", content=""
),
]
prompt = self._environment.render(
messages=messages, eos_token=self.eos_token, bos_token=self.bos_token
)
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
def _map_roles(
messages: List[llama_types.ChatCompletionRequestMessage],
role_map: Dict[str, str],
) -> List[Tuple[str, Optional[str]]]:
"""Map the message roles."""
output: List[Tuple[str, Optional[str]]] = []
for message in messages:
role = message["role"]
if role in role_map:
content: str | None = (
message["content"] if isinstance(message["content"], str) else None
)
output.append((role_map[role], content))
return output
def _format_llama2(
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
) -> str:
"""Format the prompt with the llama2 style."""
seps = [sep, sep2]
ret = system_message + sep
for i, (role, message) in enumerate(messages):
if system_message and i == 0:
m = message or ""
ret += m + seps[i % 2]
elif message:
ret += role + message + " " + seps[i % 2]
else:
ret += role + " "
return ret
def _format_add_colon_single(
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
) -> str:
"""Format the prompt with the add-colon-single style."""
ret = system_message + sep
for role, message in messages:
if message:
ret += role + ": " + message + sep
else:
ret += role + ":"
return ret
def _format_add_colon_two(
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
) -> str:
"""Format the prompt with the add-colon-two style."""
seps = [sep, sep2]
ret = system_message + seps[0]
for i, (role, message) in enumerate(messages):
if message:
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
def _format_no_colon_single(
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
) -> str:
"""Format the prompt with the no-colon-single style."""
ret = system_message
for role, message in messages:
if message:
ret += role + message + sep
else:
ret += role
return ret
def _format_add_colon_space_single(
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
) -> str:
"""Format the prompt with the add-colon-space-single style."""
ret = system_message + sep
for role, message in messages:
if message:
ret += role + ": " + message + sep
else:
ret += role + ": " # must be end with a space
return ret
def _format_chatml(
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
) -> str:
"""Format the prompt with the chatml style."""
ret = "" if system_message == "" else system_message + sep + "\n"
for role, message in messages:
if message:
ret += role + "\n" + message + sep + "\n"
else:
ret += role + "\n"
return ret
def _format_chatglm3(
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
) -> str:
"""Format the prompt with the chatglm3 style."""
ret = ""
if system_message:
ret += system_message
for role, message in messages:
if message:
ret += role + "\n" + " " + message
else:
ret += role
return ret
def to_chat_handler(self) -> LlamaChatCompletionHandler:
return chat_formatter_to_chat_completion_handler(self)
def _convert_text_completion_to_chat(
@ -426,16 +345,6 @@ def chat_formatter_to_chat_completion_handler(
return chat_completion_handler
def register_chat_format(name: str):
def decorator(f: ChatFormatter):
chat_completion_handler = chat_formatter_to_chat_completion_handler(f)
LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
name, chat_completion_handler
)
return f
return decorator
def hf_autotokenizer_to_chat_formatter(
pretrained_model_name_or_path: Union[str, os.PathLike[str]]
) -> ChatFormatter:
@ -466,7 +375,9 @@ def hf_autotokenizer_to_chat_completion_handler(
return chat_formatter_to_chat_completion_handler(chat_formatter)
def hf_tokenizer_config_to_chat_formatter(tokenizer_config: Dict[str, Any]) -> ChatFormatter:
def hf_tokenizer_config_to_chat_formatter(
tokenizer_config: Dict[str, Any]
) -> ChatFormatter:
assert isinstance(tokenizer_config, dict)
assert "chat_template" in tokenizer_config
@ -504,6 +415,7 @@ def hf_tokenizer_config_to_chat_formatter(tokenizer_config: Dict[str, Any]) -> C
eos_token=eos_token,
)
return ChatFormatterResponse(prompt=prompt, stop=eos_token)
return format_autotokenizer
@ -514,6 +426,147 @@ def hf_tokenizer_config_to_chat_completion_handler(
return chat_formatter_to_chat_completion_handler(chat_formatter)
### Utility functions for formatting chat prompts ###
def _get_system_message(
messages: List[llama_types.ChatCompletionRequestMessage],
) -> str:
"""Get the first system message."""
for message in messages:
if message["role"] == "system":
return message["content"] or ""
return ""
def _map_roles(
messages: List[llama_types.ChatCompletionRequestMessage],
role_map: Dict[str, str],
) -> List[Tuple[str, Optional[str]]]:
"""Map the message roles."""
output: List[Tuple[str, Optional[str]]] = []
for message in messages:
role = message["role"]
if role in role_map:
content: str | None = (
message["content"] if isinstance(message["content"], str) else None
)
output.append((role_map[role], content))
return output
def _format_llama2(
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
) -> str:
"""Format the prompt with the llama2 style."""
seps = [sep, sep2]
ret = system_message + sep
for i, (role, message) in enumerate(messages):
if system_message and i == 0:
m = message or ""
ret += m + seps[i % 2]
elif message:
ret += role + message + " " + seps[i % 2]
else:
ret += role + " "
return ret
def _format_add_colon_single(
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
) -> str:
"""Format the prompt with the add-colon-single style."""
ret = system_message + sep
for role, message in messages:
if message:
ret += role + ": " + message + sep
else:
ret += role + ":"
return ret
def _format_add_colon_two(
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
) -> str:
"""Format the prompt with the add-colon-two style."""
seps = [sep, sep2]
ret = system_message + seps[0]
for i, (role, message) in enumerate(messages):
if message:
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
def _format_no_colon_single(
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
) -> str:
"""Format the prompt with the no-colon-single style."""
ret = system_message
for role, message in messages:
if message:
ret += role + message + sep
else:
ret += role
return ret
def _format_add_colon_space_single(
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
) -> str:
"""Format the prompt with the add-colon-space-single style."""
ret = system_message + sep
for role, message in messages:
if message:
ret += role + ": " + message + sep
else:
ret += role + ": " # must be end with a space
return ret
def _format_chatml(
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
) -> str:
"""Format the prompt with the chatml style."""
ret = "" if system_message == "" else system_message + sep + "\n"
for role, message in messages:
if message:
ret += role + "\n" + message + sep + "\n"
else:
ret += role + "\n"
return ret
def _format_chatglm3(
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
) -> str:
"""Format the prompt with the chatglm3 style."""
ret = ""
if system_message:
ret += system_message
for role, message in messages:
if message:
ret += role + "\n" + " " + message
else:
ret += role
return ret
### Chat Formats ###
def register_chat_format(name: str):
def decorator(f: ChatFormatter):
chat_completion_handler = chat_formatter_to_chat_completion_handler(f)
LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
name, chat_completion_handler
)
return f
return decorator
# see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
# system prompt is "embedded" in the first message
@register_chat_format("llama-2")