From be09318c26add8674ce494ae7cc480cce72a4146 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Fri, 19 Jan 2024 15:04:42 -0500
Subject: [PATCH] feat: Add Jinja2ChatFormatter

---
 llama_cpp/llama_chat_format.py | 323 +++++++++++++++++++--------------
 1 file changed, 188 insertions(+), 135 deletions(-)

diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 3d18d90..4e1b174 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -121,14 +121,21 @@ def register_chat_completion_handler(name: str):
 
 @dataclasses.dataclass
 class ChatFormatterResponse:
+    """Dataclass that stores completion parameters for a given chat format and
+    create_chat_completion request.
+
+    prompt contains the formatted prompt generated from the chat format and messages.
+    stop contains the stop token or list of stop tokens to use for the chat format."""
+
     prompt: str
     stop: Optional[Union[str, List[str]]] = None
 
 
 class ChatFormatter(Protocol):
     """Base Protocol for a chat formatter. A chat formatter is a function that
-    takes a list of messages and returns a formatted prompt. It can also return
-    a stop token or list of stop tokens to use for the completion."""
+    takes a list of messages and returns a chat format response which can be used
+    to generate a completion. The response can also include a stop token or list
+    of stop tokens to use for the completion."""
 
     def __call__(
         self,
@@ -139,131 +146,43 @@ class ChatFormatter(Protocol):
         ...
 
 
-### Utility functions for formatting chat prompts ###
+class Jinja2ChatFormatter(ChatFormatter):
+    def __init__(
+        self,
+        template: str,
+        eos_token: str,
+        bos_token: str,
+    ):
+        """A chat formatter that uses jinja2 templates to format the prompt."""
+        self.template = template
+        self.eos_token = eos_token
+        self.bos_token = bos_token
 
+        self._environment = jinja2.Environment(
+            loader=jinja2.BaseLoader(),
+            trim_blocks=True,
+            lstrip_blocks=True,
+        ).from_string(self.template)
 
-def _get_system_message(
-    messages: List[llama_types.ChatCompletionRequestMessage],
-) -> str:
-    """Get the first system message."""
-    for message in messages:
-        if message["role"] == "system":
-            return message["content"] or ""
-    return ""
+    def __call__(
+        self,
+        *,
+        messages: List[llama_types.ChatCompletionRequestMessage],
+        **kwargs: Any,
+    ) -> ChatFormatterResponse:
+        messages = [
+            *messages,
+            llama_types.ChatCompletionRequestAssistantMessage(
+                role="assistant", content=""
+            ),
+        ]
+        prompt = self._environment.render(
+            messages=messages, eos_token=self.eos_token, bos_token=self.bos_token
+        )
+        return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
 
-
-def _map_roles(
-    messages: List[llama_types.ChatCompletionRequestMessage],
-    role_map: Dict[str, str],
-) -> List[Tuple[str, Optional[str]]]:
-    """Map the message roles."""
-    output: List[Tuple[str, Optional[str]]] = []
-    for message in messages:
-        role = message["role"]
-        if role in role_map:
-            content: str | None = (
-                message["content"] if isinstance(message["content"], str) else None
-            )
-            output.append((role_map[role], content))
-    return output
-
-
-def _format_llama2(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
-) -> str:
-    """Format the prompt with the llama2 style."""
-    seps = [sep, sep2]
-    ret = system_message + sep
-    for i, (role, message) in enumerate(messages):
-        if system_message and i == 0:
-            m = message or ""
-            ret += m + seps[i % 2]
-        elif message:
-            ret += role + message + " " + seps[i % 2]
-        else:
-            ret += role + " "
-    return ret
-
-
-def _format_add_colon_single(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the add-colon-single style."""
-    ret = system_message + sep
-    for role, message in messages:
-        if message:
-            ret += role + ": " + message + sep
-        else:
-            ret += role + ":"
-    return ret
-
-
-def _format_add_colon_two(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
-) -> str:
-    """Format the prompt with the add-colon-two style."""
-    seps = [sep, sep2]
-    ret = system_message + seps[0]
-    for i, (role, message) in enumerate(messages):
-        if message:
-            ret += role + ": " + message + seps[i % 2]
-        else:
-            ret += role + ":"
-    return ret
-
-
-def _format_no_colon_single(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the no-colon-single style."""
-    ret = system_message
-    for role, message in messages:
-        if message:
-            ret += role + message + sep
-        else:
-            ret += role
-    return ret
-
-
-def _format_add_colon_space_single(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the add-colon-space-single style."""
-    ret = system_message + sep
-    for role, message in messages:
-        if message:
-            ret += role + ": " + message + sep
-        else:
-            ret += role + ": "  # must be end with a space
-    return ret
-
-
-def _format_chatml(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the chatml style."""
-    ret = "" if system_message == "" else system_message + sep + "\n"
-    for role, message in messages:
-        if message:
-            ret += role + "\n" + message + sep + "\n"
-        else:
-            ret += role + "\n"
-    return ret
-
-
-def _format_chatglm3(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the chatglm3 style."""
-    ret = ""
-    if system_message:
-        ret += system_message
-    for role, message in messages:
-        if message:
-            ret += role + "\n" + " " + message
-        else:
-            ret += role
-    return ret
+    def to_chat_handler(self) -> LlamaChatCompletionHandler:
+        return chat_formatter_to_chat_completion_handler(self)
 
 
 def _convert_text_completion_to_chat(
@@ -426,16 +345,6 @@ def chat_formatter_to_chat_completion_handler(
     return chat_completion_handler
 
 
-def register_chat_format(name: str):
-    def decorator(f: ChatFormatter):
-        chat_completion_handler = chat_formatter_to_chat_completion_handler(f)
-        LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
-            name, chat_completion_handler
-        )
-        return f
-    return decorator
-
-
 def hf_autotokenizer_to_chat_formatter(
     pretrained_model_name_or_path: Union[str, os.PathLike[str]]
 ) -> ChatFormatter:
@@ -466,7 +375,9 @@ def hf_autotokenizer_to_chat_completion_handler(
     return chat_formatter_to_chat_completion_handler(chat_formatter)
 
 
-def hf_tokenizer_config_to_chat_formatter(tokenizer_config: Dict[str, Any]) -> ChatFormatter:
+def hf_tokenizer_config_to_chat_formatter(
+    tokenizer_config: Dict[str, Any]
+) -> ChatFormatter:
     assert isinstance(tokenizer_config, dict)
 
     assert "chat_template" in tokenizer_config
@@ -504,6 +415,7 @@ def hf_tokenizer_config_to_chat_formatter(tokenizer_config: Dict[str, Any]) -> C
             eos_token=eos_token,
         )
         return ChatFormatterResponse(prompt=prompt, stop=eos_token)
+
     return format_autotokenizer
 
 
@@ -514,6 +426,147 @@ def hf_tokenizer_config_to_chat_completion_handler(
     return chat_formatter_to_chat_completion_handler(chat_formatter)
 
 
+### Utility functions for formatting chat prompts ###
+
+
+def _get_system_message(
+    messages: List[llama_types.ChatCompletionRequestMessage],
+) -> str:
+    """Get the first system message."""
+    for message in messages:
+        if message["role"] == "system":
+            return message["content"] or ""
+    return ""
+
+
+def _map_roles(
+    messages: List[llama_types.ChatCompletionRequestMessage],
+    role_map: Dict[str, str],
+) -> List[Tuple[str, Optional[str]]]:
+    """Map the message roles."""
+    output: List[Tuple[str, Optional[str]]] = []
+    for message in messages:
+        role = message["role"]
+        if role in role_map:
+            content: str | None = (
+                message["content"] if isinstance(message["content"], str) else None
+            )
+            output.append((role_map[role], content))
+    return output
+
+
+def _format_llama2(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
+) -> str:
+    """Format the prompt with the llama2 style."""
+    seps = [sep, sep2]
+    ret = system_message + sep
+    for i, (role, message) in enumerate(messages):
+        if system_message and i == 0:
+            m = message or ""
+            ret += m + seps[i % 2]
+        elif message:
+            ret += role + message + " " + seps[i % 2]
+        else:
+            ret += role + " "
+    return ret
+
+
+def _format_add_colon_single(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the add-colon-single style."""
+    ret = system_message + sep
+    for role, message in messages:
+        if message:
+            ret += role + ": " + message + sep
+        else:
+            ret += role + ":"
+    return ret
+
+
+def _format_add_colon_two(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
+) -> str:
+    """Format the prompt with the add-colon-two style."""
+    seps = [sep, sep2]
+    ret = system_message + seps[0]
+    for i, (role, message) in enumerate(messages):
+        if message:
+            ret += role + ": " + message + seps[i % 2]
+        else:
+            ret += role + ":"
+    return ret
+
+
+def _format_no_colon_single(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the no-colon-single style."""
+    ret = system_message
+    for role, message in messages:
+        if message:
+            ret += role + message + sep
+        else:
+            ret += role
+    return ret
+
+
+def _format_add_colon_space_single(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the add-colon-space-single style."""
+    ret = system_message + sep
+    for role, message in messages:
+        if message:
+            ret += role + ": " + message + sep
+        else:
+            ret += role + ": "  # must be end with a space
+    return ret
+
+
+def _format_chatml(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the chatml style."""
+    ret = "" if system_message == "" else system_message + sep + "\n"
+    for role, message in messages:
+        if message:
+            ret += role + "\n" + message + sep + "\n"
+        else:
+            ret += role + "\n"
+    return ret
+
+
+def _format_chatglm3(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the chatglm3 style."""
+    ret = ""
+    if system_message:
+        ret += system_message
+    for role, message in messages:
+        if message:
+            ret += role + "\n" + " " + message
+        else:
+            ret += role
+    return ret
+
+
+### Chat Formats ###
+
+
+def register_chat_format(name: str):
+    def decorator(f: ChatFormatter):
+        chat_completion_handler = chat_formatter_to_chat_completion_handler(f)
+        LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
+            name, chat_completion_handler
+        )
+        return f
+
+    return decorator
+
+
 # see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
 # system prompt is "embedded" in the first message
 @register_chat_format("llama-2")
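
A minimal usage sketch of the new Jinja2ChatFormatter follows. It is not
part of the patch: the template string and the special tokens are
illustrative placeholders (real models ship their own chat template and
tokens in tokenizer_config.json), and the snippet assumes this patch is
applied and the jinja2 package is installed.

    from llama_cpp.llama_chat_format import Jinja2ChatFormatter

    # Illustrative template: emit the bos token, then each message as
    # "role: content" on its own line.
    template = (
        "{{ bos_token }}"
        "{% for message in messages %}"
        "{{ message['role'] }}: {{ message['content'] }}\n"
        "{% endfor %}"
    )

    formatter = Jinja2ChatFormatter(
        template=template,
        eos_token="</s>",  # placeholder special tokens, model-dependent
        bos_token="<s>",
    )

    # __call__ appends an empty assistant message before rendering, so the
    # prompt ends with the assistant header for the model to continue from.
    response = formatter(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ]
    )
    print(response.prompt)
    # <s>system: You are a helpful assistant.\nuser: Hello!\nassistant: \n
    print(response.stop)
    # ['</s>']

Calling formatter.to_chat_handler() wraps the formatter with
chat_formatter_to_chat_completion_handler, so the result can be used
directly as a chat completion handler.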