You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

58 lines
1.8 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
认证中间件
"""
import logging
from typing import Optional
from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from app.config import settings
from app.services.auth_service import decode_token
# Module-level logger; AuthMiddleware uses it to report token validation failures.
logger = logging.getLogger(__name__)
class AuthMiddleware(BaseHTTPMiddleware):
    """JWT authentication middleware.

    For every request outside the public paths, reads the ``Authorization``
    header. If it carries a Bearer token that ``decode_token`` accepts (returns
    a truthy payload), the decoded claims are stored on ``request.state`` as
    ``user_id``, ``username`` (the ``sub`` claim) and ``token_payload``.

    This middleware never rejects a request: a missing, malformed or invalid
    token simply leaves ``request.state`` without those attributes, and the
    per-route permission checks decide whether the request is allowed.
    """

    @staticmethod
    def _is_public_path(path: str) -> bool:
        """Return True if *path* should bypass token processing."""
        # "/" must be matched exactly: as a prefix it matches every path, which
        # previously made the whole middleware a no-op.
        if path == "/":
            return True
        public_subtrees = (
            "/health",
            "/docs",
            "/redoc",
            "/openapi.json",
            f"{settings.API_PREFIX}/auth/login",
        )
        # Match the path itself or anything below it (e.g. /docs/oauth2-redirect),
        # but not unrelated siblings that merely share a prefix (e.g. /docsX).
        return any(
            path == public or path.startswith(public + "/")
            for public in public_subtrees
        )

    @staticmethod
    def _attach_user(request: Request) -> None:
        """Decode a Bearer token from *request* and store its claims on ``request.state``.

        Does nothing when the header is absent, uses another scheme, or carries
        an empty token. Any error raised while decoding or reading the payload
        is logged and swallowed, so a bad token never fails the request here.
        """
        auth_header: Optional[str] = request.headers.get("Authorization")
        if not auth_header:
            return
        # HTTP authentication schemes are case-insensitive (RFC 7235 §2.1).
        scheme, _, token = auth_header.partition(" ")
        token = token.strip()
        if scheme.lower() != "bearer" or not token:
            return
        try:
            payload = decode_token(token)
            if payload:
                request.state.user_id = payload.get("user_id")
                request.state.username = payload.get("sub")
                request.state.token_payload = payload
        except Exception as exc:
            # Broad catch at the middleware boundary: token problems must not
            # turn into a 500; authorization is enforced by the routes.
            logger.warning("Token validation failed: %s", exc)

    async def dispatch(self, request: Request, call_next):
        """Attach decoded JWT claims (if any) to the request, then forward it."""
        if not self._is_public_path(request.url.path):
            self._attach_user(request)
        return await call_next(request)