llama.cpp/llama_cpp/server/app.py

560 lines
18 KiB
Python
Raw Normal View History

from __future__ import annotations
import os
import json
from threading import Lock
2023-05-27 09:12:58 -04:00
from functools import partial
from typing import Iterator, List, Optional, Union, Dict
import llama_cpp
2023-05-27 09:12:58 -04:00
import anyio
from anyio.streams.memory import MemoryObjectSendStream
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
2024-03-19 10:52:53 -04:00
from fastapi import Depends, FastAPI, APIRouter, Request, HTTPException, status, Body
from fastapi.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sse_starlette.sse import EventSourceResponse
from starlette_context.plugins import RequestIdPlugin # type: ignore
from starlette_context.middleware import RawContextMiddleware
from llama_cpp.server.model import (
LlamaProxy,
)
from llama_cpp.server.settings import (
ConfigFileSettings,
Settings,
ModelSettings,
ServerSettings,
)
from llama_cpp.server.types import (
CreateCompletionRequest,
CreateEmbeddingRequest,
CreateChatCompletionRequest,
ModelList,
TokenizeInputRequest,
TokenizeInputResponse,
TokenizeInputCountResponse,
DetokenizeInputRequest,
DetokenizeInputResponse,
)
from llama_cpp.server.errors import RouteErrorHandler
2023-07-16 14:57:39 +09:00
# All routes in this module are registered on this router, using the
# RouteErrorHandler route class imported from llama_cpp.server.errors.
router = APIRouter(route_class=RouteErrorHandler)

# Module-wide server settings; populated by set_server_settings()
# (create_app calls it). None until then.
_server_settings: Optional[ServerSettings] = None


def set_server_settings(server_settings: ServerSettings) -> None:
    """Store *server_settings* as the module-wide server configuration."""
    global _server_settings
    _server_settings = server_settings


def get_server_settings():
    """Yield the module-wide server settings.

    Written as a generator so it can be used as a FastAPI dependency
    (``Depends(get_server_settings)``); it is also read directly with
    ``next(get_server_settings())``. Yields ``None`` if
    set_server_settings() has not been called.
    """
    yield _server_settings
2023-05-01 22:38:46 -04:00
# The LlamaProxy shared by every request; created by set_llama_proxy().
_llama_proxy: Optional[LlamaProxy] = None
# Used together by get_llama_proxy(). A request waiting for the model holds
# llama_outer_lock, so a running stream can detect pending requests via
# llama_outer_lock.locked() (checked in get_event_publisher).
llama_outer_lock = Lock()
llama_inner_lock = Lock()


def set_llama_proxy(model_settings: List[ModelSettings]) -> None:
    """Build the shared LlamaProxy from *model_settings* and store it globally."""
    global _llama_proxy
    _llama_proxy = LlamaProxy(models=model_settings)
def get_llama_proxy():
    """Yield the shared LlamaProxy while holding ``llama_inner_lock``.

    Used as a FastAPI dependency. Lock order: acquire ``llama_outer_lock``,
    then ``llama_inner_lock``, then release the outer lock and yield. The
    inner lock is released only when the generator is finalized, so at
    most one caller holds the proxy at a time. A caller blocked on the
    inner lock keeps holding the outer lock, which is what makes
    ``llama_outer_lock.locked()`` a "someone is waiting" signal.
    """
    # NOTE: This double lock allows the currently streaming llama model to
    # check if any other requests are pending in the same thread and cancel
    # the stream if so.
    llama_outer_lock.acquire()
    release_outer_lock = True
    try:
        llama_inner_lock.acquire()
        try:
            llama_outer_lock.release()
            # The outer lock is already released; the outer finally must not
            # release it a second time.
            release_outer_lock = False
            yield _llama_proxy
        finally:
            llama_inner_lock.release()
    finally:
        # Still True only if an exception occurred before the outer lock
        # was released above.
        if release_outer_lock:
            llama_outer_lock.release()
