Refactor server to use factory
This commit is contained in:
parent
dd9ad1c759
commit
9eafc4c49a
3 changed files with 47 additions and 31 deletions
|
@ -24,10 +24,10 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
|
|||
import os
|
||||
import uvicorn
|
||||
|
||||
from llama_cpp.server.app import app, init_llama
|
||||
from llama_cpp.server.app import create_app
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_llama()
|
||||
app = create_app()
|
||||
|
||||
uvicorn.run(
|
||||
app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
|
||||
|
|
|
@ -2,18 +2,18 @@ import os
|
|||
import json
|
||||
from threading import Lock
|
||||
from typing import List, Optional, Union, Iterator, Dict
|
||||
from typing_extensions import TypedDict, Literal
|
||||
from typing_extensions import TypedDict, Literal, Annotated
|
||||
|
||||
import llama_cpp
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi import Depends, FastAPI, APIRouter
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model: str = os.environ.get("MODEL", "null")
|
||||
model: str
|
||||
n_ctx: int = 2048
|
||||
n_batch: int = 512
|
||||
n_threads: int = max((os.cpu_count() or 2) // 2, 1)
|
||||
|
@ -27,6 +27,14 @@ class Settings(BaseSettings):
|
|||
vocab_only: bool = False
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
llama: Optional[llama_cpp.Llama] = None
|
||||
|
||||
|
||||
def create_app(settings: Optional[Settings] = None):
|
||||
if settings is None:
|
||||
settings = Settings()
|
||||
app = FastAPI(
|
||||
title="🦙 llama.cpp Python API",
|
||||
version="0.0.1",
|
||||
|
@ -38,14 +46,10 @@ app.add_middleware(
|
|||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
llama: llama_cpp.Llama = None
|
||||
def init_llama(settings: Settings = None):
|
||||
if settings is None:
|
||||
settings = Settings()
|
||||
app.include_router(router)
|
||||
global llama
|
||||
llama = llama_cpp.Llama(
|
||||
settings.model,
|
||||
model_path=settings.model,
|
||||
f16_kv=settings.f16_kv,
|
||||
use_mlock=settings.use_mlock,
|
||||
use_mmap=settings.use_mmap,
|
||||
|
@ -60,12 +64,17 @@ def init_llama(settings: Settings = None):
|
|||
if settings.cache:
|
||||
cache = llama_cpp.LlamaCache()
|
||||
llama.set_cache(cache)
|
||||
return app
|
||||
|
||||
|
||||
llama_lock = Lock()
|
||||
|
||||
|
||||
def get_llama():
|
||||
with llama_lock:
|
||||
yield llama
|
||||
|
||||
|
||||
class CreateCompletionRequest(BaseModel):
|
||||
prompt: Union[str, List[str]]
|
||||
suffix: Optional[str] = Field(None)
|
||||
|
@ -102,7 +111,7 @@ class CreateCompletionRequest(BaseModel):
|
|||
CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
|
||||
|
||||
|
||||
@app.post(
|
||||
@router.post(
|
||||
"/v1/completions",
|
||||
response_model=CreateCompletionResponse,
|
||||
)
|
||||
|
@ -148,7 +157,7 @@ class CreateEmbeddingRequest(BaseModel):
|
|||
CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
|
||||
|
||||
|
||||
@app.post(
|
||||
@router.post(
|
||||
"/v1/embeddings",
|
||||
response_model=CreateEmbeddingResponse,
|
||||
)
|
||||
|
@ -202,7 +211,7 @@ class CreateChatCompletionRequest(BaseModel):
|
|||
CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
|
||||
|
||||
|
||||
@app.post(
|
||||
@router.post(
|
||||
"/v1/chat/completions",
|
||||
response_model=CreateChatCompletionResponse,
|
||||
)
|
||||
|
@ -256,7 +265,7 @@ class ModelList(TypedDict):
|
|||
GetModelResponse = create_model_from_typeddict(ModelList)
|
||||
|
||||
|
||||
@app.get("/v1/models", response_model=GetModelResponse)
|
||||
@router.get("/v1/models", response_model=GetModelResponse)
|
||||
def get_models() -> ModelList:
|
||||
return {
|
||||
"object": "list",
|
||||
|
|
|
@ -24,7 +24,9 @@ def test_llama_patch(monkeypatch):
|
|||
return 0
|
||||
|
||||
def mock_get_logits(*args, **kwargs):
|
||||
return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
|
||||
return (llama_cpp.c_float * n_vocab)(
|
||||
*[llama_cpp.c_float(0) for _ in range(n_vocab)]
|
||||
)
|
||||
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
|
||||
|
@ -88,6 +90,7 @@ def test_llama_patch(monkeypatch):
|
|||
def test_llama_pickle():
|
||||
import pickle
|
||||
import tempfile
|
||||
|
||||
fp = tempfile.TemporaryFile()
|
||||
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
|
||||
pickle.dump(llama, fp)
|
||||
|
@ -101,6 +104,7 @@ def test_llama_pickle():
|
|||
|
||||
assert llama.detokenize(llama.tokenize(text)) == text
|
||||
|
||||
|
||||
def test_utf8(monkeypatch):
|
||||
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
|
||||
n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
|
||||
|
@ -110,7 +114,9 @@ def test_utf8(monkeypatch):
|
|||
return 0
|
||||
|
||||
def mock_get_logits(*args, **kwargs):
|
||||
return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
|
||||
return (llama_cpp.c_float * n_vocab)(
|
||||
*[llama_cpp.c_float(0) for _ in range(n_vocab)]
|
||||
)
|
||||
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
|
||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
|
||||
|
@ -143,11 +149,12 @@ def test_utf8(monkeypatch):
|
|||
|
||||
def test_llama_server():
|
||||
from fastapi.testclient import TestClient
|
||||
from llama_cpp.server.app import app, init_llama, Settings
|
||||
s = Settings()
|
||||
s.model = MODEL
|
||||
s.vocab_only = True
|
||||
init_llama(s)
|
||||
from llama_cpp.server.app import create_app, Settings
|
||||
|
||||
settings = Settings()
|
||||
settings.model = MODEL
|
||||
settings.vocab_only = True
|
||||
app = create_app(settings)
|
||||
client = TestClient(app)
|
||||
response = client.get("/v1/models")
|
||||
assert response.json() == {
|
||||
|
|
Loading…
Reference in a new issue