fix: use both eos and bos tokens as stop sequences for hf-tokenizer-config chat format.

Author: Andrei Betlen
Date:   2024-01-22 08:32:48 -05:00
parent  2ce0b8aa2c
commit  5b982d0f8c

@@ -379,7 +379,8 @@ def hf_autotokenizer_to_chat_completion_handler(
 def hf_tokenizer_config_to_chat_formatter(
-    tokenizer_config: Dict[str, Any]
+    tokenizer_config: Dict[str, Any],
+    add_generation_prompt: bool = True,
 ) -> ChatFormatter:
     assert isinstance(tokenizer_config, dict)
@@ -401,31 +402,34 @@ def hf_tokenizer_config_to_chat_formatter(
         lstrip_blocks=True,
     ).from_string(chat_template)
 
-    def format_autotokenizer(
+    def format_tokenizer_config(
         messages: List[llama_types.ChatCompletionRequestMessage],
         **kwargs: Any,
     ) -> ChatFormatterResponse:
         # TODO: veryify this is correct
         # Add a blank assistant message to the end of the messages to prompt the model to generate a response
-        prompt = env.render(
-            messages=[
+        if add_generation_prompt:
+            messages = [
                 *messages,
                 llama_types.ChatCompletionRequestAssistantMessage(
                     role="assistant", content=""
                 ),
-            ],
+            ]
+        prompt = env.render(
+            messages=messages,
             bos_token=bos_token,
             eos_token=eos_token,
         )
-        return ChatFormatterResponse(prompt=prompt, stop=eos_token)
+        return ChatFormatterResponse(prompt=prompt, stop=[eos_token, bos_token])
 
-    return format_autotokenizer
+    return format_tokenizer_config
 
 
 def hf_tokenizer_config_to_chat_completion_handler(
     tokenizer_config: Dict[str, Any],
+    add_generation_prompt: bool = True,
 ) -> LlamaChatCompletionHandler:
-    chat_formatter = hf_tokenizer_config_to_chat_formatter(tokenizer_config)
+    chat_formatter = hf_tokenizer_config_to_chat_formatter(tokenizer_config, add_generation_prompt=add_generation_prompt)
     return chat_formatter_to_chat_completion_handler(chat_formatter)
 
 