Use mock_llama for all tests

Andrei Betlen 2023-11-21 18:13:19 -05:00
parent dbfaf53fe0
commit d7388f1ffb


@@ -160,55 +160,18 @@ def test_llama_pickle():
     assert llama.detokenize(llama.tokenize(text)) == text


-def test_utf8(mock_llama, monkeypatch):
+def test_utf8(mock_llama):
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, logits_all=True)
-    n_ctx = llama.n_ctx()
-    n_vocab = llama.n_vocab()

     output_text = "😀"
-    output_tokens = llama.tokenize(
-        output_text.encode("utf-8"), add_bos=True, special=True
-    )
-    token_eos = llama.token_eos()
-    n = 0
-
-    def reset():
-        nonlocal n
-        llama.reset()
-        n = 0
-
-    ## Set up mock function
-    def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
-        nonlocal n
-        assert batch.n_tokens > 0
-        assert llama.n_tokens == n
-        n += batch.n_tokens
-        return 0
-
-    def mock_get_logits(*args, **kwargs):
-        size = n_vocab * n_ctx
-        return (llama_cpp.c_float * size)()
-
-    def mock_sample(*args, **kwargs):
-        nonlocal n
-        if n <= len(output_tokens):
-            return output_tokens[n - 1]
-        else:
-            return token_eos
-
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)

     ## Test basic completion with utf8 multibyte
-    # mock_llama(llama, output_text)
-    reset()
+    mock_llama(llama, output_text)
     completion = llama.create_completion("", max_tokens=4)
     assert completion["choices"][0]["text"] == output_text

     ## Test basic completion with incomplete utf8 multibyte
-    # mock_llama(llama, output_text)
-    reset()
+    mock_llama(llama, output_text)
     completion = llama.create_completion("", max_tokens=1)
     assert completion["choices"][0]["text"] == ""
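
The mock_llama fixture's definition is not part of this hunk. A minimal sketch of what it plausibly looks like, assuming it simply repackages the inline monkeypatch mocks deleted above into a reusable pytest fixture: the inner setup_mock name and the (llama, output_text) call signature are inferred from the call sites, while the mock bodies are taken verbatim from the removed lines, so treat this as a reconstruction rather than the fixture's actual source.

import pytest

import llama_cpp


@pytest.fixture
def mock_llama(monkeypatch):
    def setup_mock(llama: llama_cpp.Llama, output_text: str):
        # Start each scenario from a clean context (replaces the old reset() helper).
        llama.reset()
        n_ctx = llama.n_ctx()
        n_vocab = llama.n_vocab()
        output_tokens = llama.tokenize(
            output_text.encode("utf-8"), add_bos=True, special=True
        )
        token_eos = llama.token_eos()
        n = 0

        def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
            # Track how many tokens have been decoded; 0 signals success,
            # matching the real llama_decode's return convention.
            nonlocal n
            assert batch.n_tokens > 0
            assert llama.n_tokens == n
            n += batch.n_tokens
            return 0

        def mock_get_logits(*args, **kwargs):
            # Hand back a zero-filled logits buffer of the size the wrapper expects.
            size = n_vocab * n_ctx
            return (llama_cpp.c_float * size)()

        def mock_sample(*args, **kwargs):
            # Replay the pre-tokenized output text, then terminate with EOS
            # (same logic as the removed inline mock).
            nonlocal n
            if n <= len(output_tokens):
                return output_tokens[n - 1]
            return token_eos

        monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
        monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
        monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)

    return setup_mock

Whatever its exact form, factoring the fakes into a shared fixture removes the per-test reset() bookkeeping and ensures every test drives llama_decode, llama_get_logits, and llama_sample_token through the same mocked code path, which is the point of this commit.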