You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
102 lines
3.4 KiB
102 lines
3.4 KiB
"""
|
|
基础数据路由
|
|
"""
|
|
from typing import List
|
|
from fastapi import APIRouter, Depends, Query
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.db.session import get_db
|
|
from app.schemas.base import ResponseModel
|
|
from app.services.base_data_service import BaseDataService
|
|
from app.core.security import get_current_user
|
|
from app.models.user import User
|
|
from app.utils.date_utils import parse_date
|
|
|
|
# Router that the base-data endpoints below (code list, security info,
# trading calendar) are registered on; presumably included by the app elsewhere.
router = APIRouter()
|
|
|
|
|
|
@router.get("/codes", response_model=ResponseModel)
async def get_code_list(
    security_type: str = Query(..., description="证券类型: EXTRA_STOCK_A, EXTRA_FUTURE, EXTRA_ETF, EXTRA_INDEX_A"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
    limit: int = Query(100, ge=1, le=1000, description="返回的最大代码数量"),
):
    """Return the code list for one security type, truncated to ``limit`` entries.

    Args:
        security_type: Security-type identifier forwarded unchanged to
            ``BaseDataService.get_code_list``.
        db: Database session injected by ``get_db``.
        current_user: Injected via ``get_current_user``; not used in the body,
            presumably present to require an authenticated caller.
        limit: Maximum number of entries to return. Defaults to 100 (the
            previously hard-coded cap) and is bounded to 1..1000 so the
            response size stays capped.

    Returns:
        ``ResponseModel`` whose ``data`` is ``{"codes": <first limit entries>}``.
    """
    # NOTE(review): ``Session`` is synchronous; running it inside an
    # ``async def`` endpoint blocks the event loop. Consider a plain ``def``
    # endpoint so FastAPI runs it in its threadpool.
    service = BaseDataService(db)
    codes = service.get_code_list(security_type)
    # Limit the number of returned entries to keep the payload small.
    return ResponseModel(data={"codes": codes[:limit]})
|
|
|
|
|
|
@router.get("/codes/{code}/info", response_model=ResponseModel)
async def get_code_info(
    code: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return descriptive information for a single security code.

    The security type is resolved with ``BaseDataService.get_security_type``.
    For ``"stock"`` and ``"future"`` the corresponding code table is fetched
    with ``BaseDataService.get_code_info`` and the first row whose ``code``
    column equals ``code`` is returned as a dict.

    Falls back to ``{"code": code, "security_type": security_type}`` when the
    type is not handled, the table is ``None``/empty, the table has no
    ``code`` column, or no row matches.

    Args:
        code: Security code taken from the URL path.
        db: Database session injected by ``get_db``.
        current_user: Injected via ``get_current_user``; not used in the body,
            presumably present to require an authenticated caller.

    Returns:
        ``ResponseModel`` carrying either the matched row or the fallback dict.
    """
    service = BaseDataService(db)

    # Determine the security type from the code itself.
    security_type = service.get_security_type(code)

    # Map the resolved type to the EXTRA_* identifier the service expects.
    # NOTE(review): EXTRA_ETF / EXTRA_INDEX_A (listed on /codes) are not handled
    # here — confirm what get_security_type returns for those codes.
    identifier_by_type = {
        "stock": "EXTRA_STOCK_A",
        "future": "EXTRA_FUTURE",
    }
    identifier = identifier_by_type.get(security_type)
    info = service.get_code_info(identifier) if identifier is not None else None

    # Assumes ``info`` is a pandas DataFrame (it is used via .empty/.columns/
    # .to_dict). Without the explicit column check, a table lacking "code"
    # made ``info.get("code")`` return None, so the filter became
    # ``info[False]`` and raised an unhandled KeyError.
    if info is not None and not info.empty and "code" in info.columns:
        matched = info[info["code"] == code]
        if not matched.empty:
            return ResponseModel(data=matched.to_dict("records")[0])

    return ResponseModel(data={"code": code, "security_type": security_type})
|
|
|
|
|
|
@router.get("/calendar", response_model=ResponseModel)
async def get_trading_calendar(
    market: str = Query("SH", description="市场: SH, SZ, CFE"),
    start_date: str = Query(..., description="开始日期(YYYYMMDD)"),
    end_date: str = Query(..., description="结束日期(YYYYMMDD)"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """List the trading days of ``market`` between two dates.

    Args:
        market: Market code (e.g. SH, SZ, CFE).
        start_date: Start of the range as a YYYYMMDD string.
        end_date: End of the range as a YYYYMMDD string.
        db: Database session injected by ``get_db``.
        current_user: Injected via ``get_current_user``; not used in the body,
            presumably present to require an authenticated caller.

    Returns:
        ``ResponseModel`` echoing the market and the raw date strings, plus the
        trading days in ISO format and how many there are.
    """
    service = BaseDataService(db)

    # Turn the raw query strings into date values for the service call.
    first_day, last_day = parse_date(start_date), parse_date(end_date)
    trading_days = service.get_trading_calendar(market, first_day, last_day)

    iso_days = [day.isoformat() for day in trading_days]
    payload = {
        "market": market,
        "start_date": start_date,
        "end_date": end_date,
        "trading_days": iso_days,
        "count": len(trading_days),
    }
    return ResponseModel(data=payload)
|
|
|
|
|
|
@router.get("/calendar/trading-days", response_model=ResponseModel)
async def get_trading_days(
    market: str = Query("SH", description="市场: SH, SZ, CFE"),
    days: int = Query(30, ge=1, description="最近交易日数量"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return the most recent ``days`` trading days of ``market`` up to today.

    Args:
        market: Market code (e.g. SH, SZ, CFE).
        days: Number of most recent trading days wanted; must be >= 1.
            Previously ``days=0`` returned the whole lookup window
            (``calendar[-0:]`` is the full list) and negative values produced
            an inverted date range.
        db: Database session injected by ``get_db``.
        current_user: Injected via ``get_current_user``; not used in the body,
            presumably present to require an authenticated caller.

    Returns:
        ``ResponseModel`` with the market, the ISO-formatted trading days
        (oldest first, as returned by the service) and their count. Fewer than
        ``days`` entries are returned if the calendar has fewer in the window.
    """
    from datetime import date, timedelta

    # Extra calendar days added to the lookback so that short requests still
    # reach back past long market closures (multi-day holidays); with only
    # ``days * 2``, e.g. days=3 looked back just 6 calendar days.
    holiday_padding_days = 20

    service = BaseDataService(db)
    # NOTE(review): date.today() uses the server's local date — confirm it
    # matches the exchange's timezone.
    end_date = date.today()
    # Trading days are a subset of calendar days: 2x covers weekends, the
    # padding covers longer closures.
    start_date = end_date - timedelta(days=days * 2 + holiday_padding_days)

    calendar = service.get_trading_calendar(market, start_date, end_date)
    # days >= 1 is enforced by the query validator, so this negative slice
    # keeps the last ``days`` entries (or all of them if there are fewer).
    recent_days = calendar[-days:]

    return ResponseModel(data={
        "market": market,
        "trading_days": [d.isoformat() for d in recent_days],
        "count": len(recent_days),
    })
|