Update fastapi server example

Andrei Betlen 2023-04-05 14:44:26 -04:00
parent 6de2f24aca
commit e1b5b9bb04


@@ -13,7 +13,8 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
 """
 import os
 import json
-from typing import List, Optional, Literal, Union, Iterator
+from typing import List, Optional, Literal, Union, Iterator, Dict
+from typing_extensions import TypedDict
 
 import llama_cpp
@@ -64,13 +65,24 @@ class CreateCompletionRequest(BaseModel):
     max_tokens: int = 16
     temperature: float = 0.8
     top_p: float = 0.95
-    logprobs: Optional[int] = Field(None)
     echo: bool = False
     stop: List[str] = []
-    repeat_penalty: float = 1.1
-    top_k: int = 40
     stream: bool = False
+
+    # ignored or currently unsupported
+    model: Optional[str] = Field(None)
+    n: Optional[int] = 1
+    logprobs: Optional[int] = Field(None)
+    presence_penalty: Optional[float] = 0
+    frequency_penalty: Optional[float] = 0
+    best_of: Optional[int] = 1
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    user: Optional[str] = Field(None)
+
+    # llama.cpp specific parameters
+    top_k: int = 40
+    repeat_penalty: float = 1.1
 
     class Config:
         schema_extra = {
             "example": {
@@ -91,7 +103,20 @@ def create_completion(request: CreateCompletionRequest):
     if request.stream:
         chunks: Iterator[llama_cpp.CompletionChunk] = llama(**request.dict())  # type: ignore
         return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
-    return llama(**request.dict())
+    return llama(
+        **request.dict(
+            exclude={
+                "model",
+                "n",
+                "logprobs",
+                "frequency_penalty",
+                "presence_penalty",
+                "best_of",
+                "logit_bias",
+                "user",
+            }
+        )
+    )
 
 
 class CreateEmbeddingRequest(BaseModel):
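
For reference, a minimal client sketch against the updated completion schema. The route path and host are assumptions (the route decorator is outside this hunk), and the requests package is assumed to be installed; the new OpenAI-compatible fields pass validation and are then stripped server-side before the llama call.

import requests

# Hypothetical endpoint and host; adjust to wherever the example server is running.
payload = {
    "prompt": "Q: What is the capital of France? A:",
    "max_tokens": 32,
    "temperature": 0.8,
    # new OpenAI-compatible fields: accepted by the schema, excluded before reaching llama.cpp
    "model": "llama-7b",
    "user": "example-user",
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload)
print(resp.json())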
@@ -132,6 +157,16 @@ class CreateChatCompletionRequest(BaseModel):
     stream: bool = False
     stop: List[str] = []
     max_tokens: int = 128
+
+    # ignored or currently unsupported
+    model: Optional[str] = Field(None)
+    n: Optional[int] = 1
+    presence_penalty: Optional[float] = 0
+    frequency_penalty: Optional[float] = 0
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    user: Optional[str] = Field(None)
+
+    # llama.cpp specific parameters
     repeat_penalty: float = 1.1
 
     class Config:
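
The chat schema gains the same set of pass-through fields. A hypothetical chat request, again assuming the route path and host (neither appears in this hunk) and that requests is installed:

import requests

payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello."},
    ],
    "max_tokens": 64,
    # accepted for OpenAI compatibility, excluded before llama.create_chat_completion is called
    "user": "example-user",
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json())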
@@ -160,7 +195,16 @@ async def create_chat_completion(
     request: CreateChatCompletionRequest,
 ) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
     completion_or_chunks = llama.create_chat_completion(
-        **request.dict(exclude={"model"}),
+        **request.dict(
+            exclude={
+                "model",
+                "n",
+                "presence_penalty",
+                "frequency_penalty",
+                "logit_bias",
+                "user",
+            }
+        ),
     )
 
     if request.stream:
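
Both handlers rely on pydantic's .dict(exclude=...) to drop the OpenAI-only fields so they never reach the llama.cpp bindings, which do not accept them at this point. A minimal sketch of that behaviour, using the pydantic v1 API as in the rest of this example (Demo is an illustrative model, not part of the server):

from typing import Optional
from pydantic import BaseModel

class Demo(BaseModel):
    prompt: str = "Hello"
    max_tokens: int = 16
    user: Optional[str] = None  # accepted by the schema, never forwarded

# Fields named in `exclude` are simply dropped from the resulting kwargs dict.
print(Demo().dict(exclude={"user"}))  # {'prompt': 'Hello', 'max_tokens': 16}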
@@ -179,3 +223,40 @@ async def create_chat_completion(
     )
     completion: llama_cpp.ChatCompletion = completion_or_chunks  # type: ignore
     return completion
+
+
+class ModelData(TypedDict):
+    id: str
+    object: Literal["model"]
+    owned_by: str
+    permissions: List[str]
+
+
+class ModelList(TypedDict):
+    object: Literal["list"]
+    data: List[ModelData]
+
+
+GetModelResponse = create_model_from_typeddict(ModelList)
+
+
+@app.get("/v1/models", response_model=GetModelResponse)
+def get_models() -> ModelList:
+    return {
+        "object": "list",
+        "data": [
+            {
+                "id": llama.model_path,
+                "object": "model",
+                "owned_by": "me",
+                "permissions": [],
+            }
+        ],
+    }
+
+
+if __name__ == "__main__":
+    import os
+    import uvicorn
+
+    uvicorn.run(app, host=os.getenv("HOST", "localhost"), port=os.getenv("PORT", 8000))
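
Once the server is started via the __main__ block above (HOST and PORT default to localhost:8000), the new /v1/models route can be queried directly; the id of the single entry is the path of the loaded model file. A quick check, assuming the requests package is installed:

import requests

models = requests.get("http://localhost:8000/v1/models").json()
print(models["object"])         # "list"
print(models["data"][0]["id"])  # model_path of the loaded model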