You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

102 lines
3.4 KiB

"""
基础数据路由
"""
from typing import List
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.db.session import get_db
from app.schemas.base import ResponseModel
from app.services.base_data_service import BaseDataService
from app.core.security import get_current_user
from app.models.user import User
from app.utils.date_utils import parse_date
# Router for the base-data endpoints defined below: code lists, per-code
# security info, and trading calendars.
# NOTE(review): the handlers below are `async def` but call BaseDataService
# methods synchronously (no `await`). If those calls block on DB/network I/O
# they stall the event loop; consider plain `def` handlers so FastAPI runs them
# in its threadpool — confirm against BaseDataService's implementation.
router = APIRouter()
@router.get("/codes", response_model=ResponseModel)
async def get_code_list(
    security_type: str = Query(..., description="证券类型: EXTRA_STOCK_A, EXTRA_FUTURE, EXTRA_ETF, EXTRA_INDEX_A"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
    limit: int = Query(100, ge=1, description="最大返回数量"),
):
    """Return the code list for one security type.

    Args:
        security_type: Code-table name passed straight to
            ``BaseDataService.get_code_list``.
        db: Database session injected by FastAPI.
        current_user: Authenticated user; only required for access control.
        limit: Maximum number of codes returned (default 100, must be >= 1).
            It replaces the previously hard-coded cap of 100.

    Returns:
        ResponseModel whose ``data`` holds ``codes`` (at most ``limit`` entries)
        and ``total``, the full number of codes produced by the service. Clients
        can compare ``total`` with the length of ``codes`` to detect truncation.
    """
    service = BaseDataService(db)
    codes = service.get_code_list(security_type)
    # Cap the payload size; `total` tells the client how much was left out.
    return ResponseModel(data={"codes": codes[:limit], "total": len(codes)})
@router.get("/codes/{code}/info", response_model=ResponseModel)
async def get_code_info(
    code: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return information about a single security code.

    The service infers the security type from ``code``. For ``"stock"`` and
    ``"future"`` codes, the matching row of the corresponding code table
    (``EXTRA_STOCK_A`` / ``EXTRA_FUTURE``) is returned as a dict.

    The handler falls back to a minimal ``{"code", "security_type"}`` payload in
    these cases:
    - the type has no table;
    - the table is empty;
    - the table lacks a ``code`` column;
    - no row matches.
    """
    service = BaseDataService(db)
    security_type = service.get_security_type(code)

    # Inferred security type -> code-table name understood by get_code_info.
    table_by_type = {"stock": "EXTRA_STOCK_A", "future": "EXTRA_FUTURE"}
    table_name = table_by_type.get(security_type)
    info = service.get_code_info(table_name) if table_name is not None else None

    # Check the column explicitly: with a missing column, `info.get("code")` is
    # None, `None == code` is False, and `info[False]` raises KeyError instead
    # of falling through to the fallback response.
    if info is not None and not info.empty and "code" in info.columns:
        matches = info[info["code"] == code]
        if not matches.empty:
            # Only the first matching row is returned.
            return ResponseModel(data=matches.to_dict("records")[0])
    return ResponseModel(data={"code": code, "security_type": security_type})
@router.get("/calendar", response_model=ResponseModel)
async def get_trading_calendar(
    market: str = Query("SH", description="市场: SH, SZ, CFE"),
    start_date: str = Query(..., description="开始日期(YYYYMMDD)"),
    end_date: str = Query(..., description="结束日期(YYYYMMDD)"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return the trading days of ``market`` between two dates.

    ``start_date`` and ``end_date`` are converted with ``parse_date`` before
    being passed to the service. The response echoes the raw query strings
    together with the ISO-formatted trading days and how many there are.
    """
    base_data = BaseDataService(db)
    first_day, last_day = parse_date(start_date), parse_date(end_date)
    trading_days = base_data.get_trading_calendar(market, first_day, last_day)

    iso_days = [day.isoformat() for day in trading_days]
    return ResponseModel(
        data=dict(
            market=market,
            start_date=start_date,
            end_date=end_date,
            trading_days=iso_days,
            count=len(trading_days),
        )
    )
@router.get("/calendar/trading-days", response_model=ResponseModel)
async def get_trading_days(
    market: str = Query("SH", description="市场: SH, SZ, CFE"),
    days: int = Query(30, ge=1, description="最近交易日数量"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return up to ``days`` most recent trading days of ``market``, ending today.

    ``days`` must be at least 1. With ``days == 0`` the tail slice
    ``calendar[-0:]`` would return the whole lookup window. A negative value
    would put the window start after today.

    Returns:
        ResponseModel whose ``data`` holds ``market``, the ISO-formatted
        ``trading_days`` and their ``count``.
    """
    from datetime import date, timedelta

    # The 2x calendar-day heuristic alone can under-fetch when a long market
    # closure (e.g. a multi-day public holiday) falls inside a short window.
    # The extra buffer keeps enough trading days in range.
    holiday_buffer_days = 20

    service = BaseDataService(db)
    # NOTE(review): server-local date; confirm it matches the market's time zone.
    end_date = date.today()
    start_date = end_date - timedelta(days=days * 2 + holiday_buffer_days)
    calendar = service.get_trading_calendar(market, start_date, end_date)
    # Assumes the service returns days in ascending order, so the tail holds the
    # most recent ones. With days >= 1 this slice also handles short calendars.
    recent_days = calendar[-days:]
    return ResponseModel(data={
        "market": market,
        "trading_days": [d.isoformat() for d in recent_days],
        "count": len(recent_days)
    })