You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
136 lines
3.7 KiB
136 lines
3.7 KiB
"""
|
|
认证服务
|
|
"""
|
|
from datetime import timedelta
from typing import Optional

from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from app.config import settings
from app.core.security import verify_password, get_password_hash, create_access_token
from app.models.user import User
|
|
|
|
|
|
class AuthService:
    """Authentication service: credential checks, token issuance and user management."""

    @staticmethod
    def authenticate_user(db: Session, username: str, password: str) -> User:
        """
        Verify a user's credentials.

        Args:
            db: Database session.
            username: Username to look up.
            password: Plain-text password supplied by the client.

        Returns:
            The matching, active user.

        Raises:
            HTTPException: 401 if the user does not exist or the password is
                wrong (identical response for both cases); 403 if the user
                account is disabled.
        """
        user = AuthService.get_user_by_username(db, username)

        # Unknown user and wrong password share one response so the API does
        # not directly reveal which usernames exist.
        # NOTE(review): for an unknown username verify_password is skipped, so
        # response timing may still leak existence. Consider verifying against
        # a fixed dummy hash here — confirm against verify_password's behaviour.
        if user is None or not verify_password(password, user.password_hash):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="用户名或密码错误",
            )

        # Checked only after the password matched, so the disabled state is
        # revealed only to a caller holding valid credentials.
        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="用户已被禁用",
            )

        return user

    @staticmethod
    def create_user_token(user: User) -> dict:
        """
        Create an access token for a user.

        Args:
            user: User the token is issued for; its username becomes the
                token's ``sub`` claim.

        Returns:
            A dict with ``access_token``, ``token_type`` ("bearer") and
            ``expires_in`` (lifetime in whole seconds).
        """
        access_token_expires = timedelta(hours=settings.ACCESS_TOKEN_EXPIRE_HOURS)
        access_token = create_access_token(
            data={"sub": user.username},
            expires_delta=access_token_expires,
        )

        return {
            "access_token": access_token,
            "token_type": "bearer",
            # Derived from the same timedelta passed to the token, so the
            # reported lifetime always matches the token's real expiry.
            "expires_in": int(access_token_expires.total_seconds()),
        }

    @staticmethod
    def get_user_by_username(db: Session, username: str) -> Optional[User]:
        """
        Look up a user by username.

        Args:
            db: Database session.
            username: Username to search for.

        Returns:
            The first matching user, or ``None`` if no user has that username.
        """
        return db.query(User).filter(User.username == username).first()

    @staticmethod
    def create_user(db: Session, username: str, password: str, is_superuser: bool = False) -> User:
        """
        Create and persist a new user.

        Args:
            db: Database session; committed on success, rolled back if the
                insert violates a database constraint.
            username: Username for the new account.
            password: Plain-text password; only its hash is stored.
            is_superuser: Whether the new user is a superuser.

        Returns:
            The newly created user, refreshed from the database.

        Raises:
            HTTPException: 400 if the username already exists.
        """
        # Fast path: reject obvious duplicates before hashing the password.
        if AuthService.get_user_by_username(db, username) is not None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="用户名已存在",
            )

        user = User(
            username=username,
            password_hash=get_password_hash(password),
            is_superuser=is_superuser,
        )

        db.add(user)
        try:
            db.commit()
        except IntegrityError as exc:
            # Another request may have inserted the same username between the
            # check above and this commit. Roll back so the session stays
            # usable, then report the same error as the pre-check.
            # NOTE(review): assumes User.username has a unique constraint and
            # that it is the only constraint this insert can violate — confirm
            # against the User model.
            db.rollback()
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="用户名已存在",
            ) from exc
        db.refresh(user)

        return user

    @staticmethod
    def change_password(db: Session, user: User, old_password: str, new_password: str) -> None:
        """
        Change a user's password after verifying the current one.

        Args:
            db: Database session; committed after the new hash is set.
            user: User whose password is being changed.
            old_password: Current plain-text password, which must match the
                stored hash.
            new_password: New plain-text password; only its hash is stored.

        Raises:
            HTTPException: 400 if ``old_password`` does not match.
        """
        if not verify_password(old_password, user.password_hash):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="旧密码错误",
            )

        user.password_hash = get_password_hash(new_password)
        db.commit()
|