You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
354 lines
11 KiB
354 lines
11 KiB
"""
|
|
数据库连接和模型模块
|
|
用于连接 MySQL 数据库并操作数据
|
|
"""
|
|
import os
|
|
from datetime import datetime
|
|
from typing import List, Optional, Dict, Any
|
|
from contextlib import contextmanager
|
|
|
|
from sqlalchemy import (
|
|
create_engine, Column, Integer, BigInteger, Float, String,
|
|
DateTime, Boolean, Text, ForeignKey, Index, UniqueConstraint
|
|
)
|
|
from sqlalchemy.orm import declarative_base, sessionmaker, Session
|
|
from sqlalchemy.pool import QueuePool
|
|
|
|
# Load environment variables (e.g. from a local .env file) before the DB
# settings below are read from the environment.
from dotenv import load_dotenv
from urllib.parse import quote as _url_quote

load_dotenv()


# Database configuration: each value can be overridden via an environment
# variable of the same name; the second argument is the fallback.
DB_HOST = os.getenv('DB_HOST', 'mysql')
DB_PORT = os.getenv('DB_PORT', '3306')
DB_NAME = os.getenv('DB_NAME', 'aguzhitou')
DB_USER = os.getenv('DB_USER', 'root')
# NOTE(review): a real-looking password is hard-coded as the fallback. It is
# kept only for backward compatibility with existing deployments; it should be
# removed from source control and supplied solely through DB_PASSWORD.
DB_PASSWORD = os.getenv('DB_PASSWORD', '1qazse42W3')


# The user name and password are percent-encoded so that characters such as
# '@', ':', '/', '%' or '#' in credentials cannot be misread as URL delimiters
# (SQLAlchemy percent-decodes these fields when it parses the URL).
DATABASE_URL = (
    f"mysql+pymysql://{_url_quote(DB_USER, safe='')}:{_url_quote(DB_PASSWORD, safe='')}"
    f"@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)


# Shared engine with a bounded connection pool.
engine = create_engine(
    DATABASE_URL,
    poolclass=QueuePool,
    pool_size=10,        # connections kept open in the pool
    max_overflow=20,     # extra connections allowed beyond pool_size under load
    pool_pre_ping=True,  # test each connection on checkout to discard stale ones
    pool_recycle=3600,   # replace connections older than this many seconds
)

# Session factory bound to the engine. autoflush is disabled, so pending
# changes are only sent to the database on an explicit flush/commit.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Declarative base class shared by every ORM model in this module.
Base = declarative_base()
|
|
|
|
|
|
# ==================== 数据模型 ====================
|
|
|
|
class MarketIndex(Base):
    """Market index: one row per index, unique by both ``code`` and ``name``.

    Python attributes are camelCase; where the name differs, the first
    positional argument to ``Column`` is the snake_case database column name.
    Timestamps use ``datetime.now`` (local, timezone-naive time).
    """

    __tablename__ = "market_indices"
    __table_args__ = (
        # NOTE(review): ``code`` is also declared unique=True below, which
        # already gives it a unique index; this extra index looks redundant —
        # confirm before dropping it.
        Index('idx_market_indices_code', 'code'),
    )

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String(100), unique=True, nullable=False)
    code = Column(String(50), unique=True, nullable=False)
    # Quote fields; units/scales are not defined in this module.
    current = Column(Float, default=0)
    change = Column(Float, default=0)
    changePercent = Column('change_percent', Float, default=0)
    volume = Column(BigInteger, default=0)
    turnover = Column(BigInteger, default=0)
    # Set on insert; refreshed by SQLAlchemy whenever an UPDATE is emitted for the row.
    updatedAt = Column('updated_at', DateTime, default=datetime.now, onupdate=datetime.now)
    createdAt = Column('created_at', DateTime, default=datetime.now)
|
|
|
|
|
|
class Sector(Base):
    """Sector (board) master record, unique by both ``code`` and ``name``.

    ``id`` is a 36-character string primary key with no default here; callers
    must supply it (``upsert_sector`` in this module assigns a UUID4 string).
    """

    __tablename__ = "sectors"
    __table_args__ = (
        # NOTE(review): redundant with unique=True on ``code`` — confirm before dropping.
        Index('idx_sectors_code', 'code'),
    )

    id = Column(String(36), primary_key=True)
    name = Column(String(100), unique=True, nullable=False)
    code = Column(String(50), unique=True, nullable=False)
    # Set on insert; refreshed by SQLAlchemy whenever an UPDATE is emitted for the row.
    updatedAt = Column('updated_at', DateTime, default=datetime.now, onupdate=datetime.now)
    createdAt = Column('created_at', DateTime, default=datetime.now)
