From c4a8491d42b3b93330408afc3cc2af31ae2fecb1 Mon Sep 17 00:00:00 2001 From: Mug <> Date: Wed, 26 Apr 2023 14:37:06 +0200 Subject: [PATCH 1/8] Fix decode errors permanently --- examples/low_level_api/low_level_api_chat_cpp.py | 9 ++++++--- examples/low_level_api/low_level_api_llama_cpp.py | 2 +- llama_cpp/llama.py | 12 ++++++------ 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/examples/low_level_api/low_level_api_chat_cpp.py b/examples/low_level_api/low_level_api_chat_cpp.py index d64ee8f..4a7cfc1 100644 --- a/examples/low_level_api/low_level_api_chat_cpp.py +++ b/examples/low_level_api/low_level_api_chat_cpp.py @@ -96,7 +96,7 @@ specified) expect poor results""", file=sys.stderr) print(file=sys.stderr) print(f"system_info: n_threads = {self.params.n_threads} / {cpu_count()} \ -| {llama_cpp.llama_print_system_info().decode('utf8')}", file=sys.stderr) +| {llama_cpp.llama_print_system_info().decode('utf8', errors='ignore')}", file=sys.stderr) # determine the required inference memory per token: if (self.params.mem_test): @@ -342,7 +342,7 @@ n_keep = {self.params.n_keep} # return past text def past(self): for id in self.last_n_tokens[-self.n_past:]: - yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8") + yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8", errors="ignore") # write input def input(self, prompt: str): @@ -356,7 +356,10 @@ n_keep = {self.params.n_keep} def output(self): self.remaining_tokens = self.params.n_predict for id in self.generate(): - yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8") + try: + yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8", errors="ignore") + except UnicodeDecodeError: + pass # read user input def read_input(self): diff --git a/examples/low_level_api/low_level_api_llama_cpp.py b/examples/low_level_api/low_level_api_llama_cpp.py index b048c0a..4fb5a03 100644 --- a/examples/low_level_api/low_level_api_llama_cpp.py +++ b/examples/low_level_api/low_level_api_llama_cpp.py @@ -70,7 +70,7 @@ while remaining_tokens > 0: if not input_noecho: for id in embd: print( - llama_cpp.llama_token_to_str(ctx, id).decode("utf-8"), + llama_cpp.llama_token_to_str(ctx, id).decode("utf-8", errors="ignore"), end="", flush=True, ) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index edd2eef..a6e7ae3 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -109,7 +109,7 @@ class Llama: ) if self.verbose: - print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr) + print(llama_cpp.llama_print_system_info().decode("utf-8", errors="ignore"), file=sys.stderr) def tokenize(self, text: bytes) -> List[llama_cpp.llama_token]: """Tokenize a string. 
@@ -460,7 +460,7 @@ class Llama: "model": self.model_path, "choices": [ { - "text": text[start:].decode("utf-8"), + "text": text[start:].decode("utf-8", errors="ignore"), "index": 0, "logprobs": None, "finish_reason": None, @@ -484,7 +484,7 @@ class Llama: "model": self.model_path, "choices": [ { - "text": text[returned_characters:].decode("utf-8"), + "text": text[returned_characters:].decode("utf-8", errors="ignore"), "index": 0, "logprobs": None, "finish_reason": finish_reason, @@ -496,7 +496,7 @@ class Llama: ### HACK self._completion_bytes.append(text) ### - text_str = text.decode("utf-8") + text_str = text.decode("utf-8", errors="ignore") if echo: text_str = prompt + text_str @@ -514,7 +514,7 @@ class Llama: all_tokens = prompt_tokens + completion_tokens all_token_strs = [ - self.detokenize([token]).decode("utf-8") for token in all_tokens + self.detokenize([token]).decode("utf-8", errors="ignore") for token in all_tokens ] all_logprobs = [ [Llama.logit_to_logprob(logit) for logit in row] @@ -533,7 +533,7 @@ class Llama: ) token_logprobs.append(sorted_logprobs[int(token)][0]) top_logprob = { - self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob + self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8", errors="ignore"): logprob for logprob, i in sorted_logprobs[:logprobs] } top_logprob.update({token_str: sorted_logprobs[int(token)][0]}) From 3c130f00ca65943fc4ac3db7d11cf9ca83cd5c2a Mon Sep 17 00:00:00 2001 From: Mug <> Date: Wed, 26 Apr 2023 14:38:53 +0200 Subject: [PATCH 2/8] Remove try catch from chat --- examples/low_level_api/low_level_api_chat_cpp.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/examples/low_level_api/low_level_api_chat_cpp.py b/examples/low_level_api/low_level_api_chat_cpp.py index 4a7cfc1..c383bf6 100644 --- a/examples/low_level_api/low_level_api_chat_cpp.py +++ b/examples/low_level_api/low_level_api_chat_cpp.py @@ -356,10 +356,7 @@ n_keep = {self.params.n_keep} def output(self): self.remaining_tokens = self.params.n_predict for id in self.generate(): - try: - yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8", errors="ignore") - except UnicodeDecodeError: - pass + yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8", errors="ignore") # read user input def read_input(self): From 5f81400fcb2898e9eb6b13f32dc066052d7efeef Mon Sep 17 00:00:00 2001 From: Mug <> Date: Wed, 26 Apr 2023 14:45:51 +0200 Subject: [PATCH 3/8] Also ignore errors on input prompts --- examples/low_level_api/low_level_api_chat_cpp.py | 2 +- llama_cpp/llama.py | 6 +++--- tests/test_llama.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/low_level_api/low_level_api_chat_cpp.py b/examples/low_level_api/low_level_api_chat_cpp.py index c383bf6..90b2fcb 100644 --- a/examples/low_level_api/low_level_api_chat_cpp.py +++ b/examples/low_level_api/low_level_api_chat_cpp.py @@ -201,7 +201,7 @@ n_keep = {self.params.n_keep} # tokenize a prompt def _tokenize(self, prompt, bos=True): _arr = (llama_cpp.llama_token * (len(prompt) + 1))() - _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), _arr, len(_arr), bos) + _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8", errors="ignore"), _arr, len(_arr), bos) return _arr[:_n] def set_color(self, c): diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index f442648..41e8c0a 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -358,7 +358,7 @@ class Llama: if self.verbose: llama_cpp.llama_reset_timings(self.ctx) - tokens = 
self.tokenize(input.encode("utf-8")) + tokens = self.tokenize(input.encode("utf-8", errors="ignore")) self.reset() self.eval(tokens) n_tokens = len(tokens) @@ -416,7 +416,7 @@ class Llama: completion_tokens: List[llama_cpp.llama_token] = [] # Add blank space to start of prompt to match OG llama tokenizer prompt_tokens: List[llama_cpp.llama_token] = self.tokenize( - b" " + prompt.encode("utf-8") + b" " + prompt.encode("utf-8", errors="ignore") ) text: bytes = b"" returned_characters: int = 0 @@ -431,7 +431,7 @@ class Llama: ) if stop != []: - stop_sequences = [s.encode("utf-8") for s in stop] + stop_sequences = [s.encode("utf-8", errors="ignore") for s in stop] else: stop_sequences = [] diff --git a/tests/test_llama.py b/tests/test_llama.py index 6a50256..4dab687 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -24,7 +24,7 @@ def test_llama_patch(monkeypatch): monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval) output_text = " jumps over the lazy dog." - output_tokens = llama.tokenize(output_text.encode("utf-8")) + output_tokens = llama.tokenize(output_text.encode("utf-8", errors="ignore")) token_eos = llama.token_eos() n = 0 From c39547a986540d1152493db45ed461dde04f0ffa Mon Sep 17 00:00:00 2001 From: Mug <> Date: Fri, 28 Apr 2023 12:50:30 +0200 Subject: [PATCH 4/8] Detect multi-byte responses and wait --- examples/low_level_api/low_level_api_chat_cpp.py | 2 +- llama_cpp/llama.py | 14 +++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/examples/low_level_api/low_level_api_chat_cpp.py b/examples/low_level_api/low_level_api_chat_cpp.py index 90b2fcb..6fced65 100644 --- a/examples/low_level_api/low_level_api_chat_cpp.py +++ b/examples/low_level_api/low_level_api_chat_cpp.py @@ -96,7 +96,7 @@ specified) expect poor results""", file=sys.stderr) print(file=sys.stderr) print(f"system_info: n_threads = {self.params.n_threads} / {cpu_count()} \ -| {llama_cpp.llama_print_system_info().decode('utf8', errors='ignore')}", file=sys.stderr) +| {llama_cpp.llama_print_system_info().decode('utf8')}", file=sys.stderr) # determine the required inference memory per token: if (self.params.mem_test): diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 41e8c0a..630af18 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -159,7 +159,7 @@ class Llama: ) if self.verbose: - print(llama_cpp.llama_print_system_info().decode("utf-8", errors="ignore"), file=sys.stderr) + print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr) def tokenize(self, text: bytes) -> List[llama_cpp.llama_token]: """Tokenize a string. 
@@ -446,6 +446,7 @@ class Llama: self.load_state(self.cache[prompt_tokens]) finish_reason = "length" + multibyte_fix = 0 for token in self.generate( prompt_tokens, top_k=top_k, @@ -458,6 +459,12 @@ class Llama: finish_reason = "stop" break + # Contains multi-byte UTF8 + for num,pattern in [(2, 192), (3, 224), (4, 240)]: + # Bitwise AND check + if (pattern & token == pattern): + multibyte_fix = num + if self.cache and len(completion_tokens) == 0: if prompt_tokens not in self.cache: if self.verbose: @@ -466,6 +473,11 @@ class Llama: completion_tokens.append(token) + # Stop incomplete bytes from passing + if (multibyte_fix > 0): + multibyte_fix -= 1 + continue + all_text = self.detokenize(completion_tokens) any_stop = [s for s in stop_sequences if s in all_text] if len(any_stop) > 0: From 3a987470261b26f7a005b784863b282645326dc6 Mon Sep 17 00:00:00 2001 From: Mug <> Date: Fri, 28 Apr 2023 12:54:28 +0200 Subject: [PATCH 5/8] One day, i'll fix off by 1 errors permanently too --- llama_cpp/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 630af18..5adeaf8 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -463,7 +463,7 @@ class Llama: for num,pattern in [(2, 192), (3, 224), (4, 240)]: # Bitwise AND check if (pattern & token == pattern): - multibyte_fix = num + multibyte_fix = num - 1 if self.cache and len(completion_tokens) == 0: if prompt_tokens not in self.cache: From eed61289b68903ad9ca01f85976e9ababbbb1291 Mon Sep 17 00:00:00 2001 From: Mug <> Date: Fri, 28 Apr 2023 13:16:18 +0200 Subject: [PATCH 6/8] Dont detect off tokens, detect off detokenized utf8 --- llama_cpp/llama.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 5adeaf8..92715b5 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -459,12 +459,6 @@ class Llama: finish_reason = "stop" break - # Contains multi-byte UTF8 - for num,pattern in [(2, 192), (3, 224), (4, 240)]: - # Bitwise AND check - if (pattern & token == pattern): - multibyte_fix = num - 1 - if self.cache and len(completion_tokens) == 0: if prompt_tokens not in self.cache: if self.verbose: @@ -473,12 +467,22 @@ class Llama: completion_tokens.append(token) + all_text = self.detokenize(completion_tokens) + + # Contains multi-byte UTF8 + for k,char in enumerate(all_text[-3:]): + k = 3 - k + char = int.from_bytes(char, "big") + for num,pattern in [(2, 192), (3, 224), (4, 240)]: + # Bitwise AND check + if (num > k and pattern & char == pattern): + multibyte_fix = num - k + # Stop incomplete bytes from passing if (multibyte_fix > 0): multibyte_fix -= 1 continue - all_text = self.detokenize(completion_tokens) any_stop = [s for s in stop_sequences if s in all_text] if len(any_stop) > 0: first_stop = any_stop[0] From b7d14efc8b7b62d97ed66694b0dca0e1e3b3b2f6 Mon Sep 17 00:00:00 2001 From: Mug <> Date: Fri, 28 Apr 2023 13:20:31 +0200 Subject: [PATCH 7/8] Python weirdness --- llama_cpp/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 92715b5..fe540f9 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -472,7 +472,6 @@ class Llama: # Contains multi-byte UTF8 for k,char in enumerate(all_text[-3:]): k = 3 - k - char = int.from_bytes(char, "big") for num,pattern in [(2, 192), (3, 224), (4, 240)]: # Bitwise AND check if (num > k and pattern & char == pattern): From 18a0c10032ef793b67bb8ea9e4ca9e3aaa791595 Mon Sep 17 00:00:00 2001 From: Mug <> Date: Sat, 29 
Apr 2023 12:19:22 +0200 Subject: [PATCH 8/8] Remove excessive errors="ignore" and add utf8 test --- llama_cpp/llama.py | 6 +++--- tests/test_llama.py | 38 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index fe540f9..4e3c3aa 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -358,7 +358,7 @@ class Llama: if self.verbose: llama_cpp.llama_reset_timings(self.ctx) - tokens = self.tokenize(input.encode("utf-8", errors="ignore")) + tokens = self.tokenize(input.encode("utf-8")) self.reset() self.eval(tokens) n_tokens = len(tokens) @@ -416,7 +416,7 @@ class Llama: completion_tokens: List[llama_cpp.llama_token] = [] # Add blank space to start of prompt to match OG llama tokenizer prompt_tokens: List[llama_cpp.llama_token] = self.tokenize( - b" " + prompt.encode("utf-8", errors="ignore") + b" " + prompt.encode("utf-8") ) text: bytes = b"" returned_characters: int = 0 @@ -431,7 +431,7 @@ class Llama: ) if stop != []: - stop_sequences = [s.encode("utf-8", errors="ignore") for s in stop] + stop_sequences = [s.encode("utf-8") for s in stop] else: stop_sequences = [] diff --git a/tests/test_llama.py b/tests/test_llama.py index 4dab687..4727d90 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -24,7 +24,7 @@ def test_llama_patch(monkeypatch): monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval) output_text = " jumps over the lazy dog." - output_tokens = llama.tokenize(output_text.encode("utf-8", errors="ignore")) + output_tokens = llama.tokenize(output_text.encode("utf-8")) token_eos = llama.token_eos() n = 0 @@ -93,4 +93,38 @@ def test_llama_pickle(): text = b"Hello World" - assert llama.detokenize(llama.tokenize(text)) == text \ No newline at end of file + assert llama.detokenize(llama.tokenize(text)) == text + +def test_utf8(monkeypatch): + llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True) + + ## Set up mock function + def mock_eval(*args, **kwargs): + return 0 + + monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval) + + output_text = "😀" + output_tokens = llama.tokenize(output_text.encode("utf-8")) + token_eos = llama.token_eos() + n = 0 + + def mock_sample(*args, **kwargs): + nonlocal n + if n < len(output_tokens): + n += 1 + return output_tokens[n - 1] + else: + return token_eos + + monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample) + + ## Test basic completion with utf8 multibyte + n = 0 # reset + completion = llama.create_completion("", max_tokens=4) + assert completion["choices"][0]["text"] == output_text + + ## Test basic completion with incomplete utf8 multibyte + n = 0 # reset + completion = llama.create_completion("", max_tokens=1) + assert completion["choices"][0]["text"] == ""
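
Note on the multi-byte handling introduced in PATCH 4 and refined in PATCH 6/7: the final approach inspects the detokenized byte stream rather than the token ids. A UTF-8 lead byte encodes its sequence length in its high bits (110xxxxx = 2 bytes, 1110xxxx = 3 bytes, 11110xxx = 4 bytes, hence the 192/224/240 bit patterns), so if one of the last three emitted bytes is a lead byte whose continuation bytes have not arrived yet, the loop holds the text back until the character is complete. The test added in PATCH 8 depends on this: a one-token completion of "😀" produces only part of the four-byte sequence, so the returned text must be "".

The sketch below restates that idea as a self-contained illustration. It is not part of the patched library; the helper names (incomplete_utf8_tail, stream_decode) and the chunked-bytes interface are invented for the example, while the masks and the "hold back the tail" loop mirror the check added in PATCH 6/7.

def incomplete_utf8_tail(data: bytes) -> int:
    """Return how many trailing bytes of `data` belong to an unfinished
    multi-byte UTF-8 sequence (0 if the tail is complete)."""
    # Only the last three bytes can leave a sequence open, because the
    # longest UTF-8 sequence is four bytes.
    for back, byte in enumerate(reversed(data[-3:]), start=1):
        for length, mask in [(2, 0xC0), (3, 0xE0), (4, 0xF0)]:
            # A lead byte of an N-byte sequence has its top N bits set
            # (0xC0/0xE0/0xF0 are the 192/224/240 patterns from PATCH 6);
            # if fewer than N bytes have arrived since it, hold them back.
            if byte & mask == mask and length > back:
                return back
    return 0

def stream_decode(chunks):
    """Yield text incrementally without ever splitting a multi-byte char."""
    buffer = b""
    for chunk in chunks:
        buffer += chunk
        hold = incomplete_utf8_tail(buffer)
        ready = buffer[:len(buffer) - hold]
        buffer = buffer[len(buffer) - hold:]
        if ready:
            yield ready.decode("utf-8")
    if buffer:
        # Bytes that never completed a character, e.g. when max_tokens cuts
        # the sequence short as in the PATCH 8 test.
        yield buffer.decode("utf-8", errors="ignore")

# U+1F600 arrives one byte per "token"; nothing is emitted until byte four.
print("".join(stream_decode([b"\xf0", b"\x9f", b"\x98", b"\x80"])))

Compared with decoding every token with errors="ignore" as in PATCH 1, holding the incomplete tail back avoids silently dropping bytes when one character is split across several tokens, which is why PATCH 8 can remove most of the errors="ignore" calls again.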