llama_cpp server: slight refactor to init_llama function
Define an init_llama function that starts llama with supplied settings instead of just doing it in the global context of app.py This allows the test to be less brittle by not needing to mess with os.environ, then importing the app
This commit is contained in:
parent
6d8db9d017
commit
efe8e6f879
3 changed files with 31 additions and 26 deletions
|
@ -24,9 +24,10 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
|
||||||
import os
|
import os
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
from llama_cpp.server.app import app
|
from llama_cpp.server.app import app, init_llama
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
init_llama()
|
||||||
|
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
|
app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
|
||||||
|
|
|
@ -13,7 +13,7 @@ from sse_starlette.sse import EventSourceResponse
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
model: str = os.environ["MODEL"]
|
model: str = os.environ.get("MODEL", "null")
|
||||||
n_ctx: int = 2048
|
n_ctx: int = 2048
|
||||||
n_batch: int = 512
|
n_batch: int = 512
|
||||||
n_threads: int = max((os.cpu_count() or 2) // 2, 1)
|
n_threads: int = max((os.cpu_count() or 2) // 2, 1)
|
||||||
|
@ -38,8 +38,13 @@ app.add_middleware(
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
settings = Settings()
|
|
||||||
llama = llama_cpp.Llama(
|
llama: llama_cpp.Llama = None
|
||||||
|
def init_llama(settings: Settings = None):
|
||||||
|
if settings is None:
|
||||||
|
settings = Settings()
|
||||||
|
global llama
|
||||||
|
llama = llama_cpp.Llama(
|
||||||
settings.model,
|
settings.model,
|
||||||
f16_kv=settings.f16_kv,
|
f16_kv=settings.f16_kv,
|
||||||
use_mlock=settings.use_mlock,
|
use_mlock=settings.use_mlock,
|
||||||
|
@ -51,18 +56,16 @@ llama = llama_cpp.Llama(
|
||||||
n_ctx=settings.n_ctx,
|
n_ctx=settings.n_ctx,
|
||||||
last_n_tokens_size=settings.last_n_tokens_size,
|
last_n_tokens_size=settings.last_n_tokens_size,
|
||||||
vocab_only=settings.vocab_only,
|
vocab_only=settings.vocab_only,
|
||||||
)
|
)
|
||||||
if settings.cache:
|
if settings.cache:
|
||||||
cache = llama_cpp.LlamaCache()
|
cache = llama_cpp.LlamaCache()
|
||||||
llama.set_cache(cache)
|
llama.set_cache(cache)
|
||||||
|
|
||||||
llama_lock = Lock()
|
llama_lock = Lock()
|
||||||
|
|
||||||
|
|
||||||
def get_llama():
|
def get_llama():
|
||||||
with llama_lock:
|
with llama_lock:
|
||||||
yield llama
|
yield llama
|
||||||
|
|
||||||
|
|
||||||
class CreateCompletionRequest(BaseModel):
|
class CreateCompletionRequest(BaseModel):
|
||||||
prompt: Union[str, List[str]]
|
prompt: Union[str, List[str]]
|
||||||
suffix: Optional[str] = Field(None)
|
suffix: Optional[str] = Field(None)
|
||||||
|
|
|
@ -132,10 +132,11 @@ def test_utf8(monkeypatch):
|
||||||
|
|
||||||
def test_llama_server():
|
def test_llama_server():
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
import os
|
from llama_cpp.server.app import app, init_llama, Settings
|
||||||
os.environ["MODEL"] = MODEL
|
s = Settings()
|
||||||
os.environ["VOCAB_ONLY"] = "true"
|
s.model = MODEL
|
||||||
from llama_cpp.server.app import app
|
s.vocab_only = True
|
||||||
|
init_llama(s)
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
response = client.get("/v1/models")
|
response = client.get("/v1/models")
|
||||||
assert response.json() == {
|
assert response.json() == {
|
||||||
|
|
Loading…
Reference in a new issue