|
|
|
|
|
|
class SectorQuote(Base):
    """Quote row for a sector.

    There is no unique constraint on ``sector_code``, so one sector can have
    many rows — presumably one per ``quote_time`` snapshot (confirm with the
    code that writes this table). ``sector_code`` is a plain string column with
    no ForeignKey to ``sectors.code``.
    """

    __tablename__ = "sector_quotes"
    __table_args__ = (
        Index('idx_sector_quotes_code', 'sector_code'),
        Index('idx_sector_quotes_time', 'quote_time'),
    )

    id = Column(Integer, primary_key=True, autoincrement=True)
    sectorCode = Column('sector_code', String(50), nullable=False)
    current = Column(Float, default=0)
    change = Column(Float, default=0)
    changePercent = Column('change_percent', Float, default=0)
    volume = Column(BigInteger, default=0)
    turnover = Column(BigInteger, default=0)
    # Defaults to 50 when not supplied; the score's scale is not defined here — confirm.
    momentumScore = Column('momentum_score', Float, default=50)
    rank = Column(Integer, default=0)
    previousRank = Column('previous_rank', Integer, default=0)
    # Defaults to the local insert time if the caller does not provide one.
    quoteTime = Column('quote_time', DateTime, default=datetime.now)
|
|
|
|
|
|
class Stock(Base):
    """Stock master record, unique by ``code``.

    ``id`` is a 36-character string primary key with no default here
    (``upsert_stock`` in this module assigns a UUID4 string). ``sector_code``
    is an optional plain string with no ForeignKey to ``sectors.code``.
    """

    __tablename__ = "stocks"
    __table_args__ = (
        # NOTE(review): redundant with unique=True on ``code`` — confirm before dropping.
        Index('idx_stocks_code', 'code'),
        Index('idx_stocks_sector', 'sector_code'),
    )

    id = Column(String(36), primary_key=True)
    code = Column(String(50), unique=True, nullable=False)
    name = Column(String(100), nullable=False)
    sectorCode = Column('sector_code', String(50), nullable=True)
    # Valuation fields are optional; units are not defined in this module.
    marketCap = Column('market_cap', BigInteger, nullable=True)
    pe = Column(Float, nullable=True)
    pb = Column(Float, nullable=True)
    # Set on insert; refreshed by SQLAlchemy whenever an UPDATE is emitted for the row.
    updatedAt = Column('updated_at', DateTime, default=datetime.now, onupdate=datetime.now)
    createdAt = Column('created_at', DateTime, default=datetime.now)
|
|
|
|
|
|
class StockQuote(Base):
    """Quote row for a stock.

    There is no unique constraint on ``stock_code``, so a stock can have many
    rows (``upsert_stock_quote`` in this module always inserts a new one).
    ``stock_code`` has no ForeignKey to ``stocks.code``.
    """

    __tablename__ = "stock_quotes"
    __table_args__ = (
        Index('idx_stock_quotes_code', 'stock_code'),
        Index('idx_stock_quotes_time', 'quote_time'),
    )

    id = Column(Integer, primary_key=True, autoincrement=True)
    stockCode = Column('stock_code', String(50), nullable=False)
    # Price fields; units are not defined in this module.
    price = Column(Float, default=0)
    open = Column(Float, default=0)
    high = Column(Float, default=0)
    low = Column(Float, default=0)
    preClose = Column('pre_close', Float, default=0)
    volume = Column(BigInteger, default=0)
    turnover = Column(BigInteger, default=0)
    changePercent = Column('change_percent', Float, default=0)
    turnoverRate = Column('turnover_rate', Float, nullable=True)
    amplitude = Column(Float, nullable=True)
    # Defaults to the local insert time if the caller does not provide one.
    quoteTime = Column('quote_time', DateTime, default=datetime.now)
|
|
|
|
|
|
class StockKLine(Base):
    """Stock K-line (OHLCV bar) data.

    At most one row exists per (stock_code, period, date), enforced by the
    ``uk_stock_klines`` unique constraint.
    """

    __tablename__ = "stock_klines"
    __table_args__ = (
        UniqueConstraint('stock_code', 'period', 'date', name='uk_stock_klines'),
        # NOTE(review): stock_code is the leftmost column of uk_stock_klines,
        # so this index may be redundant — confirm before dropping.
        Index('idx_stock_klines_code', 'stock_code'),
        Index('idx_stock_klines_date', 'date'),
    )

    id = Column(Integer, primary_key=True, autoincrement=True)
    stockCode = Column('stock_code', String(50), nullable=False)
    period = Column(String(20), nullable=False)  # day/week/month
    date = Column(DateTime, nullable=False)
    open = Column(Float, default=0)
    high = Column(Float, default=0)
    low = Column(Float, default=0)
    close = Column(Float, default=0)
    volume = Column(BigInteger, default=0)
    # Optional moving averages; presumably N-bar averages of close, computed
    # elsewhere — nothing in this module fills them.
    ma5 = Column(Float, nullable=True)
    ma10 = Column(Float, nullable=True)
    ma20 = Column(Float, nullable=True)
    ma30 = Column(Float, nullable=True)
    ma60 = Column(Float, nullable=True)
