Add interruptible streaming requests for llama-cpp-python server. Closes #183
This commit is contained in:
parent
98ae4e58a3
commit
4c7cdcca00
2 changed files with 29 additions and 6 deletions
|
@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
|
## [Added]
|
||||||
|
|
||||||
|
- (server) Streaming requests can are now interrupted pre-maturely when a concurrent request is made. Can be controlled with the `interrupt_requests` setting.
|
||||||
|
|
||||||
## [0.1.68]
|
## [0.1.68]
|
||||||
|
|
||||||
## [Added]
|
## [Added]
|
||||||
|
|
|
@ -146,12 +146,27 @@ def create_app(settings: Optional[Settings] = None):
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
llama_lock = Lock()
|
llama_outer_lock = Lock()
|
||||||
|
llama_inner_lock = Lock()
|
||||||
|
|
||||||
|
|
||||||
def get_llama():
|
def get_llama():
|
||||||
with llama_lock:
|
# NOTE: This double lock allows the currently streaming llama model to
|
||||||
yield llama
|
# check if any other requests are pending in the same thread and cancel
|
||||||
|
# the stream if so.
|
||||||
|
llama_outer_lock.acquire()
|
||||||
|
release_outer_lock = True
|
||||||
|
try:
|
||||||
|
llama_inner_lock.acquire()
|
||||||
|
try:
|
||||||
|
llama_outer_lock.release()
|
||||||
|
release_outer_lock = False
|
||||||
|
yield llama
|
||||||
|
finally:
|
||||||
|
llama_inner_lock.release()
|
||||||
|
finally:
|
||||||
|
if release_outer_lock:
|
||||||
|
llama_outer_lock.release()
|
||||||
|
|
||||||
|
|
||||||
def get_settings():
|
def get_settings():
|
||||||
|
@ -364,6 +379,9 @@ async def create_completion(
|
||||||
await inner_send_chan.send(dict(data=json.dumps(chunk)))
|
await inner_send_chan.send(dict(data=json.dumps(chunk)))
|
||||||
if await request.is_disconnected():
|
if await request.is_disconnected():
|
||||||
raise anyio.get_cancelled_exc_class()()
|
raise anyio.get_cancelled_exc_class()()
|
||||||
|
if llama_outer_lock.locked():
|
||||||
|
await inner_send_chan.send(dict(data="[DONE]"))
|
||||||
|
raise anyio.get_cancelled_exc_class()()
|
||||||
await inner_send_chan.send(dict(data="[DONE]"))
|
await inner_send_chan.send(dict(data="[DONE]"))
|
||||||
except anyio.get_cancelled_exc_class() as e:
|
except anyio.get_cancelled_exc_class() as e:
|
||||||
print("disconnected")
|
print("disconnected")
|
||||||
|
@ -371,7 +389,6 @@ async def create_completion(
|
||||||
print(
|
print(
|
||||||
f"Disconnected from client (via refresh/close) {request.client}"
|
f"Disconnected from client (via refresh/close) {request.client}"
|
||||||
)
|
)
|
||||||
await inner_send_chan.send(dict(closing=True))
|
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
return EventSourceResponse(
|
return EventSourceResponse(
|
||||||
|
@ -494,6 +511,9 @@ async def create_chat_completion(
|
||||||
await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
|
await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
|
||||||
if await request.is_disconnected():
|
if await request.is_disconnected():
|
||||||
raise anyio.get_cancelled_exc_class()()
|
raise anyio.get_cancelled_exc_class()()
|
||||||
|
if llama_outer_lock.locked():
|
||||||
|
await inner_send_chan.send(dict(data="[DONE]"))
|
||||||
|
raise anyio.get_cancelled_exc_class()()
|
||||||
await inner_send_chan.send(dict(data="[DONE]"))
|
await inner_send_chan.send(dict(data="[DONE]"))
|
||||||
except anyio.get_cancelled_exc_class() as e:
|
except anyio.get_cancelled_exc_class() as e:
|
||||||
print("disconnected")
|
print("disconnected")
|
||||||
|
@ -501,7 +521,6 @@ async def create_chat_completion(
|
||||||
print(
|
print(
|
||||||
f"Disconnected from client (via refresh/close) {request.client}"
|
f"Disconnected from client (via refresh/close) {request.client}"
|
||||||
)
|
)
|
||||||
await inner_send_chan.send(dict(closing=True))
|
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
return EventSourceResponse(
|
return EventSourceResponse(
|
||||||
|
@ -533,8 +552,8 @@ GetModelResponse = create_model_from_typeddict(ModelList)
|
||||||
@router.get("/v1/models", response_model=GetModelResponse)
|
@router.get("/v1/models", response_model=GetModelResponse)
|
||||||
async def get_models(
|
async def get_models(
|
||||||
settings: Settings = Depends(get_settings),
|
settings: Settings = Depends(get_settings),
|
||||||
llama: llama_cpp.Llama = Depends(get_llama),
|
|
||||||
) -> ModelList:
|
) -> ModelList:
|
||||||
|
assert llama is not None
|
||||||
return {
|
return {
|
||||||
"object": "list",
|
"object": "list",
|
||||||
"data": [
|
"data": [
|
||||||
|
|
Loading…
Add table
Reference in a new issue