Update llama_chat_format.py (#869)
* Update llama_chat_format.py: properly format llama2 with first-message prompt embedded
* Update llama_chat_format.py
parent f0b30ef7dc
commit 3580e2c5df
1 changed file with 15 additions and 11 deletions
llama_chat_format.py

@@ -73,13 +73,16 @@ def _map_roles(

 def _format_llama2(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
 ) -> str:
     """Format the prompt with the llama2 style."""
+    seps = [sep, sep2]
     ret = system_message + sep
-    for role, message in messages:
-        if message:
-            ret += role + message + " "
+    for i, (role, message) in enumerate(messages):
+        if system_message and i == 0:
+            ret += message + seps[i % 2]
+        elif message:
+            ret += role + message + " " + seps[i % 2]
         else:
             ret += role + " "
     return ret
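A minimal sketch of what the revised helper produces, for reference. The function body is copied from the new version above; the separators and role strings mirror the call site in format_llama2 further down, and the sample messages are made up for illustration.

from typing import List, Optional, Tuple

def _format_llama2(
    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
) -> str:
    """Format the prompt with the llama2 style."""
    seps = [sep, sep2]
    ret = system_message + sep
    for i, (role, message) in enumerate(messages):
        if system_message and i == 0:
            # the first user message is appended bare, so the system prompt is
            # "embedded" in it instead of opening a separate [INST] block
            ret += message + seps[i % 2]
        elif message:
            ret += role + message + " " + seps[i % 2]
        else:
            ret += role + " "
    return ret

# sample inputs mirroring format_llama2: sep=" ", sep2="</s>", roles "<s>[INST]" / "[/INST]"
system = "<s>[INST] <<SYS>>\nBe brief.\n<</SYS>>"
msgs = [("<s>[INST]", "Hello"), ("[/INST]", "Hi!"), ("<s>[INST]", "How are you?")]
print(_format_llama2(system, msgs, " ", "</s>") + "[/INST]")
# <s>[INST] <<SYS>>
# Be brief.
# <</SYS>> Hello [/INST]Hi! </s><s>[INST]How are you?  [/INST]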
@@ -324,19 +327,20 @@ def get_chat_format(name: str):
     )


 # see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
+# system prompt is "embedded" in the first message
 @register_chat_format("llama-2")
 def format_llama2(
     messages: List[llama_types.ChatCompletionRequestMessage],
     **kwargs: Any,
 ) -> ChatFormatterResponse:
-    _system_template = "[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n"
-    _roles = dict(user="[INST]", assistant="[/INST]")
-    _sep = "\n\n"
-    system_message = _get_system_message(messages)
-    system_message = _system_template.format(system_message=system_message)
+    _system_template = "<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>"
+    _roles = dict(user="<s>[INST]", assistant="[/INST]")
     _messages = _map_roles(messages, _roles)
-    _messages.append((_roles["assistant"], None))
-    _prompt = _format_llama2(system_message, _messages, _sep)
+    system_message = _get_system_message(messages)
+    if system_message:
+        system_message = _system_template.format(system_message=system_message)
+    _prompt = _format_llama2(system_message, _messages, " ", "</s>") + "[/INST]"
     return ChatFormatterResponse(prompt=_prompt)
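End to end, the change means a chat request with a system message now renders to the usual Llama-2 chat prompt shape, with the system block folded into the first [INST] turn. A minimal usage sketch, assuming a local Llama-2 chat GGUF file (the path is a placeholder) and the chat_format selection that register_chat_format feeds:

from llama_cpp import Llama

# model_path is a placeholder; chat_format="llama-2" selects the formatter registered above
llm = Llama(model_path="./llama-2-7b-chat.Q4_K_M.gguf", chat_format="llama-2")

response = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "Be brief."},
        {"role": "user", "content": "Hello"},
    ],
)
print(response["choices"][0]["message"]["content"])

# With this change applied, the prompt built for the request above should render roughly as:
#   <s>[INST] <<SYS>>\nBe brief.\n<</SYS>> Hello [/INST]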