You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

354 lines
11 KiB

"""
数据库连接和模型模块
用于连接 MySQL 数据库并操作数据
"""
import os
from contextlib import contextmanager
from datetime import datetime
from typing import List, Optional, Dict, Any
from urllib.parse import quote_plus

# 加载环境变量
from dotenv import load_dotenv
from sqlalchemy import (
    create_engine, Column, Integer, BigInteger, Float, String,
    DateTime, Boolean, Text, ForeignKey, Index, UniqueConstraint
)
from sqlalchemy.orm import declarative_base, sessionmaker, Session
from sqlalchemy.pool import QueuePool
load_dotenv()

# Database configuration, read from environment variables (load_dotenv() above
# first populates them from a .env file when one is present).
DB_HOST = os.getenv('DB_HOST', 'mysql')
DB_PORT = os.getenv('DB_PORT', '3306')
DB_NAME = os.getenv('DB_NAME', 'aguzhitou')
DB_USER = os.getenv('DB_USER', 'root')
# NOTE(review): a real-looking password is hard-coded as the fallback and is
# therefore committed to source control. Prefer requiring DB_PASSWORD to be set
# in the environment, and rotate this credential.
DB_PASSWORD = os.getenv('DB_PASSWORD', '1qazse42W3')
# User and password are percent-encoded so that characters such as '@', ':' or
# '/' inside a credential cannot be misread as URL delimiters. For credentials
# without such characters the resulting URL is unchanged.
DATABASE_URL = (
    f"mysql+pymysql://{quote_plus(DB_USER)}:{quote_plus(DB_PASSWORD)}"
    f"@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)
# Shared engine for the whole module; sessions come from SessionLocal below.
engine = create_engine(
    DATABASE_URL,
    # Pool sizing: 10 pooled connections, plus up to 20 overflow connections.
    poolclass=QueuePool,
    pool_size=10,
    max_overflow=20,
    # Check connections on checkout and recycle them after 3600 s. This is
    # presumably meant to survive the MySQL server closing idle connections.
    pool_pre_ping=True,
    pool_recycle=3600,
)

# Session factory: no autocommit, and no automatic flush before queries.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base that every model class in this module derives from.
Base = declarative_base()
# ==================== 数据模型 ====================
class MarketIndex(Base):
    """Market index: one row per index, keyed by the unique ``code``.

    Python attributes are camelCase. Where the first ``Column`` argument is a
    string, that string is the snake_case column name in the database.
    """
    __tablename__ = "market_indices"
    __table_args__ = (
        # NOTE(review): ``code`` is already declared unique=True below, so this
        # additional non-unique index looks redundant — confirm before changing
        # the schema.
        Index('idx_market_indices_code', 'code'),
    )
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Both name and code are unique, so a new code reusing an existing name
    # will fail on insert.
    name = Column(String(100), unique=True, nullable=False)
    code = Column(String(50), unique=True, nullable=False)
    # Latest quote fields; each defaults to 0 until written.
    # NOTE(review): Float without an explicit precision may map to
    # single-precision FLOAT on MySQL — confirm that is precise enough.
    current = Column(Float, default=0)
    change = Column(Float, default=0)
    # Unit (percent vs. fraction) is not established here — confirm with the feed.
    changePercent = Column('change_percent', Float, default=0)
    volume = Column(BigInteger, default=0)
    turnover = Column(BigInteger, default=0)
    # Naive local timestamps (datetime.now); updatedAt is refreshed whenever an
    # UPDATE is emitted for the row.
    updatedAt = Column('updated_at', DateTime, default=datetime.now, onupdate=datetime.now)
    createdAt = Column('created_at', DateTime, default=datetime.now)
class Sector(Base):
    """Sector master data: one row per unique ``code``.

    ``id`` is a 36-character string primary key; ``upsert_sector`` fills it
    with a UUID4 string.
    """
    __tablename__ = "sectors"
    __table_args__ = (
        # NOTE(review): ``code`` is already unique=True; this index is likely redundant.
        Index('idx_sectors_code', 'code'),
    )
    id = Column(String(36), primary_key=True)
    name = Column(String(100), unique=True, nullable=False)
    code = Column(String(50), unique=True, nullable=False)
    # Naive local timestamps; updatedAt is refreshed on UPDATE.
    updatedAt = Column('updated_at', DateTime, default=datetime.now, onupdate=datetime.now)
    createdAt = Column('created_at', DateTime, default=datetime.now)
class SectorQuote(Base):
    """Time-stamped quote snapshot for a sector.

    There is no unique constraint, so every insert adds a new row.
    ``sectorCode`` holds a sector code but is not declared as a ForeignKey,
    so the schema does not enforce that the sector exists.
    """
    __tablename__ = "sector_quotes"
    __table_args__ = (
        Index('idx_sector_quotes_code', 'sector_code'),
        Index('idx_sector_quotes_time', 'quote_time'),
    )
    id = Column(Integer, primary_key=True, autoincrement=True)
    sectorCode = Column('sector_code', String(50), nullable=False)
    current = Column(Float, default=0)
    change = Column(Float, default=0)
    changePercent = Column('change_percent', Float, default=0)
    volume = Column(BigInteger, default=0)
    turnover = Column(BigInteger, default=0)
    # Defaults to 50 rather than 0 — presumably a neutral midpoint of the
    # score range; confirm with the scoring code.
    momentumScore = Column('momentum_score', Float, default=50)
    # Current and previous ranking; ordering direction is not established here.
    rank = Column(Integer, default=0)
    previousRank = Column('previous_rank', Integer, default=0)
    # Snapshot time (naive local datetime.now when not supplied).
    quoteTime = Column('quote_time', DateTime, default=datetime.now)
class Stock(Base):
    """Stock master data: one row per unique ``code``.

    ``id`` is a 36-character string primary key; ``upsert_stock`` fills it
    with a UUID4 string. ``sectorCode`` stores a sector code but is not a
    ForeignKey, so referential integrity is not enforced by the schema.
    """
    __tablename__ = "stocks"
    __table_args__ = (
        # NOTE(review): ``code`` is already unique=True; this index is likely redundant.
        Index('idx_stocks_code', 'code'),
        Index('idx_stocks_sector', 'sector_code'),
    )
    id = Column(String(36), primary_key=True)
    code = Column(String(50), unique=True, nullable=False)
    name = Column(String(100), nullable=False)
    sectorCode = Column('sector_code', String(50), nullable=True)
    # Optional fundamentals; units are not established here — confirm with the feed.
    marketCap = Column('market_cap', BigInteger, nullable=True)
    pe = Column(Float, nullable=True)
    pb = Column(Float, nullable=True)
    # Naive local timestamps; updatedAt is refreshed on UPDATE.
    updatedAt = Column('updated_at', DateTime, default=datetime.now, onupdate=datetime.now)
    createdAt = Column('created_at', DateTime, default=datetime.now)
class StockQuote(Base):
    """Time-stamped quote snapshot for a stock.

    There is no unique constraint, so each insert adds a new row.
    """
    __tablename__ = "stock_quotes"
    __table_args__ = (
        Index('idx_stock_quotes_code', 'stock_code'),
        Index('idx_stock_quotes_time', 'quote_time'),
    )
    id = Column(Integer, primary_key=True, autoincrement=True)
    stockCode = Column('stock_code', String(50), nullable=False)
    price = Column(Float, default=0)
    # The attribute name ``open`` shadows the builtin only inside this class body.
    open = Column(Float, default=0)
    high = Column(Float, default=0)
    low = Column(Float, default=0)
    # Presumably the previous session's closing price — confirm with the feed.
    preClose = Column('pre_close', Float, default=0)
    volume = Column(BigInteger, default=0)
    turnover = Column(BigInteger, default=0)
    changePercent = Column('change_percent', Float, default=0)
    # Optional derived metrics (left NULL when not provided).
    turnoverRate = Column('turnover_rate', Float, nullable=True)
    amplitude = Column(Float, nullable=True)
    quoteTime = Column('quote_time', DateTime, default=datetime.now)
class StockKLine(Base):
    """Candlestick (K-line) bar for a stock.

    At most one row exists per (stock_code, period, date); this is enforced
    by the ``uk_stock_klines`` unique constraint.
    """
    __tablename__ = "stock_klines"
    __table_args__ = (
        UniqueConstraint('stock_code', 'period', 'date', name='uk_stock_klines'),
        Index('idx_stock_klines_code', 'stock_code'),
        Index('idx_stock_klines_date', 'date'),
    )
    id = Column(Integer, primary_key=True, autoincrement=True)
    stockCode = Column('stock_code', String(50), nullable=False)
    period = Column(String(20), nullable=False) # day/week/month
    # Bar date, stored as a DATETIME column.
    date = Column(DateTime, nullable=False)
    open = Column(Float, default=0)
    high = Column(Float, default=0)
    low = Column(Float, default=0)
    close = Column(Float, default=0)
    volume = Column(BigInteger, default=0)
    # Moving averages, presumably over 5/10/20/30/60 bars of ``close``. They
    # are nullable, likely because they are unavailable until enough history
    # exists — confirm with the producer.
    ma5 = Column(Float, nullable=True)
    ma10 = Column(Float, nullable=True)
    ma20 = Column(Float, nullable=True)
    ma30 = Column(Float, nullable=True)
    ma60 = Column(Float, nullable=True)
class SectorKLine(Base):
    """Candlestick (K-line) bar for a sector.

    At most one row exists per (sector_code, period, date); this is enforced
    by the ``uk_sector_klines`` unique constraint. Unlike the stock K-line
    table, no moving-average columns are stored.
    """
    __tablename__ = "sector_klines"
    __table_args__ = (
        UniqueConstraint('sector_code', 'period', 'date', name='uk_sector_klines'),
        Index('idx_sector_klines_code', 'sector_code'),
        Index('idx_sector_klines_date', 'date'),
    )
    id = Column(Integer, primary_key=True, autoincrement=True)
    sectorCode = Column('sector_code', String(50), nullable=False)
    # Presumably the same day/week/month values as the stock K-line table — confirm.
    period = Column(String(20), nullable=False)
    date = Column(DateTime, nullable=False)
    open = Column(Float, default=0)
    high = Column(Float, default=0)
    low = Column(Float, default=0)
    close = Column(Float, default=0)
    volume = Column(BigInteger, default=0)
class HighLowStock(Base):
    """Record of a stock reaching a new high or a new low on a given date.

    There is no unique constraint, so repeated inserts for the same stock and
    date create separate rows.
    """
    __tablename__ = "high_low_stocks"
    __table_args__ = (
        Index('idx_high_low_code', 'stock_code'),
        Index('idx_high_low_type', 'type'),
        Index('idx_high_low_date', 'date'),
    )
    id = Column(Integer, primary_key=True, autoincrement=True)
    stockCode = Column('stock_code', String(50), nullable=False)
    # 'high' or 'low'. The attribute name shadows the builtin ``type`` only
    # inside this class body.
    type = Column(String(10), nullable=False) # high/low
    price = Column(Float, default=0)
    date = Column(DateTime, nullable=False)
    # NOTE(review): exact meaning is not established here (e.g. days since the
    # previous extreme, or the look-back window) — confirm with the producer.
    daysToHighLow = Column('days_to_highlow', Integer, default=0)
    createdAt = Column('created_at', DateTime, default=datetime.now)
class MomentumStock(Base):
    """Momentum-based stock recommendation for a given date."""
    __tablename__ = "momentum_stocks"
    __table_args__ = (
        Index('idx_momentum_code', 'stock_code'),
        Index('idx_momentum_date', 'date'),
    )
    id = Column(Integer, primary_key=True, autoincrement=True)
    stockCode = Column('stock_code', String(50), nullable=False)
    momentumScore = Column('momentum_score', Float, default=0)
    # JSON-encoded string; the ORM stores it as plain TEXT and does not parse it.
    tags = Column(Text, nullable=True) # JSON string
    volumeRatio = Column('volume_ratio', Float, default=0)
    # Presumably flags a price breakout — confirm with the producer.
    breakThrough = Column('break_through', Boolean, default=False)
    date = Column(DateTime, nullable=False)
    createdAt = Column('created_at', DateTime, default=datetime.now)
# ==================== 数据库操作 ====================
@contextmanager
def get_db():
    """Provide a transactional session scope as a context manager.

    Yields a new session from ``SessionLocal``:

    * If the ``with`` body completes normally, the session is committed.
    * If the body or the commit raises ``Exception``, the session is rolled
      back and the exception propagates unchanged.
    * In all cases the session is closed.

    Usage::

        with get_db() as db:
            db.add(obj)
    """
    db = SessionLocal()
    try:
        yield db
        db.commit()
    except Exception:
        db.rollback()
        # A bare ``raise`` re-raises the active exception without adding this
        # line as an extra re-raise point in the traceback (unlike ``raise e``).
        raise
    finally:
        db.close()
def init_db():
    """Create the tables for every model declared on ``Base``.

    Runs against the module-level ``engine`` and prints a confirmation
    message when done.
    """
    metadata = Base.metadata
    metadata.create_all(engine)
    done_message = "数据库表初始化完成"
    print(done_message)
def check_connection() -> bool:
    """Check that the database is reachable.

    Opens a connection from ``engine`` and runs ``SELECT 1``.

    Returns:
        True if the query succeeded. Otherwise False, after printing the error.
    """
    # Plain SQL must be wrapped in text(): SQLAlchemy 2.0 rejects a bare str
    # passed to Connection.execute() (ObjectNotExecutableError), and 1.4
    # deprecates it. Without this, the check always reports failure.
    from sqlalchemy import text

    try:
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        return True
    except Exception as e:
        # Deliberately broad: any failure means "not reachable" for this probe.
        print(f"数据库连接失败: {e}")
        return False
# ==================== 数据操作函数 ====================
def upsert_stock(db: Session, code: str, name: str, **kwargs):
    """Insert a stock, or update the existing row with the same ``code``.

    Args:
        db: Session that the new or modified ``Stock`` belongs to. Nothing is
            flushed or committed here.
        code: Stock code used as the lookup key.
        name: Display name; always overwritten on an existing row.
        **kwargs: Extra ``Stock`` attributes (e.g. ``sectorCode``,
            ``marketCap``, ``pe``, ``pb``). Keys that are not attributes of
            ``Stock`` are ignored, on insert as well as on update.

    Returns:
        The new (pending) or existing ``Stock`` instance.

    NOTE(review): ``SessionLocal`` is configured with autoflush=False. A stock
    added by an earlier call in the same session is therefore not visible to
    the lookup below. Upserting the same code twice before a flush adds two
    rows and fails the unique constraint at commit. Call ``db.flush()``
    between such calls if that can happen.
    """
    stock = db.query(Stock).filter(Stock.code == code).first()
    is_new = stock is None
    if is_new:
        import uuid
        stock = Stock(id=str(uuid.uuid4()), code=code, name=name)
    else:
        stock.name = name
    # Apply extra attributes identically on both paths. Previously an unknown
    # key was silently ignored on update but raised TypeError on insert.
    for key, value in kwargs.items():
        if hasattr(stock, key):
            setattr(stock, key, value)
    if is_new:
        db.add(stock)
    else:
        # Set explicitly (after kwargs, as before), presumably so the timestamp
        # moves even when no other column value changed.
        stock.updatedAt = datetime.now()
    return stock
def upsert_stock_quote(db: Session, stock_code: str, **data):
    """Add a new ``StockQuote`` snapshot for ``stock_code`` to the session.

    Despite the name, this never updates an existing row: every call creates
    a new ``StockQuote``. ``data`` is passed straight to the constructor, so
    any key that is not a ``StockQuote`` attribute raises ``TypeError``.

    Returns:
        The pending ``StockQuote`` instance (not flushed or committed).
    """
    new_quote = StockQuote(stockCode=stock_code, **data)
    db.add(new_quote)
    return new_quote
def upsert_stock_kline(db: Session, stock_code: str, period: str, date: datetime, **data):
    """Insert a K-line bar, or update the bar with the same key.

    The key is (``stock_code``, ``period``, ``date``), matching the
    ``uk_stock_klines`` unique constraint.

    Args:
        db: Session that the new or modified bar belongs to. Nothing is flushed
            or committed here.
        stock_code: Stock code of the bar.
        period: Bar period (the model documents day/week/month).
        date: Bar date.
        **data: Extra ``StockKLine`` attributes (e.g. ``open``, ``high``,
            ``low``, ``close``, ``volume``, ``ma5``). Keys that are not
            attributes of ``StockKLine`` are ignored, on insert as well as
            on update.

    Returns:
        The new (pending) or existing ``StockKLine`` instance.

    NOTE(review): with an autoflush=False session (as ``SessionLocal`` is
    configured), a bar added earlier in the same session is not found by the
    lookup below. Upserting the same key twice before a flush violates the
    unique constraint at commit.
    """
    kline = (
        db.query(StockKLine)
        .filter(
            StockKLine.stockCode == stock_code,
            StockKLine.period == period,
            StockKLine.date == date,
        )
        .first()
    )
    is_new = kline is None
    if is_new:
        kline = StockKLine(stockCode=stock_code, period=period, date=date)
    # Same filtering on both paths. Previously an unknown key was ignored on
    # update but raised TypeError on insert.
    for key, value in data.items():
        if hasattr(kline, key):
            setattr(kline, key, value)
    if is_new:
        db.add(kline)
    return kline
def upsert_sector(db: Session, code: str, name: str):
    """Insert a sector, or rename the existing sector with the same ``code``.

    An existing row gets the new ``name`` and an explicit ``updatedAt``
    timestamp. A missing row is created with a fresh UUID4 string id and
    added to the session. Nothing is flushed or committed here.

    NOTE(review): with an autoflush=False session (as ``SessionLocal`` is
    configured), a sector added earlier in the same session is not found by
    the lookup below.

    Returns:
        The existing or new (pending) ``Sector`` instance.
    """
    existing = db.query(Sector).filter(Sector.code == code).first()
    if existing:
        existing.name = name
        existing.updatedAt = datetime.now()
        return existing

    import uuid
    created = Sector(id=str(uuid.uuid4()), code=code, name=name)
    db.add(created)
    return created
def upsert_market_index(db: Session, code: str, name: str, **data):
    """Insert a market index, or update the existing row with the same ``code``.

    Args:
        db: Session that the new or modified row belongs to. Nothing is flushed
            or committed here.
        code: Index code used as the lookup key.
        name: Display name; overwritten on an existing row. ``name`` is also
            unique in the table, so a new code that reuses an existing name
            will fail at flush/commit.
        **data: Extra ``MarketIndex`` attributes (e.g. ``current``, ``change``,
            ``changePercent``, ``volume``, ``turnover``). Keys that are not
            attributes of ``MarketIndex`` are ignored, on insert as well as
            on update.

    Returns:
        The new (pending) or existing ``MarketIndex`` instance.

    NOTE(review): with an autoflush=False session (as ``SessionLocal`` is
    configured), an index added earlier in the same session is not found by
    the lookup below.
    """
    index = db.query(MarketIndex).filter(MarketIndex.code == code).first()
    is_new = index is None
    if is_new:
        index = MarketIndex(code=code, name=name)
    else:
        index.name = name
    # Same filtering on both paths. Previously an unknown key was ignored on
    # update but raised TypeError on insert.
    for key, value in data.items():
        if hasattr(index, key):
            setattr(index, key, value)
    if is_new:
        db.add(index)
    return index
if __name__ == "__main__":
    # Script entry point: verify connectivity first, then create the tables.
    if not check_connection():
        print("数据库连接失败,请检查配置")
    else:
        init_db()
        print("数据库初始化成功")