Update llama_chat_format.py (#869)

* Update llama_chat_format.py

properly format llama2 with first-message prompt embedded

* Update llama_chat_format.py
This commit is contained in:
earonesty 2023-11-05 17:00:13 -05:00 committed by GitHub
parent f0b30ef7dc
commit 3580e2c5df
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -73,13 +73,16 @@ def _map_roles(
def _format_llama2(
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
) -> str:
"""Format the prompt with the llama2 style."""
seps = [sep, sep2]
ret = system_message + sep
for role, message in messages:
if message:
ret += role + message + " "
for i, (role, message) in enumerate(messages):
if system_message and i == 0:
ret += message + seps[i % 2]
elif message:
ret += role + message + " " + seps[i % 2]
else:
ret += role + " "
return ret
@@ -324,19 +327,20 @@ def get_chat_format(name: str):
)
# see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
# system prompt is "embedded" in the first message
@register_chat_format("llama-2")
def format_llama2(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render *messages* as a llama-2 style chat prompt.

    The system prompt, when present, is wrapped in <<SYS>> markers and
    embedded inside the first <s>[INST] block; completed turns are closed
    with the "</s>" separator, and the prompt ends with "[/INST]" so the
    model produces the assistant's reply.
    """
    _system_template = "<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>"
    _roles = dict(user="<s>[INST]", assistant="[/INST]")
    _messages = _map_roles(messages, _roles)
    # Trailing assistant slot with no content: the model completes it.
    _messages.append((_roles["assistant"], None))
    system_message = _get_system_message(messages)
    if system_message:
        system_message = _system_template.format(system_message=system_message)
    _prompt = _format_llama2(system_message, _messages, " ", "</s>") + "[/INST]"
    return ChatFormatterResponse(prompt=_prompt)