You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
227 lines
8.8 KiB
227 lines
8.8 KiB
|
2 months ago
|
"""
|
||
|
|
AmazingData 数据服务平台 - 实时订阅 API
|
||
|
|
"""
|
||
|
|
|
||
|
|
import json
import logging
import os
import threading
import time
from datetime import datetime
from typing import Optional, List

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session

from backend.models.database import get_db
from backend.models.schemas import (
    BaseResponse, SubscribeRequest, SubscribeResponse, TaskStatus
)
from backend.models.tables import SubscriptionTask, User
from backend.auth.dependencies import get_current_user
from backend.services.data_service import data_service
|
||
|
|
|
||
|
|
router = APIRouter()

# In-memory registry of subscription worker threads started by this process,
# keyed by subscription task id.
active_subscriptions: dict = {}
|
||
|
|
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)

# Period values used when the installed AmazingData SDK exposes no
# ``constant.Period``; the keys are the period names this module recognises.
_FALLBACK_PERIOD_MINUTES = {"min1": 1, "min5": 5, "min15": 15, "min30": 30, "min60": 60}
# Period value passed to ``subscribe_kline`` for an unrecognised period name.
_DEFAULT_PERIOD_VALUE = 5
# How often (seconds) a worker re-reads its task status from the database.
_STATUS_POLL_SECONDS = 1.0


def _enum_value(item):
    """Return ``item.value`` for enum-like items, otherwise ``item`` unchanged."""
    return item.value if hasattr(item, "value") else item


def _amazing_data_setting(name, default=None):
    """Look up an AmazingData connection setting.

    An attribute of the same name on ``data_service`` wins; otherwise the
    environment variable ``name`` is used, then ``default``.
    """
    value = getattr(data_service, name, None)
    if value is None:
        value = os.environ.get(name, default)
    return value


def _login_kwargs():
    """Build the keyword arguments for ``AmazingData.login``.

    Credentials deliberately have no in-code fallback: secrets must come from
    ``data_service`` or the environment, never from source control.

    Raises:
        RuntimeError: if the username or password is not configured.
        ValueError: if the configured port is not an integer.
    """
    username = _amazing_data_setting("AMAZING_DATA_USERNAME")
    password = _amazing_data_setting("AMAZING_DATA_PASSWORD")
    if not username or not password:
        raise RuntimeError(
            "AmazingData credentials are not configured "
            "(set AMAZING_DATA_USERNAME and AMAZING_DATA_PASSWORD)"
        )
    return {
        "username": username,
        "password": password,
        "host": _amazing_data_setting("AMAZING_DATA_HOST", "140.206.44.234"),
        # Environment variables are strings, so normalise the port to int.
        "port": int(_amazing_data_setting("AMAZING_DATA_PORT", 8600)),
    }


def _build_period_map(ad):
    """Map period names ("min1", "min5", ...) to the SDK's period values."""
    period_enum = getattr(ad.constant, "Period", None)
    if period_enum is None:
        return dict(_FALLBACK_PERIOD_MINUTES)
    return {name: getattr(period_enum, name).value for name in _FALLBACK_PERIOD_MINUTES}


def _make_kline_saver(save_path, code_val, period_val):
    """Return a callback that writes each pushed K-line payload to a JSON file.

    ``code_val`` and ``period_val`` are bound as arguments of this factory so
    every callback keeps its own pair. A closure over the subscription loop
    variables would only ever see their final values.
    """
    file_prefix = f"{code_val.replace('.', '_')}_{period_val}"

    def on_data(data):
        """Persist one payload; exceptions are logged and not re-raised."""
        try:
            now = datetime.now()
            # NOTE(review): second-resolution file names mean two payloads for
            # the same code/period within one second overwrite each other.
            filename = f"{file_prefix}_{now.strftime('%Y%m%d_%H%M%S')}.json"
            result = {
                "code": code_val,
                "period": period_val,
                "timestamp": now.isoformat(),
                "data": data,
            }
            with open(os.path.join(save_path, filename), "w", encoding="utf-8") as f:
                json.dump(result, f, ensure_ascii=False, indent=2)
        except Exception:
            logger.exception("Save data error for %s %s", code_val, period_val)

    return on_data


