Fix tests
This commit is contained in:
parent
6f0b0b1b84
commit
e32ecb0516
1 changed files with 16 additions and 12 deletions
|
@ -1,4 +1,7 @@
|
|||
import ctypes
|
||||
|
||||
import pytest
|
||||
|
||||
import llama_cpp
|
||||
|
||||
MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama.gguf"
|
||||
|
@ -36,19 +39,20 @@ def test_llama_cpp_tokenization():
|
|||
|
||||
|
||||
def test_llama_patch(monkeypatch):
|
||||
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
|
||||
n_ctx = 128
|
||||
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx)
|
||||
n_vocab = llama_cpp.llama_n_vocab(llama._model.model)
|
||||
assert n_vocab == 32000
|
||||
|
||||
## Set up mock function
|
||||
def mock_eval(*args, **kwargs):
|
||||
def mock_decode(*args, **kwargs):
|
||||
return 0
|
||||
|
||||
def mock_get_logits(*args, **kwargs):
|
||||
return (llama_cpp.c_float * n_vocab)(
|
||||
*[llama_cpp.c_float(0) for _ in range(n_vocab)]
|
||||
)
|
||||
size = n_vocab * n_ctx
|
||||
return (llama_cpp.c_float * size)()
|
||||
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_eval)
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
|
||||
|
||||
output_text = " jumps over the lazy dog."
|
||||
|
@ -126,19 +130,19 @@ def test_llama_pickle():
|
|||
|
||||
|
||||
def test_utf8(monkeypatch):
|
||||
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
|
||||
n_ctx = 512
|
||||
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx, logits_all=True)
|
||||
n_vocab = llama.n_vocab()
|
||||
|
||||
## Set up mock function
|
||||
def mock_eval(*args, **kwargs):
|
||||
def mock_decode(*args, **kwargs):
|
||||
return 0
|
||||
|
||||
def mock_get_logits(*args, **kwargs):
|
||||
return (llama_cpp.c_float * n_vocab)(
|
||||
*[llama_cpp.c_float(0) for _ in range(n_vocab)]
|
||||
)
|
||||
size = n_vocab * n_ctx
|
||||
return (llama_cpp.c_float * size)()
|
||||
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_eval)
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
|
||||
|
||||
output_text = "😀"
|
||||
|
|
Loading…
Reference in a new issue