Un-skip tests
This commit is contained in:
parent
bf3d0dcb2c
commit
c088a2b3a7
1 changed files with 13 additions and 5 deletions
|
@ -1,4 +1,3 @@
|
|||
import pytest
|
||||
import llama_cpp
|
||||
|
||||
MODEL = "./vendor/llama.cpp/models/ggml-vocab.bin"
|
||||
|
@ -15,15 +14,20 @@ def test_llama():
|
|||
assert llama.detokenize(llama.tokenize(text)) == text
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="need to update sample mocking")
|
||||
# @pytest.mark.skip(reason="need to update sample mocking")
|
||||
def test_llama_patch(monkeypatch):
|
||||
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
|
||||
n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
|
||||
|
||||
## Set up mock function
|
||||
def mock_eval(*args, **kwargs):
|
||||
return 0
|
||||
|
||||
def mock_get_logits(*args, **kwargs):
|
||||
return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
|
||||
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
|
||||
|
||||
output_text = " jumps over the lazy dog."
|
||||
output_tokens = llama.tokenize(output_text.encode("utf-8"))
|
||||
|
@ -38,7 +42,7 @@ def test_llama_patch(monkeypatch):
|
|||
else:
|
||||
return token_eos
|
||||
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample)
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
|
||||
|
||||
text = "The quick brown fox"
|
||||
|
||||
|
@ -97,15 +101,19 @@ def test_llama_pickle():
|
|||
|
||||
assert llama.detokenize(llama.tokenize(text)) == text
|
||||
|
||||
@pytest.mark.skip(reason="need to update sample mocking")
|
||||
def test_utf8(monkeypatch):
|
||||
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
|
||||
n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
|
||||
|
||||
## Set up mock function
|
||||
def mock_eval(*args, **kwargs):
|
||||
return 0
|
||||
|
||||
def mock_get_logits(*args, **kwargs):
|
||||
return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
|
||||
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
|
||||
|
||||
output_text = "😀"
|
||||
output_tokens = llama.tokenize(output_text.encode("utf-8"))
|
||||
|
@ -120,7 +128,7 @@ def test_utf8(monkeypatch):
|
|||
else:
|
||||
return token_eos
|
||||
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample)
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
|
||||
|
||||
## Test basic completion with utf8 multibyte
|
||||
n = 0 # reset
|
||||
|
|
Loading…
Reference in a new issue