fix: tokenization of special characters: (#850)
It should behave like llama.cpp, where most out of the box usages treat special characters accordingly
This commit is contained in:
parent
952e4cc3ce
commit
4d4e0f11e2
4 changed files with 13 additions and 4 deletions
|
@ -856,7 +856,7 @@ class Llama:
|
||||||
data: List[EmbeddingData] = []
|
data: List[EmbeddingData] = []
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
for index, input in enumerate(inputs):
|
for index, input in enumerate(inputs):
|
||||||
tokens = self.tokenize(input.encode("utf-8"))
|
tokens = self.tokenize(input.encode("utf-8"), special=True)
|
||||||
self.reset()
|
self.reset()
|
||||||
self.eval(tokens)
|
self.eval(tokens)
|
||||||
n_tokens = len(tokens)
|
n_tokens = len(tokens)
|
||||||
|
@ -928,7 +928,7 @@ class Llama:
|
||||||
completion_tokens: List[int] = []
|
completion_tokens: List[int] = []
|
||||||
# Add blank space to start of prompt to match OG llama tokenizer
|
# Add blank space to start of prompt to match OG llama tokenizer
|
||||||
prompt_tokens: List[int] = (
|
prompt_tokens: List[int] = (
|
||||||
self.tokenize(prompt.encode("utf-8"))
|
self.tokenize(prompt.encode("utf-8"), special=True)
|
||||||
if prompt != ""
|
if prompt != ""
|
||||||
else [self.token_bos()]
|
else [self.token_bos()]
|
||||||
)
|
)
|
||||||
|
@ -1826,7 +1826,7 @@ class LlamaTokenizer:
|
||||||
|
|
||||||
def encode(self, text: str, add_bos: bool = True) -> List[int]:
|
def encode(self, text: str, add_bos: bool = True) -> List[int]:
|
||||||
return self.llama.tokenize(
|
return self.llama.tokenize(
|
||||||
text.encode("utf-8", errors="ignore"), add_bos=add_bos
|
text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=True
|
||||||
)
|
)
|
||||||
|
|
||||||
def decode(self, tokens: List[int]) -> str:
|
def decode(self, tokens: List[int]) -> str:
|
||||||
|
|
|
@ -594,7 +594,7 @@ def make_logit_bias_processor(
|
||||||
elif logit_bias_type == "tokens":
|
elif logit_bias_type == "tokens":
|
||||||
for token, score in logit_bias.items():
|
for token, score in logit_bias.items():
|
||||||
token = token.encode("utf-8")
|
token = token.encode("utf-8")
|
||||||
for input_id in llama.tokenize(token, add_bos=False):
|
for input_id in llama.tokenize(token, add_bos=False, special=True):
|
||||||
to_bias[input_id] = score
|
to_bias[input_id] = score
|
||||||
|
|
||||||
def logit_bias_processor(
|
def logit_bias_processor(
|
||||||
|
|
0
test.py
Normal file
0
test.py
Normal file
|
@ -25,6 +25,15 @@ def test_llama_cpp_tokenization():
|
||||||
detokenized = llama.detokenize(tokens)
|
detokenized = llama.detokenize(tokens)
|
||||||
assert detokenized != text
|
assert detokenized != text
|
||||||
|
|
||||||
|
text = b"Hello World</s>"
|
||||||
|
tokens = llama.tokenize(text)
|
||||||
|
assert tokens[-1] != llama.token_eos()
|
||||||
|
assert tokens == [1, 15043, 2787, 829, 29879, 29958]
|
||||||
|
|
||||||
|
tokens = llama.tokenize(text, special=True)
|
||||||
|
assert tokens[-1] == llama.token_eos()
|
||||||
|
assert tokens == [1, 10994, 2787, 2]
|
||||||
|
|
||||||
|
|
||||||
def test_llama_patch(monkeypatch):
|
def test_llama_patch(monkeypatch):
|
||||||
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
|
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
|
||||||
|
|
Loading…
Add table
Reference in a new issue