From 33cc623346d8c24ac4283f5b79d7154b821f62a7 Mon Sep 17 00:00:00 2001 From: docmeth02 Date: Thu, 21 Dec 2023 19:44:49 +0100 Subject: [PATCH] Implement openai api compatible authentication (#1010) --- llama_cpp/server/app.py | 38 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index fa39047..db9705f 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -14,11 +14,12 @@ import llama_cpp import anyio from anyio.streams.memory import MemoryObjectSendStream from starlette.concurrency import run_in_threadpool, iterate_in_threadpool -from fastapi import Depends, FastAPI, APIRouter, Request, Response +from fastapi import Depends, FastAPI, APIRouter, Request, Response, HTTPException, status from fastapi.middleware import Middleware from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from fastapi.routing import APIRoute +from fastapi.security import HTTPBearer from pydantic import BaseModel, Field from pydantic_settings import BaseSettings from sse_starlette.sse import EventSourceResponse @@ -163,6 +164,10 @@ class Settings(BaseSettings): default=True, description="Whether to interrupt requests when a new request is received.", ) + api_key: Optional[str] = Field( + default=None, + description="API key for authentication. If set all requests need to be authenticated." + ) class ErrorResponse(TypedDict): @@ -314,6 +319,9 @@ class RouteErrorHandler(APIRoute): elapsed_time_ms = int((time.perf_counter() - start_sec) * 1000) response.headers["openai-processing-ms"] = f"{elapsed_time_ms}" return response + except HTTPException as unauthorized: + # api key check failed + raise unauthorized except Exception as exc: json_body = await request.json() try: @@ -658,6 +666,27 @@ def _logit_bias_tokens_to_input_ids( return to_bias +# Setup Bearer authentication scheme +bearer_scheme = HTTPBearer(auto_error=False) + + +async def authenticate(settings: Settings = Depends(get_settings), authorization: Optional[str] = Depends(bearer_scheme)): + # Skip API key check if it's not set in settings + if settings.api_key is None: + return True + + # check bearer credentials against the api_key + if authorization and authorization.credentials == settings.api_key: + # api key is valid + return authorization.credentials + + # raise http error 401 + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + ) + + @router.post( "/v1/completions", summary="Completion" @@ -667,6 +696,7 @@ async def create_completion( request: Request, body: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama), + authenticated: str = Depends(authenticate), ) -> llama_cpp.Completion: if isinstance(body.prompt, list): assert len(body.prompt) <= 1 @@ -740,7 +770,9 @@ class CreateEmbeddingRequest(BaseModel): summary="Embedding" ) async def create_embedding( - request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama) + request: CreateEmbeddingRequest, + llama: llama_cpp.Llama = Depends(get_llama), + authenticated: str = Depends(authenticate), ): return await run_in_threadpool( llama.create_embedding, **request.model_dump(exclude={"user"}) @@ -834,6 +866,7 @@ async def create_chat_completion( body: CreateChatCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama), settings: Settings = Depends(get_settings), + authenticated: str = Depends(authenticate), ) -> llama_cpp.ChatCompletion: exclude = { "n", @@ -895,6 +928,7 @@ class ModelList(TypedDict): @router.get("/v1/models", summary="Models") async def get_models( settings: Settings = Depends(get_settings), + authenticated: str = Depends(authenticate), ) -> ModelList: assert llama is not None return {