diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 92ca67d..3e01d1c 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -845,7 +845,7 @@ class Llama: if len(prompt_tokens) >= llama_cpp.llama_n_ctx(self.ctx): raise ValueError( - f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}" + f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}" ) if max_tokens <= 0: diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index eaa6f44..62fba04 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -1,8 +1,9 @@ import json import multiprocessing +from re import compile, Match, Pattern from threading import Lock from functools import partial -from typing import Iterator, List, Optional, Union, Dict +from typing import Callable, Coroutine, Iterator, List, Optional, Union, Dict from typing_extensions import TypedDict, Literal import llama_cpp @@ -10,8 +11,10 @@ import llama_cpp import anyio from anyio.streams.memory import MemoryObjectSendStream from starlette.concurrency import run_in_threadpool, iterate_in_threadpool -from fastapi import Depends, FastAPI, APIRouter, Request +from fastapi import Depends, FastAPI, APIRouter, Request, Response from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from fastapi.routing import APIRoute from pydantic import BaseModel, Field from pydantic_settings import BaseSettings from sse_starlette.sse import EventSourceResponse @@ -92,7 +95,190 @@ class Settings(BaseSettings): ) -router = APIRouter() +class ErrorResponse(TypedDict): + """OpenAI style error response""" + + message: str + type: str + param: Optional[str] + code: Optional[str] + + +class ErrorResponseFormatters: + """Collection of formatters for error responses. + + Args: + request (Union[CreateCompletionRequest, CreateChatCompletionRequest]): + Request body + match (Match[str]): Match object from regex pattern + + Returns: + tuple[int, ErrorResponse]: Status code and error response + """ + + @staticmethod + def context_length_exceeded( + request: Union[ + "CreateCompletionRequest", "CreateChatCompletionRequest" + ], + match: Match[str], + ) -> tuple[int, ErrorResponse]: + """Formatter for context length exceeded error""" + + context_window = int(match.group(2)) + prompt_tokens = int(match.group(1)) + completion_tokens = request.max_tokens + if hasattr(request, "messages"): + # Chat completion + message = ( + "This model's maximum context length is {} tokens. " + "However, you requested {} tokens " + "({} in the messages, {} in the completion). " + "Please reduce the length of the messages or completion." + ) + else: + # Text completion + message = ( + "This model's maximum context length is {} tokens, " + "however you requested {} tokens " + "({} in your prompt; {} for the completion). " + "Please reduce your prompt; or completion length." + ) + return 400, ErrorResponse( + message=message.format( + context_window, + completion_tokens + prompt_tokens, + prompt_tokens, + completion_tokens, + ), + type="invalid_request_error", + param="messages", + code="context_length_exceeded", + ) + + @staticmethod + def model_not_found( + request: Union[ + "CreateCompletionRequest", "CreateChatCompletionRequest" + ], + match: Match[str], + ) -> tuple[int, ErrorResponse]: + """Formatter for model_not_found error""" + + model_path = str(match.group(1)) + message = f"The model `{model_path}` does not exist" + return 400, ErrorResponse( + message=message, + type="invalid_request_error", + param=None, + code="model_not_found", + ) + + +class RouteErrorHandler(APIRoute): + """Custom APIRoute that handles application errors and exceptions""" + + # key: regex pattern for original error message from llama_cpp + # value: formatter function + pattern_and_formatters: dict[ + "Pattern", + Callable[ + [ + Union["CreateCompletionRequest", "CreateChatCompletionRequest"], + Match[str], + ], + tuple[int, ErrorResponse], + ], + ] = { + compile( + r"Requested tokens \((\d+)\) exceed context window of (\d+)" + ): ErrorResponseFormatters.context_length_exceeded, + compile( + r"Model path does not exist: (.+)" + ): ErrorResponseFormatters.model_not_found, + } + + def error_message_wrapper( + self, + error: Exception, + body: Optional[ + Union[ + "CreateChatCompletionRequest", + "CreateCompletionRequest", + "CreateEmbeddingRequest", + ] + ] = None, + ) -> tuple[int, ErrorResponse]: + """Wraps error message in OpenAI style error response""" + + if body is not None and isinstance( + body, + ( + CreateCompletionRequest, + CreateChatCompletionRequest, + ), + ): + # When text completion or chat completion + for pattern, callback in self.pattern_and_formatters.items(): + match = pattern.search(str(error)) + if match is not None: + return callback(body, match) + + # Wrap other errors as internal server error + return 500, ErrorResponse( + message=str(error), + type="internal_server_error", + param=None, + code=None, + ) + + def get_route_handler( + self, + ) -> Callable[[Request], Coroutine[None, None, Response]]: + """Defines custom route handler that catches exceptions and formats + in OpenAI style error response""" + + original_route_handler = super().get_route_handler() + + async def custom_route_handler(request: Request) -> Response: + try: + return await original_route_handler(request) + except Exception as exc: + json_body = await request.json() + try: + if "messages" in json_body: + # Chat completion + body: Optional[ + Union[ + CreateChatCompletionRequest, + CreateCompletionRequest, + CreateEmbeddingRequest, + ] + ] = CreateChatCompletionRequest(**json_body) + elif "prompt" in json_body: + # Text completion + body = CreateCompletionRequest(**json_body) + else: + # Embedding + body = CreateEmbeddingRequest(**json_body) + except Exception: + # Invalid request body + body = None + + # Get proper error message from the exception + ( + status_code, + error_message, + ) = self.error_message_wrapper(error=exc, body=body) + return JSONResponse( + {"error": error_message}, + status_code=status_code, + ) + + return custom_route_handler + + +router = APIRouter(route_class=RouteErrorHandler) settings: Optional[Settings] = None llama: Optional[llama_cpp.Llama] = None @@ -179,10 +365,33 @@ def get_settings(): yield settings +async def get_event_publisher( + request: Request, + inner_send_chan: MemoryObjectSendStream, + iterator: Iterator, +): + async with inner_send_chan: + try: + async for chunk in iterate_in_threadpool(iterator): + await inner_send_chan.send(dict(data=json.dumps(chunk))) + if await request.is_disconnected(): + raise anyio.get_cancelled_exc_class()() + if settings.interrupt_requests and llama_outer_lock.locked(): + await inner_send_chan.send(dict(data="[DONE]")) + raise anyio.get_cancelled_exc_class()() + await inner_send_chan.send(dict(data="[DONE]")) + except anyio.get_cancelled_exc_class() as e: + print("disconnected") + with anyio.move_on_after(1, shield=True): + print( + f"Disconnected from client (via refresh/close) {request.client}" + ) + raise e + model_field = Field(description="The model to use for generating completions.", default=None) max_tokens_field = Field( - default=16, ge=1, le=2048, description="The maximum number of tokens to generate." + default=16, ge=1, description="The maximum number of tokens to generate." ) temperature_field = Field( @@ -370,35 +579,31 @@ async def create_completion( make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type), ]) - if body.stream: + iterator_or_completion: Union[llama_cpp.Completion, Iterator[ + llama_cpp.CompletionChunk + ]] = await run_in_threadpool(llama, **kwargs) + + if isinstance(iterator_or_completion, Iterator): + # EAFP: It's easier to ask for forgiveness than permission + first_response = await run_in_threadpool(next, iterator_or_completion) + + # If no exception was raised from first_response, we can assume that + # the iterator is valid and we can use it to stream the response. + def iterator() -> Iterator[llama_cpp.CompletionChunk]: + yield first_response + yield from iterator_or_completion + send_chan, recv_chan = anyio.create_memory_object_stream(10) - - async def event_publisher(inner_send_chan: MemoryObjectSendStream): - async with inner_send_chan: - try: - iterator: Iterator[llama_cpp.CompletionChunk] = await run_in_threadpool(llama, **kwargs) # type: ignore - async for chunk in iterate_in_threadpool(iterator): - await inner_send_chan.send(dict(data=json.dumps(chunk))) - if await request.is_disconnected(): - raise anyio.get_cancelled_exc_class()() - if settings.interrupt_requests and llama_outer_lock.locked(): - await inner_send_chan.send(dict(data="[DONE]")) - raise anyio.get_cancelled_exc_class()() - await inner_send_chan.send(dict(data="[DONE]")) - except anyio.get_cancelled_exc_class() as e: - print("disconnected") - with anyio.move_on_after(1, shield=True): - print( - f"Disconnected from client (via refresh/close) {request.client}" - ) - raise e - return EventSourceResponse( - recv_chan, data_sender_callable=partial(event_publisher, send_chan) - ) # type: ignore + recv_chan, data_sender_callable=partial( # type: ignore + get_event_publisher, + request=request, + inner_send_chan=send_chan, + iterator=iterator(), + ) + ) else: - completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs) # type: ignore - return completion + return iterator_or_completion class CreateEmbeddingRequest(BaseModel): @@ -501,38 +706,31 @@ async def create_chat_completion( make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type), ]) - if body.stream: + iterator_or_completion: Union[llama_cpp.ChatCompletion, Iterator[ + llama_cpp.ChatCompletionChunk + ]] = await run_in_threadpool(llama.create_chat_completion, **kwargs) + + if isinstance(iterator_or_completion, Iterator): + # EAFP: It's easier to ask for forgiveness than permission + first_response = await run_in_threadpool(next, iterator_or_completion) + + # If no exception was raised from first_response, we can assume that + # the iterator is valid and we can use it to stream the response. + def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]: + yield first_response + yield from iterator_or_completion + send_chan, recv_chan = anyio.create_memory_object_stream(10) - - async def event_publisher(inner_send_chan: MemoryObjectSendStream): - async with inner_send_chan: - try: - iterator: Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(llama.create_chat_completion, **kwargs) # type: ignore - async for chat_chunk in iterate_in_threadpool(iterator): - await inner_send_chan.send(dict(data=json.dumps(chat_chunk))) - if await request.is_disconnected(): - raise anyio.get_cancelled_exc_class()() - if settings.interrupt_requests and llama_outer_lock.locked(): - await inner_send_chan.send(dict(data="[DONE]")) - raise anyio.get_cancelled_exc_class()() - await inner_send_chan.send(dict(data="[DONE]")) - except anyio.get_cancelled_exc_class() as e: - print("disconnected") - with anyio.move_on_after(1, shield=True): - print( - f"Disconnected from client (via refresh/close) {request.client}" - ) - raise e - return EventSourceResponse( - recv_chan, - data_sender_callable=partial(event_publisher, send_chan), - ) # type: ignore - else: - completion: llama_cpp.ChatCompletion = await run_in_threadpool( - llama.create_chat_completion, **kwargs # type: ignore + recv_chan, data_sender_callable=partial( # type: ignore + get_event_publisher, + request=request, + inner_send_chan=send_chan, + iterator=iterator(), + ) ) - return completion + else: + return iterator_or_completion class ModelData(TypedDict):