Merge pull request #125 from Stonelinks/app-server-module-importable

Make app server module importable
This commit is contained in:
Andrei 2023-05-01 11:31:08 -04:00 committed by GitHub
commit 79ba9ed98d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 401 additions and 270 deletions

View file

View file

@ -5,283 +5,29 @@ To run this example:
```bash ```bash
pip install fastapi uvicorn sse-starlette pip install fastapi uvicorn sse-starlette
export MODEL=../models/7B/... export MODEL=../models/7B/...
uvicorn fastapi_server_chat:app --reload ```
Then run:
```
uvicorn llama_cpp.server.app:app --reload
```
or
```
python3 -m llama_cpp.server
``` ```
Then visit http://localhost:8000/docs to see the interactive API docs. Then visit http://localhost:8000/docs to see the interactive API docs.
""" """
import os import os
import json import uvicorn
from threading import Lock
from typing import List, Optional, Literal, Union, Iterator, Dict
from typing_extensions import TypedDict
import llama_cpp
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
from sse_starlette.sse import EventSourceResponse
class Settings(BaseSettings):
model: str
n_ctx: int = 2048
n_batch: int = 512
n_threads: int = max((os.cpu_count() or 2) // 2, 1)
f16_kv: bool = True
use_mlock: bool = False # This causes a silent failure on platforms that don't support mlock (e.g. Windows) took forever to figure out...
use_mmap: bool = True
embedding: bool = True
last_n_tokens_size: int = 64
logits_all: bool = False
cache: bool = False # WARNING: This is an experimental feature
app = FastAPI(
title="🦙 llama.cpp Python API",
version="0.0.1",
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
settings = Settings()
llama = llama_cpp.Llama(
settings.model,
f16_kv=settings.f16_kv,
use_mlock=settings.use_mlock,
use_mmap=settings.use_mmap,
embedding=settings.embedding,
logits_all=settings.logits_all,
n_threads=settings.n_threads,
n_batch=settings.n_batch,
n_ctx=settings.n_ctx,
last_n_tokens_size=settings.last_n_tokens_size,
)
if settings.cache:
cache = llama_cpp.LlamaCache()
llama.set_cache(cache)
llama_lock = Lock()
def get_llama():
with llama_lock:
yield llama
class CreateCompletionRequest(BaseModel):
prompt: Union[str, List[str]]
suffix: Optional[str] = Field(None)
max_tokens: int = 16
temperature: float = 0.8
top_p: float = 0.95
echo: bool = False
stop: Optional[List[str]] = []
stream: bool = False
# ignored or currently unsupported
model: Optional[str] = Field(None)
n: Optional[int] = 1
logprobs: Optional[int] = Field(None)
presence_penalty: Optional[float] = 0
frequency_penalty: Optional[float] = 0
best_of: Optional[int] = 1
logit_bias: Optional[Dict[str, float]] = Field(None)
user: Optional[str] = Field(None)
# llama.cpp specific parameters
top_k: int = 40
repeat_penalty: float = 1.1
class Config:
schema_extra = {
"example": {
"prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
"stop": ["\n", "###"],
}
}
CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
@app.post(
"/v1/completions",
response_model=CreateCompletionResponse,
)
def create_completion(
request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama)
):
if isinstance(request.prompt, list):
request.prompt = "".join(request.prompt)
completion_or_chunks = llama(
**request.dict(
exclude={
"model",
"n",
"frequency_penalty",
"presence_penalty",
"best_of",
"logit_bias",
"user",
}
)
)
if request.stream:
chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks # type: ignore
return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
completion: llama_cpp.Completion = completion_or_chunks # type: ignore
return completion
class CreateEmbeddingRequest(BaseModel):
model: Optional[str]
input: str
user: Optional[str]
class Config:
schema_extra = {
"example": {
"input": "The food was delicious and the waiter...",
}
}
CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
@app.post(
"/v1/embeddings",
response_model=CreateEmbeddingResponse,
)
def create_embedding(
request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
):
return llama.create_embedding(**request.dict(exclude={"model", "user"}))
class ChatCompletionRequestMessage(BaseModel):
role: Union[Literal["system"], Literal["user"], Literal["assistant"]]
content: str
user: Optional[str] = None
class CreateChatCompletionRequest(BaseModel):
model: Optional[str]
messages: List[ChatCompletionRequestMessage]
temperature: float = 0.8
top_p: float = 0.95
stream: bool = False
stop: Optional[List[str]] = []
max_tokens: int = 128
# ignored or currently unsupported
model: Optional[str] = Field(None)
n: Optional[int] = 1
presence_penalty: Optional[float] = 0
frequency_penalty: Optional[float] = 0
logit_bias: Optional[Dict[str, float]] = Field(None)
user: Optional[str] = Field(None)
# llama.cpp specific parameters
repeat_penalty: float = 1.1
class Config:
schema_extra = {
"example": {
"messages": [
ChatCompletionRequestMessage(
role="system", content="You are a helpful assistant."
),
ChatCompletionRequestMessage(
role="user", content="What is the capital of France?"
),
]
}
}
CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
@app.post(
"/v1/chat/completions",
response_model=CreateChatCompletionResponse,
)
def create_chat_completion(
request: CreateChatCompletionRequest,
llama: llama_cpp.Llama = Depends(get_llama),
) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
completion_or_chunks = llama.create_chat_completion(
**request.dict(
exclude={
"model",
"n",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
}
),
)
if request.stream:
async def server_sent_events(
chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
):
for chat_chunk in chat_chunks:
yield dict(data=json.dumps(chat_chunk))
yield dict(data="[DONE]")
chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks # type: ignore
return EventSourceResponse(
server_sent_events(chunks),
)
completion: llama_cpp.ChatCompletion = completion_or_chunks # type: ignore
return completion
class ModelData(TypedDict):
id: str
object: Literal["model"]
owned_by: str
permissions: List[str]
class ModelList(TypedDict):
object: Literal["list"]
data: List[ModelData]
GetModelResponse = create_model_from_typeddict(ModelList)
@app.get("/v1/models", response_model=GetModelResponse)
def get_models() -> ModelList:
return {
"object": "list",
"data": [
{
"id": llama.model_path,
"object": "model",
"owned_by": "me",
"permissions": [],
}
],
}
from llama_cpp.server.app import app, init_llama
if __name__ == "__main__": if __name__ == "__main__":
import os init_llama()
import uvicorn
uvicorn.run( uvicorn.run(
app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000)) app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))

