Implement OpenAI API-compatible authentication (#1010)
This commit is contained in:
parent
788394c096
commit
33cc623346
1 changed file with 36 additions and 2 deletions
|
@ -14,11 +14,12 @@ import llama_cpp
|
||||||
import anyio
|
import anyio
|
||||||
from anyio.streams.memory import MemoryObjectSendStream
|
from anyio.streams.memory import MemoryObjectSendStream
|
||||||
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
|
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
|
||||||
from fastapi import Depends, FastAPI, APIRouter, Request, Response
|
from fastapi import Depends, FastAPI, APIRouter, Request, Response, HTTPException, status
|
||||||
from fastapi.middleware import Middleware
|
from fastapi.middleware import Middleware
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi.routing import APIRoute
|
from fastapi.routing import APIRoute
|
||||||
|
from fastapi.security import HTTPBearer
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
from sse_starlette.sse import EventSourceResponse
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
@ -163,6 +164,10 @@ class Settings(BaseSettings):
|
||||||
default=True,
|
default=True,
|
||||||
description="Whether to interrupt requests when a new request is received.",
|
description="Whether to interrupt requests when a new request is received.",
|
||||||
)
|
)
|
||||||
|
api_key: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="API key for authentication. If set all requests need to be authenticated."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ErrorResponse(TypedDict):
|
class ErrorResponse(TypedDict):
|
||||||
|
@ -314,6 +319,9 @@ class RouteErrorHandler(APIRoute):
|
||||||
elapsed_time_ms = int((time.perf_counter() - start_sec) * 1000)
|
elapsed_time_ms = int((time.perf_counter() - start_sec) * 1000)
|
||||||
response.headers["openai-processing-ms"] = f"{elapsed_time_ms}"
|
response.headers["openai-processing-ms"] = f"{elapsed_time_ms}"
|
||||||
return response
|
return response
|
||||||
|
except HTTPException as unauthorized:
|
||||||
|
# api key check failed
|
||||||
|
raise unauthorized
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
json_body = await request.json()
|
json_body = await request.json()
|
||||||
try:
|
try:
|
||||||
|
@ -658,6 +666,27 @@ def _logit_bias_tokens_to_input_ids(
|
||||||
return to_bias
|
return to_bias
|
||||||
|
|
||||||
|
|
||||||
|
# Bearer authentication scheme. ``auto_error=False`` leaves the decision about
# a missing/invalid ``Authorization`` header to ``authenticate`` below, so that
# servers started without an API key keep accepting unauthenticated requests.
bearer_scheme = HTTPBearer(auto_error=False)


async def authenticate(
    settings: Settings = Depends(get_settings),
    authorization: Optional[str] = Depends(bearer_scheme),
):
    """Check the request's bearer token against the configured API key.

    Intended to be used as a FastAPI dependency on every protected route.

    Args:
        settings: Server settings; only ``settings.api_key`` is read.
        authorization: Value produced by ``bearer_scheme``. Despite the
            ``Optional[str]`` annotation (kept for interface compatibility),
            this code reads a ``.credentials`` attribute from it, i.e. it is
            the credentials object returned by ``HTTPBearer`` or ``None``
            when no usable ``Authorization`` header was supplied.

    Returns:
        ``True`` when no API key is configured (authentication disabled),
        otherwise the bearer token that was successfully validated.

    Raises:
        HTTPException: 401 when an API key is configured and the request
            carries no bearer token or a token that does not match.
    """
    # Function-scope import keeps this block self-contained.
    import secrets

    # Skip the API key check entirely if no key is set in the settings.
    if settings.api_key is None:
        return True

    if authorization is not None:
        supplied = authorization.credentials
        # Constant-time comparison so response timing does not leak how much
        # of the key an attacker has guessed. Compare as bytes because
        # ``compare_digest`` rejects non-ASCII ``str`` arguments.
        if secrets.compare_digest(
            supplied.encode("utf-8"), settings.api_key.encode("utf-8")
        ):
            # API key is valid.
            return supplied

    # Missing or wrong credentials: reject with 401 and advertise the
    # expected scheme, as HTTP requires for 401 responses.
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid API key",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/v1/completions",
|
"/v1/completions",
|
||||||
summary="Completion"
|
summary="Completion"
|
||||||
|
@ -667,6 +696,7 @@ async def create_completion(
|
||||||
request: Request,
|
request: Request,
|
||||||
body: CreateCompletionRequest,
|
body: CreateCompletionRequest,
|
||||||
llama: llama_cpp.Llama = Depends(get_llama),
|
llama: llama_cpp.Llama = Depends(get_llama),
|
||||||
|
authenticated: str = Depends(authenticate),
|
||||||
) -> llama_cpp.Completion:
|
) -> llama_cpp.Completion:
|
||||||
if isinstance(body.prompt, list):
|
if isinstance(body.prompt, list):
|
||||||
assert len(body.prompt) <= 1
|
assert len(body.prompt) <= 1
|
||||||
|
@ -740,7 +770,9 @@ class CreateEmbeddingRequest(BaseModel):
|
||||||
summary="Embedding"
|
summary="Embedding"
|
||||||
)
|
)
|
||||||
async def create_embedding(
|
async def create_embedding(
|
||||||
request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
|
request: CreateEmbeddingRequest,
|
||||||
|
llama: llama_cpp.Llama = Depends(get_llama),
|
||||||
|
authenticated: str = Depends(authenticate),
|
||||||
):
|
):
|
||||||
return await run_in_threadpool(
|
return await run_in_threadpool(
|
||||||
llama.create_embedding, **request.model_dump(exclude={"user"})
|
llama.create_embedding, **request.model_dump(exclude={"user"})
|
||||||
|
@ -834,6 +866,7 @@ async def create_chat_completion(
|
||||||
body: CreateChatCompletionRequest,
|
body: CreateChatCompletionRequest,
|
||||||
llama: llama_cpp.Llama = Depends(get_llama),
|
llama: llama_cpp.Llama = Depends(get_llama),
|
||||||
settings: Settings = Depends(get_settings),
|
settings: Settings = Depends(get_settings),
|
||||||
|
authenticated: str = Depends(authenticate),
|
||||||
) -> llama_cpp.ChatCompletion:
|
) -> llama_cpp.ChatCompletion:
|
||||||
exclude = {
|
exclude = {
|
||||||
"n",
|
"n",
|
||||||
|
@ -895,6 +928,7 @@ class ModelList(TypedDict):
|
||||||
@router.get("/v1/models", summary="Models")
|
@router.get("/v1/models", summary="Models")
|
||||||
async def get_models(
|
async def get_models(
|
||||||
settings: Settings = Depends(get_settings),
|
settings: Settings = Depends(get_settings),
|
||||||
|
authenticated: str = Depends(authenticate),
|
||||||
) -> ModelList:
|
) -> ModelList:
|
||||||
assert llama is not None
|
assert llama is not None
|
||||||
return {
|
return {
|
||||||
|
|
Loading…
Reference in a new issue