You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

231 lines
7.3 KiB

"""数据同步工具 - 对应Go的cmd/sync/main.go"""
import asyncio
import os
import sys
from datetime import datetime, timedelta
from argparse import ArgumentParser
# 添加项目根目录到路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.adapters import AKShareAdapter
from app.repositories import SessionLocal
from app.repositories.stock_repository import StockRepository
from app.repositories.futures_repository import FuturesRepository
from app.models import Symbol, SymbolType, TradeCalData
from app.core.logger import info, error
def parse_date(date_str: str) -> datetime:
    """Parse a compact ``YYYYMMDD`` string into a naive ``datetime``.

    Raises:
        ValueError: if *date_str* does not follow the ``YYYYMMDD`` layout.
    """
    compact_layout = "%Y%m%d"
    parsed = datetime.strptime(date_str, compact_layout)
    return parsed
def format_date(date: datetime) -> str:
    """Render *date* as a compact ``YYYYMMDD`` string (time part is dropped)."""
    return f"{date:%Y%m%d}"
def is_stock(symbol: str) -> bool:
    """Return True if *symbol* ends with one of the suffixes ``.SH``, ``.SZ`` or ``.BJ``.

    The match is case-sensitive; any other code (e.g. a futures contract)
    is reported as not a stock.
    """
    stock_suffixes = (".SH", ".SZ", ".BJ")
    return symbol.endswith(stock_suffixes)
async def sync_stocks(adapter: AKShareAdapter, db):
    """Fetch the stock symbol list from *adapter* and persist it via *db*.

    Every fetched record becomes a ``Symbol`` with ``SymbolType.STOCK`` and
    status ``"active"``; all of them are written with a single
    ``StockRepository.save_symbols`` call.

    A ``list_date`` that cannot be parsed as ``YYYYMMDD`` is logged and
    stored as ``None`` so one bad record does not abort the whole sync.

    Raises:
        Exception: any failure while fetching or saving is logged and re-raised.
    """
    info("Syncing stock basic info...")
    try:
        symbols_data = await adapter.fetch_symbols("stock")
        repo = StockRepository(db)
        symbols = []
        for d in symbols_data:
            list_date = None
            if d.list_date:
                try:
                    list_date = parse_date(d.list_date)
                except (TypeError, ValueError) as exc:
                    # Best effort: keep the symbol, drop only the unusable date.
                    # Narrow catch so KeyboardInterrupt/CancelledError still propagate.
                    error(f"Ignoring invalid list_date {d.list_date!r} for {d.symbol_id}: {exc}")
            symbols.append(Symbol(
                symbol_id=d.symbol_id,
                symbol_type=SymbolType.STOCK,
                exchange=d.exchange,
                name=d.name,
                list_date=list_date,
                status="active",
            ))
        repo.save_symbols(symbols)
        info(f"Synced {len(symbols)} stocks")
    except Exception as e:
        error(f"Failed to sync stocks: {e}")
        raise
async def sync_futures(adapter: AKShareAdapter, db):
    """Fetch the futures contract list from *adapter* and persist it via *db*.

    Every fetched record becomes a ``Symbol`` with ``SymbolType.FUTURES``.
    A contract whose parsed ``delist_date`` is earlier than the current
    local time is marked ``"expired"``; all others are ``"active"``.

    ``list_date`` / ``delist_date`` values that cannot be parsed as
    ``YYYYMMDD`` are logged and stored as ``None`` instead of aborting the sync.

    Raises:
        Exception: any failure while fetching or saving is logged and re-raised.
    """
    info("Syncing futures basic info...")
    try:
        symbols_data = await adapter.fetch_symbols("futures")
        repo = FuturesRepository(db)

        def _parse_optional(field_name, raw, symbol_id):
            # Best-effort date parsing: empty -> None, malformed -> logged + None.
            # Narrow catch so KeyboardInterrupt/CancelledError still propagate.
            if not raw:
                return None
            try:
                return parse_date(raw)
            except (TypeError, ValueError) as exc:
                error(f"Ignoring invalid {field_name} {raw!r} for {symbol_id}: {exc}")
                return None

        # One snapshot for the whole batch so every contract is judged against the same "now".
        now = datetime.now()
        symbols = []
        for d in symbols_data:
            list_date = _parse_optional("list_date", d.list_date, d.symbol_id)
            delist_date = _parse_optional("delist_date", d.delist_date, d.symbol_id)
            # NOTE(review): delist_date parses to midnight, so a contract is already
            # "expired" during its delist day itself — confirm that is intended.
            status = "expired" if delist_date and now > delist_date else "active"
            symbols.append(Symbol(
                symbol_id=d.symbol_id,
                symbol_type=SymbolType.FUTURES,
                exchange=d.exchange,
                name=d.name,
                underlying=d.underlying,
                contract_month=d.contract_month,
                list_date=list_date,
                delist_date=delist_date,
                status=status,
            ))
        repo.save_symbols(symbols)
        info(f"Synced {len(symbols)} futures")
    except Exception as e:
        error(f"Failed to sync futures: {e}")
        raise
