Fix tests
This commit is contained in:
parent
6f0b0b1b84
commit
e32ecb0516
1 changed files with 16 additions and 12 deletions
|
@ -1,4 +1,7 @@
|
||||||
|
import ctypes
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import llama_cpp
|
import llama_cpp
|
||||||
|
|
||||||
MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama.gguf"
|
MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama.gguf"
|
||||||
|
@ -36,19 +39,20 @@ def test_llama_cpp_tokenization():
|
||||||
|
|
||||||
|
|
||||||
def test_llama_patch(monkeypatch):
|
def test_llama_patch(monkeypatch):
|
||||||
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
|
n_ctx = 128
|
||||||
|
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx)
|
||||||
n_vocab = llama_cpp.llama_n_vocab(llama._model.model)
|
n_vocab = llama_cpp.llama_n_vocab(llama._model.model)
|
||||||
|
assert n_vocab == 32000
|
||||||
|
|
||||||
## Set up mock function
|
## Set up mock function
|
||||||
def mock_eval(*args, **kwargs):
|
def mock_decode(*args, **kwargs):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def mock_get_logits(*args, **kwargs):
|
def mock_get_logits(*args, **kwargs):
|
||||||
return (llama_cpp.c_float * n_vocab)(
|
size = n_vocab * n_ctx
|
||||||
*[llama_cpp.c_float(0) for _ in range(n_vocab)]
|
return (llama_cpp.c_float * size)()
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_eval)
|
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
|
||||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
|
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
|
||||||
|
|
||||||
output_text = " jumps over the lazy dog."
|
output_text = " jumps over the lazy dog."
|
||||||
|
@ -126,19 +130,19 @@ def test_llama_pickle():
|
||||||
|
|
||||||
|
|
||||||
def test_utf8(monkeypatch):
|
def test_utf8(monkeypatch):
|
||||||
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
|
n_ctx = 512
|
||||||
|
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx, logits_all=True)
|
||||||
n_vocab = llama.n_vocab()
|
n_vocab = llama.n_vocab()
|
||||||
|
|
||||||
## Set up mock function
|
## Set up mock function
|
||||||
def mock_eval(*args, **kwargs):
|
def mock_decode(*args, **kwargs):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def mock_get_logits(*args, **kwargs):
|
def mock_get_logits(*args, **kwargs):
|
||||||
return (llama_cpp.c_float * n_vocab)(
|
size = n_vocab * n_ctx
|
||||||
*[llama_cpp.c_float(0) for _ in range(n_vocab)]
|
return (llama_cpp.c_float * size)()
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_eval)
|
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
|
||||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
|
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
|
||||||
|
|
||||||
output_text = "😀"
|
output_text = "😀"
|
||||||
|
|
Loading…
Reference in a new issue