Use mock_llama for all tests
This commit is contained in:
parent
dbfaf53fe0
commit
d7388f1ffb
1 changed files with 3 additions and 40 deletions
|
@ -160,55 +160,18 @@ def test_llama_pickle():
|
||||||
assert llama.detokenize(llama.tokenize(text)) == text
|
assert llama.detokenize(llama.tokenize(text)) == text
|
||||||
|
|
||||||
|
|
||||||
def test_utf8(mock_llama, monkeypatch):
|
def test_utf8(mock_llama):
|
||||||
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, logits_all=True)
|
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, logits_all=True)
|
||||||
n_ctx = llama.n_ctx()
|
|
||||||
n_vocab = llama.n_vocab()
|
|
||||||
|
|
||||||
output_text = "😀"
|
output_text = "😀"
|
||||||
output_tokens = llama.tokenize(
|
|
||||||
output_text.encode("utf-8"), add_bos=True, special=True
|
|
||||||
)
|
|
||||||
token_eos = llama.token_eos()
|
|
||||||
n = 0
|
|
||||||
|
|
||||||
def reset():
|
|
||||||
nonlocal n
|
|
||||||
llama.reset()
|
|
||||||
n = 0
|
|
||||||
|
|
||||||
## Set up mock function
|
|
||||||
def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
|
|
||||||
nonlocal n
|
|
||||||
assert batch.n_tokens > 0
|
|
||||||
assert llama.n_tokens == n
|
|
||||||
n += batch.n_tokens
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def mock_get_logits(*args, **kwargs):
|
|
||||||
size = n_vocab * n_ctx
|
|
||||||
return (llama_cpp.c_float * size)()
|
|
||||||
|
|
||||||
def mock_sample(*args, **kwargs):
|
|
||||||
nonlocal n
|
|
||||||
if n <= len(output_tokens):
|
|
||||||
return output_tokens[n - 1]
|
|
||||||
else:
|
|
||||||
return token_eos
|
|
||||||
|
|
||||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
|
|
||||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
|
|
||||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
|
|
||||||
|
|
||||||
## Test basic completion with utf8 multibyte
|
## Test basic completion with utf8 multibyte
|
||||||
# mock_llama(llama, output_text)
|
mock_llama(llama, output_text)
|
||||||
reset()
|
|
||||||
completion = llama.create_completion("", max_tokens=4)
|
completion = llama.create_completion("", max_tokens=4)
|
||||||
assert completion["choices"][0]["text"] == output_text
|
assert completion["choices"][0]["text"] == output_text
|
||||||
|
|
||||||
## Test basic completion with incomplete utf8 multibyte
|
## Test basic completion with incomplete utf8 multibyte
|
||||||
# mock_llama(llama, output_text)
|
mock_llama(llama, output_text)
|
||||||
reset()
|
|
||||||
completion = llama.create_completion("", max_tokens=1)
|
completion = llama.create_completion("", max_tokens=1)
|
||||||
assert completion["choices"][0]["text"] == ""
|
assert completion["choices"][0]["text"] == ""
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue