Fix tests

This commit is contained in:
Andrei Betlen 2023-11-10 05:39:42 -05:00
parent 6f0b0b1b84
commit e32ecb0516

View file

@ -1,4 +1,7 @@
import ctypes
import pytest import pytest
import llama_cpp import llama_cpp
MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama.gguf" MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama.gguf"
@ -36,19 +39,20 @@ def test_llama_cpp_tokenization():
def test_llama_patch(monkeypatch): def test_llama_patch(monkeypatch):
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True) n_ctx = 128
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx)
n_vocab = llama_cpp.llama_n_vocab(llama._model.model) n_vocab = llama_cpp.llama_n_vocab(llama._model.model)
assert n_vocab == 32000
## Set up mock function ## Set up mock function
def mock_eval(*args, **kwargs): def mock_decode(*args, **kwargs):
return 0 return 0
def mock_get_logits(*args, **kwargs): def mock_get_logits(*args, **kwargs):
return (llama_cpp.c_float * n_vocab)( size = n_vocab * n_ctx
*[llama_cpp.c_float(0) for _ in range(n_vocab)] return (llama_cpp.c_float * size)()
)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_eval) monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits) monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
output_text = " jumps over the lazy dog." output_text = " jumps over the lazy dog."
@ -126,19 +130,19 @@ def test_llama_pickle():
def test_utf8(monkeypatch): def test_utf8(monkeypatch):
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True) n_ctx = 512
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx, logits_all=True)
n_vocab = llama.n_vocab() n_vocab = llama.n_vocab()
## Set up mock function ## Set up mock function
def mock_eval(*args, **kwargs): def mock_decode(*args, **kwargs):
return 0 return 0
def mock_get_logits(*args, **kwargs): def mock_get_logits(*args, **kwargs):
return (llama_cpp.c_float * n_vocab)( size = n_vocab * n_ctx
*[llama_cpp.c_float(0) for _ in range(n_vocab)] return (llama_cpp.c_float * size)()
)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_eval) monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits) monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
output_text = "😀" output_text = "😀"