tests: Improve llama.cpp mock
This commit is contained in:
parent 63fe1370ed
commit 3dc21b2557
1 changed file with 91 additions and 50 deletions
@@ -37,77 +37,106 @@ def test_llama_cpp_tokenization():
     assert tokens[-1] == llama.token_eos()
     assert tokens == [1, 15043, 2787, 2]

+    text = b""
+    tokens = llama.tokenize(text, add_bos=True, special=True)
+    assert tokens[-1] != llama.token_eos()
+    assert tokens == [llama.token_bos()]
+    assert text == llama.detokenize(tokens)

-def test_llama_patch(monkeypatch):
+@pytest.fixture
+def mock_llama(monkeypatch):
+    def setup_mock(llama: llama_cpp.Llama, output_text: str):
+        llama.reset()
+        n_vocab = llama.n_vocab()
+        output_tokens = llama.tokenize(
+            output_text.encode("utf-8"), add_bos=True, special=True
+        )
+        n = 0
+        last_n_tokens = 0
+
+        def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
+            nonlocal n
+            nonlocal last_n_tokens
+            # Test some basic invariants of this mocking technique
+            assert ctx == llama._ctx.ctx
+            assert llama.n_tokens == n
+            assert batch.n_tokens > 0
+            n += batch.n_tokens
+            last_n_tokens = batch.n_tokens
+            return 0
+
+        def mock_get_logits(*args, **kwargs):
+            nonlocal last_n_tokens
+            size = n_vocab * last_n_tokens
+            return (llama_cpp.c_float * size)()
+
+        def mock_sample(*args, **kwargs):
+            nonlocal n
+            if n < len(output_tokens):
+                return output_tokens[n]
+            else:
+                return llama.token_eos()
+
+        monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
+        monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
+        monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
+
+    return setup_mock
+
+
+def test_llama_patch(mock_llama):
     n_ctx = 128
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx)
     n_vocab = llama_cpp.llama_n_vocab(llama._model.model)
     assert n_vocab == 32000

-    ## Set up mock function
-    def mock_decode(*args, **kwargs):
-        return 0
-
-    def mock_get_logits(*args, **kwargs):
-        size = n_vocab * n_ctx
-        return (llama_cpp.c_float * size)()
-
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
-
     text = "The quick brown fox"
-    text_tokens = llama.tokenize(text.encode("utf-8"), add_bos=True, special=True)
     output_text = " jumps over the lazy dog."
-    all_text_tokens = llama.tokenize((text + output_text).encode("utf-8"), add_bos=True, special=True)
-    output_tokens = all_text_tokens[len(text_tokens):]
-    token_eos = llama.token_eos()
-    n = 0
-
-    def mock_sample(*args, **kwargs):
-        nonlocal n
-        if n < len(output_tokens):
-            n += 1
-            return output_tokens[n - 1]
-        else:
-            return token_eos
-
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
+    all_text = text + output_text
+
+    ## Test basic completion from bos until eos
+    mock_llama(llama, all_text)
+    completion = llama.create_completion("", max_tokens=36)
+    assert completion["choices"][0]["text"] == all_text
+    assert completion["choices"][0]["finish_reason"] == "stop"

     ## Test basic completion until eos
-    n = 0  # reset
+    mock_llama(llama, all_text)
     completion = llama.create_completion(text, max_tokens=20)
     assert completion["choices"][0]["text"] == output_text
     assert completion["choices"][0]["finish_reason"] == "stop"

     ## Test streaming completion until eos
-    n = 0  # reset
+    mock_llama(llama, all_text)
     chunks = list(llama.create_completion(text, max_tokens=20, stream=True))
     assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == output_text
     assert chunks[-1]["choices"][0]["finish_reason"] == "stop"

     ## Test basic completion until stop sequence
-    n = 0  # reset
+    mock_llama(llama, all_text)
     completion = llama.create_completion(text, max_tokens=20, stop=["lazy"])
     assert completion["choices"][0]["text"] == " jumps over the "
     assert completion["choices"][0]["finish_reason"] == "stop"

     ## Test streaming completion until stop sequence
-    n = 0  # reset
-    chunks = list(llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"]))
+    mock_llama(llama, all_text)
+    chunks = list(
+        llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"])
+    )
     assert (
         "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps over the "
     )
     assert chunks[-1]["choices"][0]["finish_reason"] == "stop"

     ## Test basic completion until length
-    n = 0  # reset
+    mock_llama(llama, all_text)
     completion = llama.create_completion(text, max_tokens=2)
     assert completion["choices"][0]["text"] == " jumps"
     assert completion["choices"][0]["finish_reason"] == "length"

     ## Test streaming completion until length
-    n = 0  # reset
+    mock_llama(llama, all_text)
     chunks = list(llama.create_completion(text, max_tokens=2, stream=True))
     assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps"
     assert chunks[-1]["choices"][0]["finish_reason"] == "length"

@@ -131,44 +160,55 @@ def test_llama_pickle():
     assert llama.detokenize(llama.tokenize(text)) == text


-def test_utf8(monkeypatch):
-    n_ctx = 512
-    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx, logits_all=True)
+def test_utf8(mock_llama, monkeypatch):
+    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, logits_all=True)
+    n_ctx = llama.n_ctx()
     n_vocab = llama.n_vocab()

+    output_text = "😀"
+    output_tokens = llama.tokenize(
+        output_text.encode("utf-8"), add_bos=True, special=True
+    )
+    token_eos = llama.token_eos()
+    n = 0
+
+    def reset():
+        nonlocal n
+        llama.reset()
+        n = 0
+
     ## Set up mock function
-    def mock_decode(*args, **kwargs):
+    def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
+        nonlocal n
+        assert batch.n_tokens > 0
+        assert llama.n_tokens == n
+        n += batch.n_tokens
         return 0

     def mock_get_logits(*args, **kwargs):
         size = n_vocab * n_ctx
         return (llama_cpp.c_float * size)()

-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
-
-    output_text = "😀"
-    output_tokens = llama.tokenize(output_text.encode("utf-8"))
-    token_eos = llama.token_eos()
-    n = 0
-
     def mock_sample(*args, **kwargs):
         nonlocal n
-        if n < len(output_tokens):
-            n += 1
+        if n <= len(output_tokens):
             return output_tokens[n - 1]
         else:
             return token_eos

+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)

     ## Test basic completion with utf8 multibyte
-    n = 0  # reset
+    # mock_llama(llama, output_text)
+    reset()
     completion = llama.create_completion("", max_tokens=4)
     assert completion["choices"][0]["text"] == output_text

     ## Test basic completion with incomplete utf8 multibyte
-    n = 0  # reset
+    # mock_llama(llama, output_text)
+    reset()
     completion = llama.create_completion("", max_tokens=1)
     assert completion["choices"][0]["text"] == ""

@@ -196,5 +236,6 @@ def test_llama_server():
         ],
     }

+
 def test_llama_cpp_version():
     assert llama_cpp.__version__
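As a reading aid, not part of the commit itself: a minimal sketch of how the new mock_llama fixture is meant to be consumed. It assumes the same module-level MODEL path constant and pytest setup as the existing tests; the fixture monkeypatches llama_decode, llama_get_logits and llama_sample_token, so create_completion can return a chosen string without loading real model weights. The test name below is hypothetical.

import llama_cpp

def test_mocked_completion(mock_llama):
    # vocab_only avoids evaluating weights; the mock supplies logits and tokens.
    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=128)
    # Tell the mock what the "model" should emit, then run a normal completion;
    # generation stops when the mocked sampler falls back to the EOS token.
    mock_llama(llama, "The quick brown fox jumps over the lazy dog.")
    completion = llama.create_completion("The quick brown fox", max_tokens=20)
    assert completion["choices"][0]["text"] == " jumps over the lazy dog."
    assert completion["choices"][0]["finish_reason"] == "stop"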