From efe8e6f8795eb2f92db22b841a40ad41fb053fe1 Mon Sep 17 00:00:00 2001
From: Lucas Doyle
Date: Fri, 28 Apr 2023 23:47:36 -0700
Subject: [PATCH] llama_cpp server: slight refactor to init_llama function

Define an init_llama function that starts llama with supplied settings
instead of just doing it in the global context of app.py

This allows the test to be less brittle by not needing to mess with
os.environ, then importing the app
---
 llama_cpp/server/__main__.py |  3 ++-
 llama_cpp/server/app.py      | 45 +++++++++++++++++++-----------------
 tests/test_llama.py          |  9 ++++----
 3 files changed, 31 insertions(+), 26 deletions(-)

diff --git a/llama_cpp/server/__main__.py b/llama_cpp/server/__main__.py
index dd4767f..f57d68c 100644
--- a/llama_cpp/server/__main__.py
+++ b/llama_cpp/server/__main__.py
@@ -24,9 +24,10 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
 import os
 import uvicorn
 
-from llama_cpp.server.app import app
+from llama_cpp.server.app import app, init_llama
 
 if __name__ == "__main__":
+    init_llama()
     uvicorn.run(
         app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
     )
diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index 2c50fcb..92b023c 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -13,7 +13,7 @@ from sse_starlette.sse import EventSourceResponse
 
 
 class Settings(BaseSettings):
-    model: str = os.environ["MODEL"]
+    model: str = os.environ.get("MODEL", "null")
     n_ctx: int = 2048
     n_batch: int = 512
     n_threads: int = max((os.cpu_count() or 2) // 2, 1)
@@ -38,31 +38,34 @@ app.add_middleware(
     allow_methods=["*"],
     allow_headers=["*"],
 )
-settings = Settings()
-llama = llama_cpp.Llama(
-    settings.model,
-    f16_kv=settings.f16_kv,
-    use_mlock=settings.use_mlock,
-    use_mmap=settings.use_mmap,
-    embedding=settings.embedding,
-    logits_all=settings.logits_all,
-    n_threads=settings.n_threads,
-    n_batch=settings.n_batch,
-    n_ctx=settings.n_ctx,
-    last_n_tokens_size=settings.last_n_tokens_size,
-    vocab_only=settings.vocab_only,
-)
-if settings.cache:
-    cache = llama_cpp.LlamaCache()
-    llama.set_cache(cache)
+
+llama: llama_cpp.Llama = None
+def init_llama(settings: Settings = None):
+    if settings is None:
+        settings = Settings()
+    global llama
+    llama = llama_cpp.Llama(
+        settings.model,
+        f16_kv=settings.f16_kv,
+        use_mlock=settings.use_mlock,
+        use_mmap=settings.use_mmap,
+        embedding=settings.embedding,
+        logits_all=settings.logits_all,
+        n_threads=settings.n_threads,
+        n_batch=settings.n_batch,
+        n_ctx=settings.n_ctx,
+        last_n_tokens_size=settings.last_n_tokens_size,
+        vocab_only=settings.vocab_only,
+    )
+    if settings.cache:
+        cache = llama_cpp.LlamaCache()
+        llama.set_cache(cache)
+
 llama_lock = Lock()
-
-
 def get_llama():
     with llama_lock:
         yield llama
-
 
 class CreateCompletionRequest(BaseModel):
     prompt: Union[str, List[str]]
    suffix: Optional[str] = Field(None)
diff --git a/tests/test_llama.py b/tests/test_llama.py
index 9110286..c3f69cc 100644
--- a/tests/test_llama.py
+++ b/tests/test_llama.py
@@ -132,10 +132,11 @@ def test_utf8(monkeypatch):
 
 def test_llama_server():
     from fastapi.testclient import TestClient
-    import os
-    os.environ["MODEL"] = MODEL
-    os.environ["VOCAB_ONLY"] = "true"
-    from llama_cpp.server.app import app
+    from llama_cpp.server.app import app, init_llama, Settings
+    s = Settings()
+    s.model = MODEL
+    s.vocab_only = True
+    init_llama(s)
     client = TestClient(app)
     response = client.get("/v1/models")
     assert response.json() == {
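
Usage sketch (illustrative, not part of the patch): with this change the server can be
configured programmatically through an explicit Settings object instead of environment
variables, mirroring the updated test. The model path below is a placeholder.

    # Start the server with explicit settings via the new init_llama() entry point.
    import uvicorn
    from llama_cpp.server.app import app, init_llama, Settings

    settings = Settings()
    settings.model = "./models/ggml-model.bin"  # placeholder path, not from the patch
    settings.vocab_only = True                  # load only the vocabulary, no weights
    init_llama(settings)                        # builds the module-level Llama instance

    uvicorn.run(app, host="localhost", port=8000)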