Added RouteErrorHandler
for server
This commit is contained in:
parent
6d8892fe64
commit
1551ba10bd
2 changed files with 258 additions and 60 deletions
|
@ -845,7 +845,7 @@ class Llama:
|
|||
|
||||
if len(prompt_tokens) >= llama_cpp.llama_n_ctx(self.ctx):
|
||||
raise ValueError(
|
||||
f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
|
||||
f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
|
||||
)
|
||||
|
||||
if max_tokens <= 0:
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import json
|
||||
import multiprocessing
|
||||
from re import compile, Match, Pattern
|
||||
from threading import Lock
|
||||
from functools import partial
|
||||
from typing import Iterator, List, Optional, Union, Dict
|
||||
from typing import Callable, Coroutine, Iterator, List, Optional, Union, Dict
|
||||
from typing_extensions import TypedDict, Literal
|
||||
|
||||
import llama_cpp
|
||||
|
@ -10,8 +11,10 @@ import llama_cpp
|
|||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectSendStream
|
||||
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
|
||||
from fastapi import Depends, FastAPI, APIRouter, Request
|
||||
from fastapi import Depends, FastAPI, APIRouter, Request, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.routing import APIRoute
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import BaseSettings
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
@ -92,7 +95,190 @@ class Settings(BaseSettings):
|
|||
)
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
class ErrorResponse(TypedDict):
|
||||
"""OpenAI style error response"""
|
||||
|
||||
message: str
|
||||
type: str
|
||||
param: Optional[str]
|
||||
code: Optional[str]
|
||||
|
||||
|
||||
class ErrorResponseFormatters:
|
||||
"""Collection of formatters for error responses.
|
||||
|
||||
Args:
|
||||
request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
|
||||
Request body
|
||||
match (Match[str]): Match object from regex pattern
|
||||
|
||||
Returns:
|
||||
tuple[int, ErrorResponse]: Status code and error response
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def context_length_exceeded(
|
||||
request: Union[
|
||||
"CreateCompletionRequest", "CreateChatCompletionRequest"
|
||||
],
|
||||
match: Match[str],
|
||||
) -> tuple[int, ErrorResponse]:
|
||||
"""Formatter for context length exceeded error"""
|
||||
|
||||
context_window = int(match.group(2))
|
||||
prompt_tokens = int(match.group(1))
|
||||
completion_tokens = request.max_tokens
|
||||
if hasattr(request, "messages"):
|
||||
# Chat completion
|
||||
message = (
|
||||
"This model's maximum context length is {} tokens. "
|
||||
"However, you requested {} tokens "
|
||||
"({} in the messages, {} in the completion). "
|
||||
"Please reduce the length of the messages or completion."
|
||||
)
|
||||
else:
|
||||
# Text completion
|
||||
message = (
|
||||
"This model's maximum context length is {} tokens, "
|
||||
"however you requested {} tokens "
|
||||
"({} in your prompt; {} for the completion). "
|
||||
"Please reduce your prompt; or completion length."
|
||||
)
|
||||
return 400, ErrorResponse(
|
||||
message=message.format(
|
||||
context_window,
|
||||
completion_tokens + prompt_tokens,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
),
|
||||
type="invalid_request_error",
|
||||
param="messages",
|
||||
code="context_length_exceeded",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def model_not_found(
|
||||
request: Union[
|
||||
"CreateCompletionRequest", "CreateChatCompletionRequest"
|
||||
],
|
||||
match: Match[str],
|
||||
) -> tuple[int, ErrorResponse]:
|
||||
"""Formatter for model_not_found error"""
|
||||
|
||||
model_path = str(match.group(1))
|
||||
message = f"The model `{model_path}` does not exist"
|
||||
return 400, ErrorResponse(
|
||||
message=message,
|
||||
type="invalid_request_error",
|
||||
param=None,
|
||||
code="model_not_found",
|
||||
)
|
||||
|
||||
|
||||
class RouteErrorHandler(APIRoute):
|
||||
"""Custom APIRoute that handles application errors and exceptions"""
|
||||
|
||||
# key: regex pattern for original error message from llama_cpp
|
||||
# value: formatter function
|
||||
pattern_and_formatters: dict[
|
||||
"Pattern",
|
||||
Callable[
|
||||
[
|
||||
Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
|
||||
Match[str],
|
||||
],
|
||||
tuple[int, ErrorResponse],
|
||||
],
|
||||
] = {
|
||||
compile(
|
||||
r"Requested tokens \((\d+)\) exceed context window of (\d+)"
|
||||
): ErrorResponseFormatters.context_length_exceeded,
|
||||
compile(
|
||||
r"Model path does not exist: (.+)"
|
||||
): ErrorResponseFormatters.model_not_found,
|
||||
}
|
||||
|
||||
def error_message_wrapper(
|
||||
self,
|
||||
error: Exception,
|
||||
body: Optional[
|
||||
Union[
|
||||
"CreateChatCompletionRequest",
|
||||
"CreateCompletionRequest",
|
||||
"CreateEmbeddingRequest",
|
||||
]
|
||||
] = None,
|
||||
) -> tuple[int, ErrorResponse]:
|
||||
"""Wraps error message in OpenAI style error response"""
|
||||
|
||||
if body is not None and isinstance(
|
||||
body,
|
||||
(
|
||||
CreateCompletionRequest,
|
||||
CreateChatCompletionRequest,
|
||||
),
|
||||
):
|
||||
# When text completion or chat completion
|
||||
for pattern, callback in self.pattern_and_formatters.items():
|
||||
match = pattern.search(str(error))
|
||||
if match is not None:
|
||||
return callback(body, match)
|
||||
|
||||
# Wrap other errors as internal server error
|
||||
return 500, ErrorResponse(
|
||||
message=str(error),
|
||||
type="internal_server_error",
|
||||
param=None,
|
||||
code=None,
|
||||
)
|
||||
|
||||
def get_route_handler(
|
||||
self,
|
||||
) -> Callable[[Request], Coroutine[None, None, Response]]:
|
||||
"""Defines custom route handler that catches exceptions and formats
|
||||
in OpenAI style error response"""
|
||||
|
||||
original_route_handler = super().get_route_handler()
|
||||
|
||||
async def custom_route_handler(request: Request) -> Response:
|
||||
try:
|
||||
return await original_route_handler(request)
|
||||
except Exception as exc:
|
||||
json_body = await request.json()
|
||||
try:
|
||||
if "messages" in json_body:
|
||||
# Chat completion
|
||||
body: Optional[
|
||||
Union[
|
||||
CreateChatCompletionRequest,
|
||||
CreateCompletionRequest,
|
||||
CreateEmbeddingRequest,
|
||||
]
|
||||
] = CreateChatCompletionRequest(**json_body)
|
||||
elif "prompt" in json_body:
|
||||
# Text completion
|
||||
body = CreateCompletionRequest(**json_body)
|
||||
else:
|
||||
# Embedding
|
||||
body = CreateEmbeddingRequest(**json_body)
|
||||
except Exception:
|
||||
# Invalid request body
|
||||
body = None
|
||||
|
||||
# Get proper error message from the exception
|
||||
(
|
||||
status_code,
|
||||
error_message,
|
||||
) = self.error_message_wrapper(error=exc, body=body)
|
||||
return JSONResponse(
|
||||
{"error": error_message},
|
||||
status_code=status_code,
|
||||
)
|
||||
|
||||
return custom_route_handler
|
||||
|
||||
|
||||
router = APIRouter(route_class=RouteErrorHandler)
|
||||
|
||||
settings: Optional[Settings] = None
|
||||
llama: Optional[llama_cpp.Llama] = None
|
||||
|
@ -179,10 +365,33 @@ def get_settings():
|
|||
yield settings
|
||||
|
||||
|
||||
async def get_event_publisher(
|
||||
request: Request,
|
||||
inner_send_chan: MemoryObjectSendStream,
|
||||
iterator: Iterator,
|
||||
):
|
||||
async with inner_send_chan:
|
||||
try:
|
||||
async for chunk in iterate_in_threadpool(iterator):
|
||||
await inner_send_chan.send(dict(data=json.dumps(chunk)))
|
||||
if await request.is_disconnected():
|
||||
raise anyio.get_cancelled_exc_class()()
|
||||
if settings.interrupt_requests and llama_outer_lock.locked():
|
||||
await inner_send_chan.send(dict(data="[DONE]"))
|
||||
raise anyio.get_cancelled_exc_class()()
|
||||
await inner_send_chan.send(dict(data="[DONE]"))
|
||||
except anyio.get_cancelled_exc_class() as e:
|
||||
print("disconnected")
|
||||
with anyio.move_on_after(1, shield=True):
|
||||
print(
|
||||
f"Disconnected from client (via refresh/close) {request.client}"
|
||||
)
|
||||
raise e
|
||||
|
||||
model_field = Field(description="The model to use for generating completions.", default=None)
|
||||
|
||||
max_tokens_field = Field(
|
||||
default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
|
||||
default=16, ge=1, description="The maximum number of tokens to generate."
|
||||
)
|
||||
|
||||
temperature_field = Field(
|
||||
|
@ -370,35 +579,31 @@ async def create_completion(
|
|||
make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
|
||||
])
|
||||
|
||||
if body.stream:
|
||||
iterator_or_completion: Union[llama_cpp.Completion, Iterator[
|
||||
llama_cpp.CompletionChunk
|
||||
]] = await run_in_threadpool(llama, **kwargs)
|
||||
|
||||
if isinstance(iterator_or_completion, Iterator):
|
||||
# EAFP: It's easier to ask for forgiveness than permission
|
||||
first_response = await run_in_threadpool(next, iterator_or_completion)
|
||||
|
||||
# If no exception was raised from first_response, we can assume that
|
||||
# the iterator is valid and we can use it to stream the response.
|
||||
def iterator() -> Iterator[llama_cpp.CompletionChunk]:
|
||||
yield first_response
|
||||
yield from iterator_or_completion
|
||||
|
||||
send_chan, recv_chan = anyio.create_memory_object_stream(10)
|
||||
|
||||
async def event_publisher(inner_send_chan: MemoryObjectSendStream):
|
||||
async with inner_send_chan:
|
||||
try:
|
||||
iterator: Iterator[llama_cpp.CompletionChunk] = await run_in_threadpool(llama, **kwargs) # type: ignore
|
||||
async for chunk in iterate_in_threadpool(iterator):
|
||||
await inner_send_chan.send(dict(data=json.dumps(chunk)))
|
||||
if await request.is_disconnected():
|
||||
raise anyio.get_cancelled_exc_class()()
|
||||
if settings.interrupt_requests and llama_outer_lock.locked():
|
||||
await inner_send_chan.send(dict(data="[DONE]"))
|
||||
raise anyio.get_cancelled_exc_class()()
|
||||
await inner_send_chan.send(dict(data="[DONE]"))
|
||||
except anyio.get_cancelled_exc_class() as e:
|
||||
print("disconnected")
|
||||
with anyio.move_on_after(1, shield=True):
|
||||
print(
|
||||
f"Disconnected from client (via refresh/close) {request.client}"
|
||||
)
|
||||
raise e
|
||||
|
||||
return EventSourceResponse(
|
||||
recv_chan, data_sender_callable=partial(event_publisher, send_chan)
|
||||
) # type: ignore
|
||||
recv_chan, data_sender_callable=partial( # type: ignore
|
||||
get_event_publisher,
|
||||
request=request,
|
||||
inner_send_chan=send_chan,
|
||||
iterator=iterator(),
|
||||
)
|
||||
)
|
||||
else:
|
||||
completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs) # type: ignore
|
||||
return completion
|
||||
return iterator_or_completion
|
||||
|
||||
|
||||
class CreateEmbeddingRequest(BaseModel):
|
||||
|
@ -501,38 +706,31 @@ async def create_chat_completion(
|
|||
make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
|
||||
])
|
||||
|
||||
if body.stream:
|
||||
iterator_or_completion: Union[llama_cpp.ChatCompletion, Iterator[
|
||||
llama_cpp.ChatCompletionChunk
|
||||
]] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
|
||||
|
||||
if isinstance(iterator_or_completion, Iterator):
|
||||
# EAFP: It's easier to ask for forgiveness than permission
|
||||
first_response = await run_in_threadpool(next, iterator_or_completion)
|
||||
|
||||
# If no exception was raised from first_response, we can assume that
|
||||
# the iterator is valid and we can use it to stream the response.
|
||||
def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
|
||||
yield first_response
|
||||
yield from iterator_or_completion
|
||||
|
||||
send_chan, recv_chan = anyio.create_memory_object_stream(10)
|
||||
|
||||
async def event_publisher(inner_send_chan: MemoryObjectSendStream):
|
||||
async with inner_send_chan:
|
||||
try:
|
||||
iterator: Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(llama.create_chat_completion, **kwargs) # type: ignore
|
||||
async for chat_chunk in iterate_in_threadpool(iterator):
|
||||
await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
|
||||
if await request.is_disconnected():
|
||||
raise anyio.get_cancelled_exc_class()()
|
||||
if settings.interrupt_requests and llama_outer_lock.locked():
|
||||
await inner_send_chan.send(dict(data="[DONE]"))
|
||||
raise anyio.get_cancelled_exc_class()()
|
||||
await inner_send_chan.send(dict(data="[DONE]"))
|
||||
except anyio.get_cancelled_exc_class() as e:
|
||||
print("disconnected")
|
||||
with anyio.move_on_after(1, shield=True):
|
||||
print(
|
||||
f"Disconnected from client (via refresh/close) {request.client}"
|
||||
)
|
||||
raise e
|
||||
|
||||
return EventSourceResponse(
|
||||
recv_chan,
|
||||
data_sender_callable=partial(event_publisher, send_chan),
|
||||
) # type: ignore
|
||||
else:
|
||||
completion: llama_cpp.ChatCompletion = await run_in_threadpool(
|
||||
llama.create_chat_completion, **kwargs # type: ignore
|
||||
recv_chan, data_sender_callable=partial( # type: ignore
|
||||
get_event_publisher,
|
||||
request=request,
|
||||
inner_send_chan=send_chan,
|
||||
iterator=iterator(),
|
||||
)
|
||||
)
|
||||
return completion
|
||||
else:
|
||||
return iterator_or_completion
|
||||
|
||||
|
||||
class ModelData(TypedDict):
|
||||
|
|
Loading…
Reference in a new issue