You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
68 lines
2.1 KiB
68 lines
2.1 KiB
"""
|
|
限流中间件
|
|
"""
|
|
import logging
import math
import time
from collections import defaultdict, deque
from typing import Iterable, Optional

from fastapi import Request, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from app.config import settings
|
|
|
|
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Simple in-memory, per-client-IP sliding-window rate limiter.

    Each client IP may make at most ``limit`` requests within any rolling
    ``window``-second span. Requests over the limit are rejected with HTTP 429
    and a ``Retry-After`` header. Accepted responses carry
    ``X-RateLimit-Limit``, ``X-RateLimit-Remaining`` and ``X-RateLimit-Reset``
    headers.

    NOTE: state lives in this process's memory. Counters are not shared
    between worker processes or replicas, and they are lost on restart.
    NOTE(review): ``request.client.host`` is the direct peer address. Behind a
    reverse proxy, every client may appear as the proxy's IP; confirm how the
    deployment forwards client addresses if that matters.
    """

    # Paths that are never rate limited (health check and API docs). Each
    # entry also exempts its sub-paths, e.g. "/docs/oauth2-redirect".
    DEFAULT_SKIP_PATHS = ("/health", "/docs", "/redoc", "/openapi.json")

    def __init__(
        self,
        app,
        limit: Optional[int] = None,
        window: float = 60,
        skip_paths: Optional[Iterable[str]] = None,
    ):
        """Create the middleware.

        Args:
            app: The ASGI application to wrap.
            limit: Maximum accepted requests per client IP per window.
                Defaults to ``settings.RATE_LIMIT_PER_MINUTE``.
            window: Sliding-window length in seconds. The default of 60
                (one minute) matches the settings name.
            skip_paths: Paths exempt from limiting. A path also exempts
                everything beneath it. Defaults to ``DEFAULT_SKIP_PATHS``.
        """
        super().__init__(app)
        # client IP -> monotonic timestamps of accepted requests, oldest first.
        self.requests: "defaultdict[str, deque]" = defaultdict(deque)
        self.limit = settings.RATE_LIMIT_PER_MINUTE if limit is None else limit
        self.window = window

        paths = self.DEFAULT_SKIP_PATHS if skip_paths is None else tuple(skip_paths)
        # Match exactly or on a "/" segment boundary, so "/docs" exempts
        # "/docs/oauth2-redirect" but not "/docsearch", and "/health" does not
        # exempt "/healthy-api".
        self._skip_exact = frozenset(paths)
        self._skip_prefixes = tuple(p.rstrip("/") + "/" for p in paths)
        self._last_sweep = time.monotonic()

    def _is_exempt(self, path: str) -> bool:
        """Return True if *path* is a skip path or lies beneath one."""
        return path in self._skip_exact or path.startswith(self._skip_prefixes)

    def _expire(self, timestamps: deque, now: float) -> None:
        """Drop timestamps that have fallen out of the window ending at *now*."""
        cutoff = now - self.window
        # Timestamps come from a monotonic clock and are appended in order,
        # so the expired ones are all at the left end.
        while timestamps and timestamps[0] <= cutoff:
            timestamps.popleft()

    def _sweep(self, now: float) -> None:
        """At most once per window, forget clients with no requests left in it.

        Without this sweep, every IP ever seen would keep a dict entry forever.
        """
        if now - self._last_sweep < self.window:
            return
        self._last_sweep = now
        for ip in list(self.requests):
            timestamps = self.requests[ip]
            self._expire(timestamps, now)
            if not timestamps:
                del self.requests[ip]

    async def dispatch(self, request: Request, call_next):
        """Reject the request with 429 if its client is over the limit; otherwise forward it."""
        if self._is_exempt(request.url.path):
            return await call_next(request)

        client_ip = request.client.host if request.client else "unknown"
        # Use the monotonic clock for window accounting, since it is immune
        # to wall-clock jumps. Use the wall clock only for the epoch-seconds
        # reset header.
        now = time.monotonic()
        arrived_at = time.time()

        self._sweep(now)
        timestamps = self.requests[client_ip]
        self._expire(timestamps, now)

        if len(timestamps) >= self.limit:
            # Seconds until the oldest request leaves the window and frees a slot.
            wait = timestamps[0] + self.window - now if timestamps else self.window
            retry_after = max(1, math.ceil(wait))
            logger.warning("Rate limit exceeded for IP: %s", client_ip)
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "code": 429,
                    "message": "Too many requests",
                    "data": {"retry_after": retry_after},
                },
                headers={"Retry-After": str(retry_after)},
            )

        # There is no await between the check above and this append, so
        # concurrent requests on the event loop cannot both slip past the limit.
        timestamps.append(now)
        # Compute this before awaiting the handler, so other requests that
        # run during the await cannot skew this request's numbers.
        remaining = self.limit - len(timestamps)

        response = await call_next(request)

        response.headers["X-RateLimit-Limit"] = str(self.limit)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        response.headers["X-RateLimit-Reset"] = str(int(arrived_at + self.window))
        return response