fix: Pass raise_exception and add_generation_prompt to jinja2 chat template
This commit is contained in:
parent
411494706a
commit
078cca0361
1 changed files with 9 additions and 8 deletions
|
@ -185,16 +185,17 @@ class Jinja2ChatFormatter(ChatFormatter):
|
|||
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||
**kwargs: Any,
|
||||
) -> ChatFormatterResponse:
|
||||
if self.add_generation_prompt:
|
||||
messages = [
|
||||
*messages,
|
||||
llama_types.ChatCompletionRequestAssistantMessage(
|
||||
role="assistant", content=""
|
||||
),
|
||||
]
|
||||
def raise_exception(message: str):
|
||||
raise ValueError(message)
|
||||
|
||||
prompt = self._environment.render(
|
||||
messages=messages, eos_token=self.eos_token, bos_token=self.bos_token
|
||||
messages=messages,
|
||||
eos_token=self.eos_token,
|
||||
bos_token=self.bos_token,
|
||||
raise_exception=raise_exception,
|
||||
add_generation_prompt=self.add_generation_prompt
|
||||
)
|
||||
|
||||
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
|
||||
|
||||
def to_chat_handler(self) -> LlamaChatCompletionHandler:
|
||||
|
|
Loading…
Reference in a new issue