feat: Add add_generation_prompt option for jinja2chatformatter.
parent ac2e96d4b4
commit 7f3209b1eb
1 changed file with 9 additions and 6 deletions
@@ -152,11 +152,13 @@ class Jinja2ChatFormatter(ChatFormatter):
         template: str,
         eos_token: str,
         bos_token: str,
+        add_generation_prompt: bool = True,
     ):
         """A chat formatter that uses jinja2 templates to format the prompt."""
         self.template = template
         self.eos_token = eos_token
         self.bos_token = bos_token
+        self.add_generation_prompt = add_generation_prompt

         self._environment = jinja2.Environment(
             loader=jinja2.BaseLoader(),
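For context (not part of the commit): a minimal sketch of how the constructor might be called with the new flag. Only the parameter names come from the hunk above; the ChatML-style template, special tokens, and messages are illustrative assumptions, and the import path and direct-call usage assume the class is exposed from llama_cpp.llama_chat_format as in llama-cpp-python. The second hunk of the diff follows below.

# Illustrative sketch only: template, tokens, and messages are assumptions;
# the constructor parameters match the hunk above.
from llama_cpp.llama_chat_format import Jinja2ChatFormatter

chatml_template = (
    "{% for message in messages %}"
    "<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n"
    "{% endfor %}"
)

formatter = Jinja2ChatFormatter(
    template=chatml_template,
    eos_token="<|im_end|>",
    bos_token="<s>",
    # New option from this commit; it defaults to True, which matches the
    # previously unconditional behavior. Set to False to opt out.
    add_generation_prompt=False,
)

response = formatter(messages=[{"role": "user", "content": "Hello!"}])
print(response.prompt)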
@@ -170,6 +172,7 @@ class Jinja2ChatFormatter(ChatFormatter):
         messages: List[llama_types.ChatCompletionRequestMessage],
         **kwargs: Any,
     ) -> ChatFormatterResponse:
+        if self.add_generation_prompt:
             messages = [
                 *messages,
                 llama_types.ChatCompletionRequestAssistantMessage(
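The stat line counts 9 additions and 6 deletions, but the visible hunks contain only the three added lines, so part of the change is not shown above. Based on the visible lines, the guarded block presumably ends up roughly as in the following sketch; the re-indentation and the fields of the appended assistant message are assumptions, not text from the diff.

# Assumed reconstruction of the guarded block after this change; the exact
# fields of the appended assistant message are not visible in the diff.
if self.add_generation_prompt:
    messages = [
        *messages,
        llama_types.ChatCompletionRequestAssistantMessage(
            role="assistant", content=""
        ),
    ]

With add_generation_prompt=False the trailing assistant stub is skipped, so the rendered prompt should end after the last supplied message instead of opening a new assistant turn.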