def _task_is_running(session, task_id):
    """Return True if the task's status in the database is currently "running"."""
    try:
        status = (
            session.query(SubscriptionTask.status)
            .filter(SubscriptionTask.id == task_id)
            .scalar()
        )
    finally:
        # End the read transaction so the next poll is not served from a
        # snapshot taken by an earlier read (e.g. under REPEATABLE READ).
        session.rollback()
    return status == "running"


def _wait_for_stop(session, task_id, duration):
    """Block until the task leaves "running" or, if ``duration`` > 0, it elapses."""
    deadline = time.monotonic() + duration if duration > 0 else None
    while _task_is_running(session, task_id):
        if deadline is None:
            time.sleep(_STATUS_POLL_SECONDS)
            continue
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return
        time.sleep(min(_STATUS_POLL_SECONDS, remaining))


def _release_subscriptions(ad, codes, username):
    """Unsubscribe every code, then log out; each failure is logged, not raised."""
    for code_val in codes:
        try:
            ad.unsubscribe_kline(code=code_val)
        except Exception:
            logger.exception("Unsubscribe error for %s", code_val)
    try:
        ad.logout(username)
    except Exception:
        logger.exception("Logout error")


def _record_final_status(session, task_id, status):
    """Persist the worker's terminal ``status`` and ``stopped_at`` for the task.

    A "stopped" result only updates a row that is still "running", so a task
    already stopped elsewhere keeps the ``stopped_at`` recorded at that time.
    Failures are logged and rolled back, never raised.
    """
    try:
        query = session.query(SubscriptionTask).filter(SubscriptionTask.id == task_id)
        if status == "stopped":
            query = query.filter(SubscriptionTask.status == "running")
        query.update(
            {"status": status, "stopped_at": datetime.utcnow()},
            synchronize_session=False,
        )
        session.commit()
    except Exception:
        session.rollback()
        logger.exception("Failed to record status %r for task %s", status, task_id)


@router.post("/subscribe", response_model=BaseResponse)
async def subscribe_kline(
    request: SubscribeRequest,
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user)
):
    """Start a background real-time K-line subscription task.

    Setup (done in the request):
    - Create a ``SubscriptionTask`` row with status "running".
    - Start a daemon worker thread.

    The worker:
    - Logs in to AmazingData and subscribes to every (code, period) pair.
    - Writes each pushed payload to ``save_path`` as a JSON file.
    - Keeps running until ``request.duration`` seconds elapse (when > 0) or the
      task's database status leaves "running" (e.g. via the stop endpoint).
    - Then unsubscribes, logs out, and records "stopped" (or "error" on failure).

    Returns:
        ``BaseResponse`` with the task id and name on success, or ``code=500``
        with the error message if setup fails.
    """
    try:
        # Imported lazily, presumably so the app can start without the vendor SDK.
        import AmazingData as ad

        # Resolve credentials first so misconfiguration fails before a task row exists.
        login_kwargs = _login_kwargs()
        codes = [_enum_value(c) for c in request.codes]
        periods = [_enum_value(p) for p in request.periods]
        duration = request.duration or 0
        save_path = request.save_path or "./data/realtime"
        task_name = request.task_name or f"subscribe_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        task = SubscriptionTask(
            task_name=task_name,
            codes=codes,
            periods=periods,
            save_path=request.save_path,
            duration=request.duration,
            save_interval=request.save_interval,
            status="running",
            started_at=datetime.utcnow(),
            created_by=current_user.username if current_user else "anonymous"
        )
        db.add(task)
        db.commit()
        db.refresh(task)
        task_id = task.id

        # The worker opens its own session on the same engine. SQLAlchemy sessions
        # are not thread-safe, and the request-scoped one may be closed by get_db
        # once this request returns.
        engine = db.get_bind()

        def subscribe_worker():
            """Run the subscription until stopped or timed out, then clean up."""
            worker_db = Session(bind=engine)
            final_status = "error"
            logged_in = False
            try:
                ad.login(**login_kwargs)
                logged_in = True
                os.makedirs(save_path, exist_ok=True)

                period_map = _build_period_map(ad)
                for code_val in codes:
                    for period_val in periods:
                        if period_val not in period_map:
                            logger.warning(
                                "Unknown period %r; using %r", period_val, _DEFAULT_PERIOD_VALUE
                            )
                        ad.subscribe_kline(
                            code=code_val,
                            period=period_map.get(period_val, _DEFAULT_PERIOD_VALUE),
                            callback=_make_kline_saver(save_path, code_val, period_val),
                        )

                _wait_for_stop(worker_db, task_id, duration)
                final_status = "stopped"
            except Exception:
                logger.exception("Subscribe error (task %s)", task_id)
            finally:
                # NOTE(review): if the SDK keeps one global login, this logout may
                # also affect other concurrently running tasks — confirm.
                if logged_in:
                    _release_subscriptions(ad, codes, login_kwargs["username"])
                _record_final_status(worker_db, task_id, final_status)
                worker_db.close()
                active_subscriptions.pop(task_id, None)

        thread = threading.Thread(
            target=subscribe_worker, daemon=True, name=f"subscription-{task_id}"
        )
        # Register before starting so a fast-failing worker's cleanup finds its entry.
        active_subscriptions[task_id] = thread
        thread.start()

        return BaseResponse(
            data={
                "task_id": task_id,
                "task_name": task_name,
                "status": "running",
                "message": "Subscription started"
            }
        )

    except Exception as e:
        # Leave the request session usable if a flush/commit failed mid-transaction.
        db.rollback()
        return BaseResponse(code=500, message=str(e))
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/stop/{task_id}", response_model=BaseResponse)
async def stop_subscription(
    task_id: int,
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user)
):
    """Mark a subscription task as stopped.

    Only a task whose status is "running" is changed (to "stopped", with
    ``stopped_at`` set to the current UTC time). A task that already finished
    or failed keeps its recorded status and timestamp. Stopping is therefore
    idempotent and does not mask an "error" result.

    This endpoint only updates the database row. It does not touch the worker
    thread directly.

    Raises:
        HTTPException: 404 if no task with ``task_id`` exists.
    """
    # NOTE(review): current_user is not checked, so any caller can stop any
    # task — confirm this is intended.
    task = db.query(SubscriptionTask).filter(SubscriptionTask.id == task_id).first()
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")

    if task.status == "running":
        task.status = "stopped"
        task.stopped_at = datetime.utcnow()
        db.commit()

    return BaseResponse(message="Subscription stopped")
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/tasks", response_model=BaseResponse)
async def list_subscription_tasks(
    status: Optional[str] = None,
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user)
):
    """List subscription tasks, newest first.

    Args:
        status: when non-empty, only tasks whose status equals it are returned.

    Returns:
        ``BaseResponse`` whose data holds the serialized ``tasks`` and their
        ``total`` count.
    """
    def as_iso(moment):
        # Timestamps may be unset (e.g. a task that has not stopped yet).
        return None if not moment else moment.isoformat()

    def summarize(row):
        return {
            "id": row.id,
            "task_name": row.task_name,
            "codes": row.codes,
            "periods": row.periods,
            "status": row.status,
            "started_at": as_iso(row.started_at),
            "stopped_at": as_iso(row.stopped_at),
            "created_at": as_iso(row.created_at),
        }

    task_query = db.query(SubscriptionTask)
    if status:
        task_query = task_query.filter(SubscriptionTask.status == status)
    newest_first = task_query.order_by(SubscriptionTask.created_at.desc()).all()

    summaries = [summarize(row) for row in newest_first]
    return BaseResponse(data={"tasks": summaries, "total": len(summaries)})
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/tasks/{task_id}", response_model=BaseResponse)
async def get_subscription_task(
    task_id: int,
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user)
):
    """Return the full details of one subscription task.

    Timestamps are rendered as ISO-8601 strings, or ``None`` when unset.

    Raises:
        HTTPException: 404 if no task with ``task_id`` exists.
    """
    found = db.query(SubscriptionTask).filter_by(id=task_id).first()
    if not found:
        raise HTTPException(status_code=404, detail="Task not found")

    plain_fields = (
        "id", "task_name", "codes", "periods",
        "save_path", "duration", "save_interval", "status",
    )
    details = {field: getattr(found, field) for field in plain_fields}
    for stamp in ("started_at", "stopped_at", "created_at"):
        value = getattr(found, stamp)
        details[stamp] = value.isoformat() if value else None

    return BaseResponse(data=details)
|