2023-05-07 02:52:20 -04:00
# Passed as ``ping_message_factory`` to every streaming EventSourceResponse.
# Stays None unless set_ping_message_factory() is called (create_app does so
# when ``server_settings.disable_ping_events`` is set).
_ping_message_factory = None


def set_ping_message_factory(factory) -> None:
    """Store *factory* as the ping-message factory for streaming responses."""
    global _ping_message_factory
    _ping_message_factory = factory
def create_app(
    settings: Settings | None = None,
    server_settings: ServerSettings | None = None,
    model_settings: List[ModelSettings] | None = None,
):
    """Build and configure the FastAPI application.

    Settings are resolved in this order:

    1. If the ``CONFIG_FILE`` environment variable is set, the file it names
       (YAML when it ends in ``.yaml``/``.yml``, JSON otherwise) supplies both
       the server settings and the model list, overriding any arguments.
    2. Otherwise, if neither *server_settings* nor *model_settings* is given,
       both are derived from *settings* (a default ``Settings()`` when None).
    3. Otherwise *server_settings* and *model_settings* must both be given.

    Side effects: stores the server settings, the shared LlamaProxy and, when
    ``disable_ping_events`` is set, an empty ping-message factory in
    module-level state used by the routes.

    Raises:
        ValueError: if ``CONFIG_FILE`` names a missing file, or if server or
            model settings are still missing after resolution (for example
            when only one of *server_settings* / *model_settings* is passed).
    """
    config_file = os.environ.get("CONFIG_FILE", None)
    if config_file is not None:
        if not os.path.exists(config_file):
            raise ValueError(f"Config file {config_file} not found!")
        with open(config_file, "rb") as f:
            if config_file.endswith((".yaml", ".yml")):
                import yaml

                # safe_load: a config file must not be able to construct
                # arbitrary Python objects.
                config_file_settings = ConfigFileSettings.model_validate_json(
                    json.dumps(yaml.safe_load(f))
                )
            else:
                config_file_settings = ConfigFileSettings.model_validate_json(f.read())
        server_settings = ServerSettings.model_validate(config_file_settings)
        model_settings = config_file_settings.models

    if server_settings is None and model_settings is None:
        if settings is None:
            settings = Settings()
        server_settings = ServerSettings.model_validate(settings)
        model_settings = [ModelSettings.model_validate(settings)]

    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O`` and the server would then start without settings.
    if server_settings is None or model_settings is None:
        raise ValueError(
            "server_settings and model_settings must be provided together"
        )

    set_server_settings(server_settings)
    middleware = [Middleware(RawContextMiddleware, plugins=(RequestIdPlugin(),))]
    app = FastAPI(
        middleware=middleware,
        title="🦙 llama.cpp Python API",
        version=llama_cpp.__version__,
        root_path=server_settings.root_path,
    )
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app.include_router(router)

    set_llama_proxy(model_settings=model_settings)

    if server_settings.disable_ping_events:
        set_ping_message_factory(lambda: bytes())

    return app
2023-07-16 14:57:39 +09:00
async def get_event_publisher(
    request: Request,
    inner_send_chan: MemoryObjectSendStream,
    iterator: Iterator,
):
    """Forward chunks from *iterator* to *inner_send_chan* as SSE payloads.

    Used as the ``data_sender_callable`` of EventSourceResponse. The
    iterator is advanced in a worker thread (``iterate_in_threadpool``);
    each chunk is JSON-encoded and sent as ``{"data": ...}``. A final
    ``[DONE]`` message is sent once the iterator is exhausted.

    The stream is stopped early by raising the async backend's cancellation
    exception when the client has disconnected, or, when the server setting
    ``interrupt_requests`` is enabled, when ``llama_outer_lock`` is held
    (i.e. another request is waiting for the model); in that second case
    ``[DONE]`` is sent first. Any cancellation is printed and re-raised.
    """
    # ``async with`` closes the send stream on exit, ending the receiver.
    async with inner_send_chan:
        try:
            async for chunk in iterate_in_threadpool(iterator):
                await inner_send_chan.send(dict(data=json.dumps(chunk)))
                if await request.is_disconnected():
                    raise anyio.get_cancelled_exc_class()()
                # Hand the model over to a pending request if configured to.
                if (
                    next(get_server_settings()).interrupt_requests
                    and llama_outer_lock.locked()
                ):
                    await inner_send_chan.send(dict(data="[DONE]"))
                    raise anyio.get_cancelled_exc_class()()
            await inner_send_chan.send(dict(data="[DONE]"))
        except anyio.get_cancelled_exc_class() as e:
            print("disconnected")
            # Shield the logging from cancellation (for at most 1 second),
            # then propagate the original cancellation exception.
            with anyio.move_on_after(1, shield=True):
                print(f"Disconnected from client (via refresh/close) {request.client}")
                raise e
