Use async routes

Andrei Betlen 2023-05-27 09:12:58 -04:00
parent c2b59a5f59
commit 80066f0b80

@@ -1,12 +1,16 @@
 import json
 import multiprocessing
 from threading import Lock
-from typing import List, Optional, Union, Iterator, Dict
+from functools import partial
+from typing import Iterator, List, Optional, Union, Dict
 from typing_extensions import TypedDict, Literal

 import llama_cpp

-from fastapi import Depends, FastAPI, APIRouter
+import anyio
+from anyio.streams.memory import MemoryObjectSendStream
+from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
+from fastapi import Depends, FastAPI, APIRouter, Request
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
 from sse_starlette.sse import EventSourceResponse
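The new imports carry the weight of the change: run_in_threadpool and iterate_in_threadpool let blocking llama.cpp calls and their iterators be driven from async route handlers without stalling the event loop. A minimal sketch of that pattern, with a hypothetical slow_generate standing in for the blocking model call (not part of this commit):

# Sketch only: slow_generate is a made-up stand-in for a blocking llama.cpp call.
import time
from typing import Iterator

from fastapi import FastAPI
from starlette.concurrency import iterate_in_threadpool, run_in_threadpool

app = FastAPI()


def slow_generate(n: int) -> Iterator[str]:
    for i in range(n):
        time.sleep(0.1)  # simulate blocking inference work
        yield f"token-{i}"


@app.get("/demo")
async def demo(n: int = 3):
    # Build the iterator off the event loop, then consume it via the threadpool
    # so other requests keep being served while tokens are produced.
    iterator = await run_in_threadpool(slow_generate, n)
    return {"chunks": [chunk async for chunk in iterate_in_threadpool(iterator)]}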
@@ -241,34 +245,48 @@ CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
     "/v1/completions",
     response_model=CreateCompletionResponse,
 )
-def create_completion(
-    request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama)
+async def create_completion(
+    request: Request,
+    body: CreateCompletionRequest,
+    llama: llama_cpp.Llama = Depends(get_llama),
 ):
-    if isinstance(request.prompt, list):
-        assert len(request.prompt) <= 1
-        request.prompt = request.prompt[0] if len(request.prompt) > 0 else ""
-    completion_or_chunks = llama(
-        **request.dict(
-            exclude={
-                "n",
-                "best_of",
-                "logit_bias",
-                "user",
-            }
-        )
-    )
-    if request.stream:
-
-        async def server_sent_events(
-            chunks: Iterator[llama_cpp.CompletionChunk],
-        ):
-            for chunk in chunks:
-                yield dict(data=json.dumps(chunk))
-
-        chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks  # type: ignore
-        return EventSourceResponse(server_sent_events(chunks))
-    completion: llama_cpp.Completion = completion_or_chunks  # type: ignore
-    return completion
+    if isinstance(body.prompt, list):
+        assert len(body.prompt) <= 1
+        body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
+
+    exclude = {
+        "n",
+        "best_of",
+        "logit_bias",
+        "user",
+    }
+    kwargs = body.dict(exclude=exclude)
+    if body.stream:
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+
+        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
+            async with inner_send_chan:
+                try:
+                    iterator: Iterator[llama_cpp.CompletionChunk] = await run_in_threadpool(llama, **kwargs)  # type: ignore
+                    async for chunk in iterate_in_threadpool(iterator):
+                        await inner_send_chan.send(dict(data=json.dumps(chunk)))
+                        if await request.is_disconnected():
+                            raise anyio.get_cancelled_exc_class()()
+                    await inner_send_chan.send(dict(data="[DONE]"))
+                except anyio.get_cancelled_exc_class() as e:
+                    print("disconnected")
+                    with anyio.move_on_after(1, shield=True):
+                        print(
+                            f"Disconnected from client (via refresh/close) {request.client}"
+                        )
+                        await inner_send_chan.send(dict(closing=True))
+                        raise e
+
+        return EventSourceResponse(
+            recv_chan, data_sender_callable=partial(event_publisher, send_chan)
+        )
+    else:
+        completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs)  # type: ignore
+        return completion
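The streaming branch pairs an anyio memory object stream with sse_starlette's data_sender_callable hook: EventSourceResponse drains recv_chan while event_publisher fills send_chan off the event loop and bails out once the client disconnects. A stripped-down sketch of just that wiring, with a hypothetical fake_chunks generator in place of the llama call:

# Sketch only: fake_chunks is a made-up stand-in for the blocking completion iterator.
import json
import time
from functools import partial
from typing import Iterator

import anyio
from anyio.streams.memory import MemoryObjectSendStream
from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse
from starlette.concurrency import iterate_in_threadpool, run_in_threadpool

app = FastAPI()


def fake_chunks() -> Iterator[dict]:
    for i in range(5):
        time.sleep(0.1)  # simulate token-by-token generation
        yield {"choices": [{"text": f"token-{i}"}]}


@app.get("/stream")
async def stream():
    # Bounded channel: the producer waits once 10 events are queued for sending.
    send_chan, recv_chan = anyio.create_memory_object_stream(10)

    async def event_publisher(inner_send_chan: MemoryObjectSendStream):
        async with inner_send_chan:  # closing the channel ends the SSE response
            iterator = await run_in_threadpool(fake_chunks)
            async for chunk in iterate_in_threadpool(iterator):
                await inner_send_chan.send(dict(data=json.dumps(chunk)))
            await inner_send_chan.send(dict(data="[DONE]"))

    # sse_starlette reads events from recv_chan and runs event_publisher
    # concurrently to fill send_chan.
    return EventSourceResponse(
        recv_chan, data_sender_callable=partial(event_publisher, send_chan)
    )

The bounded buffer of 10 means the producer pauses when the client is slow to read, and closing the send stream is what ends the SSE response.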
@@ -292,10 +310,12 @@ CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
     "/v1/embeddings",
     response_model=CreateEmbeddingResponse,
 )
-def create_embedding(
+async def create_embedding(
     request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
 ):
-    return llama.create_embedding(**request.dict(exclude={"user"}))
+    return await run_in_threadpool(
+        llama.create_embedding, **request.dict(exclude={"user"})
+    )


 class ChatCompletionRequestMessage(BaseModel):
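The embeddings route stays non-streaming; the only change is awaiting the blocking create_embedding call through the threadpool. A hypothetical client call against a locally running server might look like this (base URL, port, the httpx client, and the request/response fields are assumptions in the OpenAI style, not taken from this diff):

# Hypothetical client call; base URL, port and field names are assumptions
# based on the OpenAI-compatible schema, not taken from this commit.
import httpx

resp = httpx.post(
    "http://localhost:8000/v1/embeddings",
    json={"input": "A quick test sentence."},
    timeout=60.0,
)
resp.raise_for_status()
embedding = resp.json()["data"][0]["embedding"]  # OpenAI-style response layout
print(len(embedding))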
@@ -349,35 +369,46 @@ CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
     "/v1/chat/completions",
     response_model=CreateChatCompletionResponse,
 )
-def create_chat_completion(
-    request: CreateChatCompletionRequest,
+async def create_chat_completion(
+    request: Request,
+    body: CreateChatCompletionRequest,
     llama: llama_cpp.Llama = Depends(get_llama),
 ) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
-    completion_or_chunks = llama.create_chat_completion(
-        **request.dict(
-            exclude={
-                "n",
-                "logit_bias",
-                "user",
-            }
-        ),
-    )
-
-    if request.stream:
-
-        async def server_sent_events(
-            chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
-        ):
-            for chat_chunk in chat_chunks:
-                yield dict(data=json.dumps(chat_chunk))
-            yield dict(data="[DONE]")
-
-        chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks  # type: ignore
-        return EventSourceResponse(
-            server_sent_events(chunks),
-        )
-    completion: llama_cpp.ChatCompletion = completion_or_chunks  # type: ignore
-    return completion
+    exclude = {
+        "n",
+        "logit_bias",
+        "user",
+    }
+    kwargs = body.dict(exclude=exclude)
+    if body.stream:
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+
+        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
+            async with inner_send_chan:
+                try:
+                    iterator: Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(llama.create_chat_completion, **kwargs)  # type: ignore
+                    async for chat_chunk in iterate_in_threadpool(iterator):
+                        await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
+                        if await request.is_disconnected():
+                            raise anyio.get_cancelled_exc_class()()
+                    await inner_send_chan.send(dict(data="[DONE]"))
+                except anyio.get_cancelled_exc_class() as e:
+                    print("disconnected")
+                    with anyio.move_on_after(1, shield=True):
+                        print(
+                            f"Disconnected from client (via refresh/close) {request.client}"
+                        )
+                        await inner_send_chan.send(dict(closing=True))
+                        raise e
+
+        return EventSourceResponse(
+            recv_chan,
+            data_sender_callable=partial(event_publisher, send_chan),
+        )
+    else:
+        completion: llama_cpp.ChatCompletion = await run_in_threadpool(
+            llama.create_chat_completion, **kwargs  # type: ignore
+        )
+        return completion
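On the wire, the streamed chat completion is still a plain SSE feed terminated by the [DONE] sentinel sent above. A hedged client-side sketch of consuming it (URL, port, httpx, and the OpenAI-style request/chunk fields are assumptions, not taken from this diff):

# Hypothetical client; endpoint URL and field names follow the OpenAI-style
# schema and are not taken from this commit.
import json

import httpx

payload = {
    "messages": [{"role": "user", "content": "Say hello."}],
    "stream": True,
}

with httpx.stream(
    "POST", "http://localhost:8000/v1/chat/completions", json=payload, timeout=None
) as resp:
    for line in resp.iter_lines():
        if not line.startswith("data: "):
            continue  # skip blank keep-alive lines between SSE events
        data = line[len("data: "):]
        if data == "[DONE]":  # sentinel emitted by the route above
            break
        chunk = json.loads(data)
        delta = chunk["choices"][0].get("delta", {})
        print(delta.get("content", ""), end="", flush=True)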
@@ -397,7 +428,7 @@ GetModelResponse = create_model_from_typeddict(ModelList)


 @router.get("/v1/models", response_model=GetModelResponse)
-def get_models(
+async def get_models(
     settings: Settings = Depends(get_settings),
     llama: llama_cpp.Llama = Depends(get_llama),
 ) -> ModelList:
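get_models does no blocking work, so switching it to async def simply moves it onto the event loop; FastAPI runs plain def routes in a threadpool instead. A small illustration of the two forms (general FastAPI behaviour, not code from this commit):

# General FastAPI behaviour, not specific to this commit.
from fastapi import FastAPI

app = FastAPI()


@app.get("/sync")
def sync_route():
    # Plain `def`: FastAPI executes this in a worker thread automatically,
    # so incidental blocking does not stall other requests.
    return {"kind": "threadpool"}


@app.get("/async")
async def async_route():
    # `async def`: runs directly on the event loop; anything blocking must be
    # offloaded explicitly (e.g. run_in_threadpool), as the routes above do.
    return {"kind": "event loop"}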