From c39547a986540d1152493db45ed461dde04f0ffa Mon Sep 17 00:00:00 2001 From: Mug <> Date: Fri, 28 Apr 2023 12:50:30 +0200 Subject: [PATCH] Detect multi-byte responses and wait --- examples/low_level_api/low_level_api_chat_cpp.py | 2 +- llama_cpp/llama.py | 14 +++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/examples/low_level_api/low_level_api_chat_cpp.py b/examples/low_level_api/low_level_api_chat_cpp.py index 90b2fcb..6fced65 100644 --- a/examples/low_level_api/low_level_api_chat_cpp.py +++ b/examples/low_level_api/low_level_api_chat_cpp.py @@ -96,7 +96,7 @@ specified) expect poor results""", file=sys.stderr) print(file=sys.stderr) print(f"system_info: n_threads = {self.params.n_threads} / {cpu_count()} \ -| {llama_cpp.llama_print_system_info().decode('utf8', errors='ignore')}", file=sys.stderr) +| {llama_cpp.llama_print_system_info().decode('utf8')}", file=sys.stderr) # determine the required inference memory per token: if (self.params.mem_test): diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 41e8c0a..630af18 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -159,7 +159,7 @@ class Llama: ) if self.verbose: - print(llama_cpp.llama_print_system_info().decode("utf-8", errors="ignore"), file=sys.stderr) + print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr) def tokenize(self, text: bytes) -> List[llama_cpp.llama_token]: """Tokenize a string. @@ -446,6 +446,7 @@ class Llama: self.load_state(self.cache[prompt_tokens]) finish_reason = "length" + multibyte_fix = 0 for token in self.generate( prompt_tokens, top_k=top_k, @@ -458,6 +459,12 @@ class Llama: finish_reason = "stop" break + # Contains multi-byte UTF8 + for num,pattern in [(2, 192), (3, 224), (4, 240)]: + # Bitwise AND check + if (pattern & token == pattern): + multibyte_fix = num + if self.cache and len(completion_tokens) == 0: if prompt_tokens not in self.cache: if self.verbose: @@ -466,6 +473,11 @@ class Llama: completion_tokens.append(token) + # Stop incomplete bytes from passing + if (multibyte_fix > 0): + multibyte_fix -= 1 + continue + all_text = self.detokenize(completion_tokens) any_stop = [s for s in stop_sequences if s in all_text] if len(any_stop) > 0: