Use async routes

Andrei Betlen 2023-05-27 09:12:58 -04:00
parent c2b59a5f59
commit 80066f0b80

@@ -1,12 +1,16 @@
 import json
 import multiprocessing
 from threading import Lock
-from typing import List, Optional, Union, Iterator, Dict
+from functools import partial
+from typing import Iterator, List, Optional, Union, Dict
 from typing_extensions import TypedDict, Literal
 
 import llama_cpp
 
-from fastapi import Depends, FastAPI, APIRouter
+import anyio
+from anyio.streams.memory import MemoryObjectSendStream
+from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
+from fastapi import Depends, FastAPI, APIRouter, Request
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
 from sse_starlette.sse import EventSourceResponse
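
The new imports are the whole mechanism: anyio supplies the memory object streams and cancellation primitives, functools.partial binds the publisher coroutine to its send channel, and Starlette's run_in_threadpool / iterate_in_threadpool keep the blocking llama.cpp calls off the event loop. A minimal sketch of iterate_in_threadpool on its own, with a made-up slow_tokens generator standing in for the blocking llama iterator:

import time
import anyio
from starlette.concurrency import iterate_in_threadpool

def slow_tokens():
    # Stand-in for a blocking generator such as llama(..., stream=True).
    for tok in ("hello", "world"):
        time.sleep(0.1)  # blocking work; runs in a worker thread below
        yield tok

async def main():
    # Each step of the sync iterator executes in Starlette's threadpool,
    # so the event loop stays responsive between tokens.
    async for tok in iterate_in_threadpool(slow_tokens()):
        print(tok)

anyio.run(main)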
@@ -241,35 +245,49 @@ CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
     "/v1/completions",
     response_model=CreateCompletionResponse,
 )
-def create_completion(
-    request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama)
+async def create_completion(
+    request: Request,
+    body: CreateCompletionRequest,
+    llama: llama_cpp.Llama = Depends(get_llama),
 ):
-    if isinstance(request.prompt, list):
-        assert len(request.prompt) <= 1
-        request.prompt = request.prompt[0] if len(request.prompt) > 0 else ""
+    if isinstance(body.prompt, list):
+        assert len(body.prompt) <= 1
+        body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
 
-    completion_or_chunks = llama(
-        **request.dict(
-            exclude={
-                "n",
-                "best_of",
-                "logit_bias",
-                "user",
-            }
-        )
-    )
-    if request.stream:
-
-        async def server_sent_events(
-            chunks: Iterator[llama_cpp.CompletionChunk],
-        ):
-            for chunk in chunks:
-                yield dict(data=json.dumps(chunk))
-
-        chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks  # type: ignore
-        return EventSourceResponse(server_sent_events(chunks))
-    completion: llama_cpp.Completion = completion_or_chunks  # type: ignore
-    return completion
+    exclude = {
+        "n",
+        "best_of",
+        "logit_bias",
+        "user",
+    }
+    kwargs = body.dict(exclude=exclude)
+    if body.stream:
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+
+        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
+            async with inner_send_chan:
+                try:
+                    iterator: Iterator[llama_cpp.CompletionChunk] = await run_in_threadpool(llama, **kwargs)  # type: ignore
+                    async for chunk in iterate_in_threadpool(iterator):
+                        await inner_send_chan.send(dict(data=json.dumps(chunk)))
+                        if await request.is_disconnected():
+                            raise anyio.get_cancelled_exc_class()()
+                    await inner_send_chan.send(dict(data="[DONE]"))
+                except anyio.get_cancelled_exc_class() as e:
+                    print("disconnected")
+                    with anyio.move_on_after(1, shield=True):
+                        print(
+                            f"Disconnected from client (via refresh/close) {request.client}"
+                        )
+                        await inner_send_chan.send(dict(closing=True))
+                        raise e
+
+        return EventSourceResponse(
+            recv_chan, data_sender_callable=partial(event_publisher, send_chan)
+        )
+    else:
+        completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs)  # type: ignore
+        return completion
 
 
 class CreateEmbeddingRequest(BaseModel):
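
The streaming branch above is a producer/consumer pair: the route hands EventSourceResponse the receive end of a bounded anyio memory object stream, while event_publisher, attached via data_sender_callable, pushes chunks into the send end and bails out when the client disconnects. A stripped-down sketch of the same pattern, with a dummy generate() iterator in place of the llama.cpp call (route name and payloads are illustrative only):

from functools import partial

import anyio
from anyio.streams.memory import MemoryObjectSendStream
from fastapi import FastAPI, Request
from sse_starlette.sse import EventSourceResponse
from starlette.concurrency import iterate_in_threadpool

app = FastAPI()

def generate():
    # Stand-in for the blocking completion iterator.
    yield from ("chunk-1", "chunk-2", "chunk-3")

@app.get("/stream")
async def stream(request: Request):
    # Bounded buffer: the producer blocks once 10 events are queued unread.
    send_chan, recv_chan = anyio.create_memory_object_stream(10)

    async def event_publisher(inner_send_chan: MemoryObjectSendStream):
        async with inner_send_chan:
            try:
                async for chunk in iterate_in_threadpool(generate()):
                    await inner_send_chan.send(dict(data=chunk))
                    if await request.is_disconnected():
                        raise anyio.get_cancelled_exc_class()()
                await inner_send_chan.send(dict(data="[DONE]"))
            except anyio.get_cancelled_exc_class():
                # Client went away: shield the final event from cancellation
                # for up to one second, then re-raise so anyio can clean up.
                with anyio.move_on_after(1, shield=True):
                    await inner_send_chan.send(dict(closing=True))
                raise

    return EventSourceResponse(
        recv_chan, data_sender_callable=partial(event_publisher, send_chan)
    )

The bounded stream gives natural backpressure: if the client reads slowly, the worker thread producing chunks pauses instead of buffering the whole completion in memory.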
@@ -292,10 +310,12 @@ CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
     "/v1/embeddings",
     response_model=CreateEmbeddingResponse,
 )
-def create_embedding(
+async def create_embedding(
     request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
 ):
-    return llama.create_embedding(**request.dict(exclude={"user"}))
+    return await run_in_threadpool(
+        llama.create_embedding, **request.dict(exclude={"user"})
+    )
 
 
 class ChatCompletionRequestMessage(BaseModel):
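
The embeddings route shows the non-streaming half of the change: the blocking llama.create_embedding call is wrapped in run_in_threadpool, which runs it in a worker thread and returns an awaitable, so the event loop keeps serving other requests in the meantime. A self-contained sketch of the same idea, using a made-up CPU-bound expensive() function as a stand-in:

import hashlib

import anyio
from starlette.concurrency import run_in_threadpool

def expensive(data: bytes) -> str:
    # Stand-in for a CPU-bound call such as llama.create_embedding(...).
    return hashlib.sha256(data * 100_000).hexdigest()

async def handler() -> str:
    # Awaiting run_in_threadpool yields control to the event loop while the
    # blocking function runs in Starlette's worker threadpool.
    return await run_in_threadpool(expensive, b"hello")

print(anyio.run(handler))

The non-streaming branches of the completion and chat routes apply the same wrapper, awaiting run_in_threadpool(llama, ...) or run_in_threadpool(llama.create_chat_completion, ...) instead of calling the model directly.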
@@ -349,36 +369,47 @@ CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatComplet
     "/v1/chat/completions",
     response_model=CreateChatCompletionResponse,
 )
-def create_chat_completion(
-    request: CreateChatCompletionRequest,
+async def create_chat_completion(
+    request: Request,
+    body: CreateChatCompletionRequest,
     llama: llama_cpp.Llama = Depends(get_llama),
 ) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
-    completion_or_chunks = llama.create_chat_completion(
-        **request.dict(
-            exclude={
-                "n",
-                "logit_bias",
-                "user",
-            }
-        ),
-    )
-
-    if request.stream:
-
-        async def server_sent_events(
-            chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
-        ):
-            for chat_chunk in chat_chunks:
-                yield dict(data=json.dumps(chat_chunk))
-            yield dict(data="[DONE]")
-
-        chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks  # type: ignore
-
+    exclude = {
+        "n",
+        "logit_bias",
+        "user",
+    }
+    kwargs = body.dict(exclude=exclude)
+    if body.stream:
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+
+        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
+            async with inner_send_chan:
+                try:
+                    iterator: Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(llama.create_chat_completion, **kwargs)  # type: ignore
+                    async for chat_chunk in iterate_in_threadpool(iterator):
+                        await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
+                        if await request.is_disconnected():
+                            raise anyio.get_cancelled_exc_class()()
+                    await inner_send_chan.send(dict(data="[DONE]"))
+                except anyio.get_cancelled_exc_class() as e:
+                    print("disconnected")
+                    with anyio.move_on_after(1, shield=True):
+                        print(
+                            f"Disconnected from client (via refresh/close) {request.client}"
+                        )
+                        await inner_send_chan.send(dict(closing=True))
+                        raise e
+
         return EventSourceResponse(
-            server_sent_events(chunks),
+            recv_chan,
+            data_sender_callable=partial(event_publisher, send_chan),
        )
-    completion: llama_cpp.ChatCompletion = completion_or_chunks  # type: ignore
-    return completion
+    else:
+        completion: llama_cpp.ChatCompletion = await run_in_threadpool(
+            llama.create_chat_completion, **kwargs  # type: ignore
+        )
+        return completion
 
 
 class ModelData(TypedDict):
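
From the client's point of view the wire behaviour is unchanged: with "stream": true the chat route still emits server-sent events and finishes with a [DONE] sentinel. A rough client-side sketch of consuming that stream (host, port, and the OpenAI-style delta chunk fields are assumptions, not taken from this diff):

import json

import requests

# Assumes the server is running locally on its default port.
with requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    },
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":
            break
        chunk = json.loads(payload)
        print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)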
@@ -397,7 +428,7 @@ GetModelResponse = create_model_from_typeddict(ModelList)
 
 
 @router.get("/v1/models", response_model=GetModelResponse)
-def get_models(
+async def get_models(
     settings: Settings = Depends(get_settings),
     llama: llama_cpp.Llama = Depends(get_llama),
 ) -> ModelList: