Refactor server to use factory

This commit is contained in:
Andrei Betlen 2023-05-01 22:38:46 -04:00
parent dd9ad1c759
commit 9eafc4c49a
3 changed files with 47 additions and 31 deletions

View file

@ -24,10 +24,10 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
import os
import uvicorn
from llama_cpp.server.app import app, init_llama
from llama_cpp.server.app import create_app
if __name__ == "__main__":
init_llama()
app = create_app()
uvicorn.run(
app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))

View file

@ -2,18 +2,18 @@ import os
import json
from threading import Lock
from typing import List, Optional, Union, Iterator, Dict
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal, Annotated
import llama_cpp
from fastapi import Depends, FastAPI
from fastapi import Depends, FastAPI, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
from sse_starlette.sse import EventSourceResponse
class Settings(BaseSettings):
model: str = os.environ.get("MODEL", "null")
model: str
n_ctx: int = 2048
n_batch: int = 512
n_threads: int = max((os.cpu_count() or 2) // 2, 1)
@ -27,6 +27,14 @@ class Settings(BaseSettings):
vocab_only: bool = False
router = APIRouter()
llama: Optional[llama_cpp.Llama] = None
def create_app(settings: Optional[Settings] = None):
if settings is None:
settings = Settings()
app = FastAPI(
title="🦙 llama.cpp Python API",
version="0.0.1",
@ -38,14 +46,10 @@ app.add_middleware(
allow_methods=["*"],
allow_headers=["*"],
)
llama: llama_cpp.Llama = None
def init_llama(settings: Settings = None):
if settings is None:
settings = Settings()
app.include_router(router)
global llama
llama = llama_cpp.Llama(
settings.model,
model_path=settings.model,
f16_kv=settings.f16_kv,
use_mlock=settings.use_mlock,
use_mmap=settings.use_mmap,
@ -60,12 +64,17 @@ def init_llama(settings: Settings = None):
if settings.cache:
cache = llama_cpp.LlamaCache()
llama.set_cache(cache)
return app
llama_lock = Lock()
def get_llama():
with llama_lock:
yield llama
class CreateCompletionRequest(BaseModel):
prompt: Union[str, List[str]]
suffix: Optional[str] = Field(None)
@ -102,7 +111,7 @@ class CreateCompletionRequest(BaseModel):
CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
@app.post(
@router.post(
"/v1/completions",
response_model=CreateCompletionResponse,
)
@ -148,7 +157,7 @@ class CreateEmbeddingRequest(BaseModel):
CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
@app.post(
@router.post(
"/v1/embeddings",
response_model=CreateEmbeddingResponse,
)
@ -202,7 +211,7 @@ class CreateChatCompletionRequest(BaseModel):
CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
@app.post(
@router.post(
"/v1/chat/completions",
response_model=CreateChatCompletionResponse,
)
@ -256,7 +265,7 @@ class ModelList(TypedDict):
GetModelResponse = create_model_from_typeddict(ModelList)
@app.get("/v1/models", response_model=GetModelResponse)
@router.get("/v1/models", response_model=GetModelResponse)
def get_models() -> ModelList:
return {
"object": "list",

View file

@ -24,7 +24,9 @@ def test_llama_patch(monkeypatch):
return 0
def mock_get_logits(*args, **kwargs):
return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
return (llama_cpp.c_float * n_vocab)(
*[llama_cpp.c_float(0) for _ in range(n_vocab)]
)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
@ -88,6 +90,7 @@ def test_llama_patch(monkeypatch):
def test_llama_pickle():
import pickle
import tempfile
fp = tempfile.TemporaryFile()
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
pickle.dump(llama, fp)
@ -101,6 +104,7 @@ def test_llama_pickle():
assert llama.detokenize(llama.tokenize(text)) == text
def test_utf8(monkeypatch):
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
@ -110,7 +114,9 @@ def test_utf8(monkeypatch):
return 0
def mock_get_logits(*args, **kwargs):
return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
return (llama_cpp.c_float * n_vocab)(
*[llama_cpp.c_float(0) for _ in range(n_vocab)]
)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
@ -143,11 +149,12 @@ def test_utf8(monkeypatch):
def test_llama_server():
from fastapi.testclient import TestClient
from llama_cpp.server.app import app, init_llama, Settings
s = Settings()
s.model = MODEL
s.vocab_only = True
init_llama(s)
from llama_cpp.server.app import create_app, Settings
settings = Settings()
settings.model = MODEL
settings.vocab_only = True
app = create_app(settings)
client = TestClient(app)
response = client.get("/v1/models")
assert response.json() == {