fix: Pass raise_exception and add_generation_prompt to jinja2 chat template

This commit is contained in:
Andrei Betlen 2024-01-31 08:42:21 -05:00
parent 411494706a
commit 078cca0361

View file

@ -185,16 +185,17 @@ class Jinja2ChatFormatter(ChatFormatter):
messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any,
) -> ChatFormatterResponse:
if self.add_generation_prompt:
messages = [
*messages,
llama_types.ChatCompletionRequestAssistantMessage(
role="assistant", content=""
),
]
def raise_exception(message: str):
raise ValueError(message)
prompt = self._environment.render(
messages=messages, eos_token=self.eos_token, bos_token=self.bos_token
messages=messages,
eos_token=self.eos_token,
bos_token=self.bos_token,
raise_exception=raise_exception,
add_generation_prompt=self.add_generation_prompt
)
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
def to_chat_handler(self) -> LlamaChatCompletionHandler: