You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
254 lines
7.2 KiB
254 lines
7.2 KiB
"""
|
|
认证 API 路由
|
|
"""
|
|
from datetime import timedelta
|
|
from typing import Annotated
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
|
|
|
from app.config import settings
|
|
from app.schemas import (
|
|
LoginRequest,
|
|
TokenResponse,
|
|
RefreshTokenRequest,
|
|
ResponseData,
|
|
APIKeyCreate,
|
|
APIKeyResponse
|
|
)
|
|
from app.services.auth_service import (
|
|
authenticate_user,
|
|
create_access_token,
|
|
create_refresh_token,
|
|
decode_token,
|
|
generate_api_key,
|
|
hash_api_key,
|
|
get_password_hash
|
|
)
|
|
from app.models import User, APIKey
|
|
from app.db.init_db import SQLiteSessionLocal, get_sqlite_db
|
|
from sqlalchemy.orm import Session
|
|
|
|
# Router that collects the authentication endpoints defined in this module.
router = APIRouter()

# Pulls the bearer token out of the Authorization header for the dependencies
# below. tokenUrl is what the generated OpenAPI docs post credentials to;
# NOTE(review): assumes this router is mounted under f"{API_PREFIX}/auth" so
# that this URL resolves to the /login endpoint below — confirm at mount site.
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl=f"{settings.API_PREFIX}/auth/login",
)
|
|
|
|
|
|
async def get_current_user(
    token: Annotated[str, Depends(oauth2_scheme)],
    db: Session = Depends(get_sqlite_db)
) -> User:
    """Resolve the authenticated user from a bearer access token.

    Used as a FastAPI dependency by the protected endpoints of this router.

    Args:
        token: Raw bearer token extracted by ``oauth2_scheme``.
        db: Database session provided by ``get_sqlite_db``.

    Returns:
        The active ``User`` whose ``username`` equals the token's ``sub`` claim.

    Raises:
        HTTPException: 401 when the token cannot be decoded, is a refresh
            token, carries no ``sub`` claim, or names an unknown or inactive
            user.
    """
    # 401 responses for bearer auth advertise the expected scheme.
    auth_headers = {"WWW-Authenticate": "Bearer"}

    payload = decode_token(token)
    if payload is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
            headers=auth_headers,
        )

    # The /refresh endpoint identifies refresh tokens by type == "refresh".
    # Such a token must not double as an access token on protected endpoints,
    # otherwise the access/refresh split is meaningless.
    if payload.get("type") == "refresh":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token type",
            headers=auth_headers,
        )

    # "sub" may be absent (None) or empty; neither identifies a user.
    username = payload.get("sub")
    if not username:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token payload",
            headers=auth_headers,
        )

    user = db.query(User).filter(User.username == username).first()
    if user is None or not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found or inactive",
            headers=auth_headers,
        )

    return user
|
|
|
|
|
|
@router.post("/login", response_model=ResponseData)
async def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()]):
    """
    User login (OAuth2 password flow).

    - **username**: user name
    - **password**: password

    Returns an access token and a refresh token for the authenticated user.
    """
    user = authenticate_user(form_data.username, form_data.password)
    # Also refuse deactivated accounts: get_current_user and /refresh reject
    # inactive users, so tokens issued to them here would be unusable anyway.
    # The same generic message is reused so the response does not reveal
    # whether the account exists but is disabled.
    if not user or not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Separate claims dicts per token so neither factory can see changes the
    # other might make to its ``data`` argument.
    access_token = create_access_token(
        data={"sub": user.username, "user_id": user.id}
    )
    refresh_token = create_refresh_token(
        data={"sub": user.username, "user_id": user.id}
    )

    return ResponseData(
        code=0,
        message="success",
        data={
            "access_token": access_token,
            "refresh_token": refresh_token,
            "token_type": "Bearer",
            # Lifetime of the access token in seconds.
            "expires_in": settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
        },
    )
|
|
|
|
|
|
@router.post("/refresh", response_model=ResponseData)
async def refresh_token(
    request: RefreshTokenRequest,
    db: Session = Depends(get_sqlite_db)
):
    """Exchange a valid refresh token for a new access token.

    Args:
        request: Request body carrying the ``refresh_token`` string.
        db: Database session provided by ``get_sqlite_db``.

    Returns:
        ResponseData whose ``data`` holds the new ``access_token``, its
        ``token_type`` and ``expires_in`` (seconds).

    Raises:
        HTTPException: 401 when the token is undecodable, is not a refresh
            token, carries no ``sub`` claim, or names an unknown or inactive
            user.
    """
    auth_headers = {"WWW-Authenticate": "Bearer"}

    payload = decode_token(request.refresh_token)
    if payload is None or payload.get("type") != "refresh":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired refresh token",
            headers=auth_headers,
        )

    # Without a subject the lookup below would compare username against NULL;
    # treat a missing or empty "sub" as an invalid token instead.
    username = payload.get("sub")
    if not username:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired refresh token",
            headers=auth_headers,
        )

    user = db.query(User).filter(User.username == username).first()
    if not user or not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found or inactive",
            headers=auth_headers,
        )

    access_token = create_access_token(
        data={"sub": user.username, "user_id": user.id}
    )

    return ResponseData(
        code=0,
        message="success",
        data={
            "access_token": access_token,
            # Same token_type as /login returns, so clients handle both alike.
            "token_type": "Bearer",
            "expires_in": settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
        },
    )
