You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

68 lines
2.1 KiB

"""
限流中间件
"""
import logging
import math
import time
from collections import defaultdict, deque

from fastapi import Request, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from app.config import settings
# Per-module logger; records are tagged with this module's import name.
logger: logging.Logger = logging.getLogger(name=__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Simple in-memory, sliding-window rate limiter keyed by client IP.

    A client may make at most ``limit`` requests within any rolling
    ``window``-second period. Over-limit requests are answered with HTTP 429,
    a JSON body and a ``Retry-After`` header. Allowed requests get
    ``X-RateLimit-Limit``, ``X-RateLimit-Remaining`` and ``X-RateLimit-Reset``
    headers on their response. Paths starting with any prefix in
    ``SKIP_PATHS`` are never counted.

    State is held in this object's memory, so each server process enforces
    its own independent limit, and all counts are lost on restart.
    NOTE(review): clients are identified by ``request.client.host``. Behind a
    reverse proxy this may be the proxy's address; confirm how forwarded
    headers are handled in deployment.
    """

    # Path prefixes exempt from rate limiting (health checks and API docs).
    SKIP_PATHS = ("/health", "/docs", "/redoc", "/openapi.json")

    def __init__(self, app, *, limit=None, window=60):
        """Create the middleware.

        Args:
            app: The ASGI application being wrapped.
            limit: Maximum accepted requests per client per window. Defaults
                to ``settings.RATE_LIMIT_PER_MINUTE``.
            window: Length of the sliding window in seconds (default 60).
        """
        super().__init__(app)
        # client IP -> time.monotonic() timestamps of accepted requests, oldest first.
        self.requests: dict = defaultdict(deque)
        self.limit = settings.RATE_LIMIT_PER_MINUTE if limit is None else limit
        self.window = window
        # Monotonic time at which idle clients are next purged from ``requests``.
        self._next_sweep = time.monotonic() + window

    async def dispatch(self, request: Request, call_next):
        """Forward the request if the client is under its limit; otherwise return 429."""
        if request.url.path.startswith(self.SKIP_PATHS):
            return await call_next(request)

        client_ip = request.client.host if request.client else "unknown"
        # Monotonic clock, so wall-clock adjustments cannot stretch or shrink
        # the window.
        now = time.monotonic()
        self._sweep_idle_clients(now)

        timestamps = self.requests[client_ip]
        self._evict_expired(timestamps, now)

        if len(timestamps) >= self.limit:
            logger.warning("Rate limit exceeded for IP: %s", client_ip)
            retry_after = self._retry_after(timestamps, now)
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "code": 429,
                    "message": "Too many requests",
                    "data": {"retry_after": retry_after},
                },
                headers={"Retry-After": str(retry_after)},
            )

        timestamps.append(now)
        # Snapshot header values before awaiting the handler. Concurrent
        # requests from the same client would otherwise change them meanwhile.
        remaining = self.limit - len(timestamps)
        reset_at = int(time.time() + self.window)

        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(self.limit)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        response.headers["X-RateLimit-Reset"] = str(reset_at)
        return response

    def _evict_expired(self, timestamps, now):
        """Drop timestamps that have left the window (the deque is oldest-first)."""
        while timestamps and now - timestamps[0] >= self.window:
            timestamps.popleft()

    def _retry_after(self, timestamps, now):
        """Return whole seconds (at least 1) until the oldest counted request leaves the window."""
        if not timestamps:  # only reachable when limit <= 0
            return max(1, math.ceil(self.window))
        return max(1, math.ceil(timestamps[0] + self.window - now))

    def _sweep_idle_clients(self, now):
        """At most once per window, forget clients with no request inside the window.

        Without this, an entry for every IP ever seen would stay in memory forever.
        """
        if now < self._next_sweep:
            return
        self._next_sweep = now + self.window
        idle = [
            ip
            for ip, ts in self.requests.items()
            if not ts or now - ts[-1] >= self.window
        ]
        for ip in idle:
            del self.requests[ip]