|
|
|
|
|
|
class SectorKLine(Base):
    """Sector K-line (OHLCV bar) data.

    At most one row exists per (sector_code, period, date), enforced by the
    ``uk_sector_klines`` unique constraint.
    """

    __tablename__ = "sector_klines"
    __table_args__ = (
        UniqueConstraint('sector_code', 'period', 'date', name='uk_sector_klines'),
        # NOTE(review): sector_code is the leftmost column of uk_sector_klines,
        # so this index may be redundant — confirm before dropping.
        Index('idx_sector_klines_code', 'sector_code'),
        Index('idx_sector_klines_date', 'date'),
    )

    id = Column(Integer, primary_key=True, autoincrement=True)
    sectorCode = Column('sector_code', String(50), nullable=False)
    # Presumably the same period values as StockKLine.period (day/week/month) — confirm.
    period = Column(String(20), nullable=False)
    date = Column(DateTime, nullable=False)
    open = Column(Float, default=0)
    high = Column(Float, default=0)
    low = Column(Float, default=0)
    close = Column(Float, default=0)
    volume = Column(BigInteger, default=0)
|
|
|
|
|
|
class HighLowStock(Base):
    """Record of a stock hitting a new high or a new low.

    No unique constraint exists, so the same stock/type/date can be recorded
    more than once unless the writer prevents it.
    """

    __tablename__ = "high_low_stocks"
    __table_args__ = (
        Index('idx_high_low_code', 'stock_code'),
        Index('idx_high_low_type', 'type'),
        Index('idx_high_low_date', 'date'),
    )

    id = Column(Integer, primary_key=True, autoincrement=True)
    stockCode = Column('stock_code', String(50), nullable=False)
    type = Column(String(10), nullable=False)  # high/low
    price = Column(Float, default=0)
    date = Column(DateTime, nullable=False)
    # NOTE(review): exact meaning (e.g. lookback window in days) is not defined
    # in this module — confirm with the producer of these rows.
    daysToHighLow = Column('days_to_highlow', Integer, default=0)
    createdAt = Column('created_at', DateTime, default=datetime.now)
|
|
|
|
|
|
class MomentumStock(Base):
    """Momentum stock recommendation for a given date."""

    __tablename__ = "momentum_stocks"
    __table_args__ = (
        Index('idx_momentum_code', 'stock_code'),
        Index('idx_momentum_date', 'date'),
    )

    id = Column(Integer, primary_key=True, autoincrement=True)
    stockCode = Column('stock_code', String(50), nullable=False)
    momentumScore = Column('momentum_score', Float, default=0)
    tags = Column(Text, nullable=True)  # JSON string (not validated by the database)
    volumeRatio = Column('volume_ratio', Float, default=0)
    breakThrough = Column('break_through', Boolean, default=False)
    date = Column(DateTime, nullable=False)
    createdAt = Column('created_at', DateTime, default=datetime.now)
|
|
|
|
|
|
# ==================== 数据库操作 ====================
|
|
|
|
@contextmanager
def get_db():
    """Provide a transactional database session scope.

    Yields a new ``SessionLocal()`` session. If the ``with`` body completes
    normally, the session is committed. If the body or the commit raises an
    ``Exception``, the session is rolled back and the exception propagates.
    The session is always closed afterwards.

    Usage::

        with get_db() as db:
            upsert_stock(db, code, name)

    Yields:
        Session: the session to use inside the ``with`` block.
    """
    db = SessionLocal()
    try:
        yield db
        db.commit()
    except Exception:
        db.rollback()
        # Bare ``raise`` re-raises the active exception with its original
        # traceback instead of re-raising it from this line.
        raise
    finally:
        db.close()
|
|
|
|
|
|
def init_db():
    """Create all tables registered on ``Base.metadata`` using ``engine``.

    Delegates to ``MetaData.create_all``, which skips tables that already
    exist, and then prints a confirmation message.
    """
    metadata = Base.metadata
    metadata.create_all(bind=engine)
    print("数据库表初始化完成")
|
|
|
|
|
|
def check_connection() -> bool:
    """Check that the configured database is reachable.

    Opens a connection from the module-level ``engine`` and runs ``SELECT 1``.

    Returns:
        True if the query succeeds. False if anything raises; in that case the
        error is printed, not propagated.
    """
    # Raw SQL must be wrapped in text(): passing a plain string to
    # Connection.execute() is deprecated in SQLAlchemy 1.4 and rejected in 2.0,
    # which made this check report failure even against a healthy database.
    from sqlalchemy import text

    try:
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        return True
    except Exception as e:
        print(f"数据库连接失败: {e}")
        return False
