feat: Add add_generation_prompt option for Jinja2ChatFormatter.

Andrei Betlen 2024-01-21 18:37:24 -05:00
parent ac2e96d4b4
commit 7f3209b1eb

@@ -152,11 +152,13 @@ class Jinja2ChatFormatter(ChatFormatter):
         template: str,
         eos_token: str,
         bos_token: str,
+        add_generation_prompt: bool = True,
     ):
         """A chat formatter that uses jinja2 templates to format the prompt."""
         self.template = template
         self.eos_token = eos_token
         self.bos_token = bos_token
+        self.add_generation_prompt = add_generation_prompt
 
         self._environment = jinja2.Environment(
             loader=jinja2.BaseLoader(),
@@ -170,12 +172,13 @@ class Jinja2ChatFormatter(ChatFormatter):
         messages: List[llama_types.ChatCompletionRequestMessage],
         **kwargs: Any,
     ) -> ChatFormatterResponse:
-        messages = [
-            *messages,
-            llama_types.ChatCompletionRequestAssistantMessage(
-                role="assistant", content=""
-            ),
-        ]
+        if self.add_generation_prompt:
+            messages = [
+                *messages,
+                llama_types.ChatCompletionRequestAssistantMessage(
+                    role="assistant", content=""
+                ),
+            ]
         prompt = self._environment.render(
             messages=messages, eos_token=self.eos_token, bos_token=self.bos_token
         )
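For context, a minimal usage sketch of the new flag. The import path, the ChatML-style template, and the example tokens below are illustrative assumptions rather than part of this commit; only the constructor parameters and call signature shown in the diff are relied on, and the result is assumed to expose the rendered prompt as .prompt as elsewhere in the module:

from llama_cpp.llama_chat_format import Jinja2ChatFormatter

# Illustrative ChatML-style template; real templates normally come from the model.
template = (
    "{% for message in messages %}"
    "<|im_start|>{{ message['role'] }}\n{{ message['content'] }}"
    # Close the turn only when there is content, so the blank assistant message
    # appended by add_generation_prompt=True is left open as a generation prompt.
    "{% if message['content'] %}<|im_end|>\n{% endif %}"
    "{% endfor %}"
)

formatter = Jinja2ChatFormatter(
    template=template,
    eos_token="<|im_end|>",
    bos_token="<s>",
    add_generation_prompt=True,  # default; appends an empty assistant message before rendering
)

# With add_generation_prompt=True the prompt ends in an open assistant turn;
# pass add_generation_prompt=False to render only the messages as given.
print(formatter(messages=[{"role": "user", "content": "Hello!"}]).prompt)

Note that the flag works by appending an empty assistant message to the message list rather than by passing an add_generation_prompt variable into the template itself.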