fix: use both eos and bos tokens as stop sequences for hf-tokenizer-config chat format.

Author: Andrei Betlen
Date:   2024-01-22 08:32:48 -05:00
parent  2ce0b8aa2c
commit  5b982d0f8c


@@ -379,7 +379,8 @@ def hf_autotokenizer_to_chat_completion_handler(
 def hf_tokenizer_config_to_chat_formatter(
-    tokenizer_config: Dict[str, Any]
+    tokenizer_config: Dict[str, Any],
+    add_generation_prompt: bool = True,
 ) -> ChatFormatter:
     assert isinstance(tokenizer_config, dict)
@@ -401,31 +402,34 @@ def hf_tokenizer_config_to_chat_formatter(
         lstrip_blocks=True,
     ).from_string(chat_template)
 
-    def format_autotokenizer(
+    def format_tokenizer_config(
         messages: List[llama_types.ChatCompletionRequestMessage],
         **kwargs: Any,
     ) -> ChatFormatterResponse:
         # TODO: veryify this is correct
         # Add a blank assistant message to the end of the messages to prompt the model to generate a response
-        prompt = env.render(
-            messages=[
+        if add_generation_prompt:
+            messages = [
                 *messages,
                 llama_types.ChatCompletionRequestAssistantMessage(
                     role="assistant", content=""
                 ),
-            ],
+            ]
+        prompt = env.render(
+            messages=messages,
             bos_token=bos_token,
             eos_token=eos_token,
         )
-        return ChatFormatterResponse(prompt=prompt, stop=eos_token)
+        return ChatFormatterResponse(prompt=prompt, stop=[eos_token, bos_token])
 
-    return format_autotokenizer
+    return format_tokenizer_config
 
 
 def hf_tokenizer_config_to_chat_completion_handler(
     tokenizer_config: Dict[str, Any],
+    add_generation_prompt: bool = True,
 ) -> LlamaChatCompletionHandler:
-    chat_formatter = hf_tokenizer_config_to_chat_formatter(tokenizer_config)
+    chat_formatter = hf_tokenizer_config_to_chat_formatter(tokenizer_config, add_generation_prompt=add_generation_prompt)
     return chat_formatter_to_chat_completion_handler(chat_formatter)
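
Usage sketch (not part of the commit): a minimal example of how the changed formatter could be exercised. It assumes the function is importable from llama_cpp.llama_chat_format, and that a local tokenizer_config.json provides chat_template, bos_token, and eos_token as plain strings, as the code above expects.

import json

from llama_cpp.llama_chat_format import hf_tokenizer_config_to_chat_formatter

# Hypothetical local file; any Hugging Face tokenizer_config.json with string
# "chat_template", "bos_token", and "eos_token" fields should work here.
with open("tokenizer_config.json") as f:
    tokenizer_config = json.load(f)

# add_generation_prompt=True (the default) appends a blank assistant message
# before rendering so the model is cued to produce a reply.
chat_formatter = hf_tokenizer_config_to_chat_formatter(
    tokenizer_config, add_generation_prompt=True
)

response = chat_formatter(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
)

print(response.prompt)
# After this commit the stop list contains both tokens, e.g. ["</s>", "<s>"]
# for a Llama-style tokenizer config, instead of the eos token alone.
print(response.stop)

Stopping on the bos token as well as the eos token presumably guards against chat templates where generation would otherwise run past the end of the assistant's turn and begin a new bos-delimited turn.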