fix: Avoid duplicate special tokens in chat formats (#1439)
* Templates sometimes have BOS in them, remove duplicate
* Tokenize chat format prompts before completion

  This is to ensure that we don't duplicate any special tokens. Hopefully I amended the existing formats correctly?

* Updated comment
* Corrected a few
* Add some missing internals
* Proper BOS/EOS detection
* Just let the tokenizer do the job
* Fix typo
* Align test with new response
* Changed to a warning
* Move to another PR
* Use Python warnings module

---------

Co-authored-by: Andrei Betlen <abetlen@gmail.com>
parent 951e39caf9
commit 027f7bc678
4 changed files with 25 additions and 10 deletions
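The core issue: several chat templates write the BOS marker (e.g. "<s>") directly into the prompt string, and the rendered prompt was then tokenized with the tokenizer's own BOS added on top, yielding two BOS tokens. A minimal sketch of the failure mode, assuming a llama-2-style GGUF model whose BOS is "<s>" (the model path is a placeholder):

from llama_cpp import Llama

llm = Llama(model_path="./model.gguf")  # placeholder path

# The template already wrote "<s>" into the prompt string...
prompt_from_template = "<s>[INST] Hello [/INST]"
# ...and tokenizing with add_bos=True prepends a second BOS in front of it.
tokens = llm.tokenize(prompt_from_template.encode("utf-8"), add_bos=True, special=True)
assert tokens[:2] == [llm.token_bos()] * 2  # duplicated BOS: the bug this commit avoids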
@@ -142,6 +142,14 @@ class _LlamaModel:
         assert self.model is not None
         return llama_cpp.llama_token_eos(self.model)
 
+    def token_cls(self) -> int:
+        assert self.model is not None
+        return llama_cpp.llama_token_cls(self.model)
+
+    def token_sep(self) -> int:
+        assert self.model is not None
+        return llama_cpp.llama_token_sep(self.model)
+
     def token_nl(self) -> int:
         assert self.model is not None
         return llama_cpp.llama_token_nl(self.model)
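The new token_cls/token_sep accessors mirror the existing token_bos/token_eos wrappers around the llama.cpp API. A rough usage sketch, assuming a loaded Llama instance llm (this touches the internal _LlamaModel wrapper, so treat it as illustrative rather than a stable API):

# Illustrative: read the classification and separator token ids from the
# internal model wrapper; models that define neither may return -1 here.
cls_id = llm._model.token_cls()
sep_id = llm._model.token_sep()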
@@ -8,6 +8,7 @@ import json
 import ctypes
 import typing
 import fnmatch
+import warnings
 import multiprocessing
 
 from typing import (
@@ -1019,6 +1020,12 @@ class Llama:
         )
         model_name: str = model if model is not None else self.model_path
 
+        if prompt_tokens[:2] == [self.token_bos()] * 2:
+            warnings.warn(
+                f'Detected duplicate leading "{self._model.token_get_text(self.token_bos())}" in prompt, this will likely reduce response quality, consider removing it...',
+                RuntimeWarning,
+            )
+
         # NOTE: This likely doesn't work correctly for the first token in the prompt
         # because of the extra space added to the start of the prompt_tokens
         if logit_bias is not None:
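Since the duplicate-BOS notice goes through Python's warnings module rather than a bare print, callers can escalate or silence it with the usual filters. A small sketch, assuming a loaded Llama instance llm (the prompt string is illustrative):

import warnings

# Turn the duplicate-BOS notice into an error (handy in tests), or use
# "ignore" instead of "error" to suppress it.
with warnings.catch_warnings():
    warnings.simplefilter("error", RuntimeWarning)
    # A prompt that already starts with a literal "<s>", plus the BOS the
    # tokenizer adds automatically, should trip the check above and raise here.
    llm("<s>[INST] Hello [/INST]", max_tokens=8)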
@@ -160,6 +160,7 @@ class ChatFormatterResponse:
     prompt: str
     stop: Optional[Union[str, List[str]]] = None
     stopping_criteria: Optional[llama.StoppingCriteriaList] = None
+    added_special: bool = False
 
 
 class ChatFormatter(Protocol):
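Custom formatters that already embed special tokens in their rendered prompt can set the new flag so the handler skips the extra BOS. A minimal sketch (the template string is purely illustrative):

from llama_cpp import llama_chat_format

# A formatter that bakes "<s>" into its prompt reports added_special=True so
# the completion handler tokenizes it with add_bos=False.
response = llama_chat_format.ChatFormatterResponse(
    prompt="<s>[INST] Hello [/INST]",
    stop="</s>",
    added_special=True,
)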
@@ -232,7 +233,7 @@ class Jinja2ChatFormatter(ChatFormatter):
             return tokens[-1] in self.stop_token_ids
         stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])
 
-        return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria)
+        return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria, added_special=True)
 
     def to_chat_handler(self) -> LlamaChatCompletionHandler:
         return chat_formatter_to_chat_completion_handler(self)
@@ -548,7 +549,7 @@ def chat_formatter_to_chat_completion_handler(
            tools=tools,
            tool_choice=tool_choice,
        )
-        prompt = result.prompt
+        prompt = llama.tokenize(result.prompt.encode("utf-8"), add_bos=not result.added_special, special=True)
         if result.stop is not None:
            stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
            rstop = result.stop if isinstance(result.stop, list) else [result.stop]
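This is the central change: the formatted prompt is tokenized up front, and the tokenizer only adds BOS when the formatter did not already do so. A sketch of the same decision in isolation, assuming a loaded Llama instance llm and a ChatFormatterResponse named result (not the library's internal code path):

def render_prompt_tokens(llm, result):
    # If the template already contains BOS/EOS markers (added_special=True),
    # do not let the tokenizer prepend another BOS; special=True still lets it
    # parse the markers that are present in the text.
    return llm.tokenize(
        result.prompt.encode("utf-8"),
        add_bos=not result.added_special,
        special=True,
    )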
@@ -655,7 +656,7 @@ def hf_autotokenizer_to_chat_formatter(
         prompt: str = tokenizer.apply_chat_template(messages, tokenize=False)  # type: ignore
         assert isinstance(prompt, str)
         # Return formatted prompt and eos token by default
-        return ChatFormatterResponse(prompt=prompt, stop=tokenizer.eos_token)
+        return ChatFormatterResponse(prompt=prompt, stop=tokenizer.eos_token, added_special=True)
 
     return format_autotokenizer
 
@@ -708,7 +709,7 @@ def hf_tokenizer_config_to_chat_formatter(
            bos_token=bos_token,
            eos_token=eos_token,
        )
-        return ChatFormatterResponse(prompt=prompt, stop=[eos_token, bos_token])
+        return ChatFormatterResponse(prompt=prompt, stop=[eos_token, bos_token], added_special=True)
 
     return format_tokenizer_config
 
@@ -918,7 +919,7 @@ def format_llama2(
     messages: List[llama_types.ChatCompletionRequestMessage],
     **kwargs: Any,
 ) -> ChatFormatterResponse:
-    _system_template = "<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>"
+    _system_template = "[INST] <<SYS>>\n{system_message}\n<</SYS>>"
     _roles = dict(user="<s>[INST]", assistant="[/INST]")
     _messages = _map_roles(messages, _roles)
     system_message = _get_system_message(messages)
@@ -940,11 +941,10 @@ def format_llama3(
        user="<|start_header_id|>user<|end_header_id|>\n\n",
        assistant="<|start_header_id|>assistant<|end_header_id|>\n\n",
    )
-    _begin_token = "<|begin_of_text|>"
     _sep = "<|eot_id|>"
     _messages = _map_roles(messages, _roles)
     _messages.append((_roles["assistant"], None))
-    _prompt = _format_no_colon_single(_begin_token, _messages, _sep)
+    _prompt = _format_no_colon_single("", _messages, _sep)
     return ChatFormatterResponse(prompt=_prompt, stop=_sep)
 
 
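With the literal <|begin_of_text|> removed, format_llama3 leaves BOS handling to the tokenizer (its response does not set added_special, so add_bos stays enabled downstream). A quick check, as a sketch:

from llama_cpp import llama_chat_format

messages = [{"role": "user", "content": "Hello"}]
response = llama_chat_format.format_llama3(messages=messages)
# The begin-of-text marker is no longer baked into the prompt string; the
# tokenizer is expected to supply it when add_bos is requested.
assert "<|begin_of_text|>" not in response.prompt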
@@ -1229,10 +1229,9 @@ def format_mistral_instruct(
     messages: List[llama_types.ChatCompletionRequestMessage],
     **kwargs: Any,
 ) -> ChatFormatterResponse:
-    bos = "<s>"
     eos = "</s>"
     stop = eos
-    prompt = bos
+    prompt = ""
     for message in messages:
         if (
             message["role"] == "user"
@@ -21,12 +21,13 @@ def test_mistral_instruct():
     response = llama_chat_format.format_mistral_instruct(
         messages=messages,
     )
+    prompt = ("" if response.added_special else "<s>") + response.prompt
     reference = chat_formatter.render(
         messages=messages,
         bos_token="<s>",
         eos_token="</s>",
     )
-    assert response.prompt == reference
+    assert prompt == reference
 
 
 mistral_7b_tokenizer_config = """{
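The same guard could be reused in other formatter tests: rebuild the full prompt by prepending the BOS string only when the formatter did not add special tokens itself. As a small helper, for illustration:

def full_prompt(response, bos_token="<s>"):
    # Mirrors the test above: a formatter with added_special=True already
    # carries its special tokens, so nothing is prepended.
    return ("" if response.added_special else bos_token) + response.prompt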