Un-skip tests

This commit is contained in:
Andrei Betlen 2023-05-01 15:46:03 -04:00
parent bf3d0dcb2c
commit c088a2b3a7

View file

@ -1,4 +1,3 @@
import pytest
import llama_cpp
MODEL = "./vendor/llama.cpp/models/ggml-vocab.bin"
@ -15,15 +14,20 @@ def test_llama():
assert llama.detokenize(llama.tokenize(text)) == text
@pytest.mark.skip(reason="need to update sample mocking")
# @pytest.mark.skip(reason="need to update sample mocking")
def test_llama_patch(monkeypatch):
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
## Set up mock function
def mock_eval(*args, **kwargs):
return 0
def mock_get_logits(*args, **kwargs):
return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
output_text = " jumps over the lazy dog."
output_tokens = llama.tokenize(output_text.encode("utf-8"))
@ -38,7 +42,7 @@ def test_llama_patch(monkeypatch):
else:
return token_eos
monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
text = "The quick brown fox"
@ -97,15 +101,19 @@ def test_llama_pickle():
assert llama.detokenize(llama.tokenize(text)) == text
@pytest.mark.skip(reason="need to update sample mocking")
def test_utf8(monkeypatch):
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
## Set up mock function
def mock_eval(*args, **kwargs):
return 0
def mock_get_logits(*args, **kwargs):
return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
output_text = "😀"
output_tokens = llama.tokenize(output_text.encode("utf-8"))
@ -120,7 +128,7 @@ def test_utf8(monkeypatch):
else:
return token_eos
monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
## Test basic completion with utf8 multibyte
n = 0 # reset