feat: Add add_generation_prompt option for jinja2chatformatter.

This commit is contained in:
Andrei Betlen 2024-01-21 18:37:24 -05:00
parent ac2e96d4b4
commit 7f3209b1eb

View file

@ -152,11 +152,13 @@ class Jinja2ChatFormatter(ChatFormatter):
template: str, template: str,
eos_token: str, eos_token: str,
bos_token: str, bos_token: str,
add_generation_prompt: bool = True,
): ):
"""A chat formatter that uses jinja2 templates to format the prompt.""" """A chat formatter that uses jinja2 templates to format the prompt."""
self.template = template self.template = template
self.eos_token = eos_token self.eos_token = eos_token
self.bos_token = bos_token self.bos_token = bos_token
self.add_generation_prompt = add_generation_prompt
self._environment = jinja2.Environment( self._environment = jinja2.Environment(
loader=jinja2.BaseLoader(), loader=jinja2.BaseLoader(),
@ -170,12 +172,13 @@ class Jinja2ChatFormatter(ChatFormatter):
messages: List[llama_types.ChatCompletionRequestMessage], messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any, **kwargs: Any,
) -> ChatFormatterResponse: ) -> ChatFormatterResponse:
messages = [ if self.add_generation_prompt:
*messages, messages = [
llama_types.ChatCompletionRequestAssistantMessage( *messages,
role="assistant", content="" llama_types.ChatCompletionRequestAssistantMessage(
), role="assistant", content=""
] ),
]
prompt = self._environment.render( prompt = self._environment.render(
messages=messages, eos_token=self.eos_token, bos_token=self.bos_token messages=messages, eos_token=self.eos_token, bos_token=self.bos_token
) )