You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

89 lines
2.4 KiB

"""
数据库会话管理
"""
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from typing import Generator
from app.config import settings
from app.db.base import Base
# Use a local SQLite file when no DATABASE_URL is configured; the same file
# is also the fallback when the configured database cannot be reached.
_SQLITE_FALLBACK_URL = "sqlite:///./amazing_data.db"
database_url = settings.DATABASE_URL or _SQLITE_FALLBACK_URL


def _build_engine(url: str):
    """Create a SQLAlchemy engine for *url* with backend-specific options.

    SQLite URLs get ``check_same_thread=False`` so a connection may be used
    from a thread other than the one that opened it. Every other backend
    gets a pool of 10 connections (plus up to 20 overflow) with pre-ping
    liveness checks on checkout. SQL echo follows ``settings.DEBUG``.
    """
    if url.startswith("sqlite"):
        return create_engine(
            url,
            connect_args={"check_same_thread": False},
            echo=settings.DEBUG,
        )
    return create_engine(
        url,
        pool_pre_ping=True,
        pool_size=10,
        max_overflow=20,
        echo=settings.DEBUG,
    )


def _build_verified_engine(url: str):
    """Build an engine for *url* and open one connection to prove it works.

    Raises whatever engine creation or the connection attempt raises. If the
    probe fails, the engine is disposed first so the failed attempt does not
    leave a connection pool behind.
    """
    candidate = _build_engine(url)
    try:
        # Open and immediately release one connection as a reachability probe.
        with candidate.connect():
            pass
    except Exception:
        candidate.dispose()
        raise
    return candidate


try:
    engine = _build_verified_engine(database_url)
except Exception as e:
    # Best-effort startup: report the failure and continue on local SQLite.
    print(f"数据库连接失败: {e}")
    print("使用SQLite作为备选数据库...")
    engine = _build_engine(_SQLITE_FALLBACK_URL)
# Session factory bound to the module-level engine. Sessions it produces do
# not autocommit and do not autoflush before queries, so pending changes
# reach the database only on an explicit flush() or commit().
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
def get_db() -> Generator[Session, None, None]:
    """Yield a fresh database session and close it when the consumer is done.

    Written as a generator so a dependency-injection framework (presumably
    FastAPI's ``Depends`` — confirm against callers) can run the cleanup
    step after the caller has finished with the session.
    """
    session: Session = SessionLocal()
    try:
        yield session
    finally:
        # Runs on normal exhaustion, on generator close(), and when an
        # exception is thrown into the generator.
        session.close()
def init_db() -> None:
    """Create all database tables and seed a default superuser if missing.

    Errors raised while creating tables propagate to the caller. Errors while
    seeding the ``admin`` user are printed and the transaction is rolled back;
    they are not re-raised. The session is always closed.
    """
    # Importing the model modules registers their tables on Base.metadata,
    # which create_all() below relies on; the names themselves are unused.
    from app.models import (
        user,
        config,
        stock,
        future,
        realtime,
        finance,
        cache,
        test,
    )
    from app.models.user import User
    from app.models.stock_basic import StockBasic, IndexBasic, IndexTrade
    from app.core.security import get_password_hash

    print("开始创建数据库表...")
    Base.metadata.create_all(bind=engine)
    print("数据库表创建完成")

    session = SessionLocal()
    try:
        existing_admin = (
            session.query(User).filter(User.username == "admin").first()
        )
        if existing_admin:
            # Nothing to seed; the finally clause still closes the session.
            return
        # NOTE(review): the default credentials are hard-coded and echoed to
        # stdout below; consider sourcing them from configuration.
        session.add(
            User(
                username="admin",
                password_hash=get_password_hash("admin123"),
                is_active=True,
                is_superuser=True,
            )
        )
        session.commit()
        print("Default admin user created (admin/admin123)")
    except Exception as exc:
        print(f"Error creating default user: {exc}")
        session.rollback()
    finally:
        session.close()