You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

227 lines
8.8 KiB

"""
AmazingData 数据服务平台 - 实时订阅 API
"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import Optional, List
from datetime import datetime
from backend.models.database import get_db
from backend.models.schemas import (
BaseResponse, SubscribeRequest, SubscribeResponse, TaskStatus
)
from backend.models.tables import SubscriptionTask, User
from backend.auth.dependencies import get_current_user
from backend.services.data_service import data_service
# FastAPI router carrying the real-time subscription endpoints of this module.
router = APIRouter()

# Threads started by ``POST /subscribe``, keyed by the SubscriptionTask id.
active_subscriptions: dict = dict()
def _enum_value(item):
    """Return ``item.value`` when *item* is an enum-like member, else *item* itself."""
    return item.value if hasattr(item, "value") else item


@router.post("/subscribe", response_model=BaseResponse)
async def subscribe_kline(
    request: SubscribeRequest,
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user)
):
    """Start a real-time K-line subscription task.

    Creates a ``SubscriptionTask`` row with status ``"running"``, then starts a
    daemon thread that logs in to AmazingData, subscribes every
    (code, period) pair from ``request`` and writes each callback payload to
    its own JSON file under ``request.save_path`` (default ``./data/realtime``).

    The worker runs for ``request.duration`` seconds when that is positive,
    otherwise indefinitely. In both cases it stops early once the task's
    status in the database is no longer ``"running"`` (for example after
    ``POST /stop/{task_id}``). On exit it unsubscribes, logs out and records
    ``"stopped"``, or ``"error"`` if the worker raised.

    Returns:
        ``BaseResponse`` whose ``data`` holds the task id/name on success, or
        ``code=500`` with the exception message if setup failed.
    """
    try:
        import AmazingData as ad
        import json
        import logging
        import os
        import threading
        import time

        logger = logging.getLogger(__name__)

        codes = [_enum_value(c) for c in request.codes]
        periods = [_enum_value(p) for p in request.periods]

        # The request-scoped ``db`` session belongs to this request (get_db may
        # close it once the response is sent) and SQLAlchemy sessions are not
        # meant to be shared across threads, so the worker opens its own
        # session on the same bind. Resolved before creating the task row so a
        # failure here does not leave an orphan "running" task.
        engine = db.get_bind()

        # NOTE(security): these fallbacks hard-code AmazingData credentials in
        # source control. Move them to configuration/environment and rotate the
        # password; they are kept only so existing deployments keep working.
        username = getattr(data_service, "AMAZING_DATA_USERNAME", "11200008169")
        password = getattr(data_service, "AMAZING_DATA_PASSWORD", "11200008169@2026")
        host = getattr(data_service, "AMAZING_DATA_HOST", "140.206.44.234")
        port = getattr(data_service, "AMAZING_DATA_PORT", 8600)

        # Create the task record.
        task_name = request.task_name or f"subscribe_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        task = SubscriptionTask(
            task_name=task_name,
            codes=list(codes),
            periods=list(periods),
            save_path=request.save_path,
            duration=request.duration,
            save_interval=request.save_interval,
            status="running",
            started_at=datetime.utcnow(),
            created_by=current_user.username if current_user else "anonymous"
        )
        db.add(task)
        db.commit()
        db.refresh(task)
        task_id = task.id

        def make_on_data(code_val, period_val, save_path):
            """Return a callback that saves each payload for one (code, period) pair.

            A factory is used so each callback binds its own ``code_val`` and
            ``period_val``; a closure defined directly inside the subscription
            loop would late-bind and label every file with the last pair.
            """
            def on_data(data):
                try:
                    now = datetime.now()
                    # Microseconds keep callbacks arriving within the same
                    # second from overwriting one another's file.
                    filename = (
                        f"{code_val.replace('.', '_')}_{period_val}_"
                        f"{now.strftime('%Y%m%d_%H%M%S_%f')}.json"
                    )
                    filepath = os.path.join(save_path, filename)
                    result = {
                        "code": code_val,
                        "period": period_val,
                        "timestamp": now.isoformat(),
                        "data": data
                    }
                    with open(filepath, 'w', encoding='utf-8') as f:
                        json.dump(result, f, ensure_ascii=False, indent=2)
                except Exception:
                    # Best-effort persistence: an unsavable payload is logged and
                    # must not propagate into the SDK that invoked the callback.
                    logger.exception("Save data error for %s %s", code_val, period_val)

            return on_data

        def read_status(session):
            """Return the task's committed status, or None if the row no longer exists."""
            try:
                return (
                    session.query(SubscriptionTask.status)
                    .filter(SubscriptionTask.id == task_id)
                    .scalar()
                )
            finally:
                # End the read transaction so the next poll sees commits made by
                # other sessions (e.g. the stop endpoint) under snapshot isolation.
                session.rollback()

        def subscribe_worker():
            """Subscribe, wait for duration/stop, then clean up and record the outcome."""
            session = Session(bind=engine)
            final_status = "stopped"
            logged_in = False
            touched_codes = []
            try:
                ad.login(username=username, password=password, host=host, port=port)
                logged_in = True

                save_path = request.save_path or "./data/realtime"
                os.makedirs(save_path, exist_ok=True)

                # Period name -> SDK period value; the numeric fallbacks are used
                # when ``ad.constant`` has no ``Period``.
                fallback_periods = {"min1": 1, "min5": 5, "min15": 15, "min30": 30, "min60": 60}
                if hasattr(ad.constant, "Period"):
                    period_map = {
                        name: getattr(ad.constant.Period, name).value
                        for name in fallback_periods
                    }
                else:
                    period_map = fallback_periods

                # Subscribe every code/period pair.
                for code_val in codes:
                    touched_codes.append(code_val)
                    for period_val in periods:
                        ad.subscribe_kline(
                            code=code_val,
                            # Unknown period names fall back to 5, as before.
                            period=period_map.get(period_val, 5),
                            callback=make_on_data(code_val, period_val, save_path)
                        )

                # Keep running until the duration elapses (when positive) or the
                # task is no longer "running" in the database.
                deadline = time.monotonic() + request.duration if request.duration > 0 else None
                while read_status(session) == "running":
                    if deadline is None:
                        time.sleep(1)
                        continue
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        break
                    time.sleep(min(1.0, remaining))
            except Exception:
                final_status = "error"
                logger.exception("Subscribe error")
            finally:
                # Cleanup runs on both the normal and the error path.
                for code_val in touched_codes:
                    try:
                        ad.unsubscribe_kline(code=code_val)
                    except Exception:
                        logger.exception("Unsubscribe error for %s", code_val)
                if logged_in:
                    try:
                        ad.logout(username)
                    except Exception:
                        logger.exception("Logout error")
                try:
                    row = (
                        session.query(SubscriptionTask)
                        .filter(SubscriptionTask.id == task_id)
                        .first()
                    )
                    if row is not None:
                        row.status = final_status
                        row.stopped_at = datetime.utcnow()
                        session.commit()
                except Exception:
                    session.rollback()
                    logger.exception("Failed to record final status for task %s", task_id)
                finally:
                    session.close()
                    active_subscriptions.pop(task_id, None)

        thread = threading.Thread(target=subscribe_worker, daemon=True)
        # Register before starting so a worker that exits immediately cannot pop
        # its entry before it has been added.
        active_subscriptions[task_id] = thread
        thread.start()
        return BaseResponse(
            data={
                "task_id": task_id,
                "task_name": task_name,
                "status": "running",
                "message": "Subscription started"
            }
        )
    except Exception as e:
        return BaseResponse(code=500, message=str(e))
