From 4c7cdcca00f63896a95e09a11f424237e224bc72 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Fri, 7 Jul 2023 03:04:17 -0400
Subject: [PATCH] Add interruptible streaming requests for llama-cpp-python
 server. Closes #183

---
 CHANGELOG.md            |  4 ++++
 llama_cpp/server/app.py | 31 +++++++++++++++++++++++++------
 2 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c6cfaab..11251c6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ## [Unreleased]

+## [Added]
+
+- (server) Streaming requests can now be interrupted before completion when a concurrent request is made. This can be controlled with the `interrupt_requests` setting.
+
 ## [0.1.68]

 ## [Added]
diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index ef319c7..b9d5771 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -146,12 +146,27 @@ def create_app(settings: Optional[Settings] = None):
     return app


-llama_lock = Lock()
+llama_outer_lock = Lock()
+llama_inner_lock = Lock()


 def get_llama():
-    with llama_lock:
-        yield llama
+    # NOTE: This double lock allows the currently streaming llama model to
+    # check if any other requests are pending in the same thread and cancel
+    # the stream if so.
+    llama_outer_lock.acquire()
+    release_outer_lock = True
+    try:
+        llama_inner_lock.acquire()
+        try:
+            llama_outer_lock.release()
+            release_outer_lock = False
+            yield llama
+        finally:
+            llama_inner_lock.release()
+    finally:
+        if release_outer_lock:
+            llama_outer_lock.release()


 def get_settings():
@@ -364,6 +379,9 @@ async def create_completion(
                         await inner_send_chan.send(dict(data=json.dumps(chunk)))
                         if await request.is_disconnected():
                             raise anyio.get_cancelled_exc_class()()
+                        if llama_outer_lock.locked():
+                            await inner_send_chan.send(dict(data="[DONE]"))
+                            raise anyio.get_cancelled_exc_class()()
                     await inner_send_chan.send(dict(data="[DONE]"))
                 except anyio.get_cancelled_exc_class() as e:
                     print("disconnected")
@@ -371,7 +389,6 @@
                         print(
                             f"Disconnected from client (via refresh/close) {request.client}"
                         )
-                        await inner_send_chan.send(dict(closing=True))
                         raise e

         return EventSourceResponse(
@@ -494,6 +511,9 @@ async def create_chat_completion(
                         await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
                         if await request.is_disconnected():
                             raise anyio.get_cancelled_exc_class()()
+                        if llama_outer_lock.locked():
+                            await inner_send_chan.send(dict(data="[DONE]"))
+                            raise anyio.get_cancelled_exc_class()()
                     await inner_send_chan.send(dict(data="[DONE]"))
                 except anyio.get_cancelled_exc_class() as e:
                     print("disconnected")
@@ -501,7 +521,6 @@
                         print(
                             f"Disconnected from client (via refresh/close) {request.client}"
                         )
-                        await inner_send_chan.send(dict(closing=True))
                         raise e

         return EventSourceResponse(
@@ -533,8 +552,8 @@ GetModelResponse = create_model_from_typeddict(ModelList)
 @router.get("/v1/models", response_model=GetModelResponse)
 async def get_models(
     settings: Settings = Depends(get_settings),
-    llama: llama_cpp.Llama = Depends(get_llama),
 ) -> ModelList:
+    assert llama is not None
     return {
         "object": "list",
         "data": [
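
Note (not part of the patch): the double-lock hand-off introduced in get_llama() can be illustrated with a minimal, self-contained Python sketch. The thread setup, function names (get_model, stream), and sleep timings below are invented for illustration; only the two-lock pattern and the llama_outer_lock.locked() check mirror the server code.

import threading
import time

llama_outer_lock = threading.Lock()
llama_inner_lock = threading.Lock()


def get_model():
    # Mirror of get_llama(): the outer lock orders waiters, the inner lock
    # guards the model. The outer lock is released as soon as the inner lock
    # is held, so a second caller blocked on the outer lock is visible to the
    # current holder via llama_outer_lock.locked().
    llama_outer_lock.acquire()
    release_outer_lock = True
    try:
        llama_inner_lock.acquire()
        try:
            llama_outer_lock.release()
            release_outer_lock = False
            yield "model"  # stand-in for the shared llama instance
        finally:
            llama_inner_lock.release()
    finally:
        if release_outer_lock:
            llama_outer_lock.release()


def stream(name):
    gen = get_model()
    model = next(gen)
    try:
        for i in range(10):
            print(f"{name}: chunk {i} from {model}")
            time.sleep(0.1)
            # Same check as the streaming loops in the patch: another request
            # is parked on the outer lock, so end this stream early.
            if llama_outer_lock.locked():
                print(f"{name}: interrupted by a concurrent request")
                break
    finally:
        gen.close()  # runs the generator's finally blocks, releasing the locks


first = threading.Thread(target=stream, args=("first",))
second = threading.Thread(target=stream, args=("second",))
first.start()
time.sleep(0.25)
second.start()
first.join()
second.join()

Running this shows the first stream stopping as soon as the second thread blocks on the outer lock, after which the second stream runs to completion, which is exactly the behavior the server hunks above add around the "[DONE]" events.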