271
llama_cpp/server/app.py Normal file
View file

@ -0,0 +1,271 @@
import os
import json
from threading import Lock
from typing import List, Optional, Literal, Union, Iterator, Dict
from typing_extensions import TypedDict
import llama_cpp
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
from sse_starlette.sse import EventSourceResponse
class Settings(BaseSettings):
model: str = os.environ.get("MODEL", "null")
n_ctx: int = 2048
n_batch: int = 512
n_threads: int = max((os.cpu_count() or 2) // 2, 1)
f16_kv: bool = True
use_mlock: bool = False # This causes a silent failure on platforms that don't support mlock (e.g. Windows) took forever to figure out...
use_mmap: bool = True
embedding: bool = True
last_n_tokens_size: int = 64
logits_all: bool = False
cache: bool = False # WARNING: This is an experimental feature
vocab_only: bool = False
app = FastAPI(
title="🦙 llama.cpp Python API",
version="0.0.1",
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
llama: llama_cpp.Llama = None
def init_llama(settings: Settings = None):
if settings is None:
settings = Settings()
global llama
llama = llama_cpp.Llama(
settings.model,
f16_kv=settings.f16_kv,
use_mlock=settings.use_mlock,
use_mmap=settings.use_mmap,
embedding=settings.embedding,
logits_all=settings.logits_all,
n_threads=settings.n_threads,
n_batch=settings.n_batch,
n_ctx=settings.n_ctx,
last_n_tokens_size=settings.last_n_tokens_size,
vocab_only=settings.vocab_only,
)
if settings.cache:
cache = llama_cpp.LlamaCache()
llama.set_cache(cache)
llama_lock = Lock()
def get_llama():
with llama_lock:
yield llama
class CreateCompletionRequest(BaseModel):
prompt: Union[str, List[str]]
suffix: Optional[str] = Field(None)
max_tokens: int = 16
temperature: float = 0.8
top_p: float = 0.95
echo: bool = False
stop: Optional[List[str]] = []
stream: bool = False
# ignored or currently unsupported
model: Optional[str] = Field(None)
n: Optional[int] = 1
logprobs: Optional[int] = Field(None)
presence_penalty: Optional[float] = 0
frequency_penalty: Optional[float] = 0
best_of: Optional[int] = 1
logit_bias: Optional[Dict[str, float]] = Field(None)
user: Optional[str] = Field(None)
# llama.cpp specific parameters
top_k: int = 40
repeat_penalty: float = 1.1
class Config:
schema_extra = {
"example": {
"prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
"stop": ["\n", "###"],
}
}
CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
@app.post(
"/v1/completions",
response_model=CreateCompletionResponse,
)
def create_completion(
request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama)
):
if isinstance(request.prompt, list):
request.prompt = "".join(request.prompt)
completion_or_chunks = llama(
**request.dict(
exclude={
"model",
"n",
"frequency_penalty",
"presence_penalty",
"best_of",
"logit_bias",
"user",
}
)
)
if request.stream:
chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks # type: ignore
return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
completion: llama_cpp.Completion = completion_or_chunks # type: ignore
return completion
class CreateEmbeddingRequest(BaseModel):
model: Optional[str]
input: str
user: Optional[str]
class Config:
schema_extra = {
"example": {
"input": "The food was delicious and the waiter...",
}
}
CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
@app.post(
"/v1/embeddings",
response_model=CreateEmbeddingResponse,
)
def create_embedding(
request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
):
return llama.create_embedding(**request.dict(exclude={"model", "user"}))
class ChatCompletionRequestMessage(BaseModel):
role: Union[Literal["system"], Literal["user"], Literal["assistant"]]
content: str
user: Optional[str] = None
class CreateChatCompletionRequest(BaseModel):
model: Optional[str]
messages: List[ChatCompletionRequestMessage]
temperature: float = 0.8
top_p: float = 0.95
stream: bool = False
stop: Optional[List[str]] = []
max_tokens: int = 128
# ignored or currently unsupported
model: Optional[str] = Field(None)
n: Optional[int] = 1
presence_penalty: Optional[float] = 0
frequency_penalty: Optional[float] = 0
logit_bias: Optional[Dict[str, float]] = Field(None)
user: Optional[str] = Field(None)
# llama.cpp specific parameters
repeat_penalty: float = 1.1
class Config:
schema_extra = {
"example": {
"messages": [
ChatCompletionRequestMessage(
role="system", content="You are a helpful assistant."
),
ChatCompletionRequestMessage(
role="user", content="What is the capital of France?"
),
]
}
}
CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
@app.post(
"/v1/chat/completions",
response_model=CreateChatCompletionResponse,
)
def create_chat_completion(
request: CreateChatCompletionRequest,
llama: llama_cpp.Llama = Depends(get_llama),
) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
completion_or_chunks = llama.create_chat_completion(
**request.dict(
exclude={
"model",
"n",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
}
),
)
if request.stream:
async def server_sent_events(
chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
):
for chat_chunk in chat_chunks:
yield dict(data=json.dumps(chat_chunk))
yield dict(data="[DONE]")
chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks # type: ignore
return EventSourceResponse(
server_sent_events(chunks),
)
completion: llama_cpp.ChatCompletion = completion_or_chunks # type: ignore
return completion
class ModelData(TypedDict):
id: str
object: Literal["model"]
owned_by: str
permissions: List[str]
class ModelList(TypedDict):
object: Literal["list"]
data: List[ModelData]
GetModelResponse = create_model_from_typeddict(ModelList)
@app.get("/v1/models", response_model=GetModelResponse)
def get_models() -> ModelList:
return {
"object": "list",
"data": [
{
"id": llama.model_path,
"object": "model",
"owned_by": "me",
"permissions": [],
}
],
}

95
poetry.lock generated
View file

@ -1,4 +1,25 @@
# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand. # This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand.
[[package]]
name = "anyio"
version = "3.6.2"
description = "High level compatibility layer for multiple asynchronous event loop implementations"
category = "dev"
optional = false
python-versions = ">=3.6.2"
files = [
{file = "anyio-3.6.2-py3-none-any.whl", hash = "sha256:fbbe32bd270d2a2ef3ed1c5d45041250284e31fc0a4df4a5a6071842051a51e3"},
{file = "anyio-3.6.2.tar.gz", hash = "sha256:25ea0d673ae30af41a0c442f81cf3b38c7e79fdc7b60335a4c14e05eb0947421"},
]
[package.dependencies]
idna = ">=2.8"
sniffio = ">=1.1"
[package.extras]
doc = ["packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"]
test = ["contextlib2", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (<0.15)", "uvloop (>=0.15)"]
trio = ["trio (>=0.16,<0.22)"]
[[package]] [[package]]
name = "attrs" name = "attrs"
@ -398,6 +419,64 @@ colorama = ">=0.4"
[package.extras] [package.extras]
async = ["aiofiles (>=0.7,<1.0)"] async = ["aiofiles (>=0.7,<1.0)"]
[[package]]
name = "h11"
version = "0.14.0"
description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"},
{file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"},
]
[[package]]
name = "httpcore"
version = "0.17.0"
description = "A minimal low-level HTTP client."
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "httpcore-0.17.0-py3-none-any.whl", hash = "sha256:0fdfea45e94f0c9fd96eab9286077f9ff788dd186635ae61b312693e4d943599"},
{file = "httpcore-0.17.0.tar.gz", hash = "sha256:cc045a3241afbf60ce056202301b4d8b6af08845e3294055eb26b09913ef903c"},
]
[package.dependencies]
anyio = ">=3.0,<5.0"
certifi = "*"
h11 = ">=0.13,<0.15"
sniffio = ">=1.0.0,<2.0.0"
[package.extras]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (>=1.0.0,<2.0.0)"]
[[package]]
name = "httpx"
version = "0.24.0"
description = "The next generation HTTP client."
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "httpx-0.24.0-py3-none-any.whl", hash = "sha256:447556b50c1921c351ea54b4fe79d91b724ed2b027462ab9a329465d147d5a4e"},
{file = "httpx-0.24.0.tar.gz", hash = "sha256:507d676fc3e26110d41df7d35ebd8b3b8585052450f4097401c9be59d928c63e"},
]
[package.dependencies]
certifi = "*"
httpcore = ">=0.15.0,<0.18.0"
idna = "*"
sniffio = "*"
[package.extras]
brotli = ["brotli", "brotlicffi"]
cli = ["click (>=8.0.0,<9.0.0)", "pygments (>=2.0.0,<3.0.0)", "rich (>=10,<14)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (>=1.0.0,<2.0.0)"]
[[package]] [[package]]
name = "idna" name = "idna"
version = "3.4" version = "3.4"
@ -1232,6 +1311,18 @@ files = [
{file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
] ]
[[package]]
name = "sniffio"
version = "1.3.0"
description = "Sniff out which async library your code is running under"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"},
{file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"},
]
[[package]] [[package]]
name = "tomli" name = "tomli"
version = "2.0.1" version = "2.0.1"
@ -1367,4 +1458,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.8.1" python-versions = "^3.8.1"
content-hash = "cc9babcdfdc3679a4d84f68912408a005619a576947b059146ed1b428850ece9" content-hash = "aa15e57300668bd23c051b4cd87bec4c1a58dcccd2f2b4767579fea7f2c5fa41"

View file

@ -24,6 +24,7 @@ mkdocs = "^1.4.2"
mkdocstrings = {extras = ["python"], version = "^0.20.0"} mkdocstrings = {extras = ["python"], version = "^0.20.0"}
mkdocs-material = "^9.1.4" mkdocs-material = "^9.1.4"
pytest = "^7.2.2" pytest = "^7.2.2"
httpx = "^0.24.0"
[build-system] [build-system]
requires = [ requires = [

View file

@ -128,3 +128,25 @@ def test_utf8(monkeypatch):
n = 0 # reset n = 0 # reset
completion = llama.create_completion("", max_tokens=1) completion = llama.create_completion("", max_tokens=1)
assert completion["choices"][0]["text"] == "" assert completion["choices"][0]["text"] == ""
def test_llama_server():
from fastapi.testclient import TestClient
from llama_cpp.server.app import app, init_llama, Settings
s = Settings()
s.model = MODEL
s.vocab_only = True
init_llama(s)
client = TestClient(app)
response = client.get("/v1/models")
assert response.json() == {
"object": "list",
"data": [
{
"id": MODEL,
"object": "model",
"owned_by": "me",
"permissions": [],
}
],
}