Add basic tests. Closes #24
This commit is contained in:
parent
51dbcf2693
commit
c3972b61ae
3 changed files with 167 additions and 1 deletions
88
poetry.lock
generated
88
poetry.lock
generated
|
@ -1,5 +1,24 @@
|
||||||
# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand.
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "attrs"
|
||||||
|
version = "22.2.0"
|
||||||
|
description = "Classes Without Boilerplate"
|
||||||
|
category = "dev"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.6"
|
||||||
|
files = [
|
||||||
|
{file = "attrs-22.2.0-py3-none-any.whl", hash = "sha256:29e95c7f6778868dbd49170f98f8818f78f3dc5e0e37c0b1f474e3561b240836"},
|
||||||
|
{file = "attrs-22.2.0.tar.gz", hash = "sha256:c9227bfc2f01993c03f68db37d1d15c9690188323c067c641f1a35ca58185f99"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
cov = ["attrs[tests]", "coverage-enable-subprocess", "coverage[toml] (>=5.3)"]
|
||||||
|
dev = ["attrs[docs,tests]"]
|
||||||
|
docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope.interface"]
|
||||||
|
tests = ["attrs[tests-no-zope]", "zope.interface"]
|
||||||
|
tests-no-zope = ["cloudpickle", "cloudpickle", "hypothesis", "hypothesis", "mypy (>=0.971,<0.990)", "mypy (>=0.971,<0.990)", "pympler", "pympler", "pytest (>=4.3.0)", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-mypy-plugins", "pytest-xdist[psutil]", "pytest-xdist[psutil]"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "black"
|
name = "black"
|
||||||
version = "23.1.0"
|
version = "23.1.0"
|
||||||
|
@ -328,6 +347,21 @@ files = [
|
||||||
{file = "docutils-0.19.tar.gz", hash = "sha256:33995a6753c30b7f577febfc2c50411fec6aac7f7ffeb7c4cfe5991072dcf9e6"},
|
{file = "docutils-0.19.tar.gz", hash = "sha256:33995a6753c30b7f577febfc2c50411fec6aac7f7ffeb7c4cfe5991072dcf9e6"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "exceptiongroup"
|
||||||
|
version = "1.1.1"
|
||||||
|
description = "Backport of PEP 654 (exception groups)"
|
||||||
|
category = "dev"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "exceptiongroup-1.1.1-py3-none-any.whl", hash = "sha256:232c37c63e4f682982c8b6459f33a8981039e5fb8756b2074364e5055c498c9e"},
|
||||||
|
{file = "exceptiongroup-1.1.1.tar.gz", hash = "sha256:d484c3090ba2889ae2928419117447a14daf3c1231d5e30d0aae34f354f01785"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
test = ["pytest (>=6)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ghp-import"
|
name = "ghp-import"
|
||||||
version = "2.1.0"
|
version = "2.1.0"
|
||||||
|
@ -415,6 +449,18 @@ zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""}
|
||||||
docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
|
docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
|
||||||
testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
|
testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "iniconfig"
|
||||||
|
version = "2.0.0"
|
||||||
|
description = "brain-dead simple config-ini parsing"
|
||||||
|
category = "dev"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
|
||||||
|
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "jaraco-classes"
|
name = "jaraco-classes"
|
||||||
version = "3.2.3"
|
version = "3.2.3"
|
||||||
|
@ -821,6 +867,22 @@ files = [
|
||||||
docs = ["furo (>=2022.12.7)", "proselint (>=0.13)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"]
|
docs = ["furo (>=2022.12.7)", "proselint (>=0.13)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"]
|
||||||
test = ["appdirs (==1.4.4)", "covdefaults (>=2.2.2)", "pytest (>=7.2.1)", "pytest-cov (>=4)", "pytest-mock (>=3.10)"]
|
test = ["appdirs (==1.4.4)", "covdefaults (>=2.2.2)", "pytest (>=7.2.1)", "pytest-cov (>=4)", "pytest-mock (>=3.10)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pluggy"
|
||||||
|
version = "1.0.0"
|
||||||
|
description = "plugin and hook calling mechanisms for python"
|
||||||
|
category = "dev"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.6"
|
||||||
|
files = [
|
||||||
|
{file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"},
|
||||||
|
{file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
dev = ["pre-commit", "tox"]
|
||||||
|
testing = ["pytest", "pytest-benchmark"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pycparser"
|
name = "pycparser"
|
||||||
version = "2.21"
|
version = "2.21"
|
||||||
|
@ -864,6 +926,30 @@ files = [
|
||||||
markdown = ">=3.2"
|
markdown = ">=3.2"
|
||||||
pyyaml = "*"
|
pyyaml = "*"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pytest"
|
||||||
|
version = "7.2.2"
|
||||||
|
description = "pytest: simple powerful testing with Python"
|
||||||
|
category = "dev"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "pytest-7.2.2-py3-none-any.whl", hash = "sha256:130328f552dcfac0b1cec75c12e3f005619dc5f874f0a06e8ff7263f0ee6225e"},
|
||||||
|
{file = "pytest-7.2.2.tar.gz", hash = "sha256:c99ab0c73aceb050f68929bc93af19ab6db0558791c6a0715723abe9d0ade9d4"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
attrs = ">=19.2.0"
|
||||||
|
colorama = {version = "*", markers = "sys_platform == \"win32\""}
|
||||||
|
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
|
||||||
|
iniconfig = "*"
|
||||||
|
packaging = "*"
|
||||||
|
pluggy = ">=0.12,<2.0"
|
||||||
|
tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "python-dateutil"
|
name = "python-dateutil"
|
||||||
version = "2.8.2"
|
version = "2.8.2"
|
||||||
|
@ -1281,4 +1367,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.8.1"
|
python-versions = "^3.8.1"
|
||||||
content-hash = "cffaf5e2e66ade4f429d0e938277d4fa2c4878ca7338c3c4f91721a7d3aff91b"
|
content-hash = "cc9babcdfdc3679a4d84f68912408a005619a576947b059146ed1b428850ece9"
|
||||||
|
|
|
@ -23,6 +23,7 @@ twine = "^4.0.2"
|
||||||
mkdocs = "^1.4.2"
|
mkdocs = "^1.4.2"
|
||||||
mkdocstrings = {extras = ["python"], version = "^0.20.0"}
|
mkdocstrings = {extras = ["python"], version = "^0.20.0"}
|
||||||
mkdocs-material = "^9.1.4"
|
mkdocs-material = "^9.1.4"
|
||||||
|
pytest = "^7.2.2"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = [
|
requires = [
|
||||||
|
|
79
tests/test_llama.py
Normal file
79
tests/test_llama.py
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
import llama_cpp
|
||||||
|
|
||||||
|
MODEL = "./vendor/llama.cpp/models/ggml-vocab.bin"
|
||||||
|
|
||||||
|
|
||||||
|
def test_llama():
|
||||||
|
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
|
||||||
|
|
||||||
|
assert llama
|
||||||
|
assert llama.ctx is not None
|
||||||
|
|
||||||
|
text = b"Hello World"
|
||||||
|
|
||||||
|
assert llama.detokenize(llama.tokenize(text)) == text
|
||||||
|
|
||||||
|
|
||||||
|
def test_llama_patch(monkeypatch):
|
||||||
|
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
|
||||||
|
|
||||||
|
## Set up mock function
|
||||||
|
def mock_eval(*args, **kwargs):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
|
||||||
|
|
||||||
|
output_text = " jumps over the lazy dog."
|
||||||
|
output_tokens = llama.tokenize(output_text.encode("utf-8"))
|
||||||
|
token_eos = llama.token_eos()
|
||||||
|
n = 0
|
||||||
|
|
||||||
|
def mock_sample(*args, **kwargs):
|
||||||
|
nonlocal n
|
||||||
|
if n < len(output_tokens):
|
||||||
|
n += 1
|
||||||
|
return output_tokens[n - 1]
|
||||||
|
else:
|
||||||
|
return token_eos
|
||||||
|
|
||||||
|
monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample)
|
||||||
|
|
||||||
|
text = "The quick brown fox"
|
||||||
|
|
||||||
|
## Test basic completion until eos
|
||||||
|
n = 0 # reset
|
||||||
|
completion = llama.create_completion(text, max_tokens=20)
|
||||||
|
assert completion["choices"][0]["text"] == output_text
|
||||||
|
assert completion["choices"][0]["finish_reason"] == "stop"
|
||||||
|
|
||||||
|
## Test streaming completion until eos
|
||||||
|
n = 0 # reset
|
||||||
|
chunks = llama.create_completion(text, max_tokens=20, stream=True)
|
||||||
|
assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == output_text
|
||||||
|
assert completion["choices"][0]["finish_reason"] == "stop"
|
||||||
|
|
||||||
|
## Test basic completion until stop sequence
|
||||||
|
n = 0 # reset
|
||||||
|
completion = llama.create_completion(text, max_tokens=20, stop=["lazy"])
|
||||||
|
assert completion["choices"][0]["text"] == " jumps over the "
|
||||||
|
assert completion["choices"][0]["finish_reason"] == "stop"
|
||||||
|
|
||||||
|
## Test streaming completion until stop sequence
|
||||||
|
n = 0 # reset
|
||||||
|
chunks = llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"])
|
||||||
|
assert (
|
||||||
|
"".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps over the "
|
||||||
|
)
|
||||||
|
assert completion["choices"][0]["finish_reason"] == "stop"
|
||||||
|
|
||||||
|
## Test basic completion until length
|
||||||
|
n = 0 # reset
|
||||||
|
completion = llama.create_completion(text, max_tokens=2)
|
||||||
|
assert completion["choices"][0]["text"] == " j"
|
||||||
|
assert completion["choices"][0]["finish_reason"] == "length"
|
||||||
|
|
||||||
|
## Test streaming completion until length
|
||||||
|
n = 0 # reset
|
||||||
|
chunks = llama.create_completion(text, max_tokens=2, stream=True)
|
||||||
|
assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " j"
|
||||||
|
assert completion["choices"][0]["finish_reason"] == "length"
|
Loading…
Reference in a new issue