2023-09-13 21:23:23 -04:00
def _logit_bias_tokens_to_input_ids(
    llama: llama_cpp.Llama,
    logit_bias: Dict[str, float],
) -> Dict[str, float]:
    """Re-key a text-keyed logit bias by token id.

    Every key of *logit_bias* is UTF-8 encoded and tokenized with
    ``add_bos=False, special=True``; each resulting token id (as a string)
    gets that key's bias. If several keys produce the same id, the bias of
    the key iterated last wins.
    """
    return {
        str(token_id): bias
        for text, bias in logit_bias.items()
        for token_id in llama.tokenize(
            text.encode("utf-8"), add_bos=False, special=True
        )
    }
# Setup Bearer authentication scheme. auto_error=False makes a missing or
# malformed Authorization header resolve to None instead of failing early,
# so authenticate() can allow it when no API key is configured.
bearer_scheme = HTTPBearer(auto_error=False)


async def authenticate(
    settings: ServerSettings = Depends(get_server_settings),
    authorization: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme),
):
    """Validate the request's bearer token against ``settings.api_key``.

    Returns:
        ``True`` when no API key is configured (authentication disabled),
        otherwise the accepted bearer token.

    Raises:
        HTTPException: 401 when an API key is configured and the request's
            bearer token is missing or does not match.
    """
    import secrets

    # Skip API key check if it's not set in settings
    if settings.api_key is None:
        return True

    # Constant-time comparison avoids leaking how much of the key matched
    # through response timing. Compare UTF-8 bytes because compare_digest
    # rejects non-ASCII str arguments.
    if authorization is not None and secrets.compare_digest(
        authorization.credentials.encode("utf-8"),
        settings.api_key.encode("utf-8"),
    ):
        return authorization.credentials

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid API key",
    )
