feat: Add add_generation_prompt option for jinja2chatformatter.

This commit is contained in:
Andrei Betlen 2024-01-21 18:37:24 -05:00
parent ac2e96d4b4
commit 7f3209b1eb

View file

@ -152,11 +152,13 @@ class Jinja2ChatFormatter(ChatFormatter):
template: str, template: str,
eos_token: str, eos_token: str,
bos_token: str, bos_token: str,
add_generation_prompt: bool = True,
): ):
"""A chat formatter that uses jinja2 templates to format the prompt.""" """A chat formatter that uses jinja2 templates to format the prompt."""
self.template = template self.template = template
self.eos_token = eos_token self.eos_token = eos_token
self.bos_token = bos_token self.bos_token = bos_token
self.add_generation_prompt = add_generation_prompt
self._environment = jinja2.Environment( self._environment = jinja2.Environment(
loader=jinja2.BaseLoader(), loader=jinja2.BaseLoader(),
@ -170,6 +172,7 @@ class Jinja2ChatFormatter(ChatFormatter):
messages: List[llama_types.ChatCompletionRequestMessage], messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any, **kwargs: Any,
) -> ChatFormatterResponse: ) -> ChatFormatterResponse:
if self.add_generation_prompt:
messages = [ messages = [
*messages, *messages,
llama_types.ChatCompletionRequestAssistantMessage( llama_types.ChatCompletionRequestAssistantMessage(