From 4daf77e546da78d10b8f6969cddea03d4508ff21 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Wed, 13 Sep 2023 21:23:23 -0400
Subject: [PATCH] Format

---
 llama_cpp/server/app.py | 50 ++++++++++++++++++++++----------------------------
 1 file changed, 22 insertions(+), 28 deletions(-)

diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index f4551b8..9e29555 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -144,10 +144,8 @@ class ErrorResponseFormatters:
 
     @staticmethod
     def context_length_exceeded(
-        request: Union[
-            "CreateCompletionRequest", "CreateChatCompletionRequest"
-        ],
-        match, # type: Match[str] # type: ignore
+        request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
+        match,  # type: Match[str]  # type: ignore
     ) -> Tuple[int, ErrorResponse]:
         """Formatter for context length exceeded error"""
 
@@ -184,10 +182,8 @@ class ErrorResponseFormatters:
 
     @staticmethod
     def model_not_found(
-        request: Union[
-            "CreateCompletionRequest", "CreateChatCompletionRequest"
-        ],
-        match # type: Match[str] # type: ignore
+        request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
+        match,  # type: Match[str]  # type: ignore
     ) -> Tuple[int, ErrorResponse]:
         """Formatter for model_not_found error"""
 
@@ -315,12 +311,7 @@ def create_app(settings: Optional[Settings] = None):
         settings = Settings()
 
     middleware = [
-        Middleware(
-            RawContextMiddleware,
-            plugins=(
-                plugins.RequestIdPlugin(),
-            )
-        )
+        Middleware(RawContextMiddleware, plugins=(plugins.RequestIdPlugin(),))
     ]
     app = FastAPI(
         middleware=middleware,
@@ -426,12 +417,13 @@ async def get_event_publisher(
     except anyio.get_cancelled_exc_class() as e:
         print("disconnected")
         with anyio.move_on_after(1, shield=True):
-            print(
-                f"Disconnected from client (via refresh/close) {request.client}"
-            )
+            print(f"Disconnected from client (via refresh/close) {request.client}")
             raise e
 
-model_field = Field(description="The model to use for generating completions.", default=None)
+
+model_field = Field(
+    description="The model to use for generating completions.", default=None
+)
 
 max_tokens_field = Field(
     default=16, ge=1, description="The maximum number of tokens to generate."
@@ -625,9 +617,9 @@ async def create_completion(
         ]
     )
 
-    iterator_or_completion: Union[llama_cpp.Completion, Iterator[
-        llama_cpp.CompletionChunk
-    ]] = await run_in_threadpool(llama, **kwargs)
+    iterator_or_completion: Union[
+        llama_cpp.Completion, Iterator[llama_cpp.CompletionChunk]
+    ] = await run_in_threadpool(llama, **kwargs)
 
     if isinstance(iterator_or_completion, Iterator):
         # EAFP: It's easier to ask for forgiveness than permission
@@ -641,12 +633,13 @@ async def create_completion(
 
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
-            recv_chan, data_sender_callable=partial(  # type: ignore
+            recv_chan,
+            data_sender_callable=partial(  # type: ignore
                 get_event_publisher,
                 request=request,
                 inner_send_chan=send_chan,
                 iterator=iterator(),
-            )
+            ),
         )
     else:
         return iterator_or_completion
@@ -762,9 +755,9 @@ async def create_chat_completion(
         ]
     )
 
-    iterator_or_completion: Union[llama_cpp.ChatCompletion, Iterator[
-        llama_cpp.ChatCompletionChunk
-    ]] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
+    iterator_or_completion: Union[
+        llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
+    ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
 
     if isinstance(iterator_or_completion, Iterator):
         # EAFP: It's easier to ask for forgiveness than permission
@@ -778,12 +771,13 @@ async def create_chat_completion(
 
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
-            recv_chan, data_sender_callable=partial(  # type: ignore
+            recv_chan,
+            data_sender_callable=partial(  # type: ignore
                 get_event_publisher,
                 request=request,
                 inner_send_chan=send_chan,
                 iterator=iterator(),
-            )
+            ),
         )
     else:
         return iterator_or_completion
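
Note, for reference only (not part of the patch): every EventSourceResponse call
reformatted above follows one pattern, a bounded anyio memory object stream whose
send end is fed by a publisher coroutine handed in via functools.partial while
sse-starlette drains the receive end as server-sent events. Below is a minimal
self-contained sketch of that pattern, assuming fastapi, sse-starlette, anyio,
and uvicorn are installed; the /numbers route, event_publisher helper, and
integer payloads are illustrative stand-ins, not llama-cpp-python's actual
handlers.

    from functools import partial

    import anyio
    from fastapi import FastAPI, Request
    from sse_starlette.sse import EventSourceResponse

    app = FastAPI()

    async def event_publisher(inner_send_chan, iterator):
        # Producer: emit each chunk as an SSE "data" event; closing the
        # channel when the iterator is exhausted ends the response stream.
        async with inner_send_chan:
            for chunk in iterator:
                await inner_send_chan.send(dict(data=str(chunk)))

    @app.get("/numbers")
    async def numbers(request: Request):
        # Bounded to 10 pending chunks: a slow client applies backpressure
        # to the producer instead of buffering unboundedly.
        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            # sse-starlette runs data_sender_callable in the same task
            # group that drains recv_chan, so producing and sending overlap.
            data_sender_callable=partial(
                event_publisher,
                inner_send_chan=send_chan,
                iterator=iter(range(5)),
            ),
        )

Handing over the unstarted coroutine via partial(...) lets EventSourceResponse
run the publisher inside its own task group, so a client disconnect cancels the
publisher; that is what allows get_event_publisher in the patch to catch the
cancellation and shield its cleanup with anyio.move_on_after(1, shield=True).
The sketch can be served with uvicorn sketch:app (hypothetical module name).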