|
|
|
|
|
|
@router.get("/me", response_model=ResponseData)
async def get_current_user_info(
    current_user: Annotated[User, Depends(get_current_user)]
):
    """Return the profile of the user identified by the bearer token."""
    user = current_user
    profile = {
        "id": user.id,
        "username": user.username,
        "email": user.email,
        "role": user.role,
        "is_active": user.is_active,
        "created_at": user.created_at.isoformat(),
    }
    return ResponseData(code=0, message="success", data=profile)
|
|
|
|
|
|
@router.post("/api-key", response_model=ResponseData)
async def create_api_key(
    request: APIKeyCreate,
    current_user: Annotated[User, Depends(get_current_user)],
    db: Session = Depends(get_sqlite_db)
):
    """Create a new API key for the current user.

    The plaintext key is generated here and only its hash is persisted. The
    plaintext is returned once, in this response.

    Args:
        request: Key name, optional permissions and optional lifetime in days.
        current_user: Authenticated owner of the new key.
        db: Database session provided by ``get_sqlite_db``.

    Returns:
        ResponseData with the stored key's metadata plus the plaintext ``key``.
    """
    from datetime import datetime, timedelta, timezone

    api_key = generate_api_key()
    key_hash = hash_api_key(api_key)

    expires_at = None
    # A falsy expires_days (None or 0) means the key never expires.
    if request.expires_days:
        # Naive UTC timestamp: the same value datetime.utcnow() produced,
        # without the deprecated call (DeprecationWarning on Python 3.12+).
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        expires_at = now_utc + timedelta(days=request.expires_days)

    db_api_key = APIKey(
        user_id=current_user.id,
        key_hash=key_hash,
        name=request.name,
        # NOTE(review): str() stores the Python repr (e.g. "['a', 'b']"), not
        # JSON. Left as-is because readers elsewhere may depend on this format;
        # consider json.dumps once those readers are checked.
        permissions=str(request.permissions) if request.permissions else None,
        expires_at=expires_at,
    )
    db.add(db_api_key)
    db.commit()
    # Reload DB-populated columns (id, created_at, is_active) for the response.
    db.refresh(db_api_key)

    return ResponseData(
        code=0,
        message="success",
        data={
            "id": db_api_key.id,
            "name": db_api_key.name,
            "key": api_key,  # plaintext is only ever returned here
            "permissions": request.permissions,
            "expires_at": db_api_key.expires_at.isoformat() if db_api_key.expires_at else None,
            "is_active": db_api_key.is_active,
            "created_at": db_api_key.created_at.isoformat(),
        },
    )
|
|
|
|
|
|
@router.get("/api-keys", response_model=ResponseData)
async def list_api_keys(
    current_user: Annotated[User, Depends(get_current_user)],
    db: Session = Depends(get_sqlite_db)
):
    """List the current user's active (non-revoked) API keys.

    Key hashes are never included; only metadata is returned.
    """
    owned_keys = (
        db.query(APIKey)
        .filter_by(user_id=current_user.id, is_active=True)
        .all()
    )

    def _iso_or_none(moment):
        # Optional timestamps are serialized as ISO-8601 strings or None.
        return moment.isoformat() if moment else None

    def _serialize(key_row):
        return {
            "id": key_row.id,
            "name": key_row.name,
            "permissions": key_row.permissions,
            "expires_at": _iso_or_none(key_row.expires_at),
            "is_active": key_row.is_active,
            "created_at": key_row.created_at.isoformat(),
            "last_used_at": _iso_or_none(key_row.last_used_at),
        }

    return ResponseData(
        code=0,
        message="success",
        data=[_serialize(key_row) for key_row in owned_keys],
    )
|
|
|
|
|
|
@router.delete("/api-key/{key_id}", response_model=ResponseData)
async def revoke_api_key(
    key_id: int,
    current_user: Annotated[User, Depends(get_current_user)],
    db: Session = Depends(get_sqlite_db)
):
    """Revoke one of the current user's API keys.

    This is a soft revoke: the row is kept and only flagged inactive. A key
    that does not exist, or belongs to another user, yields 404.
    """
    owned_key = (
        db.query(APIKey)
        .filter_by(id=key_id, user_id=current_user.id)
        .first()
    )
    if not owned_key:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="API Key not found",
        )

    owned_key.is_active = False
    db.commit()

    return ResponseData(code=0, message="success", data={"id": key_id, "status": "revoked"})
|