fix: 一键缓存缺失数据修改方法(未测试);增加get_all_code redis缓存

master
Lxy 2 months ago
parent b54a0cee58
commit b0d0ef298b

@ -2,7 +2,7 @@
缓存管理路由 缓存管理路由
""" """
from typing import List from typing import List
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query, BackgroundTasks
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db.session import get_db from app.db.session import get_db
@ -20,17 +20,85 @@ from app.utils.date_utils import parse_date, format_date
router = APIRouter() router = APIRouter()
@router.post("/detect-all-missing", response_model=ResponseModel) def run_cache_task(
async def detect_all_missing_data( security_type: str,
period_type: str,
start_date_str: str,
end_date_str: str,
contract_type: str,
task_id: int
):
"""后台执行缓存任务"""
from app.db.session import SessionLocal
from app.utils.date_utils import parse_date
db = SessionLocal()
try:
service = CacheService(db)
start_date = parse_date(start_date_str)
end_date = parse_date(end_date_str)
service._execute_cache_task(
task_id,
security_type,
period_type,
start_date,
end_date,
contract_type
)
except Exception as e:
import logging
logger = logging.getLogger(__name__)
logger.error(f"后台缓存任务失败: {str(e)}")
finally:
db.close()
def run_fill_missing_task(
security_type: str,
period_type: str,
start_date_str: str,
end_date_str: str,
missing_codes: List[str],
task_id: int
):
"""后台执行补齐缺失数据任务"""
from app.db.session import SessionLocal
from app.utils.date_utils import parse_date
db = SessionLocal()
try:
service = CacheService(db)
start_date = parse_date(start_date_str)
end_date = parse_date(end_date_str)
service._execute_fill_missing_task(
task_id,
security_type,
period_type,
start_date,
end_date,
missing_codes
)
except Exception as e:
import logging
logger = logging.getLogger(__name__)
logger.error(f"后台补齐任务失败: {str(e)}")
finally:
db.close()
@router.post("/fill-missing", response_model=ResponseModel)
async def fill_missing_data(
request: AllDataRequest, request: AllDataRequest,
background_tasks: BackgroundTasks,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
"""一键检测所有数据的缺失情况""" """一键补齐缺失数据(异步执行,只处理缺失代码)"""
service = CacheService(db) service = CacheService(db)
start = parse_date(request.start_date) start = parse_date(request.start_date)
end = parse_date(request.end_date) end = parse_date(request.end_date)
# 先执行检测获取缺失代码列表
result = service.detect_all_missing_data( result = service.detect_all_missing_data(
request.security_type, request.security_type,
request.period_type, request.period_type,
@ -39,21 +107,59 @@ async def detect_all_missing_data(
request.contract_type request.contract_type
) )
return ResponseModel(data=result) missing_codes = result.get("missing_code_list", [])
if not missing_codes:
return ResponseModel(data={
"task_id": None,
"message": "没有缺失数据需要补齐",
"missing_count": 0
})
# 创建补齐任务记录
task = service._create_fill_missing_task(
request.security_type,
request.period_type,
start,
end,
missing_codes
)
# 在后台执行补齐任务
background_tasks.add_task(
run_fill_missing_task,
request.security_type,
request.period_type,
request.start_date,
request.end_date,
missing_codes,
task.id
)
return ResponseModel(data={
"task_id": task.id,
"task_name": task.task_name,
"status": task.status,
"total_count": task.total_count,
"progress": float(task.progress),
"missing_count": len(missing_codes)
})
@router.post("/cache-all-missing", response_model=ResponseModel) @router.post("/cache-all-missing", response_model=ResponseModel)
async def cache_all_missing_data( async def cache_all_missing_data(
request: AllDataRequest, request: AllDataRequest,
background_tasks: BackgroundTasks,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
"""一键缓存所有缺失数据""" """一键缓存所有缺失数据(异步执行)"""
service = CacheService(db) service = CacheService(db)
start = parse_date(request.start_date) start = parse_date(request.start_date)
end = parse_date(request.end_date) end = parse_date(request.end_date)
task = service.cache_all_missing_data( # 创建任务记录
task = service._create_cache_task(
request.security_type, request.security_type,
request.period_type, request.period_type,
start, start,
@ -61,15 +167,48 @@ async def cache_all_missing_data(
request.contract_type request.contract_type
) )
# 在后台执行缓存任务
background_tasks.add_task(
run_cache_task,
request.security_type,
request.period_type,
request.start_date,
request.end_date,
request.contract_type,
task.id
)
return ResponseModel(data={ return ResponseModel(data={
"task_id": task.id, "task_id": task.id,
"task_name": task.task_name, "task_name": task.task_name,
"status": task.status, "status": task.status,
"total_count": task.total_count, "total_count": task.total_count,
"progress": task.progress "progress": float(task.progress)
}) })
@router.post("/detect-all-missing", response_model=ResponseModel)
async def detect_all_missing_data(
request: AllDataRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""一键检测所有数据的缺失情况"""
service = CacheService(db)
start = parse_date(request.start_date)
end = parse_date(request.end_date)
result = service.detect_all_missing_data(
request.security_type,
request.period_type,
start,
end,
request.contract_type
)
return ResponseModel(data=result)
@router.post("/detect-missing", response_model=ResponseModel) @router.post("/detect-missing", response_model=ResponseModel)
async def detect_missing_data( async def detect_missing_data(
request: DetectMissingRequest, request: DetectMissingRequest,

@ -0,0 +1,92 @@
"""
Redis客户端模块
"""
import redis
import json
import logging
from typing import Optional, List, Any
from app.config import settings
logger = logging.getLogger(__name__)
class RedisClient:
"""Redis客户端"""
_instance: Optional['RedisClient'] = None
_client: Optional[redis.Redis] = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if self._client is None:
try:
self._client = redis.from_url(
settings.REDIS_URL,
decode_responses=True
)
logger.info("Redis连接成功")
except Exception as e:
logger.warning(f"Redis连接失败: {str(e)}")
self._client = None
def is_connected(self) -> bool:
"""检查Redis是否连接"""
if self._client is None:
return False
try:
self._client.ping()
return True
except:
return False
def get(self, key: str) -> Optional[Any]:
"""获取缓存"""
if not self.is_connected():
return None
try:
value = self._client.get(key)
if value:
return json.loads(value)
return None
except Exception as e:
logger.error(f"Redis获取失败: {str(e)}")
return None
def set(self, key: str, value: Any, expire: int = None) -> bool:
"""设置缓存"""
if not self.is_connected():
return False
try:
self._client.set(key, json.dumps(value), ex=expire)
return True
except Exception as e:
logger.error(f"Redis设置失败: {str(e)}")
return False
def delete(self, key: str) -> bool:
"""删除缓存"""
if not self.is_connected():
return False
try:
self._client.delete(key)
return True
except Exception as e:
logger.error(f"Redis删除失败: {str(e)}")
return False
def exists(self, key: str) -> bool:
"""检查键是否存在"""
if not self.is_connected():
return False
try:
return self._client.exists(key) > 0
except Exception as e:
logger.error(f"Redis检查失败: {str(e)}")
return False
redis_client = RedisClient()

@ -16,6 +16,7 @@ from app.services.future_service import FutureService
from app.services.sdk_manager import sdk_manager from app.services.sdk_manager import sdk_manager
from app.utils.date_utils import parse_date, format_date, get_market_from_code from app.utils.date_utils import parse_date, format_date, get_market_from_code
from app.config import settings from app.config import settings
from app.core.redis_client import redis_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -23,6 +24,8 @@ logger = logging.getLogger(__name__)
class CacheService: class CacheService:
"""缓存服务""" """缓存服务"""
CODE_LIST_CACHE_EXPIRE = 12 * 60 * 60
def __init__(self, db: Session): def __init__(self, db: Session):
self.db = db self.db = db
self.base_service = BaseDataService(db) self.base_service = BaseDataService(db)
@ -31,7 +34,7 @@ class CacheService:
def get_all_codes(self, security_type: str, contract_type: str = "all") -> List[str]: def get_all_codes(self, security_type: str, contract_type: str = "all") -> List[str]:
""" """
获取所有代码列表 获取所有代码列表带Redis缓存
Args: Args:
security_type: 证券类型 (stock, future) security_type: 证券类型 (stock, future)
@ -40,21 +43,32 @@ class CacheService:
Returns: Returns:
代码列表 代码列表
""" """
cache_key = f"code_list:{security_type}:{contract_type}"
cached_codes = redis_client.get(cache_key)
if cached_codes is not None:
logger.info(f"从Redis缓存获取代码列表: {security_type}/{contract_type}, 共{len(cached_codes)}")
return cached_codes
adapter = sdk_manager.get_default_connection() adapter = sdk_manager.get_default_connection()
if not adapter: if not adapter:
raise RuntimeError("SDK连接失败") raise RuntimeError("SDK连接失败")
codes = []
if security_type == "stock": if security_type == "stock":
return adapter.get_code_list("EXTRA_STOCK_A") codes = adapter.get_code_list("EXTRA_STOCK_A")
elif security_type == "future": elif security_type == "future":
if contract_type == "main": if contract_type == "main":
# 只获取主力合约
main_contracts = adapter.get_all_main_contracts() main_contracts = adapter.get_all_main_contracts()
return list(main_contracts.values()) codes = list(main_contracts.values())
else: else:
return adapter.get_code_list("EXTRA_FUTURE") codes = adapter.get_code_list("EXTRA_FUTURE")
else:
return [] if codes:
redis_client.set(cache_key, codes, expire=self.CODE_LIST_CACHE_EXPIRE)
logger.info(f"代码列表已缓存到Redis: {security_type}/{contract_type}, 共{len(codes)}个, 有效期12小时")
return codes
def get_future_varieties(self) -> List[str]: def get_future_varieties(self) -> List[str]:
"""获取期货品种列表""" """获取期货品种列表"""
@ -214,6 +228,10 @@ class CacheService:
for date_key in daily_stats: for date_key in daily_stats:
daily_stats[date_key]["missing"] = daily_stats[date_key]["expected"] - daily_stats[date_key]["actual"] daily_stats[date_key]["missing"] = daily_stats[date_key]["expected"] - daily_stats[date_key]["actual"]
# 保存缺失代码列表到任务记录
missing_code_list = [m["code"] for m in missing_codes]
task.code_list = ",".join(missing_code_list[:500]) if missing_code_list else ""
task.status = "completed" task.status = "completed"
task.success_count = len(complete_codes) task.success_count = len(complete_codes)
task.error_count = error_count task.error_count = error_count
@ -237,7 +255,8 @@ class CacheService:
"security_type": security_type, "security_type": security_type,
"period_type": period_type, "period_type": period_type,
"daily_stats": daily_stats, "daily_stats": daily_stats,
"missing_codes": missing_codes[:100] # 只返回前100个缺失代码 "missing_codes": missing_codes[:100],
"missing_code_list": missing_code_list
} }
except Exception as e: except Exception as e:
@ -254,6 +273,202 @@ class CacheService:
"error_message": str(e) "error_message": str(e)
} }
def _create_cache_task(
self,
security_type: str,
period_type: str,
start_date: date,
end_date: date,
contract_type: str = "all"
) -> CacheTask:
"""创建缓存任务记录"""
code_list = self.get_all_codes(security_type, contract_type)
if not code_list:
raise ValueError(f"无法获取{security_type}代码列表")
task = CacheTask(
task_name=f"一键缓存所有数据 - {security_type} - {contract_type} - {len(code_list)}个代码",
task_type="cache_all_data",
security_type=security_type,
period_type=period_type,
start_date=start_date,
end_date=end_date,
code_list=",".join(code_list[:100]) + "...",
status="pending",
total_count=len(code_list),
started_at=datetime.utcnow()
)
self.db.add(task)
self.db.commit()
self.db.refresh(task)
return task
def _create_fill_missing_task(
self,
security_type: str,
period_type: str,
start_date: date,
end_date: date,
missing_codes: List[str]
) -> CacheTask:
"""创建补齐缺失数据任务记录"""
if not missing_codes:
raise ValueError("没有缺失代码需要补齐")
task = CacheTask(
task_name=f"一键补齐缺失数据 - {security_type} - {len(missing_codes)}个代码",
task_type="fill_missing_data",
security_type=security_type,
period_type=period_type,
start_date=start_date,
end_date=end_date,
code_list=",".join(missing_codes[:500]),
status="pending",
total_count=len(missing_codes),
started_at=datetime.utcnow()
)
self.db.add(task)
self.db.commit()
self.db.refresh(task)
return task
def _execute_fill_missing_task(
self,
task_id: int,
security_type: str,
period_type: str,
start_date: date,
end_date: date,
missing_codes: List[str]
):
"""执行补齐缺失数据任务"""
task = self.db.query(CacheTask).filter(CacheTask.id == task_id).first()
if not task:
return
task.status = "running"
self.db.commit()
try:
success_count = 0
error_count = 0
for i, code in enumerate(missing_codes):
try:
if security_type == "stock":
self.stock_service.get_kline([code], start_date, end_date, period_type)
elif security_type == "future":
self.future_service.get_kline([code], start_date, end_date, period_type)
success_count += 1
except Exception as e:
logger.error(f"补齐{code}数据失败: {str(e)}")
error_count += 1
if (i + 1) % 10 == 0 or i == len(missing_codes) - 1:
task.success_count = success_count
task.error_count = error_count
task.progress = min(100, int((i + 1) / len(missing_codes) * 100))
self.db.commit()
task.status = "completed"
task.success_count = success_count
task.error_count = error_count
task.completed_at = datetime.utcnow()
self.db.commit()
except Exception as e:
task.status = "failed"
task.error_message = str(e)
task.completed_at = datetime.utcnow()
self.db.commit()
def _execute_cache_task(
self,
task_id: int,
security_type: str,
period_type: str,
start_date: date,
end_date: date,
contract_type: str = "all"
):
"""执行缓存任务"""
task = self.db.query(CacheTask).filter(CacheTask.id == task_id).first()
if not task:
return
task.status = "running"
self.db.commit()
code_list = self.get_all_codes(security_type, contract_type)
try:
market = "CFE" if security_type == "future" else "SH"
trading_days = self.base_service.get_trading_calendar(market, start_date, end_date)
expected_count = len(trading_days)
success_count = 0
skipped_count = 0
error_count = 0
for i, code in enumerate(code_list):
try:
if security_type == "stock" and period_type == "daily":
actual_count = self.db.query(StockKlineDaily).filter(
and_(
StockKlineDaily.code == code,
StockKlineDaily.trade_date >= start_date,
StockKlineDaily.trade_date <= end_date
)
).count()
elif security_type == "future" and period_type == "daily":
actual_count = self.db.query(FutureKlineDaily).filter(
and_(
FutureKlineDaily.code == code,
FutureKlineDaily.trade_date >= start_date,
FutureKlineDaily.trade_date <= end_date
)
).count()
else:
actual_count = 0
if actual_count >= expected_count:
skipped_count += 1
continue
if security_type == "stock":
self.stock_service.get_kline([code], start_date, end_date, period_type)
elif security_type == "future":
self.future_service.get_kline([code], start_date, end_date, period_type)
success_count += 1
except Exception as e:
logger.error(f"缓存{code}数据失败: {str(e)}")
error_count += 1
if (i + 1) % 10 == 0 or i == len(code_list) - 1:
task.success_count = success_count
task.error_count = error_count
task.progress = min(100, int((i + 1) / len(code_list) * 100))
self.db.commit()
task.status = "completed"
task.success_count = success_count
task.error_count = error_count
task.completed_at = datetime.utcnow()
self.db.commit()
except Exception as e:
task.status = "failed"
task.error_message = str(e)
task.completed_at = datetime.utcnow()
self.db.commit()
def cache_all_missing_data( def cache_all_missing_data(
self, self,
security_type: str, security_type: str,

@ -66,3 +66,13 @@ export const getFutureVarieties = () => {
export const getMainContracts = () => { export const getMainContracts = () => {
return request.get('/cache/main-contracts') return request.get('/cache/main-contracts')
} }
export const fillMissingData = (data: {
security_type: string
period_type: string
contract_type?: string
start_date: string
end_date: string
}) => {
return request.post('/cache/fill-missing', data)
}

@ -72,7 +72,17 @@
<!-- 检测结果汇总 --> <!-- 检测结果汇总 -->
<el-card class="summary-card" v-if="detectResult"> <el-card class="summary-card" v-if="detectResult">
<template #header> <template #header>
<div style="display: flex; justify-content: space-between; align-items: center;">
<span>检测结果汇总</span> <span>检测结果汇总</span>
<el-button
type="success"
@click="handleFillMissing"
:loading="fillingMissing"
:disabled="!detectResult.missing_count || detectResult.missing_count === 0"
>
<el-icon><Download /></el-icon>
</el-button>
</div>
</template> </template>
<el-row :gutter="20"> <el-row :gutter="20">
<el-col :span="6"> <el-col :span="6">
@ -120,6 +130,22 @@
</el-row> </el-row>
</el-card> </el-card>
<!-- 缓存进度 -->
<el-card class="progress-card" v-if="cacheTask">
<template #header>
<span>缓存进度</span>
</template>
<el-progress :percentage="Math.round(cacheTask.progress || 0)" :status="getProgressStatus(cacheTask.status)" />
<el-descriptions :column="4" border style="margin-top: 15px;">
<el-descriptions-item label="任务名称">{{ cacheTask.task_name }}</el-descriptions-item>
<el-descriptions-item label="状态">
<el-tag :type="getStatusType(cacheTask.status)">{{ cacheTask.status }}</el-tag>
</el-descriptions-item>
<el-descriptions-item label="总数">{{ cacheTask.total_count }}</el-descriptions-item>
<el-descriptions-item label="已处理">{{ cacheTask.success_count || 0 }}</el-descriptions-item>
</el-descriptions>
</el-card>
<!-- 每日数据统计 --> <!-- 每日数据统计 -->
<el-card class="daily-card" v-if="detectResult && detectResult.daily_stats"> <el-card class="daily-card" v-if="detectResult && detectResult.daily_stats">
<template #header> <template #header>
@ -203,15 +229,17 @@
<script setup lang="ts"> <script setup lang="ts">
import { ref, reactive, computed } from 'vue' import { ref, reactive, computed } from 'vue'
import { ElMessage } from 'element-plus' import { ElMessage } from 'element-plus'
import { detectMissingData, batchCacheData, detectAllMissingData, cacheAllMissingData } from '@/api/cache' import { detectMissingData, batchCacheData, detectAllMissingData, cacheAllMissingData, getCacheTask, fillMissingData } from '@/api/cache'
const detecting = ref(false) const detecting = ref(false)
const caching = ref(false) const caching = ref(false)
const detectingAll = ref(false) const detectingAll = ref(false)
const cachingAll = ref(false) const cachingAll = ref(false)
const fillingMissing = ref(false)
const codeInput = ref('000001.SZ\n600000.SH') const codeInput = ref('000001.SZ\n600000.SH')
const detectResult = ref<any>(null) const detectResult = ref<any>(null)
const batchDetectResult = ref<any[]>([]) const batchDetectResult = ref<any[]>([])
const cacheTask = ref<any>(null)
const hasMissing = computed(() => batchDetectResult.value.some(r => r.missingCount > 0)) const hasMissing = computed(() => batchDetectResult.value.some(r => r.missingCount > 0))
@ -253,6 +281,12 @@ function getStatusType(status: string) {
return 'danger' return 'danger'
} }
function getProgressStatus(status: string) {
if (status === 'completed') return 'success'
if (status === 'failed') return 'exception'
return undefined
}
const parseCodes = () => { const parseCodes = () => {
return codeInput.value return codeInput.value
.split(/[\n,]/) .split(/[\n,]/)
@ -260,6 +294,25 @@ const parseCodes = () => {
.filter(c => c.length > 0) .filter(c => c.length > 0)
} }
const pollTaskProgress = async (taskId: number) => {
const poll = async () => {
if (!cacheTask.value || cacheTask.value.status === 'running' || cacheTask.value.status === 'pending') {
const res: any = await getCacheTask(taskId)
if (res.data && res.data.task) {
cacheTask.value = res.data.task
}
if (cacheTask.value && (cacheTask.value.status === 'running' || cacheTask.value.status === 'pending')) {
setTimeout(poll, 2000)
} else if (cacheTask.value && cacheTask.value.status === 'completed') {
ElMessage.success(`缓存完成:成功${cacheTask.value.success_count}个,错误${cacheTask.value.error_count}`)
} else if (cacheTask.value && cacheTask.value.status === 'failed') {
ElMessage.error(`缓存失败:${cacheTask.value.error_message}`)
}
}
}
setTimeout(poll, 1000)
}
const handleDetectAll = async () => { const handleDetectAll = async () => {
detectingAll.value = true detectingAll.value = true
detectResult.value = null detectResult.value = null
@ -290,6 +343,7 @@ const handleDetectAll = async () => {
const handleCacheAll = async () => { const handleCacheAll = async () => {
cachingAll.value = true cachingAll.value = true
cacheTask.value = null
try { try {
const res: any = await cacheAllMissingData({ const res: any = await cacheAllMissingData({
security_type: form.securityType, security_type: form.securityType,
@ -300,7 +354,17 @@ const handleCacheAll = async () => {
}) })
if (res.data) { if (res.data) {
cacheTask.value = {
task_id: res.data.task_id,
task_name: res.data.task_name,
status: res.data.status,
total_count: res.data.total_count,
progress: res.data.progress,
success_count: 0,
error_count: 0
}
ElMessage.success(`缓存任务已启动,共${res.data.total_count}个代码`) ElMessage.success(`缓存任务已启动,共${res.data.total_count}个代码`)
pollTaskProgress(res.data.task_id)
} }
} catch (error) { } catch (error) {
console.error(error) console.error(error)
@ -310,6 +374,45 @@ const handleCacheAll = async () => {
} }
} }
const handleFillMissing = async () => {
fillingMissing.value = true
cacheTask.value = null
try {
const res: any = await fillMissingData({
security_type: form.securityType,
period_type: form.periodType,
contract_type: form.contractType,
start_date: form.startDate,
end_date: form.endDate
})
if (res.data) {
if (res.data.message) {
ElMessage.info(res.data.message)
} else {
cacheTask.value = {
task_id: res.data.task_id,
task_name: res.data.task_name,
status: res.data.status,
total_count: res.data.total_count,
progress: res.data.progress,
success_count: 0,
error_count: 0
}
ElMessage.success(`补齐任务已启动,共${res.data.missing_count}个缺失代码`)
if (res.data.task_id) {
pollTaskProgress(res.data.task_id)
}
}
}
} catch (error) {
console.error(error)
ElMessage.error('补齐失败')
} finally {
fillingMissing.value = false
}
}
const handleDetect = async () => { const handleDetect = async () => {
const codes = parseCodes() const codes = parseCodes()
if (codes.length === 0) { if (codes.length === 0) {
@ -379,6 +482,10 @@ const showDetail = (row: any) => {
margin-top: 20px; margin-top: 20px;
} }
.progress-card {
margin-top: 20px;
}
.daily-card { .daily-card {
margin-top: 20px; margin-top: 20px;
} }

Loading…
Cancel
Save