fix: tokenization of special characters (#850)
It should behave like llama.cpp, where most out-of-the-box usages parse special tokens (e.g. `</s>`) as special instead of tokenizing them as literal text.
parent 952e4cc3ce
commit 4d4e0f11e2

4 changed files with 13 additions and 4 deletions
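Concretely, the change passes `special=True` to `Llama.tokenize` at the call sites below. A minimal sketch mirroring the new test in this commit (the model path is illustrative, not from the commit):

from llama_cpp import Llama

llama = Llama(model_path="./models/model.gguf", vocab_only=True)  # illustrative path

text = b"Hello World</s>"

# Default: "</s>" is tokenized as literal text.
tokens = llama.tokenize(text)
assert tokens[-1] != llama.token_eos()

# special=True: "</s>" is parsed as the EOS special token, matching llama.cpp.
tokens = llama.tokenize(text, special=True)
assert tokens[-1] == llama.token_eos()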
llama_cpp/llama.py

@@ -856,7 +856,7 @@ class Llama:
         data: List[EmbeddingData] = []
         total_tokens = 0
         for index, input in enumerate(inputs):
-            tokens = self.tokenize(input.encode("utf-8"))
+            tokens = self.tokenize(input.encode("utf-8"), special=True)
             self.reset()
             self.eval(tokens)
             n_tokens = len(tokens)
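On the embeddings path, input strings now go through the same special-token-aware tokenization. A hedged sketch (assumes a model loaded with `embedding=True`; the input string mirrors the test):

# Each embedding input is tokenized with special=True, so an embedded "</s>"
# contributes the EOS token to the evaluated sequence.
emb = llama.create_embedding("Hello World</s>")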
@@ -928,7 +928,7 @@ class Llama:
         completion_tokens: List[int] = []
         # Add blank space to start of prompt to match OG llama tokenizer
         prompt_tokens: List[int] = (
-            self.tokenize(prompt.encode("utf-8"))
+            self.tokenize(prompt.encode("utf-8"), special=True)
             if prompt != ""
             else [self.token_bos()]
         )
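For completions, prompts that embed special tokens literally are now parsed the way llama.cpp parses them. A hedged sketch (the prompt string is illustrative):

# An embedded "</s>" resolves to the single EOS id instead of being split
# into several literal-text tokens.
prompt = "Answer briefly.</s>"
prompt_tokens = llama.tokenize(prompt.encode("utf-8"), special=True)
assert llama.token_eos() in prompt_tokens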
@@ -1826,7 +1826,7 @@ class LlamaTokenizer:

     def encode(self, text: str, add_bos: bool = True) -> List[int]:
         return self.llama.tokenize(
-            text.encode("utf-8", errors="ignore"), add_bos=add_bos
+            text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=True
         )

     def decode(self, tokens: List[int]) -> str:
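The high-level wrapper inherits the behavior through `encode`. A usage sketch (assumes `import llama_cpp` and a loaded `Llama` instance named `llama`):

tokenizer = llama_cpp.LlamaTokenizer(llama)
ids = tokenizer.encode("Hello World</s>")  # "</s>" now encodes to token_eos()
round_trip = tokenizer.decode(ids)         # decoding is unchanged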
llama_cpp/server/app.py

@@ -594,7 +594,7 @@ def make_logit_bias_processor(
     elif logit_bias_type == "tokens":
         for token, score in logit_bias.items():
             token = token.encode("utf-8")
-            for input_id in llama.tokenize(token, add_bos=False):
+            for input_id in llama.tokenize(token, add_bos=False, special=True):
                 to_bias[input_id] = score

     def logit_bias_processor(
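For the server, keys in a "tokens"-type logit_bias are now tokenized with special-token parsing. A minimal sketch of the loop above (names mirror the diff; `llama` is assumed to be a loaded model):

logit_bias = {"</s>": -100.0}  # e.g. suppress end-of-sequence
to_bias: dict = {}
for token, score in logit_bias.items():
    for input_id in llama.tokenize(token.encode("utf-8"), add_bos=False, special=True):
        to_bias[input_id] = score
# Before this change, "</s>" expanded to several literal-text ids;
# now to_bias == {llama.token_eos(): -100.0}.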
test.py (new file, 0 lines)
tests/test_llama.py

@@ -25,6 +25,15 @@ def test_llama_cpp_tokenization():
     detokenized = llama.detokenize(tokens)
     assert detokenized != text

+    text = b"Hello World</s>"
+    tokens = llama.tokenize(text)
+    assert tokens[-1] != llama.token_eos()
+    assert tokens == [1, 15043, 2787, 829, 29879, 29958]
+
+    tokens = llama.tokenize(text, special=True)
+    assert tokens[-1] == llama.token_eos()
+    assert tokens == [1, 10994, 2787, 2]
+

 def test_llama_patch(monkeypatch):
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)