feat: 一键检测所有数据的缺失情况-支持指定任务ID参数;增加websocket传输实时进度

master
Lxy 1 month ago
parent 7a2094b738
commit 4db22a6f77

@ -1,7 +1,7 @@
# API v1模块
from fastapi import APIRouter
from app.api.v1 import auth, configs, base_data, stock, future, realtime, finance, cache, test, data_import, index
from app.api.v1 import auth, configs, base_data, stock, future, realtime, finance, cache, test, data_import, index, ws
api_router = APIRouter(prefix="/api/v1")
@ -16,3 +16,4 @@ api_router.include_router(finance.router, prefix="/finance", tags=["财务数据
api_router.include_router(cache.router, prefix="/cache", tags=["缓存管理"])
api_router.include_router(test.router, prefix="/test", tags=["测试中心"])
api_router.include_router(data_import.router, prefix="/import", tags=["数据导入"])
api_router.include_router(ws.router, tags=["WebSocket进度"])

@ -190,6 +190,7 @@ async def cache_all_missing_data(
@router.post("/detect-all-missing", response_model=ResponseModel)
async def detect_all_missing_data(
request: AllDataRequest,
task_id: str = Query(None, description="WebSocket任务ID"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
@ -204,7 +205,8 @@ async def detect_all_missing_data(
request.period_type,
start,
end,
request.contract_type
request.contract_type,
task_id
)
return ResponseModel(data=result)

@ -0,0 +1,33 @@
"""
WebSocket进度路由
"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from app.core.progress_manager import progress_manager
import logging
logger = logging.getLogger(__name__)
router = APIRouter()
@router.websocket("/progress/{task_id}")
async def websocket_progress(
websocket: WebSocket,
task_id: str
):
"""WebSocket进度推送"""
await progress_manager.connect(websocket, task_id)
try:
while True:
data = await websocket.receive_text()
if data == "ping":
await websocket.send_text("pong")
elif data == "close":
break
except WebSocketDisconnect:
logger.info(f"WebSocket断开连接: task_id={task_id}")
except Exception as e:
logger.error(f"WebSocket错误: {e}")
finally:
await progress_manager.disconnect(websocket, task_id)

@ -0,0 +1,105 @@
"""
进度管理器 - WebSocket实时进度推送
"""
import asyncio
import json
import logging
from typing import Dict, Set, Optional
from datetime import datetime
from fastapi import WebSocket
logger = logging.getLogger(__name__)
class ProgressManager:
"""进度管理器"""
_instance = None
_lock = asyncio.Lock()
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._connections: Dict[str, Set[WebSocket]] = {}
cls._instance._progress_data: Dict[str, Dict] = {}
return cls._instance
async def connect(self, websocket: WebSocket, task_id: str):
"""连接WebSocket"""
await websocket.accept()
if task_id not in self._connections:
self._connections[task_id] = set()
self._connections[task_id].add(websocket)
if task_id in self._progress_data:
await websocket.send_json(self._progress_data[task_id])
logger.info(f"WebSocket连接: task_id={task_id}")
async def disconnect(self, websocket: WebSocket, task_id: str):
"""断开WebSocket连接"""
if task_id in self._connections:
self._connections[task_id].discard(websocket)
if not self._connections[task_id]:
del self._connections[task_id]
logger.info(f"WebSocket断开: task_id={task_id}")
async def update_progress(self, task_id: str, progress_data: Dict):
"""更新进度并推送"""
progress_data["timestamp"] = datetime.utcnow().isoformat()
self._progress_data[task_id] = progress_data
if task_id in self._connections:
disconnected = set()
for websocket in self._connections[task_id]:
try:
await websocket.send_json(progress_data)
except Exception as e:
logger.warning(f"WebSocket发送失败: {e}")
disconnected.add(websocket)
for ws in disconnected:
self._connections[task_id].discard(ws)
async def complete_task(self, task_id: str, result: Dict):
"""完成任务"""
result["status"] = "completed"
result["timestamp"] = datetime.utcnow().isoformat()
self._progress_data[task_id] = result
if task_id in self._connections:
for websocket in self._connections[task_id]:
try:
await websocket.send_json(result)
except Exception as e:
logger.warning(f"WebSocket发送失败: {e}")
async def fail_task(self, task_id: str, error: str):
"""任务失败"""
result = {
"status": "failed",
"error": error,
"timestamp": datetime.utcnow().isoformat()
}
self._progress_data[task_id] = result
if task_id in self._connections:
for websocket in self._connections[task_id]:
try:
await websocket.send_json(result)
except Exception as e:
logger.warning(f"WebSocket发送失败: {e}")
def get_progress(self, task_id: str) -> Optional[Dict]:
"""获取进度数据"""
return self._progress_data.get(task_id)
def clear_task(self, task_id: str):
"""清除任务数据"""
if task_id in self._progress_data:
del self._progress_data[task_id]
if task_id in self._connections:
del self._connections[task_id]
progress_manager = ProgressManager()

@ -1,6 +1,7 @@
"""
缓存管理服务
"""
import asyncio
import logging
from typing import List, Dict, Optional
from datetime import date, datetime
@ -17,6 +18,7 @@ from app.services.sdk_manager import sdk_manager
from app.utils.date_utils import parse_date, format_date, get_market_from_code
from app.config import settings
from app.core.redis_client import redis_client
from app.core.progress_manager import progress_manager
logger = logging.getLogger(__name__)
@ -166,7 +168,8 @@ class CacheService:
period_type: str,
start_date: date,
end_date: date,
contract_type: str = "all"
contract_type: str = "all",
task_id: str = None
) -> Dict:
"""
一键检测所有数据的缺失情况
@ -177,11 +180,11 @@ class CacheService:
start_date: 开始日期
end_date: 结束日期
contract_type: 合约类型 (all, main) - 仅对期货有效
task_id: WebSocket任务ID
Returns:
检测结果字典
"""
# 获取所有代码
code_list = self.get_all_codes(security_type, contract_type)
if not code_list:
@ -189,7 +192,23 @@ class CacheService:
logger.info(f"获取到{len(code_list)}{security_type}代码")
# 创建检测任务
ws_task_id = task_id or f"detect_{security_type}_{datetime.utcnow().strftime('%Y%m%d%H%M%S')}"
def push_progress(progress, status, **kwargs):
try:
loop = asyncio.get_event_loop()
if loop.is_running():
asyncio.create_task(progress_manager.update_progress(ws_task_id, {
"progress": progress,
"status": status,
"total_count": len(code_list),
**kwargs
}))
except RuntimeError:
pass
push_progress(0, "starting", message="开始检测...")
task = CacheTask(
task_name=f"一键检测所有数据 - {security_type} - {contract_type} - {len(code_list)}个代码",
task_type="detect_all_missing",
@ -206,8 +225,9 @@ class CacheService:
self.db.commit()
self.db.refresh(task)
push_progress(5, "running", message="获取交易日历...")
try:
# 获取交易日历
market = "CFE" if security_type == "future" else "SH"
trading_days = self.base_service.get_trading_calendar(market, start_date, end_date)
expected_count = len(trading_days)
@ -216,7 +236,6 @@ class CacheService:
complete_codes = []
error_count = 0
# 统计每个交易日的缺失情况
daily_stats = {}
for td in trading_days:
daily_stats[format_date(td)] = {
@ -225,44 +244,114 @@ class CacheService:
"missing": 0
}
for i, code in enumerate(code_list):
try:
# 查询实际数据量
if security_type == "stock" and period_type == "daily":
records = self.db.query(StockKlineDaily).filter(
and_(
StockKlineDaily.code == code,
StockKlineDaily.trade_date >= start_date,
StockKlineDaily.trade_date <= end_date
)
).all()
actual_count = len(records)
push_progress(10, "running", message="查询数据库统计...")
# 更新每日统计
for r in records:
date_key = format_date(r.trade_date)
if date_key in daily_stats:
daily_stats[date_key]["actual"] += 1
if security_type == "stock" and period_type == "daily":
from sqlalchemy import func
elif security_type == "future" and period_type == "daily":
records = self.db.query(FutureKlineDaily).filter(
and_(
FutureKlineDaily.code == code,
FutureKlineDaily.trade_date >= start_date,
FutureKlineDaily.trade_date <= end_date
)
).all()
actual_count = len(records)
code_count_query = self.db.query(
StockKlineDaily.code,
func.count(StockKlineDaily.id).label('count')
).filter(
and_(
StockKlineDaily.trade_date >= start_date,
StockKlineDaily.trade_date <= end_date
)
).group_by(StockKlineDaily.code).all()
code_counts = {r.code: r.count for r in code_count_query}
date_count_query = self.db.query(
func.date(StockKlineDaily.trade_date).label('trade_date'),
func.count(StockKlineDaily.id).label('count')
).filter(
and_(
StockKlineDaily.trade_date >= start_date,
StockKlineDaily.trade_date <= end_date
)
).group_by(func.date(StockKlineDaily.trade_date)).all()
for r in date_count_query:
date_key = format_date(r.trade_date) if hasattr(r.trade_date, 'strftime') else str(r.trade_date)
if date_key in daily_stats:
daily_stats[date_key]["actual"] = r.count
push_progress(20, "running", message="分析数据完整性...")
for r in records:
date_key = format_date(r.trade_date)
if date_key in daily_stats:
daily_stats[date_key]["actual"] += 1
for i, code in enumerate(code_list):
actual_count = code_counts.get(code, 0)
is_missing = actual_count < expected_count
if is_missing:
missing_codes.append({
"code": code,
"actual_count": actual_count,
"expected_count": expected_count,
"missing_count": expected_count - actual_count,
"missing_ratio": (expected_count - actual_count) / expected_count if expected_count > 0 else 0
})
detail = CacheTaskDetail(
task_id=task.id,
code=code,
trade_date=start_date,
expected_count=expected_count,
actual_count=actual_count,
is_missing=True,
status="pending"
)
self.db.add(detail)
else:
actual_count = 0
complete_codes.append(code)
if (i + 1) % 500 == 0 or i == len(code_list) - 1:
task.success_count = len(missing_codes) + len(complete_codes)
task.error_count = error_count
task.progress = min(100, int((i + 1) / len(code_list) * 100))
self.db.commit()
push_progress(
20 + int((i + 1) / len(code_list) * 70),
"running",
processed=len(missing_codes) + len(complete_codes),
missing=len(missing_codes),
complete=len(complete_codes)
)
elif security_type == "future" and period_type == "daily":
from sqlalchemy import func
code_count_query = self.db.query(
FutureKlineDaily.code,
func.count(FutureKlineDaily.id).label('count')
).filter(
and_(
FutureKlineDaily.trade_date >= start_date,
FutureKlineDaily.trade_date <= end_date
)
).group_by(FutureKlineDaily.code).all()
code_counts = {r.code: r.count for r in code_count_query}
# 判断是否缺失
date_count_query = self.db.query(
func.date(FutureKlineDaily.trade_date).label('trade_date'),
func.count(FutureKlineDaily.id).label('count')
).filter(
and_(
FutureKlineDaily.trade_date >= start_date,
FutureKlineDaily.trade_date <= end_date
)
).group_by(func.date(FutureKlineDaily.trade_date)).all()
for r in date_count_query:
date_key = format_date(r.trade_date) if hasattr(r.trade_date, 'strftime') else str(r.trade_date)
if date_key in daily_stats:
daily_stats[date_key]["actual"] = r.count
push_progress(20, "running", message="分析数据完整性...")
for i, code in enumerate(code_list):
actual_count = code_counts.get(code, 0)
is_missing = actual_count < expected_count
if is_missing:
@ -287,31 +376,59 @@ class CacheService:
else:
complete_codes.append(code)
except Exception as e:
logger.error(f"检测{code}缺失数据失败: {str(e)}")
error_count += 1
if (i + 1) % 500 == 0 or i == len(code_list) - 1:
task.success_count = len(missing_codes) + len(complete_codes)
task.error_count = error_count
task.progress = min(100, int((i + 1) / len(code_list) * 100))
self.db.commit()
push_progress(
20 + int((i + 1) / len(code_list) * 70),
"running",
processed=len(missing_codes) + len(complete_codes),
missing=len(missing_codes),
complete=len(complete_codes)
)
else:
for i, code in enumerate(code_list):
actual_count = 0
is_missing = True
missing_codes.append({
"code": code,
"actual_count": 0,
"expected_count": expected_count,
"missing_count": expected_count,
"missing_ratio": 1.0
})
detail = CacheTaskDetail(
task_id=task.id,
code=code,
trade_date=start_date,
status="failed",
error_message=str(e)
expected_count=expected_count,
actual_count=0,
is_missing=True,
status="pending"
)
self.db.add(detail)
# 每100个代码更新一次进度
if (i + 1) % 100 == 0 or i == len(code_list) - 1:
task.success_count = len(missing_codes) + len(complete_codes)
task.error_count = error_count
task.progress = min(100, int((i + 1) / len(code_list) * 100))
self.db.commit()
if (i + 1) % 500 == 0 or i == len(code_list) - 1:
task.success_count = len(missing_codes)
task.error_count = error_count
task.progress = min(100, int((i + 1) / len(code_list) * 100))
self.db.commit()
push_progress(
20 + int((i + 1) / len(code_list) * 70),
"running",
processed=len(missing_codes),
missing=len(missing_codes)
)
# 计算每日缺失数
for date_key in daily_stats:
daily_stats[date_key]["missing"] = daily_stats[date_key]["expected"] - daily_stats[date_key]["actual"]
# 保存缺失代码列表到任务记录
missing_code_list = [m["code"] for m in missing_codes]
task.code_list = ",".join(missing_code_list[:500]) if missing_code_list else ""
@ -321,10 +438,18 @@ class CacheService:
task.completed_at = datetime.utcnow()
self.db.commit()
push_progress(100, "completed",
message="检测完成",
complete_count=len(complete_codes),
missing_count=len(missing_codes),
error_count=error_count
)
logger.info(f"检测完成: 完整{len(complete_codes)}个, 缺失{len(missing_codes)}个, 错误{error_count}")
return {
"task_id": task.id,
"ws_task_id": ws_task_id,
"task_name": task.task_name,
"status": task.status,
"progress": float(task.progress),
@ -349,8 +474,11 @@ class CacheService:
self.db.commit()
logger.error(f"一键检测缺失数据任务失败: {str(e)}")
push_progress(100, "failed", error=str(e))
return {
"task_id": task.id,
"ws_task_id": ws_task_id,
"task_name": task.task_name,
"status": task.status,
"error_message": str(e)

@ -1,4 +1,4 @@
import request from '@/utils/request'
import request, { cacheRequest } from '@/utils/request'
export const detectMissingData = (data: {
security_type: string
@ -7,7 +7,7 @@ export const detectMissingData = (data: {
end_date: string
code_list: string[]
}) => {
return request.post('/cache/detect-missing', data)
return cacheRequest.post('/cache/detect-missing', data)
}
export const batchCacheData = (data: {
@ -17,7 +17,7 @@ export const batchCacheData = (data: {
end_date: string
code_list: string[]
}) => {
return request.post('/cache/batch-cache', data)
return cacheRequest.post('/cache/batch-cache', data)
}
export const detectAllMissingData = (data: {
@ -26,8 +26,10 @@ export const detectAllMissingData = (data: {
contract_type?: string
start_date: string
end_date: string
task_id?: string
}) => {
return request.post('/cache/detect-all-missing', data)
const params = data.task_id ? { task_id: data.task_id } : {}
return cacheRequest.post('/cache/detect-all-missing', data, { params })
}
export const cacheAllMissingData = (data: {
@ -37,7 +39,7 @@ export const cacheAllMissingData = (data: {
start_date: string
end_date: string
}) => {
return request.post('/cache/cache-all-missing', data)
return cacheRequest.post('/cache/cache-all-missing', data)
}
export const getCacheTasks = (params?: { page?: number; page_size?: number }) => {
@ -74,5 +76,5 @@ export const fillMissingData = (data: {
start_date: string
end_date: string
}) => {
return request.post('/cache/fill-missing', data)
return cacheRequest.post('/cache/fill-missing', data)
}

@ -2,13 +2,16 @@ import axios from 'axios'
import { ElMessage } from 'element-plus'
import { useUserStore } from '@/store/user'
// 创建axios实例
const request = axios.create({
baseURL: '/api/v1',
timeout: 30000
})
// 请求拦截器
const cacheRequest = axios.create({
baseURL: '/api/v1',
timeout: 300000
})
request.interceptors.request.use(
(config) => {
const userStore = useUserStore()
@ -22,7 +25,19 @@ request.interceptors.request.use(
}
)
// 响应拦截器
cacheRequest.interceptors.request.use(
(config) => {
const userStore = useUserStore()
if (userStore.token) {
config.headers.Authorization = `Bearer ${userStore.token}`
}
return config
},
(error) => {
return Promise.reject(error)
}
)
request.interceptors.response.use(
(response) => {
const res = response.data
@ -30,7 +45,31 @@ request.interceptors.response.use(
if (res.code !== 200) {
ElMessage.error(res.message || '请求失败')
// 401未授权跳转到登录页
if (res.code === 401) {
const userStore = useUserStore()
userStore.logout()
window.location.href = '/login'
}
return Promise.reject(new Error(res.message))
}
return res
},
(error) => {
const message = error.response?.data?.message || error.message || '网络错误'
ElMessage.error(message)
return Promise.reject(error)
}
)
cacheRequest.interceptors.response.use(
(response) => {
const res = response.data
if (res.code !== 200) {
ElMessage.error(res.message || '请求失败')
if (res.code === 401) {
const userStore = useUserStore()
userStore.logout()
@ -50,3 +89,4 @@ request.interceptors.response.use(
)
export default request
export { cacheRequest }

@ -69,6 +69,32 @@
</el-form>
</el-card>
<!-- 实时进度显示 -->
<el-card class="progress-card" v-if="wsProgress.status">
<template #header>
<div style="display: flex; justify-content: space-between; align-items: center;">
<span>实时进度</span>
<el-tag :type="getProgressTagType(wsProgress.status)">
{{ wsProgress.status === 'starting' ? '启动中' :
wsProgress.status === 'running' ? '进行中' :
wsProgress.status === 'completed' ? '已完成' : '失败' }}
</el-tag>
</div>
</template>
<el-progress
:percentage="wsProgress.progress || 0"
:status="getProgressStatus(wsProgress.status)"
:stroke-width="20"
:text-inside="true"
/>
<el-descriptions :column="4" border style="margin-top: 15px;">
<el-descriptions-item label="当前状态">{{ wsProgress.message || '-' }}</el-descriptions-item>
<el-descriptions-item label="总数">{{ wsProgress.total_count || 0 }}</el-descriptions-item>
<el-descriptions-item label="已处理">{{ wsProgress.processed || 0 }}</el-descriptions-item>
<el-descriptions-item label="缺失">{{ wsProgress.missing || 0 }}</el-descriptions-item>
</el-descriptions>
</el-card>
<!-- 检测结果汇总 -->
<el-card class="summary-card" v-if="detectResult">
<template #header>
@ -240,6 +266,17 @@ const codeInput = ref('000001.SZ\n600000.SH')
const detectResult = ref<any>(null)
const batchDetectResult = ref<any[]>([])
const cacheTask = ref<any>(null)
const wsProgress = reactive({
status: '',
progress: 0,
message: '',
total_count: 0,
processed: 0,
missing: 0,
complete: 0
})
let ws: WebSocket | null = null
const hasMissing = computed(() => batchDetectResult.value.some(r => r.missingCount > 0))
@ -287,6 +324,13 @@ function getProgressStatus(status: string) {
return undefined
}
function getProgressTagType(status: string) {
if (status === 'completed') return 'success'
if (status === 'running') return 'warning'
if (status === 'starting') return 'info'
return 'danger'
}
const parseCodes = () => {
return codeInput.value
.split(/[\n,]/)
@ -294,6 +338,56 @@ const parseCodes = () => {
.filter(c => c.length > 0)
}
const connectWebSocket = (taskId: string) => {
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'
const host = window.location.host || 'localhost:3000'
ws = new WebSocket(`${protocol}//${host}/api/v1/progress/${taskId}`)
ws.onopen = () => {
console.log('WebSocket连接成功')
}
ws.onmessage = (event) => {
try {
const data = JSON.parse(event.data)
wsProgress.status = data.status
wsProgress.progress = data.progress || 0
wsProgress.message = data.message || ''
wsProgress.total_count = data.total_count || 0
wsProgress.processed = data.processed || 0
wsProgress.missing = data.missing || 0
wsProgress.complete = data.complete || 0
if (data.status === 'completed' || data.status === 'failed') {
if (ws) {
ws.close()
ws = null
}
}
} catch (e) {
console.error('WebSocket消息解析失败', e)
}
}
ws.onerror = (error) => {
console.error('WebSocket错误', error)
}
ws.onclose = () => {
console.log('WebSocket连接关闭')
ws = null
}
}
const closeWebSocket = () => {
if (ws) {
ws.close()
ws = null
}
wsProgress.status = ''
wsProgress.progress = 0
}
const pollTaskProgress = async (taskId: number) => {
const poll = async () => {
if (!cacheTask.value || cacheTask.value.status === 'running' || cacheTask.value.status === 'pending') {
@ -316,13 +410,18 @@ const pollTaskProgress = async (taskId: number) => {
const handleDetectAll = async () => {
detectingAll.value = true
detectResult.value = null
const taskId = `detect_${form.securityType}_${Date.now()}`
connectWebSocket(taskId)
try {
const res: any = await detectAllMissingData({
security_type: form.securityType,
period_type: form.periodType,
contract_type: form.contractType,
start_date: form.startDate,
end_date: form.endDate
end_date: form.endDate,
task_id: taskId
})
if (res.data) {

@ -13,7 +13,7 @@ export default defineConfig({
port: 3000,
proxy: {
'/api': {
target: 'http://localhost:8001',
target: 'http://localhost:8000',
changeOrigin: true
}
}

Loading…
Cancel
Save