Lucas Doyle efe8e6f879 llama_cpp server: slight refactor to init_llama function
Define an init_llama function that starts llama with supplied settings instead of just doing it in the global context of app.py

This allows the test to be less brittle by not needing to mess with os.environ, then importing the app
2023-04-29 11:42:23 -07:00

152 lines
4.5 KiB

import llama_cpp
# Relative path of the model file handed to `llama_cpp.Llama(model_path=...)`
# by the tests in this module; every such instance is built with
# `vocab_only=True`.
MODEL = "./vendor/llama.cpp/models/ggml-vocab.bin"
def test_llama():
    """Check that a Llama instance loads and tokenization round-trips.

    The instance is built from MODEL with ``vocab_only=True``; the test then
    verifies the object is truthy, has a non-None ``ctx``, and that
    detokenizing the tokenized bytes gives back the original bytes.
    """
    model = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
    assert model
    assert model.ctx is not None

    # Tokenize, then detokenize, and expect the input back unchanged.
    payload = b"Hello World"
    token_ids = model.tokenize(payload)
    assert model.detokenize(token_ids) == payload
def test_llama_patch(monkeypatch):
    """Exercise ``create_completion`` with the eval and sampling calls mocked.

    ``llama_eval`` is patched to always return 0 and
    ``llama_sample_top_p_top_k`` is patched to replay the tokens of a fixed
    output string followed by the EOS token, so the completion text is
    deterministic. Both the blocking and the streaming code paths are checked
    for three terminations: end-of-sequence, a stop sequence, and the
    ``max_tokens`` length limit.

    Args:
        monkeypatch: pytest's monkeypatch fixture, used to swap in the mocks.
    """
    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)

    ## Set up mock function
    def mock_eval(*args, **kwargs):
        return 0

    monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)

    output_text = " jumps over the lazy dog."
    output_tokens = llama.tokenize(output_text.encode("utf-8"))
    token_eos = llama.token_eos()
    n = 0

    def mock_sample(*args, **kwargs):
        # Replay output_tokens one per call, then keep returning EOS.
        nonlocal n
        if n < len(output_tokens):
            n += 1
            return output_tokens[n - 1]
        return token_eos

    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample)

    text = "The quick brown fox"

    ## Test basic completion until eos
    n = 0  # reset
    completion = llama.create_completion(text, max_tokens=20)
    assert completion["choices"][0]["text"] == output_text
    assert completion["choices"][0]["finish_reason"] == "stop"

    ## Test streaming completion until eos
    n = 0  # reset
    # Materialise the stream so the final chunk can be inspected; asserting on
    # `completion` here would only re-check the previous non-streaming call.
    chunks = list(llama.create_completion(text, max_tokens=20, stream=True))
    assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == output_text
    assert chunks[-1]["choices"][0]["finish_reason"] == "stop"

    ## Test basic completion until stop sequence
    n = 0  # reset
    completion = llama.create_completion(text, max_tokens=20, stop=["lazy"])
    assert completion["choices"][0]["text"] == " jumps over the "
    assert completion["choices"][0]["finish_reason"] == "stop"

    ## Test streaming completion until stop sequence
    n = 0  # reset
    chunks = list(
        llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"])
    )
    assert (
        "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps over the "
    )
    assert chunks[-1]["choices"][0]["finish_reason"] == "stop"

    ## Test basic completion until length
    n = 0  # reset
    completion = llama.create_completion(text, max_tokens=2)
    assert completion["choices"][0]["text"] == " j"
    assert completion["choices"][0]["finish_reason"] == "length"

    ## Test streaming completion until length
    n = 0  # reset
    chunks = list(llama.create_completion(text, max_tokens=2, stream=True))
    assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " j"
    assert chunks[-1]["choices"][0]["finish_reason"] == "length"
def test_llama_pickle():
    """Check that a Llama instance survives a pickle round-trip.

    The instance is dumped to a temporary file, loaded back, and the restored
    object is verified to be truthy, to have a non-None ``ctx``, and to
    round-trip bytes through tokenize/detokenize.
    """
    import pickle
    import tempfile

    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)

    # The context manager guarantees the temporary file is closed even if an
    # assertion or the (un)pickling raises.
    with tempfile.TemporaryFile() as fp:
        pickle.dump(llama, fp)
        # Rewind: after dump the file position is at EOF, so loading without
        # seeking back would hit end-of-file instead of reading the pickle.
        fp.seek(0)
        llama = pickle.load(fp)

    assert llama
    assert llama.ctx is not None

    text = b"Hello World"

    assert llama.detokenize(llama.tokenize(text)) == text
def test_utf8(monkeypatch):
    """Check completion text when the output is a multibyte UTF-8 character.

    The eval call is mocked to return 0 and the sampler is mocked to replay
    the tokens of a single emoji followed by EOS. With ``max_tokens=4`` the
    completion text is expected to equal the emoji; with ``max_tokens=1`` the
    expected text is the empty string.

    Args:
        monkeypatch: pytest's monkeypatch fixture, used to swap in the mocks.
    """
    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)

    ## Set up mock function
    def mock_eval(*args, **kwargs):
        return 0

    monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)

    expected_text = "😀"
    expected_tokens = llama.tokenize(expected_text.encode("utf-8"))
    eos_token = llama.token_eos()
    cursor = 0

    def mock_sample(*args, **kwargs):
        # Hand out the expected tokens in order; once exhausted, return EOS.
        nonlocal cursor
        if cursor >= len(expected_tokens):
            return eos_token
        cursor += 1
        return expected_tokens[cursor - 1]

    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample)

    ## Test basic completion with utf8 multibyte
    cursor = 0  # reset
    full_result = llama.create_completion("", max_tokens=4)
    assert full_result["choices"][0]["text"] == expected_text

    ## Test basic completion with incomplete utf8 multibyte
    cursor = 0  # reset
    truncated_result = llama.create_completion("", max_tokens=1)
    assert truncated_result["choices"][0]["text"] == ""
def test_llama_server():
from fastapi.testclient import TestClient
from llama_cpp.server.app import app, init_llama, Settings
s = Settings()
s.model = MODEL
s.vocab_only = True
client = TestClient(app)
response = client.get("/v1/models")
assert response.json() == {
"object": "list",
"data": [
"id": MODEL,
"object": "model",
"owned_by": "me",
"permissions": [],