Fix decode errors permanently

Mug 2023-04-26 14:37:06 +02:00
parent 1b73a15e62
commit c4a8491d42
3 changed files with 13 additions and 10 deletions

File 1 of 3:

@@ -96,7 +96,7 @@ specified) expect poor results""", file=sys.stderr)
         print(file=sys.stderr)
         print(f"system_info: n_threads = {self.params.n_threads} / {cpu_count()} \
-| {llama_cpp.llama_print_system_info().decode('utf8')}", file=sys.stderr)
+| {llama_cpp.llama_print_system_info().decode('utf8', errors='ignore')}", file=sys.stderr)
 
         # determine the required inference memory per token:
         if (self.params.mem_test):
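Why this fix works: a single llama token can carry only part of a multi-byte UTF-8 character, so decoding one token's bytes in isolation can raise UnicodeDecodeError. A minimal sketch of the failure mode and of what errors="ignore" does instead (the byte values are illustrative, not taken from the model):

    # "€" is three bytes in UTF-8: b"\xe2\x82\xac".
    # If a token boundary splits it, one token holds only a prefix:
    partial = b"\xe2\x82"

    partial.decode("utf-8")                   # raises UnicodeDecodeError
    partial.decode("utf-8", errors="ignore")  # returns "" (bad bytes dropped)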
@@ -342,7 +342,7 @@ n_keep = {self.params.n_keep}
     # return past text
     def past(self):
         for id in self.last_n_tokens[-self.n_past:]:
-            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
+            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8", errors="ignore")
 
     # write input
     def input(self, prompt: str):
@@ -356,7 +356,10 @@ n_keep = {self.params.n_keep}
     def output(self):
         self.remaining_tokens = self.params.n_predict
         for id in self.generate():
-            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
+            try:
+                yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8", errors="ignore")
+            except UnicodeDecodeError:
+                pass
 
     # read user input
     def read_input(self):
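Note that with errors="ignore" the decode can no longer raise UnicodeDecodeError, so the try/except above is a defensive extra rather than the active fix. An alternative that preserves split characters instead of dropping their bytes, not part of this commit, is the stdlib incremental decoder; a minimal sketch:

    import codecs

    decoder = codecs.getincrementaldecoder("utf-8")()

    # Feed token bytes as they stream in; a partial sequence is buffered
    # until the remaining bytes arrive, so nothing is silently lost.
    for chunk in (b"\xe2\x82", b"\xac"):  # "€" split across two tokens
        text = decoder.decode(chunk)      # "" then "€"
        print(text, end="", flush=True)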

File 2 of 3:

@@ -70,7 +70,7 @@ while remaining_tokens > 0:
     if not input_noecho:
         for id in embd:
             print(
-                llama_cpp.llama_token_to_str(ctx, id).decode("utf-8"),
+                llama_cpp.llama_token_to_str(ctx, id).decode("utf-8", errors="ignore"),
                 end="",
                 flush=True,
             )

File 3 of 3:

@@ -109,7 +109,7 @@ class Llama:
         )
         if self.verbose:
-            print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
+            print(llama_cpp.llama_print_system_info().decode("utf-8", errors="ignore"), file=sys.stderr)
 
     def tokenize(self, text: bytes) -> List[llama_cpp.llama_token]:
         """Tokenize a string.
@@ -460,7 +460,7 @@ class Llama:
                 "model": self.model_path,
                 "choices": [
                     {
-                        "text": text[start:].decode("utf-8"),
+                        "text": text[start:].decode("utf-8", errors="ignore"),
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": None,
@@ -484,7 +484,7 @@ class Llama:
                 "model": self.model_path,
                 "choices": [
                     {
-                        "text": text[returned_characters:].decode("utf-8"),
+                        "text": text[returned_characters:].decode("utf-8", errors="ignore"),
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": finish_reason,
@@ -496,7 +496,7 @@ class Llama:
         ### HACK
         self._completion_bytes.append(text)
         ###
-        text_str = text.decode("utf-8")
+        text_str = text.decode("utf-8", errors="ignore")
         if echo:
             text_str = prompt + text_str
@@ -514,7 +514,7 @@ class Llama:
         all_tokens = prompt_tokens + completion_tokens
         all_token_strs = [
-            self.detokenize([token]).decode("utf-8") for token in all_tokens
+            self.detokenize([token]).decode("utf-8", errors="ignore") for token in all_tokens
         ]
         all_logprobs = [
             [Llama.logit_to_logprob(logit) for logit in row]
@@ -533,7 +533,7 @@ class Llama:
                 )
                 token_logprobs.append(sorted_logprobs[int(token)][0])
                 top_logprob = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
+                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8", errors="ignore"): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
                 top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
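One side effect worth noting: with errors="ignore", every token whose bytes form an incomplete UTF-8 sequence decodes to the empty string, so distinct tokens can collide on the same key in the top_logprob dict. A sketch of the collision, alongside errors="replace", which keeps a visible U+FFFD marker instead of an empty string (this commit uses "ignore"; "replace" is shown only for comparison):

    t1, t2 = b"\xe2\x82", b"\xf0\x9f"     # two different partial sequences

    t1.decode("utf-8", errors="ignore")   # ""
    t2.decode("utf-8", errors="ignore")   # "" -> same dict key as t1
    t1.decode("utf-8", errors="replace")  # "\ufffd" (visible marker)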