diff --git a/llama_cpp/server/__main__.py b/llama_cpp/server/__main__.py
index dd4767f..f57d68c 100644
--- a/llama_cpp/server/__main__.py
+++ b/llama_cpp/server/__main__.py
@@ -24,9 +24,10 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
 import os
 import uvicorn
 
-from llama_cpp.server.app import app
+from llama_cpp.server.app import app, init_llama
 
 if __name__ == "__main__":
+    init_llama()
     uvicorn.run(
         app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
     )
diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index 2c50fcb..92b023c 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -13,7 +13,7 @@ from sse_starlette.sse import EventSourceResponse
 
 
 class Settings(BaseSettings):
-    model: str = os.environ["MODEL"]
+    model: str = os.environ.get("MODEL", "null")
     n_ctx: int = 2048
     n_batch: int = 512
     n_threads: int = max((os.cpu_count() or 2) // 2, 1)
@@ -38,31 +38,34 @@ app.add_middleware(
     allow_methods=["*"],
     allow_headers=["*"],
 )
-settings = Settings()
-llama = llama_cpp.Llama(
-    settings.model,
-    f16_kv=settings.f16_kv,
-    use_mlock=settings.use_mlock,
-    use_mmap=settings.use_mmap,
-    embedding=settings.embedding,
-    logits_all=settings.logits_all,
-    n_threads=settings.n_threads,
-    n_batch=settings.n_batch,
-    n_ctx=settings.n_ctx,
-    last_n_tokens_size=settings.last_n_tokens_size,
-    vocab_only=settings.vocab_only,
-)
-if settings.cache:
-    cache = llama_cpp.LlamaCache()
-    llama.set_cache(cache)
+
+llama: llama_cpp.Llama = None
+def init_llama(settings: Settings = None):
+    if settings is None:
+        settings = Settings()
+    global llama
+    llama = llama_cpp.Llama(
+        settings.model,
+        f16_kv=settings.f16_kv,
+        use_mlock=settings.use_mlock,
+        use_mmap=settings.use_mmap,
+        embedding=settings.embedding,
+        logits_all=settings.logits_all,
+        n_threads=settings.n_threads,
+        n_batch=settings.n_batch,
+        n_ctx=settings.n_ctx,
+        last_n_tokens_size=settings.last_n_tokens_size,
+        vocab_only=settings.vocab_only,
+    )
+    if settings.cache:
+        cache = llama_cpp.LlamaCache()
+        llama.set_cache(cache)
+
 llama_lock = Lock()
-
-
 def get_llama():
     with llama_lock:
         yield llama
 
-
 class CreateCompletionRequest(BaseModel):
     prompt: Union[str, List[str]]
    suffix: Optional[str] = Field(None)
diff --git a/tests/test_llama.py b/tests/test_llama.py
index 9110286..c3f69cc 100644
--- a/tests/test_llama.py
+++ b/tests/test_llama.py
@@ -132,10 +132,11 @@ def test_utf8(monkeypatch):
 
 def test_llama_server():
     from fastapi.testclient import TestClient
-    import os
-    os.environ["MODEL"] = MODEL
-    os.environ["VOCAB_ONLY"] = "true"
-    from llama_cpp.server.app import app
+    from llama_cpp.server.app import app, init_llama, Settings
+    s = Settings()
+    s.model = MODEL
+    s.vocab_only = True
+    init_llama(s)
     client = TestClient(app)
     response = client.get("/v1/models")
     assert response.json() == {
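
For reference, a minimal sketch (not part of the diff) of how the new init_llama() entry point could be used to embed the server in another script, mirroring what the updated test does; the model path is a placeholder and it assumes uvicorn and llama-cpp-python are installed.

import uvicorn

from llama_cpp.server.app import Settings, app, init_llama

# Build settings programmatically instead of relying on the MODEL env var.
settings = Settings()
settings.model = "./models/7B/ggml-model.bin"  # placeholder path, not from the diff
settings.n_ctx = 2048

# Load the model once, then serve the FastAPI app.
init_llama(settings)
uvicorn.run(app, host="localhost", port=8000)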