fix: 一键缓存缺失数据修改方法(未测试);增加get_all_code redis缓存

master
Lxy 2 months ago
parent b54a0cee58
commit b0d0ef298b

@ -2,7 +2,7 @@
缓存管理路由
"""
from typing import List
from fastapi import APIRouter, Depends, Query
from fastapi import APIRouter, Depends, Query, BackgroundTasks
from sqlalchemy.orm import Session
from app.db.session import get_db
@ -20,17 +20,85 @@ from app.utils.date_utils import parse_date, format_date
router = APIRouter()
@router.post("/detect-all-missing", response_model=ResponseModel)
async def detect_all_missing_data(
def run_cache_task(
security_type: str,
period_type: str,
start_date_str: str,
end_date_str: str,
contract_type: str,
task_id: int
):
"""后台执行缓存任务"""
from app.db.session import SessionLocal
from app.utils.date_utils import parse_date
db = SessionLocal()
try:
service = CacheService(db)
start_date = parse_date(start_date_str)
end_date = parse_date(end_date_str)
service._execute_cache_task(
task_id,
security_type,
period_type,
start_date,
end_date,
contract_type
)
except Exception as e:
import logging
logger = logging.getLogger(__name__)
logger.error(f"后台缓存任务失败: {str(e)}")
finally:
db.close()
def run_fill_missing_task(
security_type: str,
period_type: str,
start_date_str: str,
end_date_str: str,
missing_codes: List[str],
task_id: int
):
"""后台执行补齐缺失数据任务"""
from app.db.session import SessionLocal
from app.utils.date_utils import parse_date
db = SessionLocal()
try:
service = CacheService(db)
start_date = parse_date(start_date_str)
end_date = parse_date(end_date_str)
service._execute_fill_missing_task(
task_id,
security_type,
period_type,
start_date,
end_date,
missing_codes
)
except Exception as e:
import logging
logger = logging.getLogger(__name__)
logger.error(f"后台补齐任务失败: {str(e)}")
finally:
db.close()
@router.post("/fill-missing", response_model=ResponseModel)
async def fill_missing_data(
request: AllDataRequest,
background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""一键检测所有数据的缺失情况"""
"""一键补齐缺失数据(异步执行,只处理缺失代码)"""
service = CacheService(db)
start = parse_date(request.start_date)
end = parse_date(request.end_date)
# 先执行检测获取缺失代码列表
result = service.detect_all_missing_data(
request.security_type,
request.period_type,
@ -39,21 +107,59 @@ async def detect_all_missing_data(
request.contract_type
)
return ResponseModel(data=result)
missing_codes = result.get("missing_code_list", [])
if not missing_codes:
return ResponseModel(data={
"task_id": None,
"message": "没有缺失数据需要补齐",
"missing_count": 0
})
# 创建补齐任务记录
task = service._create_fill_missing_task(
request.security_type,
request.period_type,
start,
end,
missing_codes
)
# 在后台执行补齐任务
background_tasks.add_task(
run_fill_missing_task,
request.security_type,
request.period_type,
request.start_date,
request.end_date,
missing_codes,
task.id
)
return ResponseModel(data={
"task_id": task.id,
"task_name": task.task_name,
"status": task.status,
"total_count": task.total_count,
"progress": float(task.progress),
"missing_count": len(missing_codes)
})
@router.post("/cache-all-missing", response_model=ResponseModel)
async def cache_all_missing_data(
request: AllDataRequest,
background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""一键缓存所有缺失数据"""
"""一键缓存所有缺失数据(异步执行)"""
service = CacheService(db)
start = parse_date(request.start_date)
end = parse_date(request.end_date)
task = service.cache_all_missing_data(
# 创建任务记录
task = service._create_cache_task(
request.security_type,
request.period_type,
start,
@ -61,15 +167,48 @@ async def cache_all_missing_data(
request.contract_type
)
# 在后台执行缓存任务
background_tasks.add_task(
run_cache_task,
request.security_type,
request.period_type,
request.start_date,
request.end_date,
request.contract_type,
task.id
)
return ResponseModel(data={
"task_id": task.id,
"task_name": task.task_name,
"status": task.status,
"total_count": task.total_count,
"progress": task.progress
"progress": float(task.progress)
})
@router.post("/detect-all-missing", response_model=ResponseModel)
async def detect_all_missing_data(
request: AllDataRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""一键检测所有数据的缺失情况"""
service = CacheService(db)
start = parse_date(request.start_date)
end = parse_date(request.end_date)
result = service.detect_all_missing_data(
request.security_type,
request.period_type,
start,
end,
request.contract_type
)
return ResponseModel(data=result)
@router.post("/detect-missing", response_model=ResponseModel)
async def detect_missing_data(
request: DetectMissingRequest,

@ -0,0 +1,92 @@
"""
Redis客户端模块
"""
import redis
import json
import logging
from typing import Optional, List, Any
from app.config import settings
logger = logging.getLogger(__name__)
class RedisClient:
"""Redis客户端"""
_instance: Optional['RedisClient'] = None
_client: Optional[redis.Redis] = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if self._client is None:
try:
self._client = redis.from_url(
settings.REDIS_URL,
decode_responses=True
)
logger.info("Redis连接成功")
except Exception as e:
logger.warning(f"Redis连接失败: {str(e)}")
self._client = None
def is_connected(self) -> bool:
"""检查Redis是否连接"""
if self._client is None:
return False
try:
self._client.ping()
return True
except:
return False
def get(self, key: str) -> Optional[Any]:
"""获取缓存"""
if not self.is_connected():
return None
try:
value = self._client.get(key)
if value:
return json.loads(value)
return None
except Exception as e:
logger.error(f"Redis获取失败: {str(e)}")
return None
def set(self, key: str, value: Any, expire: int = None) -> bool:
"""设置缓存"""
if not self.is_connected():
return False
try:
self._client.set(key, json.dumps(value), ex=expire)
return True
except Exception as e:
logger.error(f"Redis设置失败: {str(e)}")
return False
def delete(self, key: str) -> bool:
"""删除缓存"""
if not self.is_connected():
return False
try:
self._client.delete(key)
return True
except Exception as e:
logger.error(f"Redis删除失败: {str(e)}")
return False
def exists(self, key: str) -> bool:
"""检查键是否存在"""
if not self.is_connected():
return False
try:
return self._client.exists(key) > 0
except Exception as e:
logger.error(f"Redis检查失败: {str(e)}")
return False
redis_client = RedisClient()

@ -16,6 +16,7 @@ from app.services.future_service import FutureService
from app.services.sdk_manager import sdk_manager
from app.utils.date_utils import parse_date, format_date, get_market_from_code
from app.config import settings
from app.core.redis_client import redis_client
logger = logging.getLogger(__name__)
@ -23,6 +24,8 @@ logger = logging.getLogger(__name__)
class CacheService:
"""缓存服务"""
CODE_LIST_CACHE_EXPIRE = 12 * 60 * 60
def __init__(self, db: Session):
self.db = db
self.base_service = BaseDataService(db)
@ -31,7 +34,7 @@ class CacheService:
def get_all_codes(self, security_type: str, contract_type: str = "all") -> List[str]:
"""
获取所有代码列表
获取所有代码列表带Redis缓存
Args:
security_type: 证券类型 (stock, future)
@ -40,21 +43,32 @@ class CacheService:
Returns:
代码列表
"""
cache_key = f"code_list:{security_type}:{contract_type}"
cached_codes = redis_client.get(cache_key)
if cached_codes is not None:
logger.info(f"从Redis缓存获取代码列表: {security_type}/{contract_type}, 共{len(cached_codes)}")
return cached_codes
adapter = sdk_manager.get_default_connection()
if not adapter:
raise RuntimeError("SDK连接失败")
codes = []
if security_type == "stock":
return adapter.get_code_list("EXTRA_STOCK_A")
codes = adapter.get_code_list("EXTRA_STOCK_A")
elif security_type == "future":
if contract_type == "main":
# 只获取主力合约
main_contracts = adapter.get_all_main_contracts()
return list(main_contracts.values())
codes = list(main_contracts.values())
else:
return adapter.get_code_list("EXTRA_FUTURE")
else:
return []
codes = adapter.get_code_list("EXTRA_FUTURE")
if codes:
redis_client.set(cache_key, codes, expire=self.CODE_LIST_CACHE_EXPIRE)
logger.info(f"代码列表已缓存到Redis: {security_type}/{contract_type}, 共{len(codes)}个, 有效期12小时")
return codes
def get_future_varieties(self) -> List[str]:
"""获取期货品种列表"""
@ -214,6 +228,10 @@ class CacheService:
for date_key in daily_stats:
daily_stats[date_key]["missing"] = daily_stats[date_key]["expected"] - daily_stats[date_key]["actual"]
# 保存缺失代码列表到任务记录
missing_code_list = [m["code"] for m in missing_codes]
task.code_list = ",".join(missing_code_list[:500]) if missing_code_list else ""
task.status = "completed"
task.success_count = len(complete_codes)
task.error_count = error_count
@ -237,7 +255,8 @@ class CacheService:
"security_type": security_type,
"period_type": period_type,
"daily_stats": daily_stats,
"missing_codes": missing_codes[:100] # 只返回前100个缺失代码
"missing_codes": missing_codes[:100],
"missing_code_list": missing_code_list
}
except Exception as e:
@ -254,6 +273,202 @@ class CacheService:
"error_message": str(e)
}
def _create_cache_task(
self,
security_type: str,
period_type: str,
start_date: date,
end_date: date,
contract_type: str = "all"
) -> CacheTask:
"""创建缓存任务记录"""
code_list = self.get_all_codes(security_type, contract_type)
if not code_list:
raise ValueError(f"无法获取{security_type}代码列表")
task = CacheTask(
task_name=f"一键缓存所有数据 - {security_type} - {contract_type} - {len(code_list)}个代码",
task_type="cache_all_data",
security_type=security_type,
period_type=period_type,
start_date=start_date,
end_date=end_date,
code_list=",".join(code_list[:100]) + "...",
status="pending",
total_count=len(code_list),
started_at=datetime.utcnow()
)
self.db.add(task)
self.db.commit()
self.db.refresh(task)
return task
def _create_fill_missing_task(
self,
security_type: str,
period_type: str,
start_date: date,
end_date: date,
missing_codes: List[str]
) -> CacheTask:
"""创建补齐缺失数据任务记录"""
if not missing_codes:
raise ValueError("没有缺失代码需要补齐")
task = CacheTask(
task_name=f"一键补齐缺失数据 - {security_type} - {len(missing_codes)}个代码",
task_type="fill_missing_data",
security_type=security_type,
period_type=period_type,
start_date=start_date,
end_date=end_date,
code_list=",".join(missing_codes[:500]),
status="pending",
total_count=len(missing_codes),
started_at=datetime.utcnow()
)
self.db.add(task)
self.db.commit()
self.db.refresh(task)
return task
def _execute_fill_missing_task(
self,
task_id: int,
security_type: str,
period_type: str,
start_date: date,
end_date: date,
missing_codes: List[str]
):
"""执行补齐缺失数据任务"""
task = self.db.query(CacheTask).filter(CacheTask.id == task_id).first()
if not task:
return
task.status = "running"
self.db.commit()
try:
success_count = 0
error_count = 0
for i, code in enumerate(missing_codes):
try:
if security_type == "stock":
self.stock_service.get_kline([code], start_date, end_date, period_type)
elif security_type == "future":
self.future_service.get_kline([code], start_date, end_date, period_type)
success_count += 1
except Exception as e:
logger.error(f"补齐{code}数据失败: {str(e)}")
error_count += 1
if (i + 1) % 10 == 0 or i == len(missing_codes) - 1:
task.success_count = success_count
task.error_count = error_count
task.progress = min(100, int((i + 1) / len(missing_codes) * 100))
self.db.commit()
task.status = "completed"
task.success_count = success_count
task.error_count = error_count
task.completed_at = datetime.utcnow()
self.db.commit()
except Exception as e:
task.status = "failed"
task.error_message = str(e)
task.completed_at = datetime.utcnow()
self.db.commit()
def _execute_cache_task(
self,
task_id: int,
security_type: str,
period_type: str,
start_date: date,
end_date: date,
contract_type: str = "all"
):
"""执行缓存任务"""
task = self.db.query(CacheTask).filter(CacheTask.id == task_id).first()
if not task:
return
task.status = "running"
self.db.commit()
code_list = self.get_all_codes(security_type, contract_type)
try:
market = "CFE" if security_type == "future" else "SH"
trading_days = self.base_service.get_trading_calendar(market, start_date, end_date)
expected_count = len(trading_days)
success_count = 0
skipped_count = 0
error_count = 0
for i, code in enumerate(code_list):
try:
if security_type == "stock" and period_type == "daily":
actual_count = self.db.query(StockKlineDaily).filter(
and_(
StockKlineDaily.code == code,
StockKlineDaily.trade_date >= start_date,
StockKlineDaily.trade_date <= end_date
)
).count()
elif security_type == "future" and period_type == "daily":
actual_count = self.db.query(FutureKlineDaily).filter(
and_(
FutureKlineDaily.code == code,
FutureKlineDaily.trade_date >= start_date,
FutureKlineDaily.trade_date <= end_date
)
).count()
else:
actual_count = 0
if actual_count >= expected_count:
skipped_count += 1
continue
if security_type == "stock":
self.stock_service.get_kline([code], start_date, end_date, period_type)
elif security_type == "future":
self.future_service.get_kline([code], start_date, end_date, period_type)
success_count += 1
except Exception as e:
logger.error(f"缓存{code}数据失败: {str(e)}")
error_count += 1
if (i + 1) % 10 == 0 or i == len(code_list) - 1:
task.success_count = success_count
task.error_count = error_count
task.progress = min(100, int((i + 1) / len(code_list) * 100))
self.db.commit()
task.status = "completed"
task.success_count = success_count
task.error_count = error_count
task.completed_at = datetime.utcnow()
self.db.commit()
except Exception as e:
task.status = "failed"
task.error_message = str(e)
task.completed_at = datetime.utcnow()
self.db.commit()
def cache_all_missing_data(
self,
security_type: str,

@ -66,3 +66,13 @@ export const getFutureVarieties = () => {
export const getMainContracts = () => {
return request.get('/cache/main-contracts')
}
export const fillMissingData = (data: {
security_type: string
period_type: string
contract_type?: string
start_date: string
end_date: string
}) => {
return request.post('/cache/fill-missing', data)
}

@ -72,7 +72,17 @@
<!-- 检测结果汇总 -->
<el-card class="summary-card" v-if="detectResult">
<template #header>
<span>检测结果汇总</span>
<div style="display: flex; justify-content: space-between; align-items: center;">
<span>检测结果汇总</span>
<el-button
type="success"
@click="handleFillMissing"
:loading="fillingMissing"
:disabled="!detectResult.missing_count || detectResult.missing_count === 0"
>
<el-icon><Download /></el-icon>
</el-button>
</div>
</template>
<el-row :gutter="20">
<el-col :span="6">
@ -120,6 +130,22 @@
</el-row>
</el-card>
<!-- 缓存进度 -->
<el-card class="progress-card" v-if="cacheTask">
<template #header>
<span>缓存进度</span>
</template>
<el-progress :percentage="Math.round(cacheTask.progress || 0)" :status="getProgressStatus(cacheTask.status)" />
<el-descriptions :column="4" border style="margin-top: 15px;">
<el-descriptions-item label="任务名称">{{ cacheTask.task_name }}</el-descriptions-item>
<el-descriptions-item label="状态">
<el-tag :type="getStatusType(cacheTask.status)">{{ cacheTask.status }}</el-tag>
</el-descriptions-item>
<el-descriptions-item label="总数">{{ cacheTask.total_count }}</el-descriptions-item>
<el-descriptions-item label="已处理">{{ cacheTask.success_count || 0 }}</el-descriptions-item>
</el-descriptions>
</el-card>
<!-- 每日数据统计 -->
<el-card class="daily-card" v-if="detectResult && detectResult.daily_stats">
<template #header>
@ -203,15 +229,17 @@
<script setup lang="ts">
import { ref, reactive, computed } from 'vue'
import { ElMessage } from 'element-plus'
import { detectMissingData, batchCacheData, detectAllMissingData, cacheAllMissingData } from '@/api/cache'
import { detectMissingData, batchCacheData, detectAllMissingData, cacheAllMissingData, getCacheTask, fillMissingData } from '@/api/cache'
const detecting = ref(false)
const caching = ref(false)
const detectingAll = ref(false)
const cachingAll = ref(false)
const fillingMissing = ref(false)
const codeInput = ref('000001.SZ\n600000.SH')
const detectResult = ref<any>(null)
const batchDetectResult = ref<any[]>([])
const cacheTask = ref<any>(null)
const hasMissing = computed(() => batchDetectResult.value.some(r => r.missingCount > 0))
@ -253,6 +281,12 @@ function getStatusType(status: string) {
return 'danger'
}
function getProgressStatus(status: string) {
if (status === 'completed') return 'success'
if (status === 'failed') return 'exception'
return undefined
}
const parseCodes = () => {
return codeInput.value
.split(/[\n,]/)
@ -260,6 +294,25 @@ const parseCodes = () => {
.filter(c => c.length > 0)
}
const pollTaskProgress = async (taskId: number) => {
const poll = async () => {
if (!cacheTask.value || cacheTask.value.status === 'running' || cacheTask.value.status === 'pending') {
const res: any = await getCacheTask(taskId)
if (res.data && res.data.task) {
cacheTask.value = res.data.task
}
if (cacheTask.value && (cacheTask.value.status === 'running' || cacheTask.value.status === 'pending')) {
setTimeout(poll, 2000)
} else if (cacheTask.value && cacheTask.value.status === 'completed') {
ElMessage.success(`缓存完成:成功${cacheTask.value.success_count}个,错误${cacheTask.value.error_count}`)
} else if (cacheTask.value && cacheTask.value.status === 'failed') {
ElMessage.error(`缓存失败:${cacheTask.value.error_message}`)
}
}
}
setTimeout(poll, 1000)
}
const handleDetectAll = async () => {
detectingAll.value = true
detectResult.value = null
@ -290,6 +343,7 @@ const handleDetectAll = async () => {
const handleCacheAll = async () => {
cachingAll.value = true
cacheTask.value = null
try {
const res: any = await cacheAllMissingData({
security_type: form.securityType,
@ -300,7 +354,17 @@ const handleCacheAll = async () => {
})
if (res.data) {
cacheTask.value = {
task_id: res.data.task_id,
task_name: res.data.task_name,
status: res.data.status,
total_count: res.data.total_count,
progress: res.data.progress,
success_count: 0,
error_count: 0
}
ElMessage.success(`缓存任务已启动,共${res.data.total_count}个代码`)
pollTaskProgress(res.data.task_id)
}
} catch (error) {
console.error(error)
@ -310,6 +374,45 @@ const handleCacheAll = async () => {
}
}
const handleFillMissing = async () => {
fillingMissing.value = true
cacheTask.value = null
try {
const res: any = await fillMissingData({
security_type: form.securityType,
period_type: form.periodType,
contract_type: form.contractType,
start_date: form.startDate,
end_date: form.endDate
})
if (res.data) {
if (res.data.message) {
ElMessage.info(res.data.message)
} else {
cacheTask.value = {
task_id: res.data.task_id,
task_name: res.data.task_name,
status: res.data.status,
total_count: res.data.total_count,
progress: res.data.progress,
success_count: 0,
error_count: 0
}
ElMessage.success(`补齐任务已启动,共${res.data.missing_count}个缺失代码`)
if (res.data.task_id) {
pollTaskProgress(res.data.task_id)
}
}
}
} catch (error) {
console.error(error)
ElMessage.error('补齐失败')
} finally {
fillingMissing.value = false
}
}
const handleDetect = async () => {
const codes = parseCodes()
if (codes.length === 0) {
@ -379,6 +482,10 @@ const showDetail = (row: any) => {
margin-top: 20px;
}
.progress-card {
margin-top: 20px;
}
.daily-card {
margin-top: 20px;
}

Loading…
Cancel
Save