tests: Improve llama.cpp mock

Andrei Betlen 2023-11-20 23:23:18 -05:00
parent 63fe1370ed
commit 3dc21b2557


@@ -37,77 +37,106 @@ def test_llama_cpp_tokenization():
     assert tokens[-1] == llama.token_eos()
     assert tokens == [1, 15043, 2787, 2]
 
+    text = b""
+    tokens = llama.tokenize(text, add_bos=True, special=True)
+    assert tokens[-1] != llama.token_eos()
+    assert tokens == [llama.token_bos()]
+    assert text == llama.detokenize(tokens)
+
 
-def test_llama_patch(monkeypatch):
+@pytest.fixture
+def mock_llama(monkeypatch):
+    def setup_mock(llama: llama_cpp.Llama, output_text: str):
+        llama.reset()
+        n_vocab = llama.n_vocab()
+        output_tokens = llama.tokenize(
+            output_text.encode("utf-8"), add_bos=True, special=True
+        )
+        n = 0
+        last_n_tokens = 0
+
+        def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
+            nonlocal n
+            nonlocal last_n_tokens
+            # Test some basic invariants of this mocking technique
+            assert ctx == llama._ctx.ctx
+            assert llama.n_tokens == n
+            assert batch.n_tokens > 0
+            n += batch.n_tokens
+            last_n_tokens = batch.n_tokens
+            return 0
+
+        def mock_get_logits(*args, **kwargs):
+            nonlocal last_n_tokens
+            size = n_vocab * last_n_tokens
+            return (llama_cpp.c_float * size)()
+
+        def mock_sample(*args, **kwargs):
+            nonlocal n
+            if n < len(output_tokens):
+                return output_tokens[n]
+            else:
+                return llama.token_eos()
+
+        monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
+        monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
+        monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
+
+    return setup_mock
+
+
+def test_llama_patch(mock_llama):
     n_ctx = 128
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx)
     n_vocab = llama_cpp.llama_n_vocab(llama._model.model)
     assert n_vocab == 32000
 
-    ## Set up mock function
-    def mock_decode(*args, **kwargs):
-        return 0
-
-    def mock_get_logits(*args, **kwargs):
-        size = n_vocab * n_ctx
-        return (llama_cpp.c_float * size)()
-
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
-
     text = "The quick brown fox"
-    text_tokens = llama.tokenize(text.encode("utf-8"), add_bos=True, special=True)
     output_text = " jumps over the lazy dog."
-    all_text_tokens = llama.tokenize((text + output_text).encode("utf-8"), add_bos=True, special=True)
-    output_tokens = all_text_tokens[len(text_tokens):]
-    token_eos = llama.token_eos()
-    n = 0
-
-    def mock_sample(*args, **kwargs):
-        nonlocal n
-        if n < len(output_tokens):
-            n += 1
-            return output_tokens[n - 1]
-        else:
-            return token_eos
-
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
+    all_text = text + output_text
+
+    ## Test basic completion from bos until eos
+    mock_llama(llama, all_text)
+    completion = llama.create_completion("", max_tokens=36)
+    assert completion["choices"][0]["text"] == all_text
+    assert completion["choices"][0]["finish_reason"] == "stop"
 
     ## Test basic completion until eos
-    n = 0  # reset
+    mock_llama(llama, all_text)
     completion = llama.create_completion(text, max_tokens=20)
     assert completion["choices"][0]["text"] == output_text
     assert completion["choices"][0]["finish_reason"] == "stop"
 
     ## Test streaming completion until eos
-    n = 0  # reset
+    mock_llama(llama, all_text)
     chunks = list(llama.create_completion(text, max_tokens=20, stream=True))
     assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == output_text
     assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
 
     ## Test basic completion until stop sequence
-    n = 0  # reset
+    mock_llama(llama, all_text)
     completion = llama.create_completion(text, max_tokens=20, stop=["lazy"])
     assert completion["choices"][0]["text"] == " jumps over the "
     assert completion["choices"][0]["finish_reason"] == "stop"
 
     ## Test streaming completion until stop sequence
-    n = 0  # reset
-    chunks = list(llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"]))
+    mock_llama(llama, all_text)
+    chunks = list(
+        llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"])
+    )
     assert (
         "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps over the "
     )
     assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
 
     ## Test basic completion until length
-    n = 0  # reset
+    mock_llama(llama, all_text)
     completion = llama.create_completion(text, max_tokens=2)
     assert completion["choices"][0]["text"] == " jumps"
     assert completion["choices"][0]["finish_reason"] == "length"
 
     ## Test streaming completion until length
-    n = 0  # reset
+    mock_llama(llama, all_text)
    chunks = list(llama.create_completion(text, max_tokens=2, stream=True))
     assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps"
     assert chunks[-1]["choices"][0]["finish_reason"] == "length"
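
The new mock_llama fixture replaces the per-test monkeypatching removed above: it patches llama_decode, llama_get_logits, and llama_sample_token so create_completion deterministically replays a chosen string without loading real model weights. A minimal usage sketch (hypothetical test name, not part of this commit; assumes the MODEL constant and the fixture defined in this test file):

    import llama_cpp  # MODEL and the mock_llama fixture come from this test module

    def test_mock_llama_example(mock_llama):  # hypothetical test name
        llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
        expected = "The quick brown fox jumps over the lazy dog."
        mock_llama(llama, expected)  # arm the mock: patches decode/get_logits/sample_token
        completion = llama.create_completion("", max_tokens=36)
        assert completion["choices"][0]["text"] == expected
        assert completion["choices"][0]["finish_reason"] == "stop"  # mock emits EOS at the end

Each call to mock_llama resets the llama context and the fixture's token counters, which is why test_llama_patch re-arms it before every completion instead of resetting a shared counter by hand.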
@@ -131,44 +160,55 @@ def test_llama_pickle():
     assert llama.detokenize(llama.tokenize(text)) == text
 
-def test_utf8(monkeypatch):
-    n_ctx = 512
-    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx, logits_all=True)
+def test_utf8(mock_llama, monkeypatch):
+    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, logits_all=True)
+    n_ctx = llama.n_ctx()
     n_vocab = llama.n_vocab()
 
+    output_text = "😀"
+    output_tokens = llama.tokenize(
+        output_text.encode("utf-8"), add_bos=True, special=True
+    )
+    token_eos = llama.token_eos()
+    n = 0
+
+    def reset():
+        nonlocal n
+        llama.reset()
+        n = 0
+
     ## Set up mock function
-    def mock_decode(*args, **kwargs):
+    def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
+        nonlocal n
+        assert batch.n_tokens > 0
+        assert llama.n_tokens == n
+        n += batch.n_tokens
         return 0
 
     def mock_get_logits(*args, **kwargs):
         size = n_vocab * n_ctx
         return (llama_cpp.c_float * size)()
 
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
-
-    output_text = "😀"
-    output_tokens = llama.tokenize(output_text.encode("utf-8"))
-    token_eos = llama.token_eos()
-    n = 0
-
     def mock_sample(*args, **kwargs):
         nonlocal n
-        if n < len(output_tokens):
-            n += 1
+        if n <= len(output_tokens):
             return output_tokens[n - 1]
         else:
             return token_eos
 
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
 
     ## Test basic completion with utf8 multibyte
-    n = 0  # reset
+    # mock_llama(llama, output_text)
+    reset()
     completion = llama.create_completion("", max_tokens=4)
     assert completion["choices"][0]["text"] == output_text
 
     ## Test basic completion with incomplete utf8 multibyte
-    n = 0  # reset
+    # mock_llama(llama, output_text)
+    reset()
     completion = llama.create_completion("", max_tokens=1)
     assert completion["choices"][0]["text"] == ""
@@ -196,5 +236,6 @@ def test_llama_server():
         ],
     }
 
+
 def test_llama_cpp_version():
     assert llama_cpp.__version__
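
Both versions of mock_get_logits hand back a freshly allocated ctypes float array (the tests' llama_cpp.c_float is the ctypes float type). ctypes array types zero-initialize their memory, so every logit reads as 0.0 and the patched llama_sample_token alone determines the generated tokens. A standalone sketch of that behavior, with sizes mirroring the fixture's bookkeeping:

    import ctypes

    n_vocab, last_n_tokens = 32000, 4  # vocab size and tokens in the last decoded batch
    logits = (ctypes.c_float * (n_vocab * last_n_tokens))()  # zero-initialized buffer
    assert len(logits) == n_vocab * last_n_tokens
    assert logits[0] == 0.0 and logits[-1] == 0.0  # ctypes arrays start zeroed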