"""Rate-limiting middleware.

Provides a simple in-memory, per-client-IP sliding-window rate limiter.
"""

import logging
import math
import time
from collections import defaultdict, deque

from fastapi import Request, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from app.config import settings

logger = logging.getLogger(__name__)


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Simple in-memory, per-IP sliding-window rate limiter.

    Each client IP may make at most ``settings.RATE_LIMIT_PER_MINUTE``
    requests within any rolling ``window``-second period. State is held in
    this middleware instance's memory only (not shared across processes and
    lost on restart).
    """

    # Paths exempt from rate limiting. A request path is exempt if it equals
    # an entry or continues it after a "/" (e.g. "/docs/oauth2-redirect"), so
    # unrelated routes such as "/docsearch" are NOT accidentally exempted.
    SKIP_PATHS = ("/health", "/docs", "/redoc", "/openapi.json")

    def __init__(self, app):
        super().__init__(app)
        # client IP -> monotonic timestamps of accepted requests, oldest first.
        self.requests: dict = defaultdict(deque)
        self.limit = settings.RATE_LIMIT_PER_MINUTE
        self.window = 60  # window length in seconds (1 minute)
        # Monotonic time of the last sweep that drops idle IPs.
        self._last_sweep = time.monotonic()

    async def dispatch(self, request: Request, call_next):
        """Forward the request, or reject it with 429 if its IP is over the limit.

        Exempt paths are passed straight through without being counted.
        Otherwise the request is recorded, forwarded, and ``X-RateLimit-*``
        headers are attached to the response (on 429 responses as well).
        """
        if self._is_exempt(request.url.path):
            return await call_next(request)

        client_ip = request.client.host if request.client else "unknown"
        now = time.monotonic()
        self._sweep_idle(now)

        hits = self.requests[client_ip]
        self._prune(hits, now)

        if len(hits) >= self.limit:
            logger.warning("Rate limit exceeded for IP: %s", client_ip)
            retry_after = self._seconds_until_slot(hits, now)
            response = JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "code": 429,
                    "message": "Too many requests",
                    "data": {"retry_after": retry_after},
                },
            )
            response.headers["Retry-After"] = str(retry_after)
            response.headers.update(self._limit_headers(hits, now, remaining=0))
            return response

        hits.append(now)
        # Build the headers before awaiting the handler: other requests from
        # the same IP may be recorded while this one is in flight, which would
        # otherwise skew (or make negative) the "remaining" count.
        headers = self._limit_headers(hits, now, self.limit - len(hits))

        response = await call_next(request)
        response.headers.update(headers)
        return response

    def _is_exempt(self, path):
        """Return True if *path* is one of SKIP_PATHS or lies beneath one."""
        return any(
            path == prefix or path.startswith(prefix + "/")
            for prefix in self.SKIP_PATHS
        )

    def _prune(self, hits, now):
        """Drop timestamps that have left the window from the front of *hits*."""
        cutoff = now - self.window
        while hits and hits[0] <= cutoff:
            hits.popleft()

    def _sweep_idle(self, now):
        """Forget IPs with no hits inside the window; runs at most once per window.

        Without this, ``self.requests`` would keep one entry for every
        distinct IP ever seen and grow without bound.
        """
        if now - self._last_sweep < self.window:
            return
        self._last_sweep = now
        cutoff = now - self.window
        idle = [
            ip for ip, hits in self.requests.items()
            if not hits or hits[-1] <= cutoff
        ]
        for ip in idle:
            del self.requests[ip]

    def _seconds_until_slot(self, hits, now):
        """Whole seconds (>= 1) until the oldest hit expires and a slot frees."""
        if not hits:
            # Only reachable when the limit is <= 0; no slot will ever free.
            return self.window
        return max(1, math.ceil(hits[0] + self.window - now))

    def _limit_headers(self, hits, now, remaining):
        """Build the X-RateLimit-* headers for the current state of *hits*.

        ``X-RateLimit-Reset`` is the Unix epoch second at which the newest
        recorded hit leaves the window, i.e. when the full quota is restored.
        """
        newest = hits[-1] if hits else now
        reset_epoch = time.time() + (newest + self.window - now)
        return {
            "X-RateLimit-Limit": str(self.limit),
            "X-RateLimit-Remaining": str(max(0, remaining)),
            "X-RateLimit-Reset": str(int(reset_epoch)),
        }