@router.post("/stop/{task_id}", response_model=BaseResponse)
async def stop_subscription(
    task_id: int,
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user)
):
    """Request that a subscription task stop.

    A task that is still active is marked ``"stopped"`` with the current UTC
    time as ``stopped_at``. Tasks already in a finished state (``"stopped"``
    or ``"error"``) are left untouched. This keeps an ``"error"`` outcome and
    the original stop time from being overwritten by a late or repeated stop
    call.

    Raises:
        HTTPException: 404 when no task has id ``task_id``.
    """
    task = db.query(SubscriptionTask).filter(SubscriptionTask.id == task_id).first()
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")
    if task.status not in ("stopped", "error"):
        task.status = "stopped"
        task.stopped_at = datetime.utcnow()
        db.commit()
    return BaseResponse(message="Subscription stopped")
@router.get("/tasks", response_model=BaseResponse)
async def list_subscription_tasks(
    status: Optional[str] = None,
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user)
):
    """List subscription tasks, most recently created first.

    Args:
        status: Optional status to filter on; a falsy value (``None`` or
            ``""``) returns every task.

    Returns:
        ``BaseResponse`` whose ``data`` is ``{"tasks": [...], "total": n}``.
    """
    def summarize(row):
        """Serialize one task row; timestamps become ISO strings or None."""
        def iso(moment):
            return moment.isoformat() if moment else None

        return {
            "id": row.id,
            "task_name": row.task_name,
            "codes": row.codes,
            "periods": row.periods,
            "status": row.status,
            "started_at": iso(row.started_at),
            "stopped_at": iso(row.stopped_at),
            "created_at": iso(row.created_at),
        }

    query = db.query(SubscriptionTask)
    if status:
        query = query.filter(SubscriptionTask.status == status)
    ordered = query.order_by(SubscriptionTask.created_at.desc())

    summaries = [summarize(row) for row in ordered.all()]
    return BaseResponse(data={"tasks": summaries, "total": len(summaries)})
@router.get("/tasks/{task_id}", response_model=BaseResponse)
async def get_subscription_task(
    task_id: int,
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user)
):
    """Return the full details of one subscription task.

    Raises:
        HTTPException: 404 when no task has id ``task_id``.
    """
    lookup = db.query(SubscriptionTask).filter(SubscriptionTask.id == task_id)
    task = lookup.first()
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")

    # Attributes copied verbatim, listed in response-key order.
    plain_fields = (
        "id", "task_name", "codes", "periods", "save_path",
        "duration", "save_interval", "status",
    )
    details = {name: getattr(task, name) for name in plain_fields}
    # Timestamps are rendered as ISO-8601 strings (None when unset).
    for stamp_field in ("started_at", "stopped_at", "created_at"):
        stamp = getattr(task, stamp_field)
        details[stamp_field] = stamp.isoformat() if stamp else None
    return BaseResponse(data=details)