fix: Pass raise_exception and add_generation_prompt to jinja2 chat template
This commit is contained in:
parent
411494706a
commit
078cca0361
1 changed files with 9 additions and 8 deletions
|
@ -185,16 +185,17 @@ class Jinja2ChatFormatter(ChatFormatter):
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatFormatterResponse:
|
) -> ChatFormatterResponse:
|
||||||
if self.add_generation_prompt:
|
def raise_exception(message: str):
|
||||||
messages = [
|
raise ValueError(message)
|
||||||
*messages,
|
|
||||||
llama_types.ChatCompletionRequestAssistantMessage(
|
|
||||||
role="assistant", content=""
|
|
||||||
),
|
|
||||||
]
|
|
||||||
prompt = self._environment.render(
|
prompt = self._environment.render(
|
||||||
messages=messages, eos_token=self.eos_token, bos_token=self.bos_token
|
messages=messages,
|
||||||
|
eos_token=self.eos_token,
|
||||||
|
bos_token=self.bos_token,
|
||||||
|
raise_exception=raise_exception,
|
||||||
|
add_generation_prompt=self.add_generation_prompt
|
||||||
)
|
)
|
||||||
|
|
||||||
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
|
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
|
||||||
|
|
||||||
def to_chat_handler(self) -> LlamaChatCompletionHandler:
|
def to_chat_handler(self) -> LlamaChatCompletionHandler:
|
||||||
|
|
Loading…
Reference in a new issue