Add setting to control request interruption

This commit is contained in:
Andrei Betlen 2023-07-07 03:37:23 -04:00
parent cc542b4452
commit 57d8ec3899

View file

@ -85,6 +85,10 @@ class Settings(BaseSettings):
port: int = Field( port: int = Field(
default=8000, description="Listen port" default=8000, description="Listen port"
) )
interrupt_requests: bool = Field(
default=True,
description="Whether to interrupt requests when a new request is received.",
)
router = APIRouter() router = APIRouter()
@ -379,7 +383,7 @@ async def create_completion(
await inner_send_chan.send(dict(data=json.dumps(chunk))) await inner_send_chan.send(dict(data=json.dumps(chunk)))
if await request.is_disconnected(): if await request.is_disconnected():
raise anyio.get_cancelled_exc_class()() raise anyio.get_cancelled_exc_class()()
if llama_outer_lock.locked(): if settings.interrupt_requests and llama_outer_lock.locked():
await inner_send_chan.send(dict(data="[DONE]")) await inner_send_chan.send(dict(data="[DONE]"))
raise anyio.get_cancelled_exc_class()() raise anyio.get_cancelled_exc_class()()
await inner_send_chan.send(dict(data="[DONE]")) await inner_send_chan.send(dict(data="[DONE]"))
@ -486,6 +490,7 @@ async def create_chat_completion(
request: Request, request: Request,
body: CreateChatCompletionRequest, body: CreateChatCompletionRequest,
llama: llama_cpp.Llama = Depends(get_llama), llama: llama_cpp.Llama = Depends(get_llama),
settings: Settings = Depends(get_settings),
) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]: ) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
exclude = { exclude = {
"n", "n",
@ -511,7 +516,7 @@ async def create_chat_completion(
await inner_send_chan.send(dict(data=json.dumps(chat_chunk))) await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
if await request.is_disconnected(): if await request.is_disconnected():
raise anyio.get_cancelled_exc_class()() raise anyio.get_cancelled_exc_class()()
if llama_outer_lock.locked(): if settings.interrupt_requests and llama_outer_lock.locked():
await inner_send_chan.send(dict(data="[DONE]")) await inner_send_chan.send(dict(data="[DONE]"))
raise anyio.get_cancelled_exc_class()() raise anyio.get_cancelled_exc_class()()
await inner_send_chan.send(dict(data="[DONE]")) await inner_send_chan.send(dict(data="[DONE]"))