llama.cpp/llama_cpp/server/app.py
Felipe Lorenz c139f8b5d5
feat: Add endpoints for tokenize, detokenize and count tokens (#1136)
* Add endpoint to count tokens

* Add tokenize and detokenize endpoints

* Change response key to tokens for tokenize endpoint

* Fix dependency bug

* Cleanup

* Remove example added by mistake

* Move tokenize, detokenize, and count to Extras namespace. Tag existing endpoints

---------

Co-authored-by: Andrei Betlen <abetlen@gmail.com>
2024-03-08 21:09:00 -05:00

476 lines
14 KiB
Python

from __future__ import annotations
import os
import json
from threading import Lock
from functools import partial
from typing import Iterator, List, Optional, Union, Dict
import llama_cpp
import anyio
from anyio.streams.memory import MemoryObjectSendStream
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
from fastapi import (
Depends,
FastAPI,
APIRouter,
Request,
HTTPException,
status,
)
from fastapi.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer
from sse_starlette.sse import EventSourceResponse
from starlette_context.plugins import RequestIdPlugin # type: ignore
from starlette_context.middleware import RawContextMiddleware
from llama_cpp.server.model import (
LlamaProxy,
)
from llama_cpp.server.settings import (
ConfigFileSettings,
Settings,
ModelSettings,
ServerSettings,
)
from llama_cpp.server.types import (
CreateCompletionRequest,
CreateEmbeddingRequest,
CreateChatCompletionRequest,
ModelList,
TokenizeInputRequest,
TokenizeInputResponse,
TokenizeInputCountResponse,
DetokenizeInputRequest,
DetokenizeInputResponse,
)
from llama_cpp.server.errors import RouteErrorHandler
# Router that collects every endpoint declared in this module; its routes are
# built with RouteErrorHandler as their route class.
router = APIRouter(route_class=RouteErrorHandler)
# Module-wide server settings; stays None until set_server_settings() is called.
_server_settings: Optional[ServerSettings] = None
def set_server_settings(server_settings: ServerSettings):
    """Install ``server_settings`` as the module-wide server configuration.

    The value is stored in the ``_server_settings`` module global, replacing
    whatever was stored there before.
    """
    global _server_settings
    _server_settings = server_settings
def get_server_settings():
    """Yield the server settings currently stored in ``_server_settings``.

    This is a generator: it is used both as a ``Depends(...)`` target and
    directly via ``next(get_server_settings())``. It yields ``None`` when no
    settings have been installed yet.
    """
    current_settings = _server_settings
    yield current_settings
# Module-wide LlamaProxy instance; stays None until set_llama_proxy() is called.
_llama_proxy: Optional[LlamaProxy] = None
# Lock pair used by get_llama_proxy(): the outer lock is held only while a
# request waits for the inner lock, and the inner lock is held for as long as a
# request is using the proxy. get_event_publisher() checks
# llama_outer_lock.locked() to detect that another request is waiting.
llama_outer_lock = Lock()
llama_inner_lock = Lock()
def set_llama_proxy(model_settings: List[ModelSettings]):
    """Build a ``LlamaProxy`` from ``model_settings`` and install it module-wide.

    The new proxy replaces any previously stored ``_llama_proxy``.
    """
    global _llama_proxy
    proxy = LlamaProxy(models=model_settings)
    _llama_proxy = proxy
def get_llama_proxy():
    """Yield the module-wide ``_llama_proxy`` while holding ``llama_inner_lock``.

    Locking protocol:
      1. take ``llama_outer_lock`` (marks this request as "waiting"),
      2. take ``llama_inner_lock`` (blocks until the current user is done),
      3. drop ``llama_outer_lock`` and yield the proxy,
      4. drop ``llama_inner_lock`` once the generator is resumed or closed.
    """
    # NOTE: This double lock allows the currently streaming llama model to
    # check if any other requests are pending in the same thread and cancel
    # the stream if so.
    llama_outer_lock.acquire()
    outer_still_held = True
    try:
        # ``with`` acquires the inner lock and guarantees its release, exactly
        # like an explicit acquire followed by try/finally.
        with llama_inner_lock:
            llama_outer_lock.release()
            outer_still_held = False
            yield _llama_proxy
    finally:
        # Only reached with the outer lock held if acquiring the inner lock
        # (or releasing the outer one) failed.
        if outer_still_held:
            llama_outer_lock.release()
def create_app(
    settings: Settings | None = None,
    server_settings: ServerSettings | None = None,
    model_settings: List[ModelSettings] | None = None,
):
    """Create and configure the FastAPI application.

    Configuration sources, in order of precedence:
      * the JSON file named by the ``CONFIG_FILE`` environment variable, which
        overrides ``server_settings`` and ``model_settings``;
      * explicit ``server_settings`` and ``model_settings`` (both required);
      * ``settings`` (or a default ``Settings()``), used only when neither
        ``server_settings`` nor ``model_settings`` is available.

    Side effects: stores the server settings via ``set_server_settings`` and
    builds the model proxy via ``set_llama_proxy``.

    Raises:
        ValueError: ``CONFIG_FILE`` is set but the path does not exist.
        AssertionError: only one of ``server_settings`` / ``model_settings``
            ended up being provided.
    """
    config_path = os.environ.get("CONFIG_FILE")
    if config_path is not None:
        if not os.path.exists(config_path):
            raise ValueError(f"Config file {config_path} not found!")
        with open(config_path, "rb") as config_stream:
            raw_config = config_stream.read()
        config_file_settings = ConfigFileSettings.model_validate_json(raw_config)
        server_settings = ServerSettings.model_validate(config_file_settings)
        model_settings = config_file_settings.models

    nothing_provided = server_settings is None and model_settings is None
    if nothing_provided:
        # Fall back to the combined Settings object for both halves.
        settings = Settings() if settings is None else settings
        server_settings = ServerSettings.model_validate(settings)
        model_settings = [ModelSettings.model_validate(settings)]

    assert (
        server_settings is not None and model_settings is not None
    ), "server_settings and model_settings must be provided together"

    set_server_settings(server_settings)

    request_context = Middleware(RawContextMiddleware, plugins=(RequestIdPlugin(),))
    app = FastAPI(
        middleware=[request_context],
        title="🦙 llama.cpp Python API",
        version=llama_cpp.__version__,
    )
    # NOTE(review): CORS is fully open (any origin, credentials allowed).
    cors_options = {
        "allow_origins": ["*"],
        "allow_credentials": True,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }
    app.add_middleware(CORSMiddleware, **cors_options)
    app.include_router(router)

    assert model_settings is not None
    set_llama_proxy(model_settings=model_settings)
    return app
async def get_event_publisher(
    request: Request,
    inner_send_chan: MemoryObjectSendStream,
    iterator: Iterator,
):
    """Pump chunks from ``iterator`` into ``inner_send_chan`` as SSE events.

    Each chunk is JSON-encoded and sent as ``{"data": ...}``. The stream ends
    with a ``[DONE]`` event when the iterator is exhausted. It is cut short by
    raising the backend's cancellation exception when:
      * the client has disconnected, or
      * ``interrupt_requests`` is enabled in the server settings and
        ``llama_outer_lock`` is locked (another request is waiting); in that
        case ``[DONE]`` is sent before cancelling.

    The cancellation exception is logged with ``print`` and re-raised.
    """
    cancelled_exc = anyio.get_cancelled_exc_class()
    async with inner_send_chan:
        try:
            async for chunk in iterate_in_threadpool(iterator):
                await inner_send_chan.send({"data": json.dumps(chunk)})

                client_gone = await request.is_disconnected()
                if client_gone:
                    raise cancelled_exc()

                # Settings are re-read on every chunk, as is the lock state.
                interrupt_enabled = next(get_server_settings()).interrupt_requests
                if interrupt_enabled and llama_outer_lock.locked():
                    await inner_send_chan.send({"data": "[DONE]"})
                    raise cancelled_exc()

            await inner_send_chan.send({"data": "[DONE]"})
        except cancelled_exc as exc:
            print("disconnected")
            # Shielded scope so the log line still runs during cancellation.
            with anyio.move_on_after(1, shield=True):
                print(f"Disconnected from client (via refresh/close) {request.client}")
                raise exc
def _logit_bias_tokens_to_input_ids(
    llama: llama_cpp.Llama,
    logit_bias: Dict[str, float],
) -> Dict[str, float]:
    """Re-key a token-text logit bias map by token id.

    Every key of ``logit_bias`` is UTF-8 encoded and tokenized with
    ``llama.tokenize(..., add_bos=False, special=True)``; each resulting id
    (as a string) is mapped to that key's score. If several keys produce the
    same id, the score of the key iterated last wins.
    """
    return {
        str(input_id): score
        for token, score in logit_bias.items()
        for input_id in llama.tokenize(
            token.encode("utf-8"), add_bos=False, special=True
        )
    }
# Bearer-token authentication scheme consumed by `authenticate`.
# auto_error=False: presumably makes the scheme hand back None instead of
# rejecting the request itself when no bearer credentials are sent, leaving the
# accept/reject decision to `authenticate` (confirm against FastAPI's
# HTTPBearer documentation).
bearer_scheme = HTTPBearer(auto_error=False)
async def authenticate(
    settings: Settings = Depends(get_server_settings),
    authorization: Optional[str] = Depends(bearer_scheme),
):
    """Check the request's bearer token against the configured API key.

    Args:
        settings: Object produced by ``get_server_settings``; only its
            ``api_key`` attribute is read.
        authorization: Object produced by ``bearer_scheme`` (falsy when no
            credentials were supplied); only its ``credentials`` attribute is
            read.

    Returns:
        ``True`` when no API key is configured, otherwise the presented
        credentials string when it matches the configured key.

    Raises:
        HTTPException: 401 when an API key is configured and the credentials
            are missing or do not match.
    """
    # Function-scope import keeps this block self-contained.
    import secrets

    expected_key = settings.api_key

    # Skip API key check if it's not set in settings
    if expected_key is None:
        return True

    # Compare in constant time so response timing does not reveal how much of
    # a guessed key is correct. Bytes are used because compare_digest rejects
    # non-ASCII str arguments.
    if authorization and secrets.compare_digest(
        authorization.credentials.encode("utf-8"),
        expected_key.encode("utf-8"),
    ):
        # api key is valid
        return authorization.credentials

    # raise http error 401
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid API key",
    )
