fix: Avoid duplicate special tokens in chat formats (#1439)

* Templates sometimes have BOS in them, remove duplicate

* tokenize chat format prompts before completion

This is to ensure that we don't duplicate any special tokens.

Hopefully I amended the existing formats correctly?

* updated comment

* corrected a few

* add some missing internals

* proper bos/eos detection

* just let tokenizer do the job

* typo--

* align test with new response

* changed to a warning

* move to another PR

* Use python warnings module

---------

Co-authored-by: Andrei Betlen <abetlen@gmail.com>
This commit is contained in:
Sigbjørn Skjæret 2024-06-04 16:15:41 +02:00 committed by GitHub
parent 951e39caf9
commit 027f7bc678
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 25 additions and 10 deletions

View file

@ -142,6 +142,14 @@ class _LlamaModel:
assert self.model is not None assert self.model is not None
return llama_cpp.llama_token_eos(self.model) return llama_cpp.llama_token_eos(self.model)
def token_cls(self) -> int:
assert self.model is not None
return llama_cpp.llama_token_cls(self.model)
def token_sep(self) -> int:
assert self.model is not None
return llama_cpp.llama_token_sep(self.model)
def token_nl(self) -> int: def token_nl(self) -> int:
assert self.model is not None assert self.model is not None
return llama_cpp.llama_token_nl(self.model) return llama_cpp.llama_token_nl(self.model)

View file

@ -8,6 +8,7 @@ import json
import ctypes import ctypes
import typing import typing
import fnmatch import fnmatch
import warnings
import multiprocessing import multiprocessing
from typing import ( from typing import (
@ -1019,6 +1020,12 @@ class Llama:
) )
model_name: str = model if model is not None else self.model_path model_name: str = model if model is not None else self.model_path
if prompt_tokens[:2] == [self.token_bos()] * 2:
warnings.warn(
f'Detected duplicate leading "{self._model.token_get_text(self.token_bos())}" in prompt, this will likely reduce response quality, consider removing it...',
RuntimeWarning,
)
# NOTE: This likely doesn't work correctly for the first token in the prompt # NOTE: This likely doesn't work correctly for the first token in the prompt
# because of the extra space added to the start of the prompt_tokens # because of the extra space added to the start of the prompt_tokens
if logit_bias is not None: if logit_bias is not None:

View file

@ -160,6 +160,7 @@ class ChatFormatterResponse:
prompt: str prompt: str
stop: Optional[Union[str, List[str]]] = None stop: Optional[Union[str, List[str]]] = None
stopping_criteria: Optional[llama.StoppingCriteriaList] = None stopping_criteria: Optional[llama.StoppingCriteriaList] = None
added_special: bool = False
class ChatFormatter(Protocol): class ChatFormatter(Protocol):
@ -232,7 +233,7 @@ class Jinja2ChatFormatter(ChatFormatter):
return tokens[-1] in self.stop_token_ids return tokens[-1] in self.stop_token_ids
stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token]) stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria) return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria, added_special=True)
def to_chat_handler(self) -> LlamaChatCompletionHandler: def to_chat_handler(self) -> LlamaChatCompletionHandler:
return chat_formatter_to_chat_completion_handler(self) return chat_formatter_to_chat_completion_handler(self)
@ -548,7 +549,7 @@ def chat_formatter_to_chat_completion_handler(
tools=tools, tools=tools,
tool_choice=tool_choice, tool_choice=tool_choice,
) )
prompt = result.prompt prompt = llama.tokenize(result.prompt.encode("utf-8"), add_bos=not result.added_special, special=True)
if result.stop is not None: if result.stop is not None:
stop = [] if stop is None else [stop] if isinstance(stop, str) else stop stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
rstop = result.stop if isinstance(result.stop, list) else [result.stop] rstop = result.stop if isinstance(result.stop, list) else [result.stop]
@ -655,7 +656,7 @@ def hf_autotokenizer_to_chat_formatter(
prompt: str = tokenizer.apply_chat_template(messages, tokenize=False) # type: ignore prompt: str = tokenizer.apply_chat_template(messages, tokenize=False) # type: ignore
assert isinstance(prompt, str) assert isinstance(prompt, str)
# Return formatted prompt and eos token by default # Return formatted prompt and eos token by default
return ChatFormatterResponse(prompt=prompt, stop=tokenizer.eos_token) return ChatFormatterResponse(prompt=prompt, stop=tokenizer.eos_token, added_special=True)
return format_autotokenizer return format_autotokenizer
@ -708,7 +709,7 @@ def hf_tokenizer_config_to_chat_formatter(
bos_token=bos_token, bos_token=bos_token,
eos_token=eos_token, eos_token=eos_token,
) )
return ChatFormatterResponse(prompt=prompt, stop=[eos_token, bos_token]) return ChatFormatterResponse(prompt=prompt, stop=[eos_token, bos_token], added_special=True)
return format_tokenizer_config return format_tokenizer_config
@ -918,7 +919,7 @@ def format_llama2(
messages: List[llama_types.ChatCompletionRequestMessage], messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any, **kwargs: Any,
) -> ChatFormatterResponse: ) -> ChatFormatterResponse:
_system_template = "<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>" _system_template = "[INST] <<SYS>>\n{system_message}\n<</SYS>>"
_roles = dict(user="<s>[INST]", assistant="[/INST]") _roles = dict(user="<s>[INST]", assistant="[/INST]")
_messages = _map_roles(messages, _roles) _messages = _map_roles(messages, _roles)
system_message = _get_system_message(messages) system_message = _get_system_message(messages)
@ -940,11 +941,10 @@ def format_llama3(
user="<|start_header_id|>user<|end_header_id|>\n\n", user="<|start_header_id|>user<|end_header_id|>\n\n",
assistant="<|start_header_id|>assistant<|end_header_id|>\n\n", assistant="<|start_header_id|>assistant<|end_header_id|>\n\n",
) )
_begin_token = "<|begin_of_text|>"
_sep = "<|eot_id|>" _sep = "<|eot_id|>"
_messages = _map_roles(messages, _roles) _messages = _map_roles(messages, _roles)
_messages.append((_roles["assistant"], None)) _messages.append((_roles["assistant"], None))
_prompt = _format_no_colon_single(_begin_token, _messages, _sep) _prompt = _format_no_colon_single("", _messages, _sep)
return ChatFormatterResponse(prompt=_prompt, stop=_sep) return ChatFormatterResponse(prompt=_prompt, stop=_sep)
@ -1229,10 +1229,9 @@ def format_mistral_instruct(
messages: List[llama_types.ChatCompletionRequestMessage], messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any, **kwargs: Any,
) -> ChatFormatterResponse: ) -> ChatFormatterResponse:
bos = "<s>"
eos = "</s>" eos = "</s>"
stop = eos stop = eos
prompt = bos prompt = ""
for message in messages: for message in messages:
if ( if (
message["role"] == "user" message["role"] == "user"

View file

@ -21,12 +21,13 @@ def test_mistral_instruct():
response = llama_chat_format.format_mistral_instruct( response = llama_chat_format.format_mistral_instruct(
messages=messages, messages=messages,
) )
prompt = ("" if response.added_special else "<s>") + response.prompt
reference = chat_formatter.render( reference = chat_formatter.render(
messages=messages, messages=messages,
bos_token="<s>", bos_token="<s>",
eos_token="</s>", eos_token="</s>",
) )
assert response.prompt == reference assert prompt == reference
mistral_7b_tokenizer_config = """{ mistral_7b_tokenizer_config = """{