You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

136 lines
3.7 KiB

"""
认证服务
"""
from datetime import timedelta
from typing import Optional

from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from app.models.user import User
from app.core.security import verify_password, get_password_hash, create_access_token
from app.config import settings
class AuthService:
    """Authentication service: credential checks, token issuing and user management."""

    # Lazily computed hash of a throwaway password. authenticate_user verifies
    # the supplied password against it when the username does not exist, so an
    # unknown username costs roughly the same hashing work as a wrong password
    # (mitigates username enumeration via response timing). A race on first use
    # only means the hash may be computed twice, which is harmless.
    _dummy_password_hash = None

    @classmethod
    def _get_dummy_password_hash(cls):
        """Return the timing-equalization hash, computing it on first call."""
        if cls._dummy_password_hash is None:
            cls._dummy_password_hash = get_password_hash("timing-equalization-dummy-password")
        return cls._dummy_password_hash

    @staticmethod
    def _credentials_error() -> HTTPException:
        """Build the 401 raised for an unknown user or a wrong password.

        Both cases share one message so a client cannot tell which one failed.
        The WWW-Authenticate header is included because a 401 response is
        required to carry one (RFC 7235, section 3.1).
        """
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户名或密码错误",
            headers={"WWW-Authenticate": "Bearer"},
        )

    @staticmethod
    def _username_taken_error() -> HTTPException:
        """Build the 400 raised when a username is already registered."""
        return HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="用户名已存在",
        )

    @staticmethod
    def authenticate_user(db: Session, username: str, password: str) -> User:
        """Validate a username/password pair.

        Args:
            db: Database session.
            username: Username to look up.
            password: Plain-text password supplied by the client.

        Returns:
            The matching, active ``User``.

        Raises:
            HTTPException: 401 if the user does not exist or the password is
                wrong (same message for both); 403 if the account is disabled.
        """
        user = AuthService.get_user_by_username(db, username)
        if user is None:
            # Do comparable hashing work to the wrong-password path so the two
            # failures are not distinguishable by response time.
            verify_password(password, AuthService._get_dummy_password_hash())
            raise AuthService._credentials_error()
        if not verify_password(password, user.password_hash):
            raise AuthService._credentials_error()
        # Checked only after the password matches, so the "disabled" status is
        # not revealed to someone who does not know the password.
        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="用户已被禁用",
            )
        return user

    @staticmethod
    def create_user_token(user: User) -> dict:
        """Issue a bearer access token for ``user``.

        The token subject (``sub``) is the username. The token lifetime is
        ``settings.ACCESS_TOKEN_EXPIRE_HOURS`` hours.

        Args:
            user: User the token is issued for.

        Returns:
            Dict with ``access_token``, ``token_type`` (always ``"bearer"``)
            and ``expires_in`` (token lifetime in whole seconds).
        """
        access_token_expires = timedelta(hours=settings.ACCESS_TOKEN_EXPIRE_HOURS)
        access_token = create_access_token(
            data={"sub": user.username},
            expires_delta=access_token_expires,
        )
        return {
            "access_token": access_token,
            "token_type": "bearer",
            # Derived from the same timedelta passed to the token so the two
            # values can never disagree.
            "expires_in": int(access_token_expires.total_seconds()),
        }

    @staticmethod
    def get_user_by_username(db: Session, username: str) -> Optional[User]:
        """Return the user named ``username``, or ``None`` if there is none."""
        return db.query(User).filter(User.username == username).first()

    @staticmethod
    def create_user(db: Session, username: str, password: str, is_superuser: bool = False) -> User:
        """Create and persist a new user.

        Args:
            db: Database session. It is committed on success and rolled back
                if the commit raises ``IntegrityError``.
            username: Desired username; must not already exist.
            password: Plain-text password; only its hash is stored.
            is_superuser: Whether the new user is a superuser.

        Returns:
            The newly created ``User``, refreshed from the database.

        Raises:
            HTTPException: 400 if the username is already taken.
            IntegrityError: if the insert violates some other constraint.
        """
        if AuthService.get_user_by_username(db, username) is not None:
            raise AuthService._username_taken_error()

        user = User(
            username=username,
            password_hash=get_password_hash(password),
            is_superuser=is_superuser,
        )
        db.add(user)
        try:
            db.commit()
        except IntegrityError as exc:
            # The pre-check above can race with a concurrent registration of
            # the same name. Roll back so the session is usable again, then
            # report a duplicate as the same 400 the pre-check would raise.
            # NOTE(review): assumes a unique constraint on User.username — confirm.
            db.rollback()
            if AuthService.get_user_by_username(db, username) is not None:
                raise AuthService._username_taken_error() from exc
            raise
        db.refresh(user)
        return user

    @staticmethod
    def change_password(db: Session, user: User, old_password: str, new_password: str) -> None:
        """Replace ``user``'s password after verifying the current one.

        Args:
            db: Database session; committed after the hash is updated.
            user: User whose password changes. Presumably attached to ``db``;
                the new hash is only persisted by the commit if it is.
            old_password: Current plain-text password, verified first.
            new_password: New plain-text password; only its hash is stored.

        Raises:
            HTTPException: 400 if ``old_password`` does not match.
        """
        if not verify_password(old_password, user.password_hash):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="旧密码错误",
            )
        user.password_hash = get_password_hash(new_password)
        db.commit()