async def sync_calendar(adapter: AKShareAdapter, db, start: str, end: str):
    """Fetch the trading calendar from *start* to *end* and store it for both markets.

    The calendar is requested from *adapter* for exchange ``"SH"`` only; the
    resulting list of ``TradeCalData`` rows is then saved through the stock
    repository first and the futures repository second.

    Raises:
        Exception: any failure is logged and re-raised.
    """
    info(f"Syncing trading calendar from {start} to {end}...")
    try:
        raw_days = await adapter.fetch_trading_calendar("SH", start, end)
        calendar_days = [
            TradeCalData(date=day.date, is_trading_day=day.is_trading_day)
            for day in raw_days
        ]
        # NOTE(review): futures reuse the SH calendar rather than fetching their own;
        # confirm the futures exchanges share these trading days.
        for repository_cls in (StockRepository, FuturesRepository):
            repository_cls(db).save_trading_calendar(calendar_days)
        info(f"Synced {len(calendar_days)} calendar days")
    except Exception as e:
        error(f"Failed to sync calendar: {e}")
        raise
async def sync_klines(adapter: AKShareAdapter, db, symbol: str, start: str, end: str, freq: str):
    """Fetch klines for *symbol* from *start* to *end* at period *freq* and store them.

    Routing: when ``is_stock(symbol)`` is true the items are tagged with
    ``symbol`` and saved through ``StockRepository``; otherwise they are saved
    through ``FuturesRepository`` with *symbol* passed explicitly.

    A non-positive or missing ``open_interest`` is stored as ``None``.

    Raises:
        Exception: whatever ``Frequency(freq)`` raises for an unknown period
            (now raised before any network fetch), or any fetch/save failure;
            all are logged and re-raised.
    """
    info(f"Syncing {freq} klines for {symbol} from {start} to {end}...")
    try:
        from app.models import Frequency, KLineItem

        # Validate the period once, up front, so a typo fails before the network round-trip.
        frequency = Frequency(freq)
        klines_data = await adapter.fetch_klines(symbol, start, end, freq)

        items = [
            KLineItem(
                # NOTE(review): fromtimestamp uses the local timezone and assumes
                # d.time is epoch seconds — confirm against the adapter.
                time=datetime.fromtimestamp(d.time),
                open=d.open,
                high=d.high,
                low=d.low,
                close=d.close,
                volume=d.volume,
                amount=d.amount,
                # Guard against None: `None > 0` raises TypeError in Python 3.
                open_interest=(
                    d.open_interest
                    if d.open_interest is not None and d.open_interest > 0
                    else None
                ),
            )
            for d in klines_data
        ]

        if is_stock(symbol):
            # The stock repository takes no symbol argument, so tag each item.
            for item in items:
                item.symbol = symbol
            StockRepository(db).save_klines(frequency, items)
        else:
            FuturesRepository(db).save_klines(frequency, symbol, items)
        info(f"Synced {len(items)} klines")
    except Exception as e:
        error(f"Failed to sync klines: {e}")
        raise
async def main():
    """Command-line entry point: parse arguments and run the requested sync.

    Supported ``--type`` values are ``stocks``, ``futures``, ``calendar`` and
    ``klines``.

    Default date ranges:
        * ``calendar``: 30 days ago through 180 days ahead.
        * ``klines``: 7 days ago through today.

    Input is validated before any connection is opened. ``--start``/``--end``
    must be ``YYYYMMDD``, and ``klines`` requires ``--symbol``. After the
    adapter connects, it is closed even if creating or closing the DB session
    fails.
    """
    parser = ArgumentParser(description="Market Data Sync Tool")
    parser.add_argument(
        "--type", "-t",
        required=True,
        choices=["stocks", "futures", "calendar", "klines"],
        help="同步类型",
    )
    parser.add_argument("--start", "-s", help="开始日期 YYYYMMDD")
    parser.add_argument("--end", "-e", help="结束日期 YYYYMMDD")
    parser.add_argument("--symbol", help="标的代码klines类型需要")
    parser.add_argument("--freq", "-f", default="1d", help="K线周期")
    args = parser.parse_args()

    # Reject bad input before opening any connection.
    if args.type == "klines" and not args.symbol:
        error("symbol is required for klines sync")
        sys.exit(1)
    for flag, value in (("--start", args.start), ("--end", args.end)):
        if value is not None:
            try:
                parse_date(value)
            except ValueError:
                parser.error(f"{flag} must be a date in YYYYMMDD format, got {value!r}")

    now = datetime.now()

    # AKShare needs no token.
    adapter = AKShareAdapter()
    await adapter.connect({"timeout": 30})
    try:
        db = SessionLocal()
        try:
            if args.type == "stocks":
                await sync_stocks(adapter, db)
            elif args.type == "futures":
                await sync_futures(adapter, db)
            elif args.type == "calendar":
                start = args.start or format_date(now - timedelta(days=30))
                end = args.end or format_date(now + timedelta(days=180))
                await sync_calendar(adapter, db, start, end)
            elif args.type == "klines":
                start = args.start or format_date(now - timedelta(days=7))
                end = args.end or format_date(now)
                await sync_klines(adapter, db, args.symbol, start, end, args.freq)
        finally:
            db.close()
    finally:
        # Outer finally: the adapter is released even if SessionLocal() or db.close() raises.
        await adapter.close()
if __name__ == "__main__":
    # Script entry point: run the async main() to completion.
    asyncio.run(main())