Use async routes

This commit is contained in:
Andrei Betlen 2023-05-27 09:12:58 -04:00
parent c2b59a5f59
commit 80066f0b80

View file

@ -1,12 +1,16 @@
import json
import multiprocessing
from threading import Lock
from typing import List, Optional, Union, Iterator, Dict
from functools import partial
from typing import Iterator, List, Optional, Union, Dict
from typing_extensions import TypedDict, Literal
import llama_cpp
from fastapi import Depends, FastAPI, APIRouter
import anyio
from anyio.streams.memory import MemoryObjectSendStream
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
from fastapi import Depends, FastAPI, APIRouter, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
from sse_starlette.sse import EventSourceResponse
@ -241,34 +245,48 @@ CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
"/v1/completions",
response_model=CreateCompletionResponse,
)
def create_completion(
request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama)
async def create_completion(
request: Request,
body: CreateCompletionRequest,
llama: llama_cpp.Llama = Depends(get_llama),
):
if isinstance(request.prompt, list):
assert len(request.prompt) <= 1
request.prompt = request.prompt[0] if len(request.prompt) > 0 else ""
if isinstance(body.prompt, list):
assert len(body.prompt) <= 1
body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
completion_or_chunks = llama(
**request.dict(
exclude = {
"n",
"best_of",
"logit_bias",
"user",
}
)
)
if request.stream:
kwargs = body.dict(exclude=exclude)
if body.stream:
send_chan, recv_chan = anyio.create_memory_object_stream(10)
async def server_sent_events(
chunks: Iterator[llama_cpp.CompletionChunk],
):
for chunk in chunks:
yield dict(data=json.dumps(chunk))
async def event_publisher(inner_send_chan: MemoryObjectSendStream):
async with inner_send_chan:
try:
iterator: Iterator[llama_cpp.CompletionChunk] = await run_in_threadpool(llama, **kwargs) # type: ignore
async for chunk in iterate_in_threadpool(iterator):
await inner_send_chan.send(dict(data=json.dumps(chunk)))
if await request.is_disconnected():
raise anyio.get_cancelled_exc_class()()
await inner_send_chan.send(dict(data="[DONE]"))
except anyio.get_cancelled_exc_class() as e:
print("disconnected")
with anyio.move_on_after(1, shield=True):
print(
f"Disconnected from client (via refresh/close) {request.client}"
)
await inner_send_chan.send(dict(closing=True))
raise e
chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks # type: ignore
return EventSourceResponse(server_sent_events(chunks))
completion: llama_cpp.Completion = completion_or_chunks # type: ignore
return EventSourceResponse(
recv_chan, data_sender_callable=partial(event_publisher, send_chan)
)
else:
completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs) # type: ignore
return completion
@ -292,10 +310,12 @@ CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
"/v1/embeddings",
response_model=CreateEmbeddingResponse,
)
def create_embedding(
async def create_embedding(
request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
):
return llama.create_embedding(**request.dict(exclude={"user"}))
return await run_in_threadpool(
llama.create_embedding, **request.dict(exclude={"user"})
)
class ChatCompletionRequestMessage(BaseModel):
@ -349,35 +369,46 @@ CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatComplet
"/v1/chat/completions",
response_model=CreateChatCompletionResponse,
)
def create_chat_completion(
request: CreateChatCompletionRequest,
async def create_chat_completion(
request: Request,
body: CreateChatCompletionRequest,
llama: llama_cpp.Llama = Depends(get_llama),
) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
completion_or_chunks = llama.create_chat_completion(
**request.dict(
exclude = {
"n",
"logit_bias",
"user",
}
),
kwargs = body.dict(exclude=exclude)
if body.stream:
send_chan, recv_chan = anyio.create_memory_object_stream(10)
async def event_publisher(inner_send_chan: MemoryObjectSendStream):
async with inner_send_chan:
try:
iterator: Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(llama.create_chat_completion, **kwargs) # type: ignore
async for chat_chunk in iterate_in_threadpool(iterator):
await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
if await request.is_disconnected():
raise anyio.get_cancelled_exc_class()()
await inner_send_chan.send(dict(data="[DONE]"))
except anyio.get_cancelled_exc_class() as e:
print("disconnected")
with anyio.move_on_after(1, shield=True):
print(
f"Disconnected from client (via refresh/close) {request.client}"
)
if request.stream:
async def server_sent_events(
chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
):
for chat_chunk in chat_chunks:
yield dict(data=json.dumps(chat_chunk))
yield dict(data="[DONE]")
chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks # type: ignore
await inner_send_chan.send(dict(closing=True))
raise e
return EventSourceResponse(
server_sent_events(chunks),
recv_chan,
data_sender_callable=partial(event_publisher, send_chan),
)
else:
completion: llama_cpp.ChatCompletion = await run_in_threadpool(
llama.create_chat_completion, **kwargs # type: ignore
)
completion: llama_cpp.ChatCompletion = completion_or_chunks # type: ignore
return completion
@ -397,7 +428,7 @@ GetModelResponse = create_model_from_typeddict(ModelList)
@router.get("/v1/models", response_model=GetModelResponse)
def get_models(
async def get_models(
settings: Settings = Depends(get_settings),
llama: llama_cpp.Llama = Depends(get_llama),
) -> ModelList: