fix: Pass raise_exception and add_generation_prompt to jinja2 chat template

This commit is contained in:
Andrei Betlen 2024-01-31 08:42:21 -05:00
parent 411494706a
commit 078cca0361

View file

@ -185,16 +185,17 @@ class Jinja2ChatFormatter(ChatFormatter):
messages: List[llama_types.ChatCompletionRequestMessage], messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any, **kwargs: Any,
) -> ChatFormatterResponse: ) -> ChatFormatterResponse:
if self.add_generation_prompt: def raise_exception(message: str):
messages = [ raise ValueError(message)
*messages,
llama_types.ChatCompletionRequestAssistantMessage(
role="assistant", content=""
),
]
prompt = self._environment.render( prompt = self._environment.render(
messages=messages, eos_token=self.eos_token, bos_token=self.bos_token messages=messages,
eos_token=self.eos_token,
bos_token=self.bos_token,
raise_exception=raise_exception,
add_generation_prompt=self.add_generation_prompt
) )
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token]) return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
def to_chat_handler(self) -> LlamaChatCompletionHandler: def to_chat_handler(self) -> LlamaChatCompletionHandler: