llama_cpp server: slight refactor to init_llama function

Define an init_llama function that initializes llama from the supplied settings, instead of doing it at import time in the global scope of app.py

This makes the test less brittle: it no longer has to mutate os.environ and then import the app
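
For example, a test (or any embedder) can now configure and start the server programmatically. A minimal sketch, assuming the interfaces introduced in the diffs below; the model path is a placeholder:

    import uvicorn
    from llama_cpp.server.app import app, init_llama, Settings

    settings = Settings()
    settings.model = "./models/ggml-model.bin"  # placeholder path
    init_llama(settings)  # constructs the global Llama instance from settings

    uvicorn.run(app, host="localhost", port=8000)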
Lucas Doyle 2023-04-28 23:47:36 -07:00
parent 6d8db9d017
commit efe8e6f879
3 changed files with 31 additions and 26 deletions

llama_cpp/server/__main__.py

@@ -24,9 +24,10 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
 import os
 import uvicorn
-from llama_cpp.server.app import app
+from llama_cpp.server.app import app, init_llama
 
 if __name__ == "__main__":
+    init_llama()
     uvicorn.run(
         app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
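
Called with no argument, init_llama() falls back to a default Settings(), so the existing environment-driven startup keeps working. A sketch of that path; the env value is a placeholder:

    import os
    import uvicorn
    from llama_cpp.server.app import app, init_llama

    os.environ["MODEL"] = "./models/ggml-model.bin"  # placeholder; read by Settings()

    init_llama()  # no Settings passed: builds one from env vars / defaults
    uvicorn.run(
        app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
    )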

llama_cpp/server/app.py

@@ -13,7 +13,7 @@ from sse_starlette.sse import EventSourceResponse
 class Settings(BaseSettings):
-    model: str = os.environ["MODEL"]
+    model: str = os.environ.get("MODEL", "null")
     n_ctx: int = 2048
     n_batch: int = 512
     n_threads: int = max((os.cpu_count() or 2) // 2, 1)
@@ -38,7 +38,12 @@ app.add_middleware(
     allow_methods=["*"],
     allow_headers=["*"],
 )
 
-llama = llama_cpp.Llama(
-    settings.model,
-    f16_kv=settings.f16_kv,
+llama: llama_cpp.Llama = None
+def init_llama(settings: Settings = None):
+    if settings is None:
+        settings = Settings()
+    global llama
+    llama = llama_cpp.Llama(
+        settings.model,
+        f16_kv=settings.f16_kv,
@@ -55,14 +60,12 @@ llama = llama_cpp.Llama(
-if settings.cache:
-    cache = llama_cpp.LlamaCache()
-    llama.set_cache(cache)
+    if settings.cache:
+        cache = llama_cpp.LlamaCache()
+        llama.set_cache(cache)
 
 llama_lock = Lock()
 
 def get_llama():
     with llama_lock:
         yield llama
 
 class CreateCompletionRequest(BaseModel):
     prompt: Union[str, List[str]]
     suffix: Optional[str] = Field(None)
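
get_llama() is a generator dependency: FastAPI holds it open for the duration of a request, so llama_lock serializes all access to the single shared Llama instance. A sketch of how a route consumes it; the route shown is illustrative, not part of this diff:

    from fastapi import Depends

    @app.post("/v1/completions")
    def create_completion(
        request: CreateCompletionRequest,
        llama: llama_cpp.Llama = Depends(get_llama),  # acquires llama_lock
    ):
        # Simplified: the real handler forwards many more request fields.
        return llama(request.prompt)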

tests/test_llama.py

@@ -132,10 +132,11 @@ def test_utf8(monkeypatch):
 
 def test_llama_server():
     from fastapi.testclient import TestClient
-    import os
-    os.environ["MODEL"] = MODEL
-    os.environ["VOCAB_ONLY"] = "true"
-    from llama_cpp.server.app import app
+    from llama_cpp.server.app import app, init_llama, Settings
+    s = Settings()
+    s.model = MODEL
+    s.vocab_only = True
+    init_llama(s)
     client = TestClient(app)
     response = client.get("/v1/models")
     assert response.json() == {
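
The same idea extends to a reusable fixture. A hypothetical sketch, not part of this commit; the fixture name and model path are assumptions:

    import pytest
    from fastapi.testclient import TestClient
    from llama_cpp.server.app import app, init_llama, Settings

    @pytest.fixture
    def client():
        s = Settings()
        s.model = "./models/ggml-vocab.bin"  # placeholder path
        s.vocab_only = True  # load only the tokenizer vocabulary, no weights
        init_llama(s)
        return TestClient(app)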