# OpenAPI tag attached to the OpenAI-compatible /v1 endpoints below.
openai_v1_tag = "OpenAI V1"
2023-05-01 22:38:46 -04:00
@router.post(
    "/v1/completions",
    summary="Completion",
    dependencies=[Depends(authenticate)],
    response_model=Union[
        llama_cpp.CreateCompletionResponse,
        str,
    ],
    responses={
        "200": {
            "description": "Successful Response",
            "content": {
                "application/json": {
                    "schema": {
                        "anyOf": [
                            {"$ref": "#/components/schemas/CreateCompletionResponse"}
                        ],
                        "title": "Completion response, when stream=False",
                    }
                },
                "text/event-stream": {
                    "schema": {
                        "type": "string",
                        "title": "Server Side Streaming response, when stream=True. "
                        + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format",  # noqa: E501
                        "example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
                    }
                },
            },
        }
    },
    tags=[openai_v1_tag],
)
@router.post(
    "/v1/engines/copilot-codex/completions",
    include_in_schema=False,
    dependencies=[Depends(authenticate)],
    tags=[openai_v1_tag],
)
async def create_completion(
    request: Request,
    body: CreateCompletionRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> llama_cpp.Completion:
    # Text completion endpoint. Returns the completion directly, or an SSE
    # stream when the model call returns an iterator (stream=True).
    # A list-valued prompt may hold at most one element (an empty list means
    # ""). Longer lists are rejected with 400; this used to be an ``assert``,
    # which produced a 500 and vanished entirely under ``python -O``.
    if isinstance(body.prompt, list):
        if len(body.prompt) > 1:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Only a single prompt per request is supported",
            )
        body.prompt = body.prompt[0] if body.prompt else ""

    # The hidden copilot-codex route always uses the "copilot-codex" model,
    # regardless of the model named in the body.
    llama = llama_proxy(
        body.model
        if request.url.path != "/v1/engines/copilot-codex/completions"
        else "copilot-codex"
    )

    # Request fields that are not forwarded to the model call.
    exclude = {
        "n",
        "best_of",
        "logit_bias_type",
        "user",
    }
    kwargs = body.model_dump(exclude=exclude)

    if body.logit_bias is not None:
        kwargs["logit_bias"] = (
            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
            if body.logit_bias_type == "tokens"
            else body.logit_bias
        )

    if body.grammar is not None:
        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

    iterator_or_completion: Union[
        llama_cpp.CreateCompletionResponse,
        Iterator[llama_cpp.CreateCompletionStreamResponse],
    ] = await run_in_threadpool(llama, **kwargs)

    if isinstance(iterator_or_completion, Iterator):
        # EAFP: It's easier to ask for forgiveness than permission
        first_response = await run_in_threadpool(next, iterator_or_completion)

        # If no exception was raised from first_response, we can assume that
        # the iterator is valid and we can use it to stream the response.
        def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
            yield first_response
            yield from iterator_or_completion

        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(  # type: ignore
                get_event_publisher,
                request=request,
                inner_send_chan=send_chan,
                iterator=iterator(),
            ),
            sep="\n",
            ping_message_factory=_ping_message_factory,
        )
    else:
        return iterator_or_completion
2023-05-01 22:38:46 -04:00
@router.post(
    "/v1/embeddings",
    summary="Embedding",
    dependencies=[Depends(authenticate)],
    tags=[openai_v1_tag],
)
async def create_embedding(
    request: CreateEmbeddingRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
):
    # Look up the requested model and run its create_embedding in a worker
    # thread, forwarding every request field except ``user`` as kwargs.
    embed = llama_proxy(request.model).create_embedding
    embed_kwargs = request.model_dump(exclude={"user"})
    return await run_in_threadpool(embed, **embed_kwargs)
2023-05-01 22:38:46 -04:00
@router.post(
    "/v1/chat/completions",
    summary="Chat",
    dependencies=[Depends(authenticate)],
    response_model=Union[llama_cpp.ChatCompletion, str],
    responses={
        "200": {
            "description": "Successful Response",
            "content": {
                "application/json": {
                    "schema": {
                        "anyOf": [
                            {
                                "$ref": "#/components/schemas/CreateChatCompletionResponse"
                            }
                        ],
                        "title": "Completion response, when stream=False",
                    }
                },
                "text/event-stream": {
                    "schema": {
                        "type": "string",
                        # The ". " separator was missing here, which rendered
                        # as "stream=TrueSee SSE format"; it now matches the
                        # completions route.
                        "title": "Server Side Streaming response, when stream=True. "
                        + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format",  # noqa: E501
                        "example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
                    }
                },
            },
        }
    },
    tags=[openai_v1_tag],
)
async def create_chat_completion(
    request: Request,
    body: CreateChatCompletionRequest = Body(
        openapi_examples={
            "normal": {
                "summary": "Chat Completion",
                "value": {
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "What is the capital of France?"},
                    ],
                },
            },
            "json_mode": {
                "summary": "JSON Mode",
                "value": {
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "Who won the world series in 2020"},
                    ],
                    "response_format": {"type": "json_object"},
                },
            },
            "tool_calling": {
                "summary": "Tool Calling",
                "value": {
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "Extract Jason is 30 years old."},
                    ],
                    "tools": [
                        {
                            "type": "function",
                            "function": {
                                "name": "User",
                                "description": "User record",
                                "parameters": {
                                    "type": "object",
                                    "properties": {
                                        "name": {"type": "string"},
                                        "age": {"type": "number"},
                                    },
                                    "required": ["name", "age"],
                                },
                            },
                        }
                    ],
                    "tool_choice": {
                        "type": "function",
                        "function": {
                            "name": "User",
                        },
                    },
                },
            },
            "logprobs": {
                "summary": "Logprobs",
                "value": {
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "What is the capital of France?"},
                    ],
                    "logprobs": True,
                    "top_logprobs": 10,
                },
            },
        }
    ),
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> llama_cpp.ChatCompletion:
    # Chat completion endpoint. Returns the completion directly, or an SSE
    # stream when create_chat_completion returns an iterator (stream=True).
    # Request fields that are not forwarded to the model call.
    exclude = {
        "n",
        "logit_bias_type",
        "user",
    }
    kwargs = body.model_dump(exclude=exclude)
    llama = llama_proxy(body.model)

    if body.logit_bias is not None:
        kwargs["logit_bias"] = (
            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
            if body.logit_bias_type == "tokens"
            else body.logit_bias
        )

    if body.grammar is not None:
        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

    iterator_or_completion: Union[
        llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
    ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)

    if isinstance(iterator_or_completion, Iterator):
        # EAFP: It's easier to ask for forgiveness than permission
        first_response = await run_in_threadpool(next, iterator_or_completion)

        # If no exception was raised from first_response, we can assume that
        # the iterator is valid and we can use it to stream the response.
        def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
            yield first_response
            yield from iterator_or_completion

        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(  # type: ignore
                get_event_publisher,
                request=request,
                inner_send_chan=send_chan,
                iterator=iterator(),
            ),
            sep="\n",
            ping_message_factory=_ping_message_factory,
        )
    else:
        return iterator_or_completion
