diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index 882c902..ea9dec4 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -1,12 +1,16 @@
 import json
 import multiprocessing
 from threading import Lock
-from typing import List, Optional, Union, Iterator, Dict
+from functools import partial
+from typing import Iterator, List, Optional, Union, Dict
 from typing_extensions import TypedDict, Literal
 
 import llama_cpp
 
-from fastapi import Depends, FastAPI, APIRouter
+import anyio
+from anyio.streams.memory import MemoryObjectSendStream
+from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
+from fastapi import Depends, FastAPI, APIRouter, Request
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
 from sse_starlette.sse import EventSourceResponse
@@ -241,35 +245,49 @@ CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
     "/v1/completions",
     response_model=CreateCompletionResponse,
 )
-def create_completion(
-    request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama)
+async def create_completion(
+    request: Request,
+    body: CreateCompletionRequest,
+    llama: llama_cpp.Llama = Depends(get_llama),
 ):
-    if isinstance(request.prompt, list):
-        assert len(request.prompt) <= 1
-        request.prompt = request.prompt[0] if len(request.prompt) > 0 else ""
+    if isinstance(body.prompt, list):
+        assert len(body.prompt) <= 1
+        body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
 
-    completion_or_chunks = llama(
-        **request.dict(
-            exclude={
-                "n",
-                "best_of",
-                "logit_bias",
-                "user",
-            }
+    exclude = {
+        "n",
+        "best_of",
+        "logit_bias",
+        "user",
+    }
+    kwargs = body.dict(exclude=exclude)
+    if body.stream:
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+
+        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
+            async with inner_send_chan:
+                try:
+                    iterator: Iterator[llama_cpp.CompletionChunk] = await run_in_threadpool(llama, **kwargs)  # type: ignore
+                    async for chunk in iterate_in_threadpool(iterator):
+                        await inner_send_chan.send(dict(data=json.dumps(chunk)))
+                        if await request.is_disconnected():
+                            raise anyio.get_cancelled_exc_class()()
+                    await inner_send_chan.send(dict(data="[DONE]"))
+                except anyio.get_cancelled_exc_class() as e:
+                    print("disconnected")
+                    with anyio.move_on_after(1, shield=True):
+                        print(
+                            f"Disconnected from client (via refresh/close) {request.client}"
+                        )
+                        await inner_send_chan.send(dict(closing=True))
+                        raise e
+
+        return EventSourceResponse(
+            recv_chan, data_sender_callable=partial(event_publisher, send_chan)
         )
-    )
-    if request.stream:
-
-        async def server_sent_events(
-            chunks: Iterator[llama_cpp.CompletionChunk],
-        ):
-            for chunk in chunks:
-                yield dict(data=json.dumps(chunk))
-
-        chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks  # type: ignore
-        return EventSourceResponse(server_sent_events(chunks))
-    completion: llama_cpp.Completion = completion_or_chunks  # type: ignore
-    return completion
+    else:
+        completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs)  # type: ignore
+        return completion
 
 
 class CreateEmbeddingRequest(BaseModel):
@@ -292,10 +310,12 @@ CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
     "/v1/embeddings",
     response_model=CreateEmbeddingResponse,
 )
-def create_embedding(
+async def create_embedding(
     request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
 ):
-    return llama.create_embedding(**request.dict(exclude={"user"}))
+    return await run_in_threadpool(
+        llama.create_embedding, **request.dict(exclude={"user"})
+    )
 
 
 class ChatCompletionRequestMessage(BaseModel):
@@ -349,36 +369,47 @@ CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatComplet
     "/v1/chat/completions",
     response_model=CreateChatCompletionResponse,
 )
-def create_chat_completion(
-    request: CreateChatCompletionRequest,
+async def create_chat_completion(
+    request: Request,
+    body: CreateChatCompletionRequest,
     llama: llama_cpp.Llama = Depends(get_llama),
 ) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
-    completion_or_chunks = llama.create_chat_completion(
-        **request.dict(
-            exclude={
-                "n",
-                "logit_bias",
-                "user",
-            }
-        ),
-    )
+    exclude = {
+        "n",
+        "logit_bias",
+        "user",
+    }
+    kwargs = body.dict(exclude=exclude)
+    if body.stream:
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
 
-    if request.stream:
-
-        async def server_sent_events(
-            chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
-        ):
-            for chat_chunk in chat_chunks:
-                yield dict(data=json.dumps(chat_chunk))
-            yield dict(data="[DONE]")
-
-        chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks  # type: ignore
+        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
+            async with inner_send_chan:
+                try:
+                    iterator: Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(llama.create_chat_completion, **kwargs)  # type: ignore
+                    async for chat_chunk in iterate_in_threadpool(iterator):
+                        await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
+                        if await request.is_disconnected():
+                            raise anyio.get_cancelled_exc_class()()
+                    await inner_send_chan.send(dict(data="[DONE]"))
+                except anyio.get_cancelled_exc_class() as e:
+                    print("disconnected")
+                    with anyio.move_on_after(1, shield=True):
+                        print(
+                            f"Disconnected from client (via refresh/close) {request.client}"
+                        )
+                        await inner_send_chan.send(dict(closing=True))
+                        raise e
 
         return EventSourceResponse(
-            server_sent_events(chunks),
+            recv_chan,
+            data_sender_callable=partial(event_publisher, send_chan),
         )
-    completion: llama_cpp.ChatCompletion = completion_or_chunks  # type: ignore
-    return completion
+    else:
+        completion: llama_cpp.ChatCompletion = await run_in_threadpool(
+            llama.create_chat_completion, **kwargs  # type: ignore
+        )
+        return completion
 
 
 class ModelData(TypedDict):
@@ -397,7 +428,7 @@ GetModelResponse = create_model_from_typeddict(ModelList)
 
 
 @router.get("/v1/models", response_model=GetModelResponse)
-def get_models(
+async def get_models(
     settings: Settings = Depends(get_settings),
     llama: llama_cpp.Llama = Depends(get_llama),
 ) -> ModelList:
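
For reference, a minimal client sketch (not part of the diff) that consumes the streaming `/v1/completions` route as rewritten above. The host, port, prompt, and `max_tokens` value are assumptions for illustration; it uses `requests` with `stream=True` to read the server-sent events produced by `EventSourceResponse` line by line.

```python
# Hypothetical usage sketch: stream completion chunks from /v1/completions.
# Assumes the server is running locally on port 8000; prompt and max_tokens
# are placeholder values.
import json
import requests

with requests.post(
    "http://localhost:8000/v1/completions",
    json={"prompt": "Hello, my name is", "max_tokens": 32, "stream": True},
    stream=True,
) as response:
    for raw_line in response.iter_lines():
        if not raw_line:
            continue  # SSE events are separated by blank lines
        line = raw_line.decode("utf-8")
        if not line.startswith("data: "):
            continue  # skip anything that is not a data field
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        print(chunk["choices"][0]["text"], end="", flush=True)
```

Since the rewritten route polls `request.is_disconnected()` between chunks, closing a client like this mid-stream should let the server stop pulling from the generation iterator rather than running it to completion.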