You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
102 lines
3.2 KiB
102 lines
3.2 KiB
# -*- coding: utf-8 -*-
|
|
"""
|
|
===================================
|
|
get_latest_data 测试
|
|
===================================
|
|
|
|
职责:
|
|
1. 验证 get_latest_data 方法
|
|
2. 测试返回数据按日期降序排列
|
|
3. 测试 days 参数限制
|
|
"""
|
|
|
|
import os
|
|
import tempfile
|
|
import unittest
|
|
from datetime import date, timedelta
|
|
|
|
import pandas as pd
|
|
|
|
from src.config import Config
|
|
from src.storage import DatabaseManager, StockDaily
|
|
|
|
|
|
class GetLatestDataTestCase(unittest.TestCase):
|
|
"""get_latest_data 方法测试"""
|
|
#
|
|
# def setUp(self) -> None:
|
|
# """为每个用例初始化独立数据库"""
|
|
# self._temp_dir = tempfile.TemporaryDirectory()
|
|
# self._db_path = os.path.join(self._temp_dir.name, "test_get_latest_data.db")
|
|
# os.environ["DATABASE_PATH"] = self._db_path
|
|
#
|
|
# Config._instance = None
|
|
# DatabaseManager.reset_instance()
|
|
# self.db = DatabaseManager.get_instance()
|
|
#
|
|
# def tearDown(self) -> None:
|
|
# """清理资源"""
|
|
# DatabaseManager.reset_instance()
|
|
# self._temp_dir.cleanup()
|
|
|
|
def _insert_stock_data(self, code: str, days_ago: int, close: float) -> None:
|
|
"""插入测试用股票数据"""
|
|
target_date = date.today() - timedelta(days=days_ago)
|
|
df = pd.DataFrame([{
|
|
'date': target_date,
|
|
'open': close - 1,
|
|
'high': close + 1,
|
|
'low': close - 2,
|
|
'close': close,
|
|
'volume': 1000000,
|
|
'amount': 10000000,
|
|
'pct_chg': 1.5,
|
|
}])
|
|
self.db.save_daily_data(df, code, data_source="TestData")
|
|
|
|
def test_get_latest_data_returns_empty_when_no_data(self) -> None:
|
|
"""无数据时返回空列表"""
|
|
result = self.db.get_latest_data("999999", days=2)
|
|
self.assertEqual(result, [])
|
|
|
|
def test_get_latest_data_returns_correct_count(self) -> None:
|
|
"""返回正确数量的数据"""
|
|
# 插入5天数据
|
|
for i in range(5):
|
|
self._insert_stock_data("600519", days_ago=i, close=100.0 + i)
|
|
|
|
# 请求2天数据
|
|
result = self.db.get_latest_data("600519", days=2)
|
|
self.assertEqual(len(result), 2)
|
|
|
|
# 请求5天数据
|
|
result = self.db.get_latest_data("600519", days=5)
|
|
self.assertEqual(len(result), 5)
|
|
|
|
def test_get_latest_data_ordered_by_date_desc(self) -> None:
|
|
"""验证数据按日期降序排列"""
|
|
# 插入3天数据
|
|
for i in range(3):
|
|
self._insert_stock_data("600519", days_ago=i, close=100.0 + i)
|
|
|
|
result = self.db.get_latest_data("600519", days=3)
|
|
|
|
# 验证日期降序(最新日期在前)
|
|
self.assertEqual(len(result), 3)
|
|
self.assertGreater(result[0].date, result[1].date)
|
|
self.assertGreater(result[1].date, result[2].date)
|
|
|
|
def test_get_latest_data_filters_by_code(self) -> None:
|
|
"""验证按股票代码过滤"""
|
|
# 插入不同股票的数据
|
|
self._insert_stock_data("600519", days_ago=0, close=100.0)
|
|
self._insert_stock_data("000001", days_ago=0, close=50.0)
|
|
|
|
result = self.db.get_latest_data("600519", days=5)
|
|
self.assertEqual(len(result), 1)
|
|
self.assertEqual(result[0].code, "600519")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|