diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index 6923aec..911200c 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -12,12 +12,15 @@ import anyio
 from anyio.streams.memory import MemoryObjectSendStream
 from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
 from fastapi import Depends, FastAPI, APIRouter, Request, Response
+from fastapi.middleware import Middleware
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
 from fastapi.routing import APIRoute
 from pydantic import BaseModel, Field
 from pydantic_settings import BaseSettings
 from sse_starlette.sse import EventSourceResponse
+from starlette_context import plugins
+from starlette_context.middleware import RawContextMiddleware
 
 import numpy as np
 import numpy.typing as npt
@@ -306,7 +309,17 @@ llama: Optional[llama_cpp.Llama] = None
 def create_app(settings: Optional[Settings] = None):
     if settings is None:
         settings = Settings()
+
+    middleware = [
+        Middleware(
+            RawContextMiddleware,
+            plugins=(
+                plugins.RequestIdPlugin(),
+            )
+        )
+    ]
     app = FastAPI(
+        middleware=middleware,
         title="🦙 llama.cpp Python API",
         version="0.0.1",
     )
diff --git a/pyproject.toml b/pyproject.toml
index 61b5bec..0bc863e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,6 +32,7 @@ server = [
     "fastapi>=0.100.0",
     "pydantic-settings>=2.0.1",
     "sse-starlette>=1.6.1",
+    "starlette-context>=0.3.6,<0.4"
 ]
 test = [
     "pytest>=7.4.0",