diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index f92793d..36382a4 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -73,13 +73,16 @@ def _map_roles(
 def _format_llama2(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
 ) -> str:
     """Format the prompt with the llama2 style."""
+    seps = [sep, sep2]
     ret = system_message + sep
-    for role, message in messages:
-        if message:
-            ret += role + message + " "
+    for i, (role, message) in enumerate(messages):
+        if system_message and i == 0:
+            ret += message + seps[i % 2]
+        elif message:
+            ret += role + message + " " + seps[i % 2]
         else:
             ret += role + " "
     return ret
@@ -324,19 +327,20 @@ def get_chat_format(name: str):
 )
+# see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
+# system prompt is "embedded" in the first message
 @register_chat_format("llama-2")
 def format_llama2(
     messages: List[llama_types.ChatCompletionRequestMessage],
     **kwargs: Any,
 ) -> ChatFormatterResponse:
-    _system_template = "[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n"
-    _roles = dict(user="[INST]", assistant="[/INST]")
-    _sep = "\n\n"
-    system_message = _get_system_message(messages)
-    system_message = _system_template.format(system_message=system_message)
+    _system_template = "<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>"
+    _roles = dict(user="<s>[INST]", assistant="[/INST]")
     _messages = _map_roles(messages, _roles)
-    _messages.append((_roles["assistant"], None))
-    _prompt = _format_llama2(system_message, _messages, _sep)
+    system_message = _get_system_message(messages)
+    if system_message:
+        system_message = _system_template.format(system_message=system_message)
+    _prompt = _format_llama2(system_message, _messages, " ", "</s>") + "[/INST]"
     return ChatFormatterResponse(prompt=_prompt)
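
For reviewers, a standalone sketch (not part of the patch) of what the new
_format_llama2 produces. The helper body is copied from the first hunk above;
the message tuples approximate what _map_roles would emit for a system prompt
plus a two-turn conversation, assuming _map_roles drops the system entry as
elsewhere in this module:

    from typing import List, Optional, Tuple

    def _format_llama2(
        system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
    ) -> str:
        """Format the prompt with the llama2 style."""
        seps = [sep, sep2]  # even-indexed turns (user) take sep, odd-indexed (assistant) take sep2
        ret = system_message + sep
        for i, (role, message) in enumerate(messages):
            if system_message and i == 0:
                # the system template already carries the first "<s>[INST]" tag,
                # so the first user message is appended without its role prefix
                ret += message + seps[i % 2]
            elif message:
                ret += role + message + " " + seps[i % 2]
            else:
                ret += role + " "
        return ret

    system_message = "<s>[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>"
    mapped = [
        ("<s>[INST]", "Hello"),    # user
        ("[/INST]", "Hi there!"),  # assistant
        ("<s>[INST]", "Bye"),      # user
    ]
    print(_format_llama2(system_message, mapped, " ", "</s>") + "[/INST]")

which prints:

    <s>[INST] <<SYS>>
    You are a helpful assistant.
    <</SYS>> Hello [/INST]Hi there! </s><s>[INST]Bye  [/INST]

Note that message is typed Optional[str], so the i == 0 branch assumes the
first mapped message has content; a None there would raise a TypeError.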
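
And a quick end-to-end check against the registered handler. This assumes
get_chat_format (visible in the second hunk's context) simply returns the
formatter registered under "llama-2":

    from llama_cpp.llama_chat_format import get_chat_format

    handler = get_chat_format("llama-2")
    response = handler(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello"},
        ]
    )
    print(response.prompt)

Expected output:

    <s>[INST] <<SYS>>
    You are a helpful assistant.
    <</SYS>> Hello [/INST]

With no system message, the "<s>[INST] <<SYS>>...<</SYS>>" preamble is skipped
entirely and the first user turn supplies its own "<s>[INST]" tag via _roles.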