Detect multi-byte responses and wait

This commit is contained in:
Mug 2023-04-28 12:50:30 +02:00
parent 5f81400fcb
commit c39547a986
2 changed files with 14 additions and 2 deletions

View file

@@ -96,7 +96,7 @@ specified) expect poor results""", file=sys.stderr)
print(file=sys.stderr)
print(f"system_info: n_threads = {self.params.n_threads} / {cpu_count()} \
| {llama_cpp.llama_print_system_info().decode('utf8', errors='ignore')}", file=sys.stderr)
| {llama_cpp.llama_print_system_info().decode('utf8')}", file=sys.stderr)
# determine the required inference memory per token:
if (self.params.mem_test):

View file

@@ -159,7 +159,7 @@ class Llama:
)
if self.verbose:
print(llama_cpp.llama_print_system_info().decode("utf-8", errors="ignore"), file=sys.stderr)
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
def tokenize(self, text: bytes) -> List[llama_cpp.llama_token]:
"""Tokenize a string.
@@ -446,6 +446,7 @@ class Llama:
self.load_state(self.cache[prompt_tokens])
finish_reason = "length"
multibyte_fix = 0
for token in self.generate(
prompt_tokens,
top_k=top_k,
@@ -458,6 +459,12 @@
finish_reason = "stop"
break
# Contains multi-byte UTF8
for num,pattern in [(2, 192), (3, 224), (4, 240)]:
# Bitwise AND check
if (pattern & token == pattern):
multibyte_fix = num
if self.cache and len(completion_tokens) == 0:
if prompt_tokens not in self.cache:
if self.verbose:
@@ -466,6 +473,11 @@
completion_tokens.append(token)
# Stop incomplete bytes from passing
if (multibyte_fix > 0):
multibyte_fix -= 1
continue
all_text = self.detokenize(completion_tokens)
any_stop = [s for s in stop_sequences if s in all_text]
if len(any_stop) > 0: