You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

231 lines
7.3 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""数据同步工具 - 对应Go的cmd/sync/main.go"""
import asyncio
import os
import sys
from datetime import datetime, timedelta
from argparse import ArgumentParser
# 添加项目根目录到路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.adapters import AKShareAdapter
from app.repositories import SessionLocal
from app.repositories.stock_repository import StockRepository
from app.repositories.futures_repository import FuturesRepository
from app.models import Symbol, SymbolType, TradeCalData
from app.core.logger import info, error
def parse_date(date_str: str) -> datetime:
    """Convert a compact ``YYYYMMDD`` string into a naive ``datetime``.

    Args:
        date_str: Date text such as ``"20240131"``.

    Returns:
        The parsed date at midnight (no timezone attached).

    Raises:
        ValueError: if *date_str* does not match ``%Y%m%d``.
    """
    compact_format = "%Y%m%d"
    parsed = datetime.strptime(date_str, compact_format)
    return parsed
def format_date(date: datetime) -> str:
    """Render *date* as a compact ``YYYYMMDD`` string (the inverse of ``parse_date``).

    Any time-of-day component is dropped from the output.
    """
    return f"{date:%Y%m%d}"
def is_stock(symbol: str) -> bool:
    """Return True when *symbol* carries a stock-exchange suffix.

    Recognized suffixes (case-sensitive): ``.SH``, ``.SZ`` and ``.BJ``.
    Anything else is treated as a non-stock (e.g. a futures contract) by callers.
    """
    stock_suffixes = (".SH", ".SZ", ".BJ")
    return symbol.endswith(stock_suffixes)
async def sync_stocks(adapter: AKShareAdapter, db):
    """Fetch basic stock information and persist it via ``StockRepository``.

    Every record from ``adapter.fetch_symbols("stock")`` becomes a ``Symbol``
    with ``symbol_type=SymbolType.STOCK`` and ``status="active"``; the whole
    list is saved with a single ``save_symbols`` call.

    Args:
        adapter: Connected market-data adapter.
        db: Database session passed to ``StockRepository``.

    Raises:
        Exception: any failure is logged with ``error`` and re-raised.
    """
    info("Syncing stock basic info...")
    try:
        symbols_data = await adapter.fetch_symbols("stock")
        repo = StockRepository(db)
        symbols = []
        bad_dates = 0
        for d in symbols_data:
            list_date = None
            if d.list_date:
                try:
                    list_date = parse_date(d.list_date)
                except (ValueError, TypeError):
                    # Best effort: keep the symbol without a listing date rather
                    # than abort the whole sync. The catch is deliberately narrow
                    # so unrelated bugs and KeyboardInterrupt are not swallowed.
                    bad_dates += 1
            symbols.append(Symbol(
                symbol_id=d.symbol_id,
                symbol_type=SymbolType.STOCK,
                exchange=d.exchange,
                name=d.name,
                list_date=list_date,
                status="active",
            ))
        repo.save_symbols(symbols)
        if bad_dates:
            info(f"Skipped {bad_dates} unparsable stock list dates")
        info(f"Synced {len(symbols)} stocks")
    except Exception as e:
        error(f"Failed to sync stocks: {e}")
        raise
async def sync_futures(adapter: AKShareAdapter, db):
    """Fetch basic futures-contract information and persist it.

    Every record from ``adapter.fetch_symbols("futures")`` becomes a ``Symbol``
    with ``symbol_type=SymbolType.FUTURES``. A contract is stored with
    ``status="expired"`` when the current time is later than its parsed
    delist date (midnight of that day), otherwise with ``"active"``.

    Args:
        adapter: Connected market-data adapter.
        db: Database session passed to ``FuturesRepository``.

    Raises:
        Exception: any failure is logged with ``error`` and re-raised.
    """
    info("Syncing futures basic info...")
    try:
        symbols_data = await adapter.fetch_symbols("futures")
        repo = FuturesRepository(db)
        # One reference instant for the whole batch, so every contract's status
        # is judged against the same "now".
        now = datetime.now()
        symbols = []
        bad_dates = 0
        for d in symbols_data:
            list_date = None
            delist_date = None
            # Best effort on both dates: a malformed value leaves the field empty
            # instead of aborting the sync. The catches are deliberately narrow so
            # unrelated bugs and KeyboardInterrupt are not swallowed.
            if d.list_date:
                try:
                    list_date = parse_date(d.list_date)
                except (ValueError, TypeError):
                    bad_dates += 1
            if d.delist_date:
                try:
                    delist_date = parse_date(d.delist_date)
                except (ValueError, TypeError):
                    bad_dates += 1
            # NOTE(review): delist_date is parsed as midnight, so a contract is
            # already "expired" during its delist day. Confirm whether that last
            # day should still count as active.
            status = "expired" if delist_date and now > delist_date else "active"
            symbols.append(Symbol(
                symbol_id=d.symbol_id,
                symbol_type=SymbolType.FUTURES,
                exchange=d.exchange,
                name=d.name,
                underlying=d.underlying,
                contract_month=d.contract_month,
                list_date=list_date,
                delist_date=delist_date,
                status=status,
            ))
        repo.save_symbols(symbols)
        if bad_dates:
            info(f"Skipped {bad_dates} unparsable futures list/delist dates")
        info(f"Synced {len(symbols)} futures")
    except Exception as e:
        error(f"Failed to sync futures: {e}")
        raise
async def sync_calendar(adapter: AKShareAdapter, db, start: str, end: str):
    """Fetch the trading calendar for [start, end] and store it for stocks and futures.

    Only the Shanghai ("SH") exchange calendar is fetched. The same list of
    ``TradeCalData`` rows is then written to both the stock and the futures
    repository.

    Args:
        adapter: Connected market-data adapter.
        db: Database session shared by both repositories.
        start: First date of the range (``YYYYMMDD``).
        end: Last date of the range (``YYYYMMDD``).

    Raises:
        Exception: any failure is logged with ``error`` and re-raised.
    """
    info(f"Syncing trading calendar from {start} to {end}...")
    try:
        raw_days = await adapter.fetch_trading_calendar("SH", start, end)
        stock_repo = StockRepository(db)
        calendar = [
            TradeCalData(date=day.date, is_trading_day=day.is_trading_day)
            for day in raw_days
        ]
        stock_repo.save_trading_calendar(calendar)
        # Futures reuse the SH calendar rather than fetching their own.
        FuturesRepository(db).save_trading_calendar(calendar)
        info(f"Synced {len(calendar)} calendar days")
    except Exception as exc:
        error(f"Failed to sync calendar: {exc}")
        raise
async def sync_klines(adapter: AKShareAdapter, db, symbol: str, start: str, end: str, freq: str):
    """Fetch K-line bars for one symbol and persist them.

    The target repository is chosen with ``is_stock``: symbols ending in
    ``.SH``, ``.SZ`` or ``.BJ`` go to ``StockRepository``, everything else
    goes to ``FuturesRepository``.

    Args:
        adapter: Connected market-data adapter used for ``fetch_klines``.
        db: Database session passed to the chosen repository.
        symbol: Instrument code (stocks carry an exchange suffix, e.g. ``600000.SH``).
        start: Start date (``YYYYMMDD``).
        end: End date (``YYYYMMDD``).
        freq: Bar frequency string; must be accepted by ``Frequency(freq)``.

    Raises:
        Exception: any failure, including an invalid *freq*, is logged with
            ``error`` and re-raised.
    """
    info(f"Syncing {freq} klines for {symbol} from {start} to {end}...")
    try:
        from app.models import Frequency, KLineItem

        # Validate the frequency before the network fetch, so bad input fails
        # fast instead of after all bars have been downloaded.
        frequency = Frequency(freq)

        klines_data = await adapter.fetch_klines(symbol, start, end, freq)
        items = [
            KLineItem(
                # NOTE(review): assumes d.time is a Unix timestamp in seconds;
                # fromtimestamp yields a naive local-time datetime. Confirm.
                time=datetime.fromtimestamp(d.time),
                open=d.open,
                high=d.high,
                low=d.low,
                close=d.close,
                volume=d.volume,
                amount=d.amount,
                # A missing (None) or non-positive open interest is stored as
                # None instead of raising TypeError on the comparison.
                open_interest=(
                    d.open_interest
                    if d.open_interest and d.open_interest > 0
                    else None
                ),
            )
            for d in klines_data
        ]

        if is_stock(symbol):
            # Stock bars carry their symbol on each item; the stock
            # repository's save_klines is called without a symbol argument.
            for item in items:
                item.symbol = symbol
            StockRepository(db).save_klines(frequency, items)
        else:
            FuturesRepository(db).save_klines(frequency, symbol, items)
        info(f"Synced {len(items)} klines")
    except Exception as e:
        error(f"Failed to sync klines: {e}")
        raise
async def main():
    """Command-line entry point: parse arguments and run one sync job.

    Behavior by ``--type``:

    - ``stocks`` and ``futures`` sync basic symbol information.
    - ``calendar`` defaults to a range from 30 days ago to 180 days ahead.
    - ``klines`` defaults to the last 7 days and requires ``--symbol``.

    The process exits with status 1 when ``--symbol`` is missing for
    ``klines``; this check runs before any connection is opened. The database
    session and the adapter are always closed once they have been opened.
    """
    parser = ArgumentParser(description="Market Data Sync Tool")
    parser.add_argument(
        "--type", "-t",
        required=True,
        choices=["stocks", "futures", "calendar", "klines"],
        help="同步类型",
    )
    parser.add_argument("--start", "-s", help="开始日期 YYYYMMDD")
    parser.add_argument("--end", "-e", help="结束日期 YYYYMMDD")
    parser.add_argument("--symbol", help="标的代码klines类型需要")
    parser.add_argument("--freq", "-f", default="1d", help="K线周期")
    args = parser.parse_args()

    # Reject invalid argument combinations before connecting to anything.
    if args.type == "klines" and not args.symbol:
        error("symbol is required for klines sync")
        sys.exit(1)

    now = datetime.now()

    # Initialize the adapter (AKShare needs no API token).
    adapter = AKShareAdapter()
    await adapter.connect({"timeout": 30})
    try:
        # Nested try/finally: the adapter is closed even if creating or
        # closing the database session raises.
        db = SessionLocal()
        try:
            if args.type == "stocks":
                await sync_stocks(adapter, db)
            elif args.type == "futures":
                await sync_futures(adapter, db)
            elif args.type == "calendar":
                start = args.start or format_date(now - timedelta(days=30))
                end = args.end or format_date(now + timedelta(days=180))
                await sync_calendar(adapter, db, start, end)
            elif args.type == "klines":
                start = args.start or format_date(now - timedelta(days=7))
                end = args.end or format_date(now)
                await sync_klines(adapter, db, args.symbol, start, end, args.freq)
        finally:
            db.close()
    finally:
        await adapter.close()
if __name__ == "__main__":
    # Script entry: run the async CLI on a fresh event loop.
    asyncio.run(main())