tests: avoid constantly reallocating logits
This commit is contained in:
parent
0a7e05bc10
commit
0ea244499e
1 changed files with 18 additions and 16 deletions
|
@ -47,16 +47,22 @@ def test_llama_cpp_tokenization():
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_llama(monkeypatch):
|
def mock_llama(monkeypatch):
|
||||||
def setup_mock(llama: llama_cpp.Llama, output_text: str):
|
def setup_mock(llama: llama_cpp.Llama, output_text: str):
|
||||||
|
n_ctx = llama.n_ctx()
|
||||||
n_vocab = llama.n_vocab()
|
n_vocab = llama.n_vocab()
|
||||||
output_tokens = llama.tokenize(
|
output_tokens = llama.tokenize(
|
||||||
output_text.encode("utf-8"), add_bos=True, special=True
|
output_text.encode("utf-8"), add_bos=True, special=True
|
||||||
)
|
)
|
||||||
|
logits = (llama_cpp.c_float * (n_vocab * n_ctx))(-100.0)
|
||||||
|
for i in range(n_ctx):
|
||||||
|
output_idx = i + 1 # logits for first tokens predict second token
|
||||||
|
if output_idx < len(output_tokens):
|
||||||
|
logits[i * n_vocab + output_tokens[output_idx]] = 100.0
|
||||||
|
else:
|
||||||
|
logits[i * n_vocab + llama.token_eos()] = 100.0
|
||||||
n = 0
|
n = 0
|
||||||
last_n_tokens = 0
|
last_n_tokens = 0
|
||||||
|
|
||||||
def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
|
def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
|
||||||
nonlocal n
|
|
||||||
nonlocal last_n_tokens
|
|
||||||
# Test some basic invariants of this mocking technique
|
# Test some basic invariants of this mocking technique
|
||||||
assert ctx == llama._ctx.ctx, "context does not match mock_llama"
|
assert ctx == llama._ctx.ctx, "context does not match mock_llama"
|
||||||
assert batch.n_tokens > 0, "no tokens in batch"
|
assert batch.n_tokens > 0, "no tokens in batch"
|
||||||
|
@ -70,26 +76,22 @@ def mock_llama(monkeypatch):
|
||||||
batch.n_tokens - 1
|
batch.n_tokens - 1
|
||||||
], "logits not allocated for last token"
|
], "logits not allocated for last token"
|
||||||
# Update the mock context state
|
# Update the mock context state
|
||||||
|
nonlocal n
|
||||||
|
nonlocal last_n_tokens
|
||||||
n = max(batch.pos[i] for i in range(batch.n_tokens)) + 1
|
n = max(batch.pos[i] for i in range(batch.n_tokens)) + 1
|
||||||
last_n_tokens = batch.n_tokens
|
last_n_tokens = batch.n_tokens
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def mock_get_logits(*args, **kwargs):
|
def mock_get_logits(ctx: llama_cpp.llama_context_p):
|
||||||
nonlocal n
|
# Test some basic invariants of this mocking technique
|
||||||
nonlocal last_n_tokens
|
assert ctx == llama._ctx.ctx, "context does not match mock_llama"
|
||||||
assert n > 0, "mock_llama_decode not called"
|
assert n > 0, "mock_llama_decode not called"
|
||||||
assert last_n_tokens > 0, "mock_llama_decode not called"
|
assert last_n_tokens > 0, "mock_llama_decode not called"
|
||||||
logits = (llama_cpp.c_float * (last_n_tokens * n_vocab))(-100.0)
|
# Return view of logits for last_n_tokens
|
||||||
for logits_idx, output_idx in enumerate(
|
return (llama_cpp.c_float * (last_n_tokens * n_vocab)).from_address(
|
||||||
range(n - last_n_tokens + 1, n + 1)
|
ctypes.addressof(logits)
|
||||||
):
|
+ (n - last_n_tokens) * n_vocab * ctypes.sizeof(llama_cpp.c_float)
|
||||||
if output_idx < len(output_tokens):
|
)
|
||||||
logits[
|
|
||||||
logits_idx * last_n_tokens + output_tokens[output_idx]
|
|
||||||
] = 100.0
|
|
||||||
else:
|
|
||||||
logits[logits_idx * last_n_tokens + llama.token_eos()] = 100.0
|
|
||||||
return logits
|
|
||||||
|
|
||||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
|
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
|
||||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
|
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
|
||||||
|
|
Loading…
Add table
Reference in a new issue