Detect multi-byte responses and wait
This commit is contained in:
parent
5f81400fcb
commit
c39547a986
2 changed files with 14 additions and 2 deletions
|
@ -96,7 +96,7 @@ specified) expect poor results""", file=sys.stderr)
|
|||
|
||||
print(file=sys.stderr)
|
||||
print(f"system_info: n_threads = {self.params.n_threads} / {cpu_count()} \
|
||||
| {llama_cpp.llama_print_system_info().decode('utf8', errors='ignore')}", file=sys.stderr)
|
||||
| {llama_cpp.llama_print_system_info().decode('utf8')}", file=sys.stderr)
|
||||
|
||||
# determine the required inference memory per token:
|
||||
if (self.params.mem_test):
|
||||
|
|
|
@ -159,7 +159,7 @@ class Llama:
|
|||
)
|
||||
|
||||
if self.verbose:
|
||||
print(llama_cpp.llama_print_system_info().decode("utf-8", errors="ignore"), file=sys.stderr)
|
||||
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
|
||||
|
||||
def tokenize(self, text: bytes) -> List[llama_cpp.llama_token]:
|
||||
"""Tokenize a string.
|
||||
|
@ -446,6 +446,7 @@ class Llama:
|
|||
self.load_state(self.cache[prompt_tokens])
|
||||
|
||||
finish_reason = "length"
|
||||
multibyte_fix = 0
|
||||
for token in self.generate(
|
||||
prompt_tokens,
|
||||
top_k=top_k,
|
||||
|
@ -458,6 +459,12 @@ class Llama:
|
|||
finish_reason = "stop"
|
||||
break
|
||||
|
||||
# Contains multi-byte UTF8
|
||||
for num,pattern in [(2, 192), (3, 224), (4, 240)]:
|
||||
# Bitwise AND check
|
||||
if (pattern & token == pattern):
|
||||
multibyte_fix = num
|
||||
|
||||
if self.cache and len(completion_tokens) == 0:
|
||||
if prompt_tokens not in self.cache:
|
||||
if self.verbose:
|
||||
|
@ -466,6 +473,11 @@ class Llama:
|
|||
|
||||
completion_tokens.append(token)
|
||||
|
||||
# Stop incomplete bytes from passing
|
||||
if (multibyte_fix > 0):
|
||||
multibyte_fix -= 1
|
||||
continue
|
||||
|
||||
all_text = self.detokenize(completion_tokens)
|
||||
any_stop = [s for s in stop_sequences if s in all_text]
|
||||
if len(any_stop) > 0:
|
||||
|
|
Loading…
Add table
Reference in a new issue