"""Database connection and ORM model module.

Connects to a MySQL database and provides models plus helper functions for
reading and writing market data.
"""
import os
import uuid
from contextlib import contextmanager
from datetime import datetime
from typing import List, Optional, Dict, Any, Iterator
from urllib.parse import quote_plus

from sqlalchemy import (
    create_engine, Column, Integer, BigInteger, Float, String, DateTime,
    Boolean, Text, ForeignKey, Index, UniqueConstraint, text,
)
from sqlalchemy.orm import declarative_base, sessionmaker, Session
from sqlalchemy.pool import QueuePool

# Load environment variables (e.g. from a .env file) before reading config.
from dotenv import load_dotenv

load_dotenv()

# Database configuration
DB_HOST = os.getenv('DB_HOST', 'mysql')
DB_PORT = os.getenv('DB_PORT', '3306')
DB_NAME = os.getenv('DB_NAME', 'aguzhitou')
DB_USER = os.getenv('DB_USER', 'root')
# NOTE(review): a real-looking default password is committed in source; prefer
# requiring DB_PASSWORD from the environment. Kept as-is for backward
# compatibility with existing deployments.
DB_PASSWORD = os.getenv('DB_PASSWORD', '1qazse42W3')

# Credentials are URL-escaped so characters such as '@', ':', '/' or '%' in
# the user name or password do not corrupt the connection URL. For values
# without special characters (including the defaults) the URL is unchanged.
DATABASE_URL = (
    f"mysql+pymysql://{quote_plus(DB_USER)}:{quote_plus(DB_PASSWORD)}"
    f"@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)

# Create the engine
engine = create_engine(
    DATABASE_URL,
    poolclass=QueuePool,
    pool_size=10,
    max_overflow=20,
    pool_pre_ping=True,  # test connections on checkout so stale ones are replaced
    pool_recycle=3600,   # recycle connections older than one hour
)

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Base = declarative_base()


# ==================== Data models ====================

class MarketIndex(Base):
    """Market index snapshot, one row per index code."""
    __tablename__ = "market_indices"
    # NOTE(review): `code` is already unique=True, which MySQL backs with its
    # own index; this explicit index looks redundant — confirm before dropping.
    __table_args__ = (
        Index('idx_market_indices_code', 'code'),
    )

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String(100), unique=True, nullable=False)
    code = Column(String(50), unique=True, nullable=False)
    current = Column(Float, default=0)
    change = Column(Float, default=0)
    changePercent = Column('change_percent', Float, default=0)
    volume = Column(BigInteger, default=0)
    turnover = Column(BigInteger, default=0)
    updatedAt = Column('updated_at', DateTime, default=datetime.now, onupdate=datetime.now)
    createdAt = Column('created_at', DateTime, default=datetime.now)


class Sector(Base):
    """Sector (board) master data."""
    __tablename__ = "sectors"
    __table_args__ = (
        Index('idx_sectors_code', 'code'),
    )

    id = Column(String(36), primary_key=True)  # UUID4 string, set by upsert_sector
    name = Column(String(100), unique=True, nullable=False)
    code = Column(String(50), unique=True, nullable=False)
    updatedAt = Column('updated_at', DateTime, default=datetime.now, onupdate=datetime.now)
    createdAt = Column('created_at', DateTime, default=datetime.now)


class SectorQuote(Base):
    """Sector quote snapshot; rows are keyed by sector code and quote time."""
    __tablename__ = "sector_quotes"
    __table_args__ = (
        Index('idx_sector_quotes_code', 'sector_code'),
        Index('idx_sector_quotes_time', 'quote_time'),
    )

    id = Column(Integer, primary_key=True, autoincrement=True)
    sectorCode = Column('sector_code', String(50), nullable=False)
    current = Column(Float, default=0)
    change = Column(Float, default=0)
    changePercent = Column('change_percent', Float, default=0)
    volume = Column(BigInteger, default=0)
    turnover = Column(BigInteger, default=0)
    momentumScore = Column('momentum_score', Float, default=50)
    rank = Column(Integer, default=0)
    previousRank = Column('previous_rank', Integer, default=0)
    quoteTime = Column('quote_time', DateTime, default=datetime.now)


class Stock(Base):
    """Stock master data."""
    __tablename__ = "stocks"
    __table_args__ = (
        Index('idx_stocks_code', 'code'),
        Index('idx_stocks_sector', 'sector_code'),
    )

    id = Column(String(36), primary_key=True)  # UUID4 string, set by upsert_stock
    code = Column(String(50), unique=True, nullable=False)
    name = Column(String(100), nullable=False)
    sectorCode = Column('sector_code', String(50), nullable=True)
    marketCap = Column('market_cap', BigInteger, nullable=True)
    pe = Column(Float, nullable=True)
    pb = Column(Float, nullable=True)
    updatedAt = Column('updated_at', DateTime, default=datetime.now, onupdate=datetime.now)
    createdAt = Column('created_at', DateTime, default=datetime.now)


class StockQuote(Base):
    """Stock quote snapshot; a new row is inserted per quote."""
    __tablename__ = "stock_quotes"
    __table_args__ = (
        Index('idx_stock_quotes_code', 'stock_code'),
        Index('idx_stock_quotes_time', 'quote_time'),
    )

    id = Column(Integer, primary_key=True, autoincrement=True)
    stockCode = Column('stock_code', String(50), nullable=False)
    price = Column(Float, default=0)
    open = Column(Float, default=0)
    high = Column(Float, default=0)
    low = Column(Float, default=0)
    preClose = Column('pre_close', Float, default=0)
    volume = Column(BigInteger, default=0)
    turnover = Column(BigInteger, default=0)
    changePercent = Column('change_percent', Float, default=0)
    turnoverRate = Column('turnover_rate', Float, nullable=True)
    amplitude = Column(Float, nullable=True)
    quoteTime = Column('quote_time', DateTime, default=datetime.now)


class StockKLine(Base):
    """Stock candlestick (K-line) bar, unique per (stock_code, period, date)."""
    __tablename__ = "stock_klines"
    __table_args__ = (
        UniqueConstraint('stock_code', 'period', 'date', name='uk_stock_klines'),
        Index('idx_stock_klines_code', 'stock_code'),
        Index('idx_stock_klines_date', 'date'),
    )

    id = Column(Integer, primary_key=True, autoincrement=True)
    stockCode = Column('stock_code', String(50), nullable=False)
    period = Column(String(20), nullable=False)  # day/week/month
    date = Column(DateTime, nullable=False)
    open = Column(Float, default=0)
    high = Column(Float, default=0)
    low = Column(Float, default=0)
    close = Column(Float, default=0)
    volume = Column(BigInteger, default=0)
    # Moving-average columns; presumably N-period MAs of close — confirm with producer.
    ma5 = Column(Float, nullable=True)
    ma10 = Column(Float, nullable=True)
    ma20 = Column(Float, nullable=True)
    ma30 = Column(Float, nullable=True)
    ma60 = Column(Float, nullable=True)


class SectorKLine(Base):
    """Sector candlestick (K-line) bar, unique per (sector_code, period, date)."""
    __tablename__ = "sector_klines"
    __table_args__ = (
        UniqueConstraint('sector_code', 'period', 'date', name='uk_sector_klines'),
        Index('idx_sector_klines_code', 'sector_code'),
        Index('idx_sector_klines_date', 'date'),
    )

    id = Column(Integer, primary_key=True, autoincrement=True)
    sectorCode = Column('sector_code', String(50), nullable=False)
    period = Column(String(20), nullable=False)
    date = Column(DateTime, nullable=False)
    open = Column(Float, default=0)
    high = Column(Float, default=0)
    low = Column(Float, default=0)
    close = Column(Float, default=0)
    volume = Column(BigInteger, default=0)


class HighLowStock(Base):
    """Record of a stock hitting a new high or new low."""
    __tablename__ = "high_low_stocks"
    __table_args__ = (
        Index('idx_high_low_code', 'stock_code'),
        Index('idx_high_low_type', 'type'),
        Index('idx_high_low_date', 'date'),
    )

    id = Column(Integer, primary_key=True, autoincrement=True)
    stockCode = Column('stock_code', String(50), nullable=False)
    type = Column(String(10), nullable=False)  # high/low
    price = Column(Float, default=0)
    date = Column(DateTime, nullable=False)
    daysToHighLow = Column('days_to_highlow', Integer, default=0)
    createdAt = Column('created_at', DateTime, default=datetime.now)


class MomentumStock(Base):
    """Momentum stock recommendation."""
    __tablename__ = "momentum_stocks"
    __table_args__ = (
        Index('idx_momentum_code', 'stock_code'),
        Index('idx_momentum_date', 'date'),
    )

    id = Column(Integer, primary_key=True, autoincrement=True)
    stockCode = Column('stock_code', String(50), nullable=False)
    momentumScore = Column('momentum_score', Float, default=0)
    tags = Column(Text, nullable=True)  # JSON string
    volumeRatio = Column('volume_ratio', Float, default=0)
    breakThrough = Column('break_through', Boolean, default=False)
    date = Column(DateTime, nullable=False)
    createdAt = Column('created_at', DateTime, default=datetime.now)


