You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

69 lines
1.7 KiB

"""
数据库会话管理
"""
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from typing import Generator
from app.config import settings
from app.db.base import Base
# Create the database engine.
# SQLite URLs get `check_same_thread=False` passed through to the driver so a
# connection may be used from a thread other than the one that opened it;
# every other backend gets an explicitly sized connection pool with
# pre-ping enabled on checkout.
engine = create_engine(
    settings.DATABASE_URL,
    echo=settings.DEBUG,
    **(
        {"connect_args": {"check_same_thread": False}}
        if settings.DATABASE_URL.startswith("sqlite")
        else {"pool_pre_ping": True, "pool_size": 10, "max_overflow": 20}
    ),
)
# Create the session factory: each call to SessionLocal() returns a new
# Session bound to `engine`, with autoflush and autocommit both turned off.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
def get_db() -> Generator[Session, None, None]:
    """Dependency function that provides a database session.

    Yields:
        A new session created from ``SessionLocal``. The session is closed
        once the consumer finishes with the generator, including when the
        consumer raises.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Runs on normal exhaustion, on an exception thrown into the
        # generator, and when the generator is closed early.
        session.close()
def init_db() -> None:
    """Initialize the database tables and seed a default admin account.

    Creates all tables known to ``Base.metadata`` on ``engine``, then adds a
    superuser named ``admin`` when no user with that username exists yet.
    Any error raised while seeding is printed and the transaction is rolled
    back; it is not re-raised.
    """
    # Function-scope imports. The model modules are not referenced below;
    # presumably they are imported so their tables are registered on
    # Base.metadata before create_all runs — confirm against app.models.
    from app.models import user, config, stock, future, realtime, finance, cache, test
    from app.models.user import User
    from app.core.security import get_password_hash

    Base.metadata.create_all(bind=engine)

    session = SessionLocal()
    try:
        # Nothing to seed when an "admin" row is already present.
        if session.query(User).filter(User.username == "admin").first():
            return

        # NOTE(review): the default password is hard-coded and echoed to
        # stdout below — it should be changed after the first login.
        default_admin = User(
            username="admin",
            password_hash=get_password_hash("admin123"),
            is_active=True,
            is_superuser=True,
        )
        session.add(default_admin)
        session.commit()
        print("Default admin user created (admin/admin123)")
    except Exception as exc:
        # Best-effort seeding: report the failure and undo the pending work.
        print(f"Error creating default user: {exc}")
        session.rollback()
    finally:
        session.close()