Implement openai api compatible authentication (#1010)
This commit is contained in:
parent
788394c096
commit
33cc623346
1 changed files with 36 additions and 2 deletions
|
@ -14,11 +14,12 @@ import llama_cpp
|
|||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectSendStream
|
||||
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
|
||||
from fastapi import Depends, FastAPI, APIRouter, Request, Response
|
||||
from fastapi import Depends, FastAPI, APIRouter, Request, Response, HTTPException, status
|
||||
from fastapi.middleware import Middleware
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.routing import APIRoute
|
||||
from fastapi.security import HTTPBearer
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import BaseSettings
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
@ -163,6 +164,10 @@ class Settings(BaseSettings):
|
|||
default=True,
|
||||
description="Whether to interrupt requests when a new request is received.",
|
||||
)
|
||||
api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="API key for authentication. If set all requests need to be authenticated."
|
||||
)
|
||||
|
||||
|
||||
class ErrorResponse(TypedDict):
|
||||
|
@ -314,6 +319,9 @@ class RouteErrorHandler(APIRoute):
|
|||
elapsed_time_ms = int((time.perf_counter() - start_sec) * 1000)
|
||||
response.headers["openai-processing-ms"] = f"{elapsed_time_ms}"
|
||||
return response
|
||||
except HTTPException as unauthorized:
|
||||
# api key check failed
|
||||
raise unauthorized
|
||||
except Exception as exc:
|
||||
json_body = await request.json()
|
||||
try:
|
||||
|
@ -658,6 +666,27 @@ def _logit_bias_tokens_to_input_ids(
|
|||
return to_bias
|
||||
|
||||
|
||||
# Setup Bearer authentication scheme
|
||||
bearer_scheme = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def authenticate(settings: Settings = Depends(get_settings), authorization: Optional[str] = Depends(bearer_scheme)):
|
||||
# Skip API key check if it's not set in settings
|
||||
if settings.api_key is None:
|
||||
return True
|
||||
|
||||
# check bearer credentials against the api_key
|
||||
if authorization and authorization.credentials == settings.api_key:
|
||||
# api key is valid
|
||||
return authorization.credentials
|
||||
|
||||
# raise http error 401
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/completions",
|
||||
summary="Completion"
|
||||
|
@ -667,6 +696,7 @@ async def create_completion(
|
|||
request: Request,
|
||||
body: CreateCompletionRequest,
|
||||
llama: llama_cpp.Llama = Depends(get_llama),
|
||||
authenticated: str = Depends(authenticate),
|
||||
) -> llama_cpp.Completion:
|
||||
if isinstance(body.prompt, list):
|
||||
assert len(body.prompt) <= 1
|
||||
|
@ -740,7 +770,9 @@ class CreateEmbeddingRequest(BaseModel):
|
|||
summary="Embedding"
|
||||
)
|
||||
async def create_embedding(
|
||||
request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
|
||||
request: CreateEmbeddingRequest,
|
||||
llama: llama_cpp.Llama = Depends(get_llama),
|
||||
authenticated: str = Depends(authenticate),
|
||||
):
|
||||
return await run_in_threadpool(
|
||||
llama.create_embedding, **request.model_dump(exclude={"user"})
|
||||
|
@ -834,6 +866,7 @@ async def create_chat_completion(
|
|||
body: CreateChatCompletionRequest,
|
||||
llama: llama_cpp.Llama = Depends(get_llama),
|
||||
settings: Settings = Depends(get_settings),
|
||||
authenticated: str = Depends(authenticate),
|
||||
) -> llama_cpp.ChatCompletion:
|
||||
exclude = {
|
||||
"n",
|
||||
|
@ -895,6 +928,7 @@ class ModelList(TypedDict):
|
|||
@router.get("/v1/models", summary="Models")
|
||||
async def get_models(
|
||||
settings: Settings = Depends(get_settings),
|
||||
authenticated: str = Depends(authenticate),
|
||||
) -> ModelList:
|
||||
assert llama is not None
|
||||
return {
|
||||
|
|
Loading…
Reference in a new issue