feat: Add add_generation_prompt option for jinja2chatformatter.
This commit is contained in:
parent
ac2e96d4b4
commit
7f3209b1eb
1 changed files with 9 additions and 6 deletions
|
@ -152,11 +152,13 @@ class Jinja2ChatFormatter(ChatFormatter):
|
|||
template: str,
|
||||
eos_token: str,
|
||||
bos_token: str,
|
||||
add_generation_prompt: bool = True,
|
||||
):
|
||||
"""A chat formatter that uses jinja2 templates to format the prompt."""
|
||||
self.template = template
|
||||
self.eos_token = eos_token
|
||||
self.bos_token = bos_token
|
||||
self.add_generation_prompt = add_generation_prompt
|
||||
|
||||
self._environment = jinja2.Environment(
|
||||
loader=jinja2.BaseLoader(),
|
||||
|
@ -170,12 +172,13 @@ class Jinja2ChatFormatter(ChatFormatter):
|
|||
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||
**kwargs: Any,
|
||||
) -> ChatFormatterResponse:
|
||||
messages = [
|
||||
*messages,
|
||||
llama_types.ChatCompletionRequestAssistantMessage(
|
||||
role="assistant", content=""
|
||||
),
|
||||
]
|
||||
if self.add_generation_prompt:
|
||||
messages = [
|
||||
*messages,
|
||||
llama_types.ChatCompletionRequestAssistantMessage(
|
||||
role="assistant", content=""
|
||||
),
|
||||
]
|
||||
prompt = self._environment.render(
|
||||
messages=messages, eos_token=self.eos_token, bos_token=self.bos_token
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue