tests: avoid constantly reallocating logits

This commit is contained in:
Andrei Betlen 2023-11-22 04:31:05 -05:00
parent 0a7e05bc10
commit 0ea244499e

View file

@ -47,16 +47,22 @@ def test_llama_cpp_tokenization():
@pytest.fixture @pytest.fixture
def mock_llama(monkeypatch): def mock_llama(monkeypatch):
def setup_mock(llama: llama_cpp.Llama, output_text: str): def setup_mock(llama: llama_cpp.Llama, output_text: str):
n_ctx = llama.n_ctx()
n_vocab = llama.n_vocab() n_vocab = llama.n_vocab()
output_tokens = llama.tokenize( output_tokens = llama.tokenize(
output_text.encode("utf-8"), add_bos=True, special=True output_text.encode("utf-8"), add_bos=True, special=True
) )
logits = (llama_cpp.c_float * (n_vocab * n_ctx))(-100.0)
for i in range(n_ctx):
output_idx = i + 1 # logits for first tokens predict second token
if output_idx < len(output_tokens):
logits[i * n_vocab + output_tokens[output_idx]] = 100.0
else:
logits[i * n_vocab + llama.token_eos()] = 100.0
n = 0 n = 0
last_n_tokens = 0 last_n_tokens = 0
def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch): def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
nonlocal n
nonlocal last_n_tokens
# Test some basic invariants of this mocking technique # Test some basic invariants of this mocking technique
assert ctx == llama._ctx.ctx, "context does not match mock_llama" assert ctx == llama._ctx.ctx, "context does not match mock_llama"
assert batch.n_tokens > 0, "no tokens in batch" assert batch.n_tokens > 0, "no tokens in batch"
@ -70,26 +76,22 @@ def mock_llama(monkeypatch):
batch.n_tokens - 1 batch.n_tokens - 1
], "logits not allocated for last token" ], "logits not allocated for last token"
# Update the mock context state # Update the mock context state
nonlocal n
nonlocal last_n_tokens
n = max(batch.pos[i] for i in range(batch.n_tokens)) + 1 n = max(batch.pos[i] for i in range(batch.n_tokens)) + 1
last_n_tokens = batch.n_tokens last_n_tokens = batch.n_tokens
return 0 return 0
def mock_get_logits(*args, **kwargs): def mock_get_logits(ctx: llama_cpp.llama_context_p):
nonlocal n # Test some basic invariants of this mocking technique
nonlocal last_n_tokens assert ctx == llama._ctx.ctx, "context does not match mock_llama"
assert n > 0, "mock_llama_decode not called" assert n > 0, "mock_llama_decode not called"
assert last_n_tokens > 0, "mock_llama_decode not called" assert last_n_tokens > 0, "mock_llama_decode not called"
logits = (llama_cpp.c_float * (last_n_tokens * n_vocab))(-100.0) # Return view of logits for last_n_tokens
for logits_idx, output_idx in enumerate( return (llama_cpp.c_float * (last_n_tokens * n_vocab)).from_address(
range(n - last_n_tokens + 1, n + 1) ctypes.addressof(logits)
): + (n - last_n_tokens) * n_vocab * ctypes.sizeof(llama_cpp.c_float)
if output_idx < len(output_tokens): )
logits[
logits_idx * last_n_tokens + output_tokens[output_idx]
] = 100.0
else:
logits[logits_idx * last_n_tokens + llama.token_eos()] = 100.0
return logits
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode) monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits) monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)