feat: Add add_generation_prompt option for jinja2chatformatter.
parent ac2e96d4b4
commit 7f3209b1eb
1 changed file with 9 additions and 6 deletions
@@ -152,11 +152,13 @@ class Jinja2ChatFormatter(ChatFormatter):
         template: str,
         eos_token: str,
         bos_token: str,
+        add_generation_prompt: bool = True,
     ):
         """A chat formatter that uses jinja2 templates to format the prompt."""
         self.template = template
         self.eos_token = eos_token
         self.bos_token = bos_token
+        self.add_generation_prompt = add_generation_prompt

         self._environment = jinja2.Environment(
             loader=jinja2.BaseLoader(),
@@ -170,12 +172,13 @@ class Jinja2ChatFormatter(ChatFormatter):
         messages: List[llama_types.ChatCompletionRequestMessage],
         **kwargs: Any,
     ) -> ChatFormatterResponse:
-        messages = [
-            *messages,
-            llama_types.ChatCompletionRequestAssistantMessage(
-                role="assistant", content=""
-            ),
-        ]
+        if self.add_generation_prompt:
+            messages = [
+                *messages,
+                llama_types.ChatCompletionRequestAssistantMessage(
+                    role="assistant", content=""
+                ),
+            ]
         prompt = self._environment.render(
             messages=messages, eos_token=self.eos_token, bos_token=self.bos_token
         )
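For context, a minimal usage sketch of the new option. It assumes the Jinja2ChatFormatter API from llama_cpp.llama_chat_format and a ChatFormatterResponse with a prompt attribute; the ChatML-style template string is illustrative only and is not part of this commit. With add_generation_prompt=True (the default), an empty assistant message is appended before rendering, so templates that loop over messages end the prompt with an open assistant turn for the model to complete.

# Minimal usage sketch (assumptions: Jinja2ChatFormatter import path and
# ChatFormatterResponse.prompt; the template below is an illustrative example).
from llama_cpp.llama_chat_format import Jinja2ChatFormatter

template = (
    "{% for message in messages %}"
    "<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n"
    "{% endfor %}"
)

formatter = Jinja2ChatFormatter(
    template=template,
    eos_token="<|im_end|>",
    bos_token="<|im_start|>",
    add_generation_prompt=True,  # appends an empty assistant message before rendering
)

response = formatter(messages=[{"role": "user", "content": "Hello!"}])
print(response.prompt)

Setting add_generation_prompt=False skips the appended assistant message, which is useful when the rendered text should stop after the last user turn.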