|
|
|
|
|
|
# ==================== 数据操作函数 ====================
|
|
|
|
def upsert_stock(db: Session, code: str, name: str, **kwargs):
    """Insert a Stock, or update the existing one that has the same ``code``.

    On update, ``name`` is overwritten, and each keyword argument whose key is
    an attribute of the Stock object is assigned; other keys are silently
    ignored. ``updatedAt`` is then set to ``datetime.now()``, overriding any
    ``updatedAt`` passed in ``kwargs``.

    On insert, a new Stock is created with a UUID4 string ``id`` and all
    ``kwargs`` passed to the constructor. Unknown keys therefore raise here,
    unlike on the update path. The new object is added to ``db``.

    Nothing is flushed or committed here; the caller's session controls that.
    The lookup and the insert are separate steps, so concurrent callers for the
    same code can race; the unique constraint on ``code`` would then reject one.

    Returns:
        The updated or newly created Stock instance.
    """
    existing = db.query(Stock).filter(Stock.code == code).first()
    if existing is None:
        import uuid
        created = Stock(id=str(uuid.uuid4()), code=code, name=name, **kwargs)
        db.add(created)
        return created

    existing.name = name
    for attr_name, attr_value in kwargs.items():
        if hasattr(existing, attr_name):
            setattr(existing, attr_name, attr_value)
    existing.updatedAt = datetime.now()
    return existing
|
|
|
|
|
|
def upsert_stock_quote(db: Session, stock_code: str, **data):
    """Add a new StockQuote row for ``stock_code``.

    Despite the name, this never updates an existing row: every call builds a
    new StockQuote from ``data`` and adds it to ``db``, without flushing or
    committing. Unknown keys in ``data`` raise from the model constructor.

    Returns:
        The newly created StockQuote instance.
    """
    new_row = StockQuote(stockCode=stock_code, **data)
    db.add(new_row)
    return new_row
|
|
|
|
|
|
def upsert_stock_kline(db: Session, stock_code: str, period: str, date: datetime, **data):
    """Insert a StockKLine bar, or update the bar for (stock_code, period, date).

    On update, each key in ``data`` that is an attribute of the bar object is
    assigned; other keys are silently ignored. On insert, a new StockKLine is
    built from the key fields plus ``data`` (unknown keys raise) and added to
    ``db``. Nothing is flushed or committed here.

    Returns:
        The updated or newly created StockKLine instance.
    """
    bar = (
        db.query(StockKLine)
        .filter(
            StockKLine.stockCode == stock_code,
            StockKLine.period == period,
            StockKLine.date == date,
        )
        .first()
    )

    if bar is None:
        bar = StockKLine(stockCode=stock_code, period=period, date=date, **data)
        db.add(bar)
        return bar

    for field_name, field_value in data.items():
        if hasattr(bar, field_name):
            setattr(bar, field_name, field_value)
    return bar
|
|
|
|
|
|
def upsert_sector(db: Session, code: str, name: str):
    """Insert a Sector, or rename the existing one that has the same ``code``.

    On update, ``name`` is overwritten and ``updatedAt`` is set to
    ``datetime.now()``. On insert, a new Sector with a UUID4 string ``id`` is
    added to ``db``. Nothing is flushed or committed here.

    Returns:
        The updated or newly created Sector instance.
    """
    found = db.query(Sector).filter(Sector.code == code).first()
    if found is not None:
        found.name = name
        found.updatedAt = datetime.now()
        return found

    import uuid
    fresh = Sector(id=str(uuid.uuid4()), code=code, name=name)
    db.add(fresh)
    return fresh
|
|
|
|
|
|
def upsert_market_index(db: Session, code: str, name: str, **data):
    """Insert a MarketIndex, or update the existing one that has the same ``code``.

    On update, ``name`` is overwritten, and each key in ``data`` that is an
    attribute of the object is assigned; other keys are silently ignored.
    ``updatedAt`` is not set explicitly here, so it only changes through the
    column's ``onupdate`` when an UPDATE is actually emitted. On insert, a new
    MarketIndex is built from ``code``, ``name`` and ``data`` (unknown keys
    raise) and added to ``db``. Nothing is flushed or committed here.

    Returns:
        The updated or newly created MarketIndex instance.
    """
    row = db.query(MarketIndex).filter(MarketIndex.code == code).first()
    if row is None:
        row = MarketIndex(code=code, name=name, **data)
        db.add(row)
        return row

    row.name = name
    for attr_name in data:
        if hasattr(row, attr_name):
            setattr(row, attr_name, data[attr_name])
    return row
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: verify connectivity first, then create any missing tables.
    connected = check_connection()
    if not connected:
        print("数据库连接失败,请检查配置")
    else:
        init_db()
        print("数据库初始化成功")
|