Refactor server to use factory
This commit is contained in:
parent
dd9ad1c759
commit
9eafc4c49a
3 changed files with 47 additions and 31 deletions
|
@ -24,10 +24,10 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
|
||||||
import os
|
import os
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
from llama_cpp.server.app import app, init_llama
|
from llama_cpp.server.app import create_app
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
init_llama()
|
app = create_app()
|
||||||
|
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
|
app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
|
||||||
|
|
|
@ -2,18 +2,18 @@ import os
|
||||||
import json
|
import json
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import List, Optional, Union, Iterator, Dict
|
from typing import List, Optional, Union, Iterator, Dict
|
||||||
from typing_extensions import TypedDict, Literal
|
from typing_extensions import TypedDict, Literal, Annotated
|
||||||
|
|
||||||
import llama_cpp
|
import llama_cpp
|
||||||
|
|
||||||
from fastapi import Depends, FastAPI
|
from fastapi import Depends, FastAPI, APIRouter
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
|
from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
|
||||||
from sse_starlette.sse import EventSourceResponse
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
model: str = os.environ.get("MODEL", "null")
|
model: str
|
||||||
n_ctx: int = 2048
|
n_ctx: int = 2048
|
||||||
n_batch: int = 512
|
n_batch: int = 512
|
||||||
n_threads: int = max((os.cpu_count() or 2) // 2, 1)
|
n_threads: int = max((os.cpu_count() or 2) // 2, 1)
|
||||||
|
@ -27,6 +27,14 @@ class Settings(BaseSettings):
|
||||||
vocab_only: bool = False
|
vocab_only: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
llama: Optional[llama_cpp.Llama] = None
|
||||||
|
|
||||||
|
|
||||||
|
def create_app(settings: Optional[Settings] = None):
|
||||||
|
if settings is None:
|
||||||
|
settings = Settings()
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="🦙 llama.cpp Python API",
|
title="🦙 llama.cpp Python API",
|
||||||
version="0.0.1",
|
version="0.0.1",
|
||||||
|
@ -38,14 +46,10 @@ app.add_middleware(
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
app.include_router(router)
|
||||||
llama: llama_cpp.Llama = None
|
|
||||||
def init_llama(settings: Settings = None):
|
|
||||||
if settings is None:
|
|
||||||
settings = Settings()
|
|
||||||
global llama
|
global llama
|
||||||
llama = llama_cpp.Llama(
|
llama = llama_cpp.Llama(
|
||||||
settings.model,
|
model_path=settings.model,
|
||||||
f16_kv=settings.f16_kv,
|
f16_kv=settings.f16_kv,
|
||||||
use_mlock=settings.use_mlock,
|
use_mlock=settings.use_mlock,
|
||||||
use_mmap=settings.use_mmap,
|
use_mmap=settings.use_mmap,
|
||||||
|
@ -60,12 +64,17 @@ def init_llama(settings: Settings = None):
|
||||||
if settings.cache:
|
if settings.cache:
|
||||||
cache = llama_cpp.LlamaCache()
|
cache = llama_cpp.LlamaCache()
|
||||||
llama.set_cache(cache)
|
llama.set_cache(cache)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
llama_lock = Lock()
|
llama_lock = Lock()
|
||||||
|
|
||||||
|
|
||||||
def get_llama():
|
def get_llama():
|
||||||
with llama_lock:
|
with llama_lock:
|
||||||
yield llama
|
yield llama
|
||||||
|
|
||||||
|
|
||||||
class CreateCompletionRequest(BaseModel):
|
class CreateCompletionRequest(BaseModel):
|
||||||
prompt: Union[str, List[str]]
|
prompt: Union[str, List[str]]
|
||||||
suffix: Optional[str] = Field(None)
|
suffix: Optional[str] = Field(None)
|
||||||
|
@ -102,7 +111,7 @@ class CreateCompletionRequest(BaseModel):
|
||||||
CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
|
CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
|
||||||
|
|
||||||
|
|
||||||
@app.post(
|
@router.post(
|
||||||
"/v1/completions",
|
"/v1/completions",
|
||||||
response_model=CreateCompletionResponse,
|
response_model=CreateCompletionResponse,
|
||||||
)
|
)
|
||||||
|
@ -148,7 +157,7 @@ class CreateEmbeddingRequest(BaseModel):
|
||||||
CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
|
CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
|
||||||
|
|
||||||
|
|
||||||
@app.post(
|
@router.post(
|
||||||
"/v1/embeddings",
|
"/v1/embeddings",
|
||||||
response_model=CreateEmbeddingResponse,
|
response_model=CreateEmbeddingResponse,
|
||||||
)
|
)
|
||||||
|
@ -202,7 +211,7 @@ class CreateChatCompletionRequest(BaseModel):
|
||||||
CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
|
CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
|
||||||
|
|
||||||
|
|
||||||
@app.post(
|
@router.post(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
response_model=CreateChatCompletionResponse,
|
response_model=CreateChatCompletionResponse,
|
||||||
)
|
)
|
||||||
|
@ -256,7 +265,7 @@ class ModelList(TypedDict):
|
||||||
GetModelResponse = create_model_from_typeddict(ModelList)
|
GetModelResponse = create_model_from_typeddict(ModelList)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/models", response_model=GetModelResponse)
|
@router.get("/v1/models", response_model=GetModelResponse)
|
||||||
def get_models() -> ModelList:
|
def get_models() -> ModelList:
|
||||||
return {
|
return {
|
||||||
"object": "list",
|
"object": "list",
|
||||||
|
|
|
@ -24,7 +24,9 @@ def test_llama_patch(monkeypatch):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def mock_get_logits(*args, **kwargs):
|
def mock_get_logits(*args, **kwargs):
|
||||||
return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
|
return (llama_cpp.c_float * n_vocab)(
|
||||||
|
*[llama_cpp.c_float(0) for _ in range(n_vocab)]
|
||||||
|
)
|
||||||
|
|
||||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
|
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
|
||||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
|
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
|
||||||
|
@ -88,6 +90,7 @@ def test_llama_patch(monkeypatch):
|
||||||
def test_llama_pickle():
|
def test_llama_pickle():
|
||||||
import pickle
|
import pickle
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
fp = tempfile.TemporaryFile()
|
fp = tempfile.TemporaryFile()
|
||||||
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
|
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
|
||||||
pickle.dump(llama, fp)
|
pickle.dump(llama, fp)
|
||||||
|
@ -101,6 +104,7 @@ def test_llama_pickle():
|
||||||
|
|
||||||
assert llama.detokenize(llama.tokenize(text)) == text
|
assert llama.detokenize(llama.tokenize(text)) == text
|
||||||
|
|
||||||
|
|
||||||
def test_utf8(monkeypatch):
|
def test_utf8(monkeypatch):
|
||||||
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
|
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
|
||||||
n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
|
n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
|
||||||
|
@ -110,7 +114,9 @@ def test_utf8(monkeypatch):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def mock_get_logits(*args, **kwargs):
|
def mock_get_logits(*args, **kwargs):
|
||||||
return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
|
return (llama_cpp.c_float * n_vocab)(
|
||||||
|
*[llama_cpp.c_float(0) for _ in range(n_vocab)]
|
||||||
|
)
|
||||||
|
|
||||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
|
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
|
||||||
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
|
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
|
||||||
|
@ -143,11 +149,12 @@ def test_utf8(monkeypatch):
|
||||||
|
|
||||||
def test_llama_server():
|
def test_llama_server():
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from llama_cpp.server.app import app, init_llama, Settings
|
from llama_cpp.server.app import create_app, Settings
|
||||||
s = Settings()
|
|
||||||
s.model = MODEL
|
settings = Settings()
|
||||||
s.vocab_only = True
|
settings.model = MODEL
|
||||||
init_llama(s)
|
settings.vocab_only = True
|
||||||
|
app = create_app(settings)
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
response = client.get("/v1/models")
|
response = client.get("/v1/models")
|
||||||
assert response.json() == {
|
assert response.json() == {
|
||||||
|
|
Loading…
Reference in a new issue