diff --git a/tests/test_llama.py b/tests/test_llama.py index 5448743..23c7e86 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -1,4 +1,7 @@ +import ctypes + import pytest + import llama_cpp MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama.gguf" @@ -36,19 +39,20 @@ def test_llama_cpp_tokenization(): def test_llama_patch(monkeypatch): - llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True) + n_ctx = 128 + llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx) n_vocab = llama_cpp.llama_n_vocab(llama._model.model) + assert n_vocab == 32000 ## Set up mock function - def mock_eval(*args, **kwargs): + def mock_decode(*args, **kwargs): return 0 def mock_get_logits(*args, **kwargs): - return (llama_cpp.c_float * n_vocab)( - *[llama_cpp.c_float(0) for _ in range(n_vocab)] - ) + size = n_vocab * n_ctx + return (llama_cpp.c_float * size)() - monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_eval) + monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode) monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits) output_text = " jumps over the lazy dog." @@ -126,19 +130,19 @@ def test_llama_pickle(): def test_utf8(monkeypatch): - llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True) + n_ctx = 512 + llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx, logits_all=True) n_vocab = llama.n_vocab() ## Set up mock function - def mock_eval(*args, **kwargs): + def mock_decode(*args, **kwargs): return 0 def mock_get_logits(*args, **kwargs): - return (llama_cpp.c_float * n_vocab)( - *[llama_cpp.c_float(0) for _ in range(n_vocab)] - ) + size = n_vocab * n_ctx + return (llama_cpp.c_float * size)() - monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_eval) + monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode) monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits) output_text = "😀"