tests: don't mock sampling functions
This commit is contained in:
parent
d7388f1ffb
commit
0a7e05bc10
1 changed files with 27 additions and 15 deletions
|
@ -47,7 +47,6 @@ def test_llama_cpp_tokenization():
|
|||
@pytest.fixture
|
||||
def mock_llama(monkeypatch):
|
||||
def setup_mock(llama: llama_cpp.Llama, output_text: str):
|
||||
llama.reset()
|
||||
n_vocab = llama.n_vocab()
|
||||
output_tokens = llama.tokenize(
|
||||
output_text.encode("utf-8"), add_bos=True, special=True
|
||||
|
@ -59,28 +58,41 @@ def mock_llama(monkeypatch):
|
|||
nonlocal n
|
||||
nonlocal last_n_tokens
|
||||
# Test some basic invariants of this mocking technique
|
||||
assert ctx == llama._ctx.ctx
|
||||
assert llama.n_tokens == n
|
||||
assert batch.n_tokens > 0
|
||||
n += batch.n_tokens
|
||||
assert ctx == llama._ctx.ctx, "context does not match mock_llama"
|
||||
assert batch.n_tokens > 0, "no tokens in batch"
|
||||
assert all(
|
||||
batch.n_seq_id[i] == 1 for i in range(batch.n_tokens)
|
||||
), "n_seq >1 not supported by mock_llama"
|
||||
assert all(
|
||||
batch.seq_id[i][0] == 0 for i in range(batch.n_tokens)
|
||||
), "n_seq >1 not supported by mock_llama"
|
||||
assert batch.logits[
|
||||
batch.n_tokens - 1
|
||||
], "logits not allocated for last token"
|
||||
# Update the mock context state
|
||||
n = max(batch.pos[i] for i in range(batch.n_tokens)) + 1
|
||||
last_n_tokens = batch.n_tokens
|
||||
return 0
|
||||
|
||||
def mock_get_logits(*args, **kwargs):
|
||||
nonlocal last_n_tokens
|
||||
size = n_vocab * last_n_tokens
|
||||
return (llama_cpp.c_float * size)()
|
||||
|
||||
def mock_sample(*args, **kwargs):
|
||||
nonlocal n
|
||||
if n < len(output_tokens):
|
||||
return output_tokens[n]
|
||||
else:
|
||||
return llama.token_eos()
|
||||
nonlocal last_n_tokens
|
||||
assert n > 0, "mock_llama_decode not called"
|
||||
assert last_n_tokens > 0, "mock_llama_decode not called"
|
||||
logits = (llama_cpp.c_float * (last_n_tokens * n_vocab))(-100.0)
|
||||
for logits_idx, output_idx in enumerate(
|
||||
range(n - last_n_tokens + 1, n + 1)
|
||||
):
|
||||
if output_idx < len(output_tokens):
|
||||
logits[
|
||||
logits_idx * last_n_tokens + output_tokens[output_idx]
|
||||
] = 100.0
|
||||
else:
|
||||
logits[logits_idx * last_n_tokens + llama.token_eos()] = 100.0
|
||||
return logits
|
||||
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
|
||||
|
||||
return setup_mock
|
||||
|
||||
|
|
Loading…
Reference in a new issue