# ==================== Database operations ====================

@contextmanager
def get_db() -> Iterator[Session]:
    """Provide a transactional session as a context manager.

    Commits when the ``with`` body finishes without an exception; otherwise
    rolls back and re-raises the original exception. The session is always
    closed on exit.
    """
    db = SessionLocal()
    try:
        yield db
        db.commit()
    except Exception:
        db.rollback()
        raise  # bare raise re-raises the active exception unchanged
    finally:
        db.close()


def init_db() -> None:
    """Create all tables declared on ``Base`` that do not exist yet."""
    Base.metadata.create_all(bind=engine)
    print("数据库表初始化完成")


def check_connection() -> bool:
    """Return True if a trivial query succeeds against the database.

    Any exception is printed and reported as False.
    """
    try:
        with engine.connect() as conn:
            # Raw SQL must be wrapped in text(); SQLAlchemy 2.x rejects plain
            # strings passed to Connection.execute().
            conn.execute(text("SELECT 1"))
        return True
    except Exception as e:
        print(f"数据库连接失败: {e}")
        return False


# ==================== Data helpers ====================
# NOTE: SessionLocal is created with autoflush=False, so the lookup query in
# the upsert helpers below does not see objects added earlier in the same
# session but not yet flushed. Upserting the same key twice before a flush or
# commit inserts two rows; for unique columns the commit then fails.

def _apply_fields(obj: Any, data: Dict[str, Any]) -> None:
    """Set each ``data`` item as an attribute on ``obj``, skipping unknown names."""
    for key, value in data.items():
        if hasattr(obj, key):
            setattr(obj, key, value)


def upsert_stock(db: Session, code: str, name: str, **kwargs) -> Stock:
    """Insert a stock or update the existing one with the same ``code``.

    On update, ``name`` and every kwarg naming an existing attribute are set
    (unknown kwargs are ignored) and ``updatedAt`` is refreshed. On insert, a
    new UUID4 ``id`` is generated and all kwargs are passed to the ``Stock``
    constructor, which rejects unknown keys.
    """
    stock = db.query(Stock).filter(Stock.code == code).first()
    if stock is not None:
        stock.name = name
        _apply_fields(stock, kwargs)
        stock.updatedAt = datetime.now()
    else:
        stock = Stock(id=str(uuid.uuid4()), code=code, name=name, **kwargs)
        db.add(stock)
    return stock


def upsert_stock_quote(db: Session, stock_code: str, **data) -> StockQuote:
    """Insert a new stock quote row.

    Despite the name, no existing row is looked up; every call adds a row.
    """
    quote = StockQuote(stockCode=stock_code, **data)
    db.add(quote)
    return quote


def upsert_stock_kline(db: Session, stock_code: str, period: str,
                       date: datetime, **data) -> StockKLine:
    """Insert or update the K-line bar identified by (stock_code, period, date).

    On update, only keys naming existing attributes are applied; on insert,
    all keys are passed to the ``StockKLine`` constructor.
    """
    kline = db.query(StockKLine).filter(
        StockKLine.stockCode == stock_code,
        StockKLine.period == period,
        StockKLine.date == date,
    ).first()
    if kline is not None:
        _apply_fields(kline, data)
    else:
        kline = StockKLine(stockCode=stock_code, period=period, date=date, **data)
        db.add(kline)
    return kline


def upsert_sector(db: Session, code: str, name: str) -> Sector:
    """Insert a sector or update the name of the existing one with ``code``."""
    sector = db.query(Sector).filter(Sector.code == code).first()
    if sector is not None:
        sector.name = name
        sector.updatedAt = datetime.now()
    else:
        sector = Sector(id=str(uuid.uuid4()), code=code, name=name)
        db.add(sector)
    return sector


def upsert_market_index(db: Session, code: str, name: str, **data) -> MarketIndex:
    """Insert a market index or update the existing one with ``code``.

    On update, ``name`` and every key naming an existing attribute are set;
    ``updatedAt`` is refreshed only through the column's ``onupdate`` hook.
    """
    index = db.query(MarketIndex).filter(MarketIndex.code == code).first()
    if index is not None:
        index.name = name
        _apply_fields(index, data)
    else:
        index = MarketIndex(code=code, name=name, **data)
        db.add(index)
    return index


if __name__ == "__main__":
    # Check connectivity, then create tables
    if check_connection():
        init_db()
        print("数据库初始化成功")
    else:
        print("数据库连接失败,请检查配置")