# Tag passed as `tags=[...]` to the /v1/* routes in this module.
openai_v1_tag = "OpenAI V1"
@router.post(
    "/v1/completions",
    summary="Completion",
    dependencies=[Depends(authenticate)],
    response_model=Union[
        llama_cpp.CreateCompletionResponse,
        str,
    ],
    responses={
        "200": {
            "description": "Successful Response",
            "content": {
                "application/json": {
                    "schema": {
                        "anyOf": [
                            {"$ref": "#/components/schemas/CreateCompletionResponse"}
                        ],
                        "title": "Completion response, when stream=False",
                    }
                },
                "text/event-stream": {
                    "schema": {
                        "type": "string",
                        "title": "Server Side Streaming response, when stream=True. "
                        + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format",  # noqa: E501
                        "example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
                    }
                },
            },
        }
    },
    tags=[openai_v1_tag],
)
@router.post(
    "/v1/engines/copilot-codex/completions",
    include_in_schema=False,
    dependencies=[Depends(authenticate)],
    tags=[openai_v1_tag],
)
async def create_completion(
    request: Request,
    body: CreateCompletionRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> llama_cpp.Completion:
    """Handle a text-completion request, streamed or not.

    Args:
        request: Incoming request; its URL path selects the model alias
            (``"copilot-codex"`` for the copilot route, ``body.model``
            otherwise) and it is handed to the event publisher when streaming.
        body: Completion parameters. A list-valued ``prompt`` may contain at
            most one entry; an empty list is treated as ``""``.
        llama_proxy: Callable resolving a model alias to a llama instance.

    Returns:
        The completion object, or an ``EventSourceResponse`` when the model
        call returned an iterator.

    Raises:
        HTTPException: 400 when ``body.prompt`` is a list with more than one
            entry.
    """
    if isinstance(body.prompt, list):
        # Explicit check instead of ``assert``: asserts are removed under
        # ``python -O`` (extra prompts would then be silently dropped) and an
        # AssertionError is not a meaningful client-facing error.
        if len(body.prompt) > 1:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Only a single prompt per request is supported",
            )
        body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""

    llama = llama_proxy(
        body.model
        if request.url.path != "/v1/engines/copilot-codex/completions"
        else "copilot-codex"
    )

    # Request fields that are not forwarded to the model call.
    exclude = {
        "n",
        "best_of",
        "logit_bias_type",
        "user",
    }
    kwargs = body.model_dump(exclude=exclude)

    if body.logit_bias is not None:
        # Biases keyed by token text are converted to token-id keys.
        kwargs["logit_bias"] = (
            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
            if body.logit_bias_type == "tokens"
            else body.logit_bias
        )

    if body.grammar is not None:
        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

    iterator_or_completion: Union[
        llama_cpp.CreateCompletionResponse,
        Iterator[llama_cpp.CreateCompletionStreamResponse],
    ] = await run_in_threadpool(llama, **kwargs)

    if isinstance(iterator_or_completion, Iterator):
        # EAFP: It's easier to ask for forgiveness than permission
        first_response = await run_in_threadpool(next, iterator_or_completion)

        # If no exception was raised from first_response, we can assume that
        # the iterator is valid and we can use it to stream the response.
        def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
            yield first_response
            yield from iterator_or_completion

        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(  # type: ignore
                get_event_publisher,
                request=request,
                inner_send_chan=send_chan,
                iterator=iterator(),
            ),
            sep="\n",
        )
    else:
        return iterator_or_completion
