Implement openai api compatible authentication (#1010)

Authored by docmeth02 on 2023-12-21 19:44:49 +01:00; committed by GitHub.
parent 788394c096
commit 33cc623346

@@ -14,11 +14,12 @@ import llama_cpp
 import anyio
 from anyio.streams.memory import MemoryObjectSendStream
 from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
-from fastapi import Depends, FastAPI, APIRouter, Request, Response
+from fastapi import Depends, FastAPI, APIRouter, Request, Response, HTTPException, status
 from fastapi.middleware import Middleware
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
 from fastapi.routing import APIRoute
+from fastapi.security import HTTPBearer
 from pydantic import BaseModel, Field
 from pydantic_settings import BaseSettings
 from sse_starlette.sse import EventSourceResponse
@@ -163,6 +164,10 @@ class Settings(BaseSettings):
         default=True,
         description="Whether to interrupt requests when a new request is received.",
     )
+    api_key: Optional[str] = Field(
+        default=None,
+        description="API key for authentication. If set all requests need to be authenticated."
+    )


 class ErrorResponse(TypedDict):
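
Because api_key is a plain pydantic-settings field, it can be supplied like any other server option. A minimal sketch of enabling it when embedding the server in-process, assuming the module's create_app(settings=...) entry point; the model path and key value below are placeholders, not part of this commit:

# Sketch: enabling the new api_key setting in-process (assumed entry point).
import uvicorn
from llama_cpp.server.app import create_app, Settings

settings = Settings(
    model="./models/model.gguf",  # placeholder model path
    api_key="sk-my-secret-key",   # enables the Bearer check added below
)
app = create_app(settings=settings)
uvicorn.run(app, host="127.0.0.1", port=8000)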
@@ -314,6 +319,9 @@ class RouteErrorHandler(APIRoute):
                 elapsed_time_ms = int((time.perf_counter() - start_sec) * 1000)
                 response.headers["openai-processing-ms"] = f"{elapsed_time_ms}"
                 return response
+            except HTTPException as unauthorized:
+                # api key check failed
+                raise unauthorized
             except Exception as exc:
                 json_body = await request.json()
                 try:
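
Re-raising the HTTPException here lets the 401 raised by the authentication dependency (added below) pass through RouteErrorHandler untouched; without this clause, the generic except Exception branch would read the request body and rewrap the auth failure as an ordinary error response.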
@@ -658,6 +666,27 @@ def _logit_bias_tokens_to_input_ids(
     return to_bias


+# Setup Bearer authentication scheme
+bearer_scheme = HTTPBearer(auto_error=False)
+
+
+async def authenticate(settings: Settings = Depends(get_settings), authorization: Optional[str] = Depends(bearer_scheme)):
+    # Skip API key check if it's not set in settings
+    if settings.api_key is None:
+        return True
+
+    # check bearer credentials against the api_key
+    if authorization and authorization.credentials == settings.api_key:
+        # api key is valid
+        return authorization.credentials
+
+    # raise http error 401
+    raise HTTPException(
+        status_code=status.HTTP_401_UNAUTHORIZED,
+        detail="Invalid API key",
+    )
+
+
 @router.post(
     "/v1/completions",
     summary="Completion"
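
The authenticate dependency is wired into each endpoint below (completions, embeddings, chat completions, models), so a client authenticates exactly as it would against the OpenAI API: by sending the key as a Bearer token. A sketch using the openai v1 Python client, where base_url, api_key, and the model name are placeholders:

# Sketch: authenticating against the protected server with the openai client.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="sk-my-secret-key",  # sent as "Authorization: Bearer sk-my-secret-key"
)
completion = client.completions.create(
    model="llama-2-7b",  # placeholder; the server uses its loaded model
    prompt="Hello",
    max_tokens=16,
)
print(completion.choices[0].text)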
@@ -667,6 +696,7 @@ async def create_completion(
     request: Request,
     body: CreateCompletionRequest,
     llama: llama_cpp.Llama = Depends(get_llama),
+    authenticated: str = Depends(authenticate),
 ) -> llama_cpp.Completion:
     if isinstance(body.prompt, list):
         assert len(body.prompt) <= 1
@@ -740,7 +770,9 @@ class CreateEmbeddingRequest(BaseModel):
     summary="Embedding"
 )
 async def create_embedding(
-    request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
+    request: CreateEmbeddingRequest,
+    llama: llama_cpp.Llama = Depends(get_llama),
+    authenticated: str = Depends(authenticate),
 ):
     return await run_in_threadpool(
         llama.create_embedding, **request.model_dump(exclude={"user"})
@@ -834,6 +866,7 @@ async def create_chat_completion(
     body: CreateChatCompletionRequest,
     llama: llama_cpp.Llama = Depends(get_llama),
     settings: Settings = Depends(get_settings),
+    authenticated: str = Depends(authenticate),
 ) -> llama_cpp.ChatCompletion:
     exclude = {
         "n",
@@ -895,6 +928,7 @@ class ModelList(TypedDict):
 @router.get("/v1/models", summary="Models")
 async def get_models(
     settings: Settings = Depends(get_settings),
+    authenticated: str = Depends(authenticate),
 ) -> ModelList:
     assert llama is not None
     return {
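
A quick end-to-end check of the new gate, assuming a server running locally with api_key set; the URL and key are placeholders:

# Sketch: verifying that the Bearer dependency gates /v1/models.
import requests

base = "http://localhost:8000"  # placeholder server address

# Without credentials the authenticate dependency answers 401 ...
assert requests.get(f"{base}/v1/models").status_code == 401

# ... and with the matching key the request reaches the route handler.
headers = {"Authorization": "Bearer sk-my-secret-key"}  # must match api_key
assert requests.get(f"{base}/v1/models", headers=headers).status_code == 200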