fix: tokenization of special characters (#850)

It should behave like llama.cpp, where most out-of-the-box usages
treat special characters accordingly.
Antoine Lizee 2023-11-02 01:29:06 +00:00 committed by Andrei Betlen
parent 952e4cc3ce
commit 4d4e0f11e2
4 changed files with 13 additions and 4 deletions
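
In practice the change threads special=True through the library's internal tokenize calls, so special tokens written in the input text (e.g. "</s>") are parsed into their single token ids instead of being split into plain-text pieces, matching llama.cpp. A minimal sketch of the difference, assuming any LLaMA-style GGUF model (the path below is a placeholder):

from llama_cpp import Llama

# Placeholder path: any LLaMA-style GGUF works; vocab_only=True loads just
# the tokenizer, which is all this sketch needs.
llama = Llama(model_path="./models/llama-model.gguf", vocab_only=True)

text = b"Hello World</s>"

# Default behavior: "</s>" is split into plain-text pieces.
tokens = llama.tokenize(text)
assert tokens[-1] != llama.token_eos()

# With special=True (what this commit now passes internally), "</s>" is
# parsed as the single EOS special token.
tokens = llama.tokenize(text, special=True)
assert tokens[-1] == llama.token_eos()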

llama_cpp/llama.py

@@ -856,7 +856,7 @@ class Llama:
         data: List[EmbeddingData] = []
         total_tokens = 0
         for index, input in enumerate(inputs):
-            tokens = self.tokenize(input.encode("utf-8"))
+            tokens = self.tokenize(input.encode("utf-8"), special=True)
             self.reset()
             self.eval(tokens)
             n_tokens = len(tokens)
@@ -928,7 +928,7 @@ class Llama:
         completion_tokens: List[int] = []
         # Add blank space to start of prompt to match OG llama tokenizer
         prompt_tokens: List[int] = (
-            self.tokenize(prompt.encode("utf-8"))
+            self.tokenize(prompt.encode("utf-8"), special=True)
             if prompt != ""
             else [self.token_bos()]
         )
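
This hunk is the user-visible part for completions: prompt strings are now tokenized with special-token parsing, so markers such as "</s>" (or chat-template tokens) inside a prompt map to their special ids. Continuing the sketch above (the prompt text is illustrative):

# Reuses `llama` from the sketch above.
prompt = "Once upon a time</s>"
prompt_tokens = llama.tokenize(prompt.encode("utf-8"), special=True)
assert llama.token_eos() in prompt_tokens  # "</s>" became the EOS id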
@@ -1826,7 +1826,7 @@ class LlamaTokenizer:
     def encode(self, text: str, add_bos: bool = True) -> List[int]:
         return self.llama.tokenize(
-            text.encode("utf-8", errors="ignore"), add_bos=add_bos
+            text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=True
         )

     def decode(self, tokens: List[int]) -> str:
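
The high-level LlamaTokenizer wrapper picks up the same behavior through encode. A brief sketch, again reusing the model from above:

from llama_cpp import LlamaTokenizer

tokenizer = LlamaTokenizer(llama)
ids = tokenizer.encode("Hello World</s>")
assert ids[-1] == llama.token_eos()  # "</s>" is now encoded as EOS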

llama_cpp/server/app.py

@@ -594,7 +594,7 @@ def make_logit_bias_processor(
     elif logit_bias_type == "tokens":
         for token, score in logit_bias.items():
             token = token.encode("utf-8")
-            for input_id in llama.tokenize(token, add_bos=False):
+            for input_id in llama.tokenize(token, add_bos=False, special=True):
                 to_bias[input_id] = score

     def logit_bias_processor(
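
For the server, this changes how string keys in a logit_bias map (with logit_bias_type="tokens") resolve to ids: a special token written as a string now biases the actual special id rather than its spelled-out text pieces. A rough sketch of the mapping loop from the diff above (the bias value is illustrative):

# Suppress EOS via its string form; -100.0 is an illustrative score.
logit_bias = {"</s>": -100.0}

to_bias = {}  # token id -> bias score
for token, score in logit_bias.items():
    token_bytes = token.encode("utf-8")
    for input_id in llama.tokenize(token_bytes, add_bos=False, special=True):
        to_bias[input_id] = score

# "</s>" now maps to the single EOS id rather than three text pieces.
assert list(to_bias.keys()) == [llama.token_eos()]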

test.py (new empty file, 0 lines)

tests/test_llama.py

@@ -25,6 +25,15 @@ def test_llama_cpp_tokenization():
     detokenized = llama.detokenize(tokens)
     assert detokenized != text
+
+    text = b"Hello World</s>"
+    tokens = llama.tokenize(text)
+    assert tokens[-1] != llama.token_eos()
+    assert tokens == [1, 15043, 2787, 829, 29879, 29958]
+
+    tokens = llama.tokenize(text, special=True)
+    assert tokens[-1] == llama.token_eos()
+    assert tokens == [1, 10994, 2787, 2]


 def test_llama_patch(monkeypatch):
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)