@router.post(
    "/v1/embeddings",
    summary="Embedding",
    dependencies=[Depends(authenticate)],
    tags=[openai_v1_tag],
)
async def create_embedding(
    request: CreateEmbeddingRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
):
    """Compute embeddings for the request using the model named in it.

    Every request field except ``user`` is forwarded as a keyword argument to
    the model's ``create_embedding``, which runs in the thread pool.
    """
    llama = llama_proxy(request.model)
    embed = llama.create_embedding
    embedding_kwargs = request.model_dump(exclude={"user"})
    return await run_in_threadpool(embed, **embedding_kwargs)
@router.post(
    "/v1/chat/completions",
    summary="Chat",
    dependencies=[Depends(authenticate)],
    response_model=Union[llama_cpp.ChatCompletion, str],
    responses={
        "200": {
            "description": "Successful Response",
            "content": {
                "application/json": {
                    "schema": {
                        "anyOf": [
                            {
                                "$ref": "#/components/schemas/CreateChatCompletionResponse"
                            }
                        ],
                        "title": "Completion response, when stream=False",
                    }
                },
                "text/event-stream": {
                    "schema": {
                        "type": "string",
                        # The ". " separator was missing, so the title rendered
                        # as "...stream=TrueSee SSE format...".
                        "title": "Server Side Streaming response, when stream=True. "
                        + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format",  # noqa: E501
                        "example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
                    }
                },
            },
        }
    },
    tags=[openai_v1_tag],
)
async def create_chat_completion(
    request: Request,
    body: CreateChatCompletionRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> llama_cpp.ChatCompletion:
    """Handle a chat-completion request, streamed or not.

    Args:
        request: Incoming request, handed to the event publisher when
            streaming.
        body: Chat completion parameters; ``body.model`` selects the model.
        llama_proxy: Callable resolving a model alias to a llama instance.

    Returns:
        The chat completion object, or an ``EventSourceResponse`` when the
        model call returned an iterator.
    """
    # Request fields that are not forwarded to the model call.
    exclude = {
        "n",
        "logit_bias_type",
        "user",
    }
    kwargs = body.model_dump(exclude=exclude)
    llama = llama_proxy(body.model)

    if body.logit_bias is not None:
        # Biases keyed by token text are converted to token-id keys.
        kwargs["logit_bias"] = (
            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
            if body.logit_bias_type == "tokens"
            else body.logit_bias
        )

    if body.grammar is not None:
        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

    iterator_or_completion: Union[
        llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
    ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)

    if isinstance(iterator_or_completion, Iterator):
        # EAFP: It's easier to ask for forgiveness than permission
        first_response = await run_in_threadpool(next, iterator_or_completion)

        # If no exception was raised from first_response, we can assume that
        # the iterator is valid and we can use it to stream the response.
        def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
            yield first_response
            yield from iterator_or_completion

        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(  # type: ignore
                get_event_publisher,
                request=request,
                inner_send_chan=send_chan,
                iterator=iterator(),
            ),
            sep="\n",
        )
    else:
        return iterator_or_completion
