Add interruptible streaming requests for llama-cpp-python server. Closes #183

This commit is contained in:
Andrei Betlen 2023-07-07 03:04:17 -04:00
parent 98ae4e58a3
commit 4c7cdcca00
2 changed files with 29 additions and 6 deletions

View file

@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
## [Added]
- (server) Streaming requests can are now interrupted pre-maturely when a concurrent request is made. Can be controlled with the `interrupt_requests` setting.
## [0.1.68] ## [0.1.68]
## [Added] ## [Added]

View file

@ -146,12 +146,27 @@ def create_app(settings: Optional[Settings] = None):
return app return app
llama_lock = Lock() llama_outer_lock = Lock()
llama_inner_lock = Lock()
def get_llama(): def get_llama():
with llama_lock: # NOTE: This double lock allows the currently streaming llama model to
yield llama # check if any other requests are pending in the same thread and cancel
# the stream if so.
llama_outer_lock.acquire()
release_outer_lock = True
try:
llama_inner_lock.acquire()
try:
llama_outer_lock.release()
release_outer_lock = False
yield llama
finally:
llama_inner_lock.release()
finally:
if release_outer_lock:
llama_outer_lock.release()
def get_settings(): def get_settings():
@ -364,6 +379,9 @@ async def create_completion(
await inner_send_chan.send(dict(data=json.dumps(chunk))) await inner_send_chan.send(dict(data=json.dumps(chunk)))
if await request.is_disconnected(): if await request.is_disconnected():
raise anyio.get_cancelled_exc_class()() raise anyio.get_cancelled_exc_class()()
if llama_outer_lock.locked():
await inner_send_chan.send(dict(data="[DONE]"))
raise anyio.get_cancelled_exc_class()()
await inner_send_chan.send(dict(data="[DONE]")) await inner_send_chan.send(dict(data="[DONE]"))
except anyio.get_cancelled_exc_class() as e: except anyio.get_cancelled_exc_class() as e:
print("disconnected") print("disconnected")
@ -371,7 +389,6 @@ async def create_completion(
print( print(
f"Disconnected from client (via refresh/close) {request.client}" f"Disconnected from client (via refresh/close) {request.client}"
) )
await inner_send_chan.send(dict(closing=True))
raise e raise e
return EventSourceResponse( return EventSourceResponse(
@ -494,6 +511,9 @@ async def create_chat_completion(
await inner_send_chan.send(dict(data=json.dumps(chat_chunk))) await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
if await request.is_disconnected(): if await request.is_disconnected():
raise anyio.get_cancelled_exc_class()() raise anyio.get_cancelled_exc_class()()
if llama_outer_lock.locked():
await inner_send_chan.send(dict(data="[DONE]"))
raise anyio.get_cancelled_exc_class()()
await inner_send_chan.send(dict(data="[DONE]")) await inner_send_chan.send(dict(data="[DONE]"))
except anyio.get_cancelled_exc_class() as e: except anyio.get_cancelled_exc_class() as e:
print("disconnected") print("disconnected")
@ -501,7 +521,6 @@ async def create_chat_completion(
print( print(
f"Disconnected from client (via refresh/close) {request.client}" f"Disconnected from client (via refresh/close) {request.client}"
) )
await inner_send_chan.send(dict(closing=True))
raise e raise e
return EventSourceResponse( return EventSourceResponse(
@ -533,8 +552,8 @@ GetModelResponse = create_model_from_typeddict(ModelList)
@router.get("/v1/models", response_model=GetModelResponse) @router.get("/v1/models", response_model=GetModelResponse)
async def get_models( async def get_models(
settings: Settings = Depends(get_settings), settings: Settings = Depends(get_settings),
llama: llama_cpp.Llama = Depends(get_llama),
) -> ModelList: ) -> ModelList:
assert llama is not None
return { return {
"object": "list", "object": "list",
"data": [ "data": [