From 4d4e0f11e292984320e59f1dc1ecf9a0cc437a75 Mon Sep 17 00:00:00 2001 From: Antoine Lizee Date: Thu, 2 Nov 2023 01:29:06 +0000 Subject: [PATCH] fix: tokenization of special characters: (#850) It should behave like llama.cpp, where most out of the box usages treat special characters accordingly --- llama_cpp/llama.py | 6 +++--- llama_cpp/server/app.py | 2 +- test.py | 0 tests/test_llama.py | 9 +++++++++ 4 files changed, 13 insertions(+), 4 deletions(-) create mode 100644 test.py diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index bc747cf..c9ea90f 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -856,7 +856,7 @@ class Llama: data: List[EmbeddingData] = [] total_tokens = 0 for index, input in enumerate(inputs): - tokens = self.tokenize(input.encode("utf-8")) + tokens = self.tokenize(input.encode("utf-8"), special=True) self.reset() self.eval(tokens) n_tokens = len(tokens) @@ -928,7 +928,7 @@ class Llama: completion_tokens: List[int] = [] # Add blank space to start of prompt to match OG llama tokenizer prompt_tokens: List[int] = ( - self.tokenize(prompt.encode("utf-8")) + self.tokenize(prompt.encode("utf-8"), special=True) if prompt != "" else [self.token_bos()] ) @@ -1826,7 +1826,7 @@ class LlamaTokenizer: def encode(self, text: str, add_bos: bool = True) -> List[int]: return self.llama.tokenize( - text.encode("utf-8", errors="ignore"), add_bos=add_bos + text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=True ) def decode(self, tokens: List[int]) -> str: diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index 930ad5d..f8d8c76 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -594,7 +594,7 @@ def make_logit_bias_processor( elif logit_bias_type == "tokens": for token, score in logit_bias.items(): token = token.encode("utf-8") - for input_id in llama.tokenize(token, add_bos=False): + for input_id in llama.tokenize(token, add_bos=False, special=True): to_bias[input_id] = score def logit_bias_processor( diff --git a/test.py b/test.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_llama.py b/tests/test_llama.py index 76291fb..330b69b 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -25,6 +25,15 @@ def test_llama_cpp_tokenization(): detokenized = llama.detokenize(tokens) assert detokenized != text + text = b"Hello World" + tokens = llama.tokenize(text) + assert tokens[-1] != llama.token_eos() + assert tokens == [1, 15043, 2787, 829, 29879, 29958] + + tokens = llama.tokenize(text, special=True) + assert tokens[-1] == llama.token_eos() + assert tokens == [1, 10994, 2787, 2] + def test_llama_patch(monkeypatch): llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)