@router.get(
    "/v1/models",
    summary="Models",
    dependencies=[Depends(authenticate)],
    tags=[openai_v1_tag],
)
async def get_models(
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> ModelList:
    """List the model aliases obtained by iterating ``llama_proxy``.

    Each alias becomes one ``"model"`` entry inside a ``"list"`` envelope.
    """
    model_entries = []
    for model_alias in llama_proxy:
        entry = {
            "id": model_alias,
            "object": "model",
            "owned_by": "me",
            "permissions": [],
        }
        model_entries.append(entry)
    return {"object": "list", "data": model_entries}
# Tag passed as `tags=[...]` to the /extras/* routes in this module.
extras_tag = "Extras"
@router.post(
    "/extras/tokenize",
    summary="Tokenize",
    dependencies=[Depends(authenticate)],
    tags=[extras_tag],
)
async def tokenize(
    body: TokenizeInputRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> TokenizeInputResponse:
    """Tokenize ``body.input`` with the model named by ``body.model``.

    The input is UTF-8 encoded and tokenized with ``special=True``; the
    resulting tokens are returned under the ``"tokens"`` key.
    """
    llama = llama_proxy(body.model)
    encoded_input = body.input.encode("utf-8")
    token_ids = llama.tokenize(encoded_input, special=True)
    return {"tokens": token_ids}
@router.post(
    "/extras/tokenize/count",
    summary="Tokenize Count",
    dependencies=[Depends(authenticate)],
    tags=[extras_tag],
)
async def count_query_tokens(
    body: TokenizeInputRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> TokenizeInputCountResponse:
    """Count the tokens ``body.input`` produces for the model ``body.model``.

    The input is UTF-8 encoded and tokenized with ``special=True``; the number
    of resulting tokens is returned under the ``"count"`` key.
    """
    llama = llama_proxy(body.model)
    encoded_input = body.input.encode("utf-8")
    token_ids = llama.tokenize(encoded_input, special=True)
    token_count = len(token_ids)
    return {"count": token_count}
@router.post(
    "/extras/detokenize",
    summary="Detokenize",
    dependencies=[Depends(authenticate)],
    tags=[extras_tag],
)
async def detokenize(
    body: DetokenizeInputRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> DetokenizeInputResponse:
    """Convert ``body.tokens`` back to text with the model ``body.model``.

    The bytes returned by the model's ``detokenize`` are decoded as UTF-8 and
    returned under the ``"text"`` key. Bytes that do not form valid UTF-8 are
    dropped rather than raising, so a token list that ends in the middle of a
    multi-byte character still yields the decodable part of the text.
    """
    raw_text = llama_proxy(body.model).detokenize(body.tokens)
    # A strict decode raises UnicodeDecodeError when the requested tokens end
    # on a partial multi-byte sequence; ignore the undecodable bytes instead.
    text = raw_text.decode("utf-8", errors="ignore")
    return {"text": text}