You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
|
|
|
|
|
"""
|
|
|
|
|
|
认证中间件
|
|
|
|
|
|
"""
|
|
|
|
|
|
import logging
|
|
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
|
|
from fastapi import Request, HTTPException, status
|
|
|
|
|
|
from fastapi.responses import JSONResponse
|
|
|
|
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
|
|
|
|
|
|
|
|
from app.config import settings
|
|
|
|
|
|
from app.services.auth_service import decode_token
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AuthMiddleware(BaseHTTPMiddleware):
    """JWT authentication middleware.

    For every request whose path is not public, reads the ``Authorization``
    header. If it carries a ``Bearer`` token that ``decode_token`` turns into
    a truthy payload, the decoded claims are stored on ``request.state``:

    * ``request.state.user_id``       -- ``payload.get("user_id")``
    * ``request.state.username``      -- ``payload.get("sub")``
    * ``request.state.token_payload`` -- the full decoded payload

    The middleware never rejects a request. Missing, malformed or invalid
    tokens simply leave ``request.state`` untouched; permission checks are
    performed in the individual routes.
    """

    @staticmethod
    def _is_public_path(path: str) -> bool:
        """Return True if *path* is exempt from token processing.

        ``"/"`` is matched exactly: every URL path starts with ``/``, so a
        prefix match on it would exempt every request. Every other entry
        matches itself or any sub-path below it (e.g. ``/docs`` also covers
        ``/docs/oauth2-redirect``), but not unrelated siblings such as
        ``/docsfoo``.
        """
        if path == "/":
            return True
        # Built per call so a change to settings.API_PREFIX is picked up,
        # matching the original request-time evaluation.
        public_prefixes = (
            "/health",
            "/docs",
            "/redoc",
            "/openapi.json",
            f"{settings.API_PREFIX}/auth/login",
        )
        return any(
            path == prefix or path.startswith(prefix + "/")
            for prefix in public_prefixes
        )

    @staticmethod
    def _extract_bearer_token(auth_header: str) -> Optional[str]:
        """Return the token from a ``Bearer <token>`` header value, else None.

        The scheme name is compared case-insensitively, because HTTP
        authentication schemes are case-insensitive (RFC 7235 section 2.1).
        Surrounding whitespace is stripped from the token, and an empty
        token yields None.
        """
        scheme, _, token = auth_header.partition(" ")
        if scheme.lower() != "bearer":
            return None
        token = token.strip()
        return token or None

    async def dispatch(self, request: Request, call_next):
        """Attach decoded JWT claims to ``request.state`` and continue.

        Args:
            request: The incoming request.
            call_next: Callback that forwards the request down the
                middleware/route chain and returns the response.

        Returns:
            The downstream response, unmodified. No request is ever
            short-circuited here.
        """
        # Public endpoints (health checks, API docs, login) need no token.
        if self._is_public_path(request.url.path):
            return await call_next(request)

        auth_header: Optional[str] = request.headers.get("Authorization")
        # No header or a non-Bearer scheme: continue unauthenticated; the
        # individual routes decide whether authentication is required.
        token = self._extract_bearer_token(auth_header) if auth_header else None

        if token:
            try:
                payload = decode_token(token)
                if payload:
                    request.state.user_id = payload.get("user_id")
                    request.state.username = payload.get("sub")
                    request.state.token_payload = payload
            except Exception as exc:
                # Best effort: a token that cannot be validated leaves the
                # request unauthenticated instead of failing it here.
                logger.warning("Token validation failed: %s", exc)

        return await call_next(request)
|