From 3580e2c5df79414cacc993ead38457c5317ad5cf Mon Sep 17 00:00:00 2001 From: earonesty Date: Sun, 5 Nov 2023 17:00:13 -0500 Subject: [PATCH] Update llama_chat_format.py (#869) * Update llama_chat_format.py properly formal llama2 with first-message prompt embedded * Update llama_chat_format.py --- llama_cpp/llama_chat_format.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index f92793d..36382a4 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -73,13 +73,16 @@ def _map_roles( def _format_llama2( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str + system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str ) -> str: """Format the prompt with the llama2 style.""" + seps = [sep, sep2] ret = system_message + sep - for role, message in messages: - if message: - ret += role + message + " " + for i, (role, message) in enumerate(messages): + if system_message and i == 0: + ret += message + seps[i % 2] + elif message: + ret += role + message + " " + seps[i % 2] else: ret += role + " " return ret @@ -324,19 +327,20 @@ def get_chat_format(name: str): ) +# see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py +# system prompt is "embedded" in the first message @register_chat_format("llama-2") def format_llama2( messages: List[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: - _system_template = "[INST] <>\n{system_message}\n<>\n\n" - _roles = dict(user="[INST]", assistant="[/INST]") - _sep = "\n\n" - system_message = _get_system_message(messages) - system_message = _system_template.format(system_message=system_message) + _system_template = "[INST] <>\n{system_message}\n<>" + _roles = dict(user="[INST]", assistant="[/INST]") _messages = _map_roles(messages, _roles) - _messages.append((_roles["assistant"], None)) - _prompt = _format_llama2(system_message, _messages, _sep) + system_message = _get_system_message(messages) + if system_message: + system_message = _system_template.format(system_message=system_message) + _prompt = _format_llama2(system_message, _messages, " ", "") + "[/INST]" return ChatFormatterResponse(prompt=_prompt)