Merge pull request #481 from c0sogi/main
Added `RouteErrorHandler` for server
This commit is contained in:
commit
365d9a4367
2 changed files with 258 additions and 60 deletions
|
@ -845,7 +845,7 @@ class Llama:
|
||||||
|
|
||||||
if len(prompt_tokens) >= llama_cpp.llama_n_ctx(self.ctx):
|
if len(prompt_tokens) >= llama_cpp.llama_n_ctx(self.ctx):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
|
f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if max_tokens <= 0:
|
if max_tokens <= 0:
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
import json
|
import json
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
from re import compile, Match, Pattern
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Iterator, List, Optional, Union, Dict
|
from typing import Callable, Coroutine, Iterator, List, Optional, Union, Dict
|
||||||
from typing_extensions import TypedDict, Literal
|
from typing_extensions import TypedDict, Literal
|
||||||
|
|
||||||
import llama_cpp
|
import llama_cpp
|
||||||
|
@ -10,8 +11,10 @@ import llama_cpp
|
||||||
import anyio
|
import anyio
|
||||||
from anyio.streams.memory import MemoryObjectSendStream
|
from anyio.streams.memory import MemoryObjectSendStream
|
||||||
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
|
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
|
||||||
from fastapi import Depends, FastAPI, APIRouter, Request
|
from fastapi import Depends, FastAPI, APIRouter, Request, Response
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.routing import APIRoute
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
from sse_starlette.sse import EventSourceResponse
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
@ -94,7 +97,190 @@ class Settings(BaseSettings):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
class ErrorResponse(TypedDict):
|
||||||
|
"""OpenAI style error response"""
|
||||||
|
|
||||||
|
message: str
|
||||||
|
type: str
|
||||||
|
param: Optional[str]
|
||||||
|
code: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorResponseFormatters:
|
||||||
|
"""Collection of formatters for error responses.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
|
||||||
|
Request body
|
||||||
|
match (Match[str]): Match object from regex pattern
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[int, ErrorResponse]: Status code and error response
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def context_length_exceeded(
|
||||||
|
request: Union[
|
||||||
|
"CreateCompletionRequest", "CreateChatCompletionRequest"
|
||||||
|
],
|
||||||
|
match: Match[str],
|
||||||
|
) -> tuple[int, ErrorResponse]:
|
||||||
|
"""Formatter for context length exceeded error"""
|
||||||
|
|
||||||
|
context_window = int(match.group(2))
|
||||||
|
prompt_tokens = int(match.group(1))
|
||||||
|
completion_tokens = request.max_tokens
|
||||||
|
if hasattr(request, "messages"):
|
||||||
|
# Chat completion
|
||||||
|
message = (
|
||||||
|
"This model's maximum context length is {} tokens. "
|
||||||
|
"However, you requested {} tokens "
|
||||||
|
"({} in the messages, {} in the completion). "
|
||||||
|
"Please reduce the length of the messages or completion."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Text completion
|
||||||
|
message = (
|
||||||
|
"This model's maximum context length is {} tokens, "
|
||||||
|
"however you requested {} tokens "
|
||||||
|
"({} in your prompt; {} for the completion). "
|
||||||
|
"Please reduce your prompt; or completion length."
|
||||||
|
)
|
||||||
|
return 400, ErrorResponse(
|
||||||
|
message=message.format(
|
||||||
|
context_window,
|
||||||
|
completion_tokens + prompt_tokens,
|
||||||
|
prompt_tokens,
|
||||||
|
completion_tokens,
|
||||||
|
),
|
||||||
|
type="invalid_request_error",
|
||||||
|
param="messages",
|
||||||
|
code="context_length_exceeded",
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def model_not_found(
|
||||||
|
request: Union[
|
||||||
|
"CreateCompletionRequest", "CreateChatCompletionRequest"
|
||||||
|
],
|
||||||
|
match: Match[str],
|
||||||
|
) -> tuple[int, ErrorResponse]:
|
||||||
|
"""Formatter for model_not_found error"""
|
||||||
|
|
||||||
|
model_path = str(match.group(1))
|
||||||
|
message = f"The model `{model_path}` does not exist"
|
||||||
|
return 400, ErrorResponse(
|
||||||
|
message=message,
|
||||||
|
type="invalid_request_error",
|
||||||
|
param=None,
|
||||||
|
code="model_not_found",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RouteErrorHandler(APIRoute):
|
||||||
|
"""Custom APIRoute that handles application errors and exceptions"""
|
||||||
|
|
||||||
|
# key: regex pattern for original error message from llama_cpp
|
||||||
|
# value: formatter function
|
||||||
|
pattern_and_formatters: dict[
|
||||||
|
"Pattern",
|
||||||
|
Callable[
|
||||||
|
[
|
||||||
|
Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
|
||||||
|
Match[str],
|
||||||
|
],
|
||||||
|
tuple[int, ErrorResponse],
|
||||||
|
],
|
||||||
|
] = {
|
||||||
|
compile(
|
||||||
|
r"Requested tokens \((\d+)\) exceed context window of (\d+)"
|
||||||
|
): ErrorResponseFormatters.context_length_exceeded,
|
||||||
|
compile(
|
||||||
|
r"Model path does not exist: (.+)"
|
||||||
|
): ErrorResponseFormatters.model_not_found,
|
||||||
|
}
|
||||||
|
|
||||||
|
def error_message_wrapper(
|
||||||
|
self,
|
||||||
|
error: Exception,
|
||||||
|
body: Optional[
|
||||||
|
Union[
|
||||||
|
"CreateChatCompletionRequest",
|
||||||
|
"CreateCompletionRequest",
|
||||||
|
"CreateEmbeddingRequest",
|
||||||
|
]
|
||||||
|
] = None,
|
||||||
|
) -> tuple[int, ErrorResponse]:
|
||||||
|
"""Wraps error message in OpenAI style error response"""
|
||||||
|
|
||||||
|
if body is not None and isinstance(
|
||||||
|
body,
|
||||||
|
(
|
||||||
|
CreateCompletionRequest,
|
||||||
|
CreateChatCompletionRequest,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
# When text completion or chat completion
|
||||||
|
for pattern, callback in self.pattern_and_formatters.items():
|
||||||
|
match = pattern.search(str(error))
|
||||||
|
if match is not None:
|
||||||
|
return callback(body, match)
|
||||||
|
|
||||||
|
# Wrap other errors as internal server error
|
||||||
|
return 500, ErrorResponse(
|
||||||
|
message=str(error),
|
||||||
|
type="internal_server_error",
|
||||||
|
param=None,
|
||||||
|
code=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_route_handler(
|
||||||
|
self,
|
||||||
|
) -> Callable[[Request], Coroutine[None, None, Response]]:
|
||||||
|
"""Defines custom route handler that catches exceptions and formats
|
||||||
|
in OpenAI style error response"""
|
||||||
|
|
||||||
|
original_route_handler = super().get_route_handler()
|
||||||
|
|
||||||
|
async def custom_route_handler(request: Request) -> Response:
|
||||||
|
try:
|
||||||
|
return await original_route_handler(request)
|
||||||
|
except Exception as exc:
|
||||||
|
json_body = await request.json()
|
||||||
|
try:
|
||||||
|
if "messages" in json_body:
|
||||||
|
# Chat completion
|
||||||
|
body: Optional[
|
||||||
|
Union[
|
||||||
|
CreateChatCompletionRequest,
|
||||||
|
CreateCompletionRequest,
|
||||||
|
CreateEmbeddingRequest,
|
||||||
|
]
|
||||||
|
] = CreateChatCompletionRequest(**json_body)
|
||||||
|
elif "prompt" in json_body:
|
||||||
|
# Text completion
|
||||||
|
body = CreateCompletionRequest(**json_body)
|
||||||
|
else:
|
||||||
|
# Embedding
|
||||||
|
body = CreateEmbeddingRequest(**json_body)
|
||||||
|
except Exception:
|
||||||
|
# Invalid request body
|
||||||
|
body = None
|
||||||
|
|
||||||
|
# Get proper error message from the exception
|
||||||
|
(
|
||||||
|
status_code,
|
||||||
|
error_message,
|
||||||
|
) = self.error_message_wrapper(error=exc, body=body)
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": error_message},
|
||||||
|
status_code=status_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
return custom_route_handler
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(route_class=RouteErrorHandler)
|
||||||
|
|
||||||
settings: Optional[Settings] = None
|
settings: Optional[Settings] = None
|
||||||
llama: Optional[llama_cpp.Llama] = None
|
llama: Optional[llama_cpp.Llama] = None
|
||||||
|
@ -183,10 +369,33 @@ def get_settings():
|
||||||
yield settings
|
yield settings
|
||||||
|
|
||||||
|
|
||||||
|
async def get_event_publisher(
|
||||||
|
request: Request,
|
||||||
|
inner_send_chan: MemoryObjectSendStream,
|
||||||
|
iterator: Iterator,
|
||||||
|
):
|
||||||
|
async with inner_send_chan:
|
||||||
|
try:
|
||||||
|
async for chunk in iterate_in_threadpool(iterator):
|
||||||
|
await inner_send_chan.send(dict(data=json.dumps(chunk)))
|
||||||
|
if await request.is_disconnected():
|
||||||
|
raise anyio.get_cancelled_exc_class()()
|
||||||
|
if settings.interrupt_requests and llama_outer_lock.locked():
|
||||||
|
await inner_send_chan.send(dict(data="[DONE]"))
|
||||||
|
raise anyio.get_cancelled_exc_class()()
|
||||||
|
await inner_send_chan.send(dict(data="[DONE]"))
|
||||||
|
except anyio.get_cancelled_exc_class() as e:
|
||||||
|
print("disconnected")
|
||||||
|
with anyio.move_on_after(1, shield=True):
|
||||||
|
print(
|
||||||
|
f"Disconnected from client (via refresh/close) {request.client}"
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
model_field = Field(description="The model to use for generating completions.", default=None)
|
model_field = Field(description="The model to use for generating completions.", default=None)
|
||||||
|
|
||||||
max_tokens_field = Field(
|
max_tokens_field = Field(
|
||||||
default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
|
default=16, ge=1, description="The maximum number of tokens to generate."
|
||||||
)
|
)
|
||||||
|
|
||||||
temperature_field = Field(
|
temperature_field = Field(
|
||||||
|
@ -374,35 +583,31 @@ async def create_completion(
|
||||||
make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
|
make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
|
||||||
])
|
])
|
||||||
|
|
||||||
if body.stream:
|
iterator_or_completion: Union[llama_cpp.Completion, Iterator[
|
||||||
|
llama_cpp.CompletionChunk
|
||||||
|
]] = await run_in_threadpool(llama, **kwargs)
|
||||||
|
|
||||||
|
if isinstance(iterator_or_completion, Iterator):
|
||||||
|
# EAFP: It's easier to ask for forgiveness than permission
|
||||||
|
first_response = await run_in_threadpool(next, iterator_or_completion)
|
||||||
|
|
||||||
|
# If no exception was raised from first_response, we can assume that
|
||||||
|
# the iterator is valid and we can use it to stream the response.
|
||||||
|
def iterator() -> Iterator[llama_cpp.CompletionChunk]:
|
||||||
|
yield first_response
|
||||||
|
yield from iterator_or_completion
|
||||||
|
|
||||||
send_chan, recv_chan = anyio.create_memory_object_stream(10)
|
send_chan, recv_chan = anyio.create_memory_object_stream(10)
|
||||||
|
|
||||||
async def event_publisher(inner_send_chan: MemoryObjectSendStream):
|
|
||||||
async with inner_send_chan:
|
|
||||||
try:
|
|
||||||
iterator: Iterator[llama_cpp.CompletionChunk] = await run_in_threadpool(llama, **kwargs) # type: ignore
|
|
||||||
async for chunk in iterate_in_threadpool(iterator):
|
|
||||||
await inner_send_chan.send(dict(data=json.dumps(chunk)))
|
|
||||||
if await request.is_disconnected():
|
|
||||||
raise anyio.get_cancelled_exc_class()()
|
|
||||||
if settings.interrupt_requests and llama_outer_lock.locked():
|
|
||||||
await inner_send_chan.send(dict(data="[DONE]"))
|
|
||||||
raise anyio.get_cancelled_exc_class()()
|
|
||||||
await inner_send_chan.send(dict(data="[DONE]"))
|
|
||||||
except anyio.get_cancelled_exc_class() as e:
|
|
||||||
print("disconnected")
|
|
||||||
with anyio.move_on_after(1, shield=True):
|
|
||||||
print(
|
|
||||||
f"Disconnected from client (via refresh/close) {request.client}"
|
|
||||||
)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
return EventSourceResponse(
|
return EventSourceResponse(
|
||||||
recv_chan, data_sender_callable=partial(event_publisher, send_chan)
|
recv_chan, data_sender_callable=partial( # type: ignore
|
||||||
) # type: ignore
|
get_event_publisher,
|
||||||
|
request=request,
|
||||||
|
inner_send_chan=send_chan,
|
||||||
|
iterator=iterator(),
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs) # type: ignore
|
return iterator_or_completion
|
||||||
return completion
|
|
||||||
|
|
||||||
|
|
||||||
class CreateEmbeddingRequest(BaseModel):
|
class CreateEmbeddingRequest(BaseModel):
|
||||||
|
@ -505,38 +710,31 @@ async def create_chat_completion(
|
||||||
make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
|
make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
|
||||||
])
|
])
|
||||||
|
|
||||||
if body.stream:
|
iterator_or_completion: Union[llama_cpp.ChatCompletion, Iterator[
|
||||||
|
llama_cpp.ChatCompletionChunk
|
||||||
|
]] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
|
||||||
|
|
||||||
|
if isinstance(iterator_or_completion, Iterator):
|
||||||
|
# EAFP: It's easier to ask for forgiveness than permission
|
||||||
|
first_response = await run_in_threadpool(next, iterator_or_completion)
|
||||||
|
|
||||||
|
# If no exception was raised from first_response, we can assume that
|
||||||
|
# the iterator is valid and we can use it to stream the response.
|
||||||
|
def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
|
||||||
|
yield first_response
|
||||||
|
yield from iterator_or_completion
|
||||||
|
|
||||||
send_chan, recv_chan = anyio.create_memory_object_stream(10)
|
send_chan, recv_chan = anyio.create_memory_object_stream(10)
|
||||||
|
|
||||||
async def event_publisher(inner_send_chan: MemoryObjectSendStream):
|
|
||||||
async with inner_send_chan:
|
|
||||||
try:
|
|
||||||
iterator: Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(llama.create_chat_completion, **kwargs) # type: ignore
|
|
||||||
async for chat_chunk in iterate_in_threadpool(iterator):
|
|
||||||
await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
|
|
||||||
if await request.is_disconnected():
|
|
||||||
raise anyio.get_cancelled_exc_class()()
|
|
||||||
if settings.interrupt_requests and llama_outer_lock.locked():
|
|
||||||
await inner_send_chan.send(dict(data="[DONE]"))
|
|
||||||
raise anyio.get_cancelled_exc_class()()
|
|
||||||
await inner_send_chan.send(dict(data="[DONE]"))
|
|
||||||
except anyio.get_cancelled_exc_class() as e:
|
|
||||||
print("disconnected")
|
|
||||||
with anyio.move_on_after(1, shield=True):
|
|
||||||
print(
|
|
||||||
f"Disconnected from client (via refresh/close) {request.client}"
|
|
||||||
)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
return EventSourceResponse(
|
return EventSourceResponse(
|
||||||
recv_chan,
|
recv_chan, data_sender_callable=partial( # type: ignore
|
||||||
data_sender_callable=partial(event_publisher, send_chan),
|
get_event_publisher,
|
||||||
) # type: ignore
|
request=request,
|
||||||
else:
|
inner_send_chan=send_chan,
|
||||||
completion: llama_cpp.ChatCompletion = await run_in_threadpool(
|
iterator=iterator(),
|
||||||
llama.create_chat_completion, **kwargs # type: ignore
|
)
|
||||||
)
|
)
|
||||||
return completion
|
else:
|
||||||
|
return iterator_or_completion
|
||||||
|
|
||||||
|
|
||||||
class ModelData(TypedDict):
|
class ModelData(TypedDict):
|
||||||
|
|
Loading…
Reference in a new issue