# llama.cpp/tests/test_llama.py
#
# 164 lines, 5.1 KiB, Python
# (Raw / Normal View / History)
#
# 2023-04-05 07:23:15 +00:00
import llama_cpp
# Vocab-only GGML model from the vendored llama.cpp checkout; the tests here load
# it with vocab_only=True. The path is relative, so the tests presumably have to
# run from the repository root — TODO confirm.
MODEL = "./vendor/llama.cpp/models/ggml-vocab.bin"
def test_llama():
    """Load the vocab-only model and check that tokenize/detokenize round-trips."""
    model = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
    assert model
    assert model.ctx is not None

    original = b"Hello World"
    tokens = model.tokenize(original)
    assert model.detokenize(tokens) == original
def test_llama_patch(monkeypatch):
    """Exercise create_completion against mocked llama.cpp C bindings.

    llama_eval, llama_get_logits and llama_sample_token are monkeypatched so
    that sampling replays the tokens of ``output_text`` and then returns EOS.
    The test checks the completion text and ``finish_reason`` for the three
    ways a completion ends (EOS, stop sequence, max_tokens), both with and
    without streaming.
    """
    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
    n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))

    ## Set up mock functions
    def mock_eval(*args, **kwargs):
        # Stub out evaluation; returns 0 (presumably the success code).
        return 0

    def mock_get_logits(*args, **kwargs):
        # All-zero logits; the token actually chosen comes from mock_sample.
        return (llama_cpp.c_float * n_vocab)(
            *[llama_cpp.c_float(0) for _ in range(n_vocab)]
        )

    monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)

    output_text = " jumps over the lazy dog."
    output_tokens = llama.tokenize(output_text.encode("utf-8"))
    token_eos = llama.token_eos()
    n = 0  # index of the next token mock_sample will return

    def mock_sample(*args, **kwargs):
        # Replay output_tokens in order, then return EOS on every later call.
        nonlocal n
        if n < len(output_tokens):
            n += 1
            return output_tokens[n - 1]
        return token_eos

    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)

    text = "The quick brown fox"

    # Streaming cases materialize the generator with list() so the text can be
    # joined AND the last chunk's finish_reason can be checked (assumes the
    # final streamed chunk carries the finish_reason — confirm against
    # Llama.create_completion).

    ## Test basic completion until eos
    n = 0  # reset
    completion = llama.create_completion(text, max_tokens=20)
    assert completion["choices"][0]["text"] == output_text
    assert completion["choices"][0]["finish_reason"] == "stop"

    ## Test streaming completion until eos
    n = 0  # reset
    chunks = list(llama.create_completion(text, max_tokens=20, stream=True))
    assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == output_text
    assert chunks[-1]["choices"][0]["finish_reason"] == "stop"

    ## Test basic completion until stop sequence
    n = 0  # reset
    completion = llama.create_completion(text, max_tokens=20, stop=["lazy"])
    assert completion["choices"][0]["text"] == " jumps over the "
    assert completion["choices"][0]["finish_reason"] == "stop"

    ## Test streaming completion until stop sequence
    n = 0  # reset
    chunks = list(
        llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"])
    )
    assert (
        "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps over the "
    )
    assert chunks[-1]["choices"][0]["finish_reason"] == "stop"

    ## Test basic completion until length
    n = 0  # reset
    completion = llama.create_completion(text, max_tokens=2)
    assert completion["choices"][0]["text"] == " j"
    assert completion["choices"][0]["finish_reason"] == "length"

    ## Test streaming completion until length
    n = 0  # reset
    chunks = list(llama.create_completion(text, max_tokens=2, stream=True))
    assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " j"
    assert chunks[-1]["choices"][0]["finish_reason"] == "length"
def test_llama_pickle():
    """Check that a Llama instance survives a pickle round-trip through a file."""
    import pickle
    import tempfile

    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
    # Manage the temp file with `with` so it is closed (and removed) even if
    # pickling or unpickling raises.
    with tempfile.TemporaryFile() as fp:
        pickle.dump(llama, fp)
        fp.seek(0)
        # Unpickling is safe here: the data was just written by this test.
        llama = pickle.load(fp)

    assert llama
    assert llama.ctx is not None
    text = b"Hello World"
    assert llama.detokenize(llama.tokenize(text)) == text
def test_utf8(monkeypatch):
    """Check that completions only return whole UTF-8 characters.

    The C bindings are monkeypatched so sampling replays the tokens of a
    multibyte emoji and then EOS. With max_tokens=4 the full emoji is
    expected back. With max_tokens=1 the text is expected to be empty, which
    presumes the emoji spans more than one token.
    """
    model = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
    vocab_size = int(llama_cpp.llama_n_vocab(model.ctx))

    def fake_eval(*args, **kwargs):
        return 0

    def fake_get_logits(*args, **kwargs):
        zeros = [llama_cpp.c_float(0) for _ in range(vocab_size)]
        return (llama_cpp.c_float * vocab_size)(*zeros)

    monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", fake_eval)
    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", fake_get_logits)

    emoji = "😀"
    emoji_tokens = model.tokenize(emoji.encode("utf-8"))
    eos = model.token_eos()
    cursor = 0

    def fake_sample(*args, **kwargs):
        # Hand out the emoji's tokens in order; once exhausted, always EOS.
        nonlocal cursor
        if cursor >= len(emoji_tokens):
            return eos
        cursor += 1
        return emoji_tokens[cursor - 1]

    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", fake_sample)

    # Enough tokens to produce the complete multibyte character.
    cursor = 0
    result = model.create_completion("", max_tokens=4)
    assert result["choices"][0]["text"] == emoji

    # Too few tokens: the incomplete byte sequence must not be emitted.
    cursor = 0
    result = model.create_completion("", max_tokens=1)
    assert result["choices"][0]["text"] == ""
def test_llama_server():
    """Initialise the FastAPI app on the vocab-only model and query GET /v1/models."""
    from fastapi.testclient import TestClient
    from llama_cpp.server.app import app, init_llama, Settings

    settings = Settings()
    settings.model = MODEL
    settings.vocab_only = True
    init_llama(settings)

    expected_model = {
        "id": MODEL,
        "object": "model",
        "owned_by": "me",
        "permissions": [],
    }
    response = TestClient(app).get("/v1/models")
    assert response.json() == {"object": "list", "data": [expected_model]}