@router.get(
    "/v1/models",
    summary="Models",
    dependencies=[Depends(authenticate)],
    tags=[openai_v1_tag],
)
async def get_models(
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> ModelList:
    # One entry per model alias yielded by iterating the proxy, in the
    # OpenAI "list" envelope.
    entries = [
        dict(id=alias, object="model", owned_by="me", permissions=[])
        for alias in llama_proxy
    ]
    return dict(object="list", data=entries)
# OpenAPI tag attached to the /extras/* tokenizer utility endpoints below.
extras_tag = "Extras"
@router.post(
    "/extras/tokenize",
    summary="Tokenize",
    dependencies=[Depends(authenticate)],
    tags=[extras_tag],
)
async def tokenize(
    body: TokenizeInputRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> TokenizeInputResponse:
    # Tokenize the UTF-8 encoding of ``body.input`` with the selected model;
    # special=True is passed through to the model's tokenizer.
    model = llama_proxy(body.model)
    encoded = body.input.encode("utf-8")
    return TokenizeInputResponse(tokens=model.tokenize(encoded, special=True))
@router.post(
    "/extras/tokenize/count",
    summary="Tokenize Count",
    dependencies=[Depends(authenticate)],
    tags=[extras_tag],
)
async def count_query_tokens(
    body: TokenizeInputRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> TokenizeInputCountResponse:
    # Same tokenization as /extras/tokenize (UTF-8 input, special=True);
    # only the number of resulting tokens is returned.
    model = llama_proxy(body.model)
    token_ids = model.tokenize(body.input.encode("utf-8"), special=True)
    return TokenizeInputCountResponse(count=len(token_ids))
@router.post(
    "/extras/detokenize",
    summary="Detokenize",
    dependencies=[Depends(authenticate)],
    tags=[extras_tag],
)
async def detokenize(
    body: DetokenizeInputRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> DetokenizeInputResponse:
    # Convert ``body.tokens`` back to text with the selected model. The
    # model returns bytes, and an arbitrary token list (e.g. one that splits
    # a multi-byte character) may not be valid UTF-8. Decode with U+FFFD
    # replacement characters instead of failing the request with a
    # UnicodeDecodeError.
    raw = llama_proxy(body.model).detokenize(body.tokens)
    text = raw.decode("utf-8", errors="replace")
    return DetokenizeInputResponse(text=text)