Merge pull request #118 from SagsMug/main

Fix UnicodeDecodeError permanently

commit 755f9fa455
4 changed files with 59 additions and 10 deletions

@@ -201,7 +201,7 @@ n_keep = {self.params.n_keep}
     # tokenize a prompt
     def _tokenize(self, prompt, bos=True):
         _arr = (llama_cpp.llama_token * (len(prompt) + 1))()
-        _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), _arr, len(_arr), bos)
+        _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8", errors="ignore"), _arr, len(_arr), bos)
         return _arr[:_n]
 
     def set_color(self, c):
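
For context on the encode() change above: encoding a Python str to UTF-8 rarely fails, but unpaired surrogates (which can appear in mis-decoded input) do raise UnicodeEncodeError, and errors="ignore" drops the offending code points instead. A minimal sketch of the difference (illustration only, not part of the patch):

    s = "abc\ud800def"                     # contains a lone surrogate
    # s.encode("utf8")                     # raises UnicodeEncodeError
    s.encode("utf8", errors="ignore")      # b'abcdef' -- surrogate dropped
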
@@ -342,7 +342,7 @@ n_keep = {self.params.n_keep}
     # return past text
     def past(self):
         for id in self.last_n_tokens[-self.n_past:]:
-            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
+            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8", errors="ignore")
 
     # write input
     def input(self, prompt: str):
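
The decode() changes are the heart of the fix: tokens are decoded one at a time, a token boundary can fall in the middle of a multi-byte UTF-8 character, and decoding such a fragment raises UnicodeDecodeError. With errors="ignore" the incomplete fragment is dropped instead of crashing the generator. A minimal illustration (not part of the patch):

    half = "😀".encode("utf-8")[:2]        # b'\xf0\x9f', an incomplete sequence
    # half.decode("utf-8")                 # raises UnicodeDecodeError
    half.decode("utf-8", errors="ignore")  # '' -- fragment silently dropped
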
@@ -356,7 +356,7 @@ n_keep = {self.params.n_keep}
     def output(self):
         self.remaining_tokens = self.params.n_predict
         for id in self.generate():
-            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
+            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8", errors="ignore")
 
     # read user input
     def read_input(self):

@@ -70,7 +70,7 @@ while remaining_tokens > 0:
     if not input_noecho:
         for id in embd:
             print(
-                llama_cpp.llama_token_to_str(ctx, id).decode("utf-8"),
+                llama_cpp.llama_token_to_str(ctx, id).decode("utf-8", errors="ignore"),
                 end="",
                 flush=True,
             )

@@ -446,6 +446,7 @@ class Llama:
             self.load_state(self.cache[prompt_tokens])
 
         finish_reason = "length"
+        multibyte_fix = 0
         for token in self.generate(
             prompt_tokens,
             top_k=top_k,
@@ -467,6 +468,20 @@ class Llama:
             completion_tokens.append(token)
 
             all_text = self.detokenize(completion_tokens)
+
+            # Contains multi-byte UTF8
+            for k,char in enumerate(all_text[-3:]):
+                k = 3 - k
+                for num,pattern in [(2, 192), (3, 224), (4, 240)]:
+                    # Bitwise AND check
+                    if (num > k and pattern & char == pattern):
+                        multibyte_fix = num - k
+
+            # Stop incomplete bytes from passing
+            if (multibyte_fix > 0):
+                multibyte_fix -= 1
+                continue
+
             any_stop = [s for s in stop_sequences if s in all_text]
             if len(any_stop) > 0:
                 first_stop = any_stop[0]
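
The loop added above inspects the last three bytes of the buffer for a UTF-8 lead byte whose sequence is not yet complete. Lead bytes announce their length in the high bits: 110xxxxx (192 and up) starts a 2-byte character, 1110xxxx (224 and up) a 3-byte one, and 11110xxx (240 and up) a 4-byte one. When such a byte is found too close to the end, multibyte_fix is set, and the subsequent check decrements it and continues, deferring stop-sequence matching and streamed output until the character has a chance to complete. The same idea as a standalone, hypothetical helper (a sketch for illustration, not code from this patch):

    def missing_utf8_bytes(data: bytes) -> int:
        """How many continuation bytes are missing if `data` ends inside a
        multi-byte UTF-8 character; 0 if it ends on a character boundary."""
        tail = data[-3:]  # a lead byte needs at most 4 bytes in total
        missing = 0
        for i, byte in enumerate(tail):
            k = len(tail) - i  # distance of this byte from the end (1..3)
            for length, mask in [(2, 0b11000000), (3, 0b11100000), (4, 0b11110000)]:
                # Lead byte of a `length`-byte sequence, with fewer than
                # `length` bytes left in the buffer?
                if length > k and byte & mask == mask:
                    missing = length - k
        return missing

    assert missing_utf8_bytes("😀".encode("utf-8")) == 0       # complete character
    assert missing_utf8_bytes("😀".encode("utf-8")[:2]) == 2   # two bytes still to come
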
@@ -495,7 +510,7 @@ class Llama:
                     "model": self.model_path,
                     "choices": [
                         {
-                            "text": text[start:].decode("utf-8"),
+                            "text": text[start:].decode("utf-8", errors="ignore"),
                             "index": 0,
                             "logprobs": None,
                             "finish_reason": None,
@@ -516,7 +531,7 @@ class Llama:
                 "model": self.model_path,
                 "choices": [
                     {
-                        "text": text[returned_characters:].decode("utf-8"),
+                        "text": text[returned_characters:].decode("utf-8", errors="ignore"),
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": finish_reason,
@@ -525,7 +540,7 @@ class Llama:
                 }
             return
 
-        text_str = text.decode("utf-8")
+        text_str = text.decode("utf-8", errors="ignore")
 
         if echo:
             text_str = prompt + text_str
@@ -543,7 +558,7 @@ class Llama:
 
         all_tokens = prompt_tokens + completion_tokens
         all_token_strs = [
-            self.detokenize([token]).decode("utf-8") for token in all_tokens
+            self.detokenize([token]).decode("utf-8", errors="ignore") for token in all_tokens
         ]
         all_logprobs = [
             [Llama.logit_to_logprob(logit) for logit in row]
@@ -562,7 +577,7 @@ class Llama:
             )
             token_logprobs.append(sorted_logprobs[int(token)][0])
             top_logprob = {
-                self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
+                self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8", errors="ignore"): logprob
                 for logprob, i in sorted_logprobs[:logprobs]
             }
             top_logprob.update({token_str: sorted_logprobs[int(token)][0]})

@@ -93,4 +93,38 @@ def test_llama_pickle():
 
     text = b"Hello World"
 
     assert llama.detokenize(llama.tokenize(text)) == text
+
+def test_utf8(monkeypatch):
+    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
+
+    ## Set up mock function
+    def mock_eval(*args, **kwargs):
+        return 0
+
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
+
+    output_text = "😀"
+    output_tokens = llama.tokenize(output_text.encode("utf-8"))
+    token_eos = llama.token_eos()
+    n = 0
+
+    def mock_sample(*args, **kwargs):
+        nonlocal n
+        if n < len(output_tokens):
+            n += 1
+            return output_tokens[n - 1]
+        else:
+            return token_eos
+
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample)
+
+    ## Test basic completion with utf8 multibyte
+    n = 0  # reset
+    completion = llama.create_completion("", max_tokens=4)
+    assert completion["choices"][0]["text"] == output_text
+
+    ## Test basic completion with incomplete utf8 multibyte
+    n = 0  # reset
+    completion = llama.create_completion("", max_tokens=1)
+    assert completion["choices"][0]["text"] == ""
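
The test drives create_completion deterministically by monkeypatching llama_eval and the sampler to replay the tokens of "😀", then checks that a full emoji round-trips and that a truncated one (max_tokens=1) yields an empty string rather than an exception. Assuming the file lives at tests/test_llama.py (the path is not shown in this view), it can be run in isolation with:

    pytest tests/test_llama.py -k test_utf8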