Remove unnecessary errors="ignore" from UTF-8 encode calls and add a UTF-8 multibyte completion test

This commit is contained in:
Mug 2023-04-29 12:19:22 +02:00
parent b7d14efc8b
commit 18a0c10032
2 changed files with 39 additions and 5 deletions

View file

@ -358,7 +358,7 @@ class Llama:
if self.verbose:
llama_cpp.llama_reset_timings(self.ctx)
tokens = self.tokenize(input.encode("utf-8", errors="ignore"))
tokens = self.tokenize(input.encode("utf-8"))
self.reset()
self.eval(tokens)
n_tokens = len(tokens)
@ -416,7 +416,7 @@ class Llama:
completion_tokens: List[llama_cpp.llama_token] = []
# Add blank space to start of prompt to match OG llama tokenizer
prompt_tokens: List[llama_cpp.llama_token] = self.tokenize(
b" " + prompt.encode("utf-8", errors="ignore")
b" " + prompt.encode("utf-8")
)
text: bytes = b""
returned_characters: int = 0
@ -431,7 +431,7 @@ class Llama:
)
if stop != []:
stop_sequences = [s.encode("utf-8", errors="ignore") for s in stop]
stop_sequences = [s.encode("utf-8") for s in stop]
else:
stop_sequences = []

View file

@ -24,7 +24,7 @@ def test_llama_patch(monkeypatch):
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
output_text = " jumps over the lazy dog."
output_tokens = llama.tokenize(output_text.encode("utf-8", errors="ignore"))
output_tokens = llama.tokenize(output_text.encode("utf-8"))
token_eos = llama.token_eos()
n = 0
@ -93,4 +93,38 @@ def test_llama_pickle():
text = b"Hello World"
assert llama.detokenize(llama.tokenize(text)) == text
assert llama.detokenize(llama.tokenize(text)) == text
def test_utf8(monkeypatch):
    """Completion must pass multi-byte UTF-8 output through intact, and must
    return an empty string (not garbage) when generation stops mid-sequence.

    Model eval and sampling are monkeypatched so no real model is needed:
    the sampler replays the tokens of a single emoji, then EOS.
    """
    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)

    # Eval is a no-op: we only care about the detokenization path.
    def fake_eval(*args, **kwargs):
        return 0

    monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", fake_eval)

    # "😀" encodes to four UTF-8 bytes, which tokenize to multiple tokens —
    # exactly the case where naive per-token decoding would corrupt output.
    output_text = "😀"
    output_tokens = llama.tokenize(output_text.encode("utf-8"))
    token_eos = llama.token_eos()
    n = 0

    # Replay the emoji's tokens one per call, then emit EOS forever.
    def fake_sample(*args, **kwargs):
        nonlocal n
        if n >= len(output_tokens):
            return token_eos
        n += 1
        return output_tokens[n - 1]

    monkeypatch.setattr(
        "llama_cpp.llama_cpp.llama_sample_top_p_top_k", fake_sample
    )

    # Full emoji: all tokens fit in the budget, text round-trips exactly.
    n = 0  # reset replay position
    completion = llama.create_completion("", max_tokens=4)
    assert completion["choices"][0]["text"] == output_text

    # Truncated emoji: only one token generated, leaving an incomplete
    # UTF-8 sequence — the API should yield "" rather than mojibake.
    n = 0  # reset replay position
    completion = llama.create_completion("", max_tokens=1)
    assert completion["choices"][0]["text"] == ""