diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 02bdbcf..6c274aa 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -379,7 +379,8 @@ def hf_autotokenizer_to_chat_completion_handler(
 
 
 def hf_tokenizer_config_to_chat_formatter(
-    tokenizer_config: Dict[str, Any]
+    tokenizer_config: Dict[str, Any],
+    add_generation_prompt: bool = True,
 ) -> ChatFormatter:
     assert isinstance(tokenizer_config, dict)
 
@@ -401,31 +402,34 @@ def hf_tokenizer_config_to_chat_formatter(
         lstrip_blocks=True,
     ).from_string(chat_template)
 
-    def format_autotokenizer(
+    def format_tokenizer_config(
         messages: List[llama_types.ChatCompletionRequestMessage],
         **kwargs: Any,
     ) -> ChatFormatterResponse:
         # TODO: veryify this is correct
         # Add a blank assistant message to the end of the messages to prompt the model to generate a response
-        prompt = env.render(
-            messages=[
+        if add_generation_prompt:
+            messages = [
                 *messages,
                 llama_types.ChatCompletionRequestAssistantMessage(
                     role="assistant", content=""
                 ),
-            ],
+            ]
+        prompt = env.render(
+            messages=messages,
             bos_token=bos_token,
             eos_token=eos_token,
         )
-        return ChatFormatterResponse(prompt=prompt, stop=eos_token)
+        return ChatFormatterResponse(prompt=prompt, stop=[eos_token, bos_token])
 
-    return format_autotokenizer
+    return format_tokenizer_config
 
 
 def hf_tokenizer_config_to_chat_completion_handler(
     tokenizer_config: Dict[str, Any],
+    add_generation_prompt: bool = True,
 ) -> LlamaChatCompletionHandler:
-    chat_formatter = hf_tokenizer_config_to_chat_formatter(tokenizer_config)
+    chat_formatter = hf_tokenizer_config_to_chat_formatter(tokenizer_config, add_generation_prompt=add_generation_prompt)
     return chat_formatter_to_chat_completion_handler(chat_formatter)
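
A minimal usage sketch (not part of the patch) of how the new add_generation_prompt flag would be passed through. It assumes an HF-style tokenizer_config.json that provides "chat_template", "bos_token", and "eos_token"; the file path and message contents below are illustrative only.

# Usage sketch, not part of the patch. Assumes an HF-style tokenizer_config.json
# providing "chat_template", "bos_token", and "eos_token"; path and messages are
# illustrative.
import json

from llama_cpp import llama_chat_format

with open("tokenizer_config.json") as f:
    tokenizer_config = json.load(f)

# add_generation_prompt=True (the new default) appends a blank assistant turn so
# the rendered prompt ends where the model should start replying; pass False to
# render the messages as-is.
formatter = llama_chat_format.hf_tokenizer_config_to_chat_formatter(
    tokenizer_config, add_generation_prompt=True
)

response = formatter(messages=[{"role": "user", "content": "Hello!"}])
print(response.prompt)  # chat template rendered with the config's bos/eos tokens
print(response.stop)    # now a list, [eos_token, bos_token], rather than a single string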