diff --git a/examples/low_level_api/low_level_api_chat_cpp.py b/examples/low_level_api/low_level_api_chat_cpp.py
index d64ee8f..6fced65 100644
--- a/examples/low_level_api/low_level_api_chat_cpp.py
+++ b/examples/low_level_api/low_level_api_chat_cpp.py
@@ -201,7 +201,7 @@ n_keep = {self.params.n_keep}
     # tokenize a prompt
     def _tokenize(self, prompt, bos=True):
         _arr = (llama_cpp.llama_token * (len(prompt) + 1))()
-        _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), _arr, len(_arr), bos)
+        _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8", errors="ignore"), _arr, len(_arr), bos)
         return _arr[:_n]
 
     def set_color(self, c):
@@ -342,7 +342,7 @@ n_keep = {self.params.n_keep}
     # return past text
     def past(self):
         for id in self.last_n_tokens[-self.n_past:]:
-            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
+            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8", errors="ignore")
 
     # write input
     def input(self, prompt: str):
@@ -356,7 +356,7 @@ n_keep = {self.params.n_keep}
     def output(self):
         self.remaining_tokens = self.params.n_predict
         for id in self.generate():
-            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
+            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8", errors="ignore")
 
     # read user input
     def read_input(self):
diff --git a/examples/low_level_api/low_level_api_llama_cpp.py b/examples/low_level_api/low_level_api_llama_cpp.py
index b048c0a..4fb5a03 100644
--- a/examples/low_level_api/low_level_api_llama_cpp.py
+++ b/examples/low_level_api/low_level_api_llama_cpp.py
@@ -70,7 +70,7 @@ while remaining_tokens > 0:
     if not input_noecho:
         for id in embd:
             print(
-                llama_cpp.llama_token_to_str(ctx, id).decode("utf-8"),
+                llama_cpp.llama_token_to_str(ctx, id).decode("utf-8", errors="ignore"),
                 end="",
                 flush=True,
             )
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index df9a491..4e3c3aa 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -446,6 +446,7 @@ class Llama:
             self.load_state(self.cache[prompt_tokens])
 
         finish_reason = "length"
+        multibyte_fix = 0
         for token in self.generate(
             prompt_tokens,
             top_k=top_k,
@@ -467,6 +468,20 @@ class Llama:
             completion_tokens.append(token)
 
             all_text = self.detokenize(completion_tokens)
+
+            # Contains multi-byte UTF8
+            for k,char in enumerate(all_text[-3:]):
+                k = 3 - k
+                for num,pattern in [(2, 192), (3, 224), (4, 240)]:
+                    # Bitwise AND check
+                    if (num > k and pattern & char == pattern):
+                        multibyte_fix = num - k
+
+            # Stop incomplete bytes from passing
+            if (multibyte_fix > 0):
+                multibyte_fix -= 1
+                continue
+
             any_stop = [s for s in stop_sequences if s in all_text]
             if len(any_stop) > 0:
                 first_stop = any_stop[0]
@@ -495,7 +510,7 @@ class Llama:
                     "model": self.model_path,
                     "choices": [
                         {
-                            "text": text[start:].decode("utf-8"),
+                            "text": text[start:].decode("utf-8", errors="ignore"),
                             "index": 0,
                             "logprobs": None,
                             "finish_reason": None,
@@ -516,7 +531,7 @@ class Llama:
                 "model": self.model_path,
                 "choices": [
                     {
-                        "text": text[returned_characters:].decode("utf-8"),
+                        "text": text[returned_characters:].decode("utf-8", errors="ignore"),
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": finish_reason,
@@ -525,7 +540,7 @@ class Llama:
             }
             return
 
-        text_str = text.decode("utf-8")
+        text_str = text.decode("utf-8", errors="ignore")
 
         if echo:
             text_str = prompt + text_str
@@ -543,7 +558,7 @@ class Llama:
 
             all_tokens = prompt_tokens + completion_tokens
             all_token_strs = [
-                self.detokenize([token]).decode("utf-8") for token in all_tokens
+                self.detokenize([token]).decode("utf-8", errors="ignore") for token in all_tokens
             ]
             all_logprobs = [
                 [Llama.logit_to_logprob(logit) for logit in row]
@@ -562,7 +577,7 @@ class Llama:
                 )
                 token_logprobs.append(sorted_logprobs[int(token)][0])
                 top_logprob = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
+                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8", errors="ignore"): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                }
                 top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
diff --git a/tests/test_llama.py b/tests/test_llama.py
index 6a50256..4727d90 100644
--- a/tests/test_llama.py
+++ b/tests/test_llama.py
@@ -93,4 +93,38 @@ def test_llama_pickle():
 
     text = b"Hello World"
 
-    assert llama.detokenize(llama.tokenize(text)) == text
\ No newline at end of file
+    assert llama.detokenize(llama.tokenize(text)) == text
+
+def test_utf8(monkeypatch):
+    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
+
+    ## Set up mock function
+    def mock_eval(*args, **kwargs):
+        return 0
+
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
+
+    output_text = "😀"
+    output_tokens = llama.tokenize(output_text.encode("utf-8"))
+    token_eos = llama.token_eos()
+    n = 0
+
+    def mock_sample(*args, **kwargs):
+        nonlocal n
+        if n < len(output_tokens):
+            n += 1
+            return output_tokens[n - 1]
+        else:
+            return token_eos
+
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample)
+
+    ## Test basic completion with utf8 multibyte
+    n = 0  # reset
+    completion = llama.create_completion("", max_tokens=4)
+    assert completion["choices"][0]["text"] == output_text
+
+    ## Test basic completion with incomplete utf8 multibyte
+    n = 0  # reset
+    completion = llama.create_completion("", max_tokens=1)
+    assert completion["choices"][0]["text"] == ""
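Note on the multibyte handling in llama_cpp/llama.py above: detokenizing one token at a time can leave the accumulated byte buffer ending in the middle of a multi-byte UTF-8 character, so the patch tracks how many continuation bytes are still outstanding (multibyte_fix) and skips the rest of that loop iteration, including the stop-sequence check, until the character is complete, while the added errors="ignore" arguments keep any remaining partial sequence from raising UnicodeDecodeError. The sketch below restates the same lead-byte check as a standalone helper under that reading of the patch; the function name incomplete_utf8_suffix is ours, not part of the library.

# Illustrative sketch only (not part of the patch): the lead-byte masks checked in
# the llama.py hunk (192/0xC0, 224/0xE0, 240/0xF0) wrapped in a hypothetical helper.

def incomplete_utf8_suffix(buf: bytes) -> int:
    """Return how many bytes are still missing from a trailing multi-byte
    UTF-8 character in buf, or 0 if buf ends on a character boundary."""
    # Only the last three bytes can hold the lead byte of an unfinished
    # 2-, 3-, or 4-byte sequence.
    for back, byte in enumerate(reversed(buf[-3:]), start=1):
        # Check the longest sequences first; a lead byte of an N-byte
        # sequence has every bit of its mask set.
        for length, lead in ((4, 0xF0), (3, 0xE0), (2, 0xC0)):
            if length > back and byte & lead == lead:
                return length - back
    return 0


if __name__ == "__main__":
    smiley = "😀".encode("utf-8")                    # four bytes: f0 9f 98 80
    assert incomplete_utf8_suffix(smiley) == 0       # complete character
    assert incomplete_utf8_suffix(smiley[:1]) == 3   # only the lead byte so far
    assert incomplete_utf8_suffix(smiley[:3]) == 1   # one continuation byte missing
    # Without errors="ignore", decoding the partial sequence raises UnicodeDecodeError.
    assert smiley[:3].decode("utf-8", errors="ignore") == ""

Checking only the last three bytes is enough because a UTF-8 character is at most four bytes long, so the lead byte of any unfinished sequence must fall within that window.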