You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

256 lines
6.7 KiB

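// Command sync is the market data synchronization tool (数据同步工具).
// It pulls stock/futures reference data, trading calendars and K-lines from
// Tushare and writes them to the database.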
package main
import (
"context"
"flag"
"fmt"
"log"
"os"
"time"
"market-data-service/adapter/tushare"
"market-data-service/api"
"market-data-service/internal/repository"
)
// main parses the command-line flags, validates them, connects to the
// database and runs the sync routine selected by -type.
//
// Environment:
//   - TUSHARE_TOKEN (required): Tushare API token.
//   - DATABASE_URL (optional): database DSN; falls back to a local DSN with
//     placeholder credentials when unset.
func main() {
	var (
		syncType   = flag.String("type", "", "同步类型: stocks, futures, calendar, klines")
		startDate  = flag.String("start", "", "开始日期 YYYYMMDD")
		endDate    = flag.String("end", "", "结束日期 YYYYMMDD")
		symbol     = flag.String("symbol", "", "标的代码")
		underlying = flag.String("underlying", "", "期货品种代码")
		freq       = flag.String("freq", "1d", "K线周期: 1m/5m/15m/30m/60m/1d")
	)
	flag.Parse()

	if *syncType == "" {
		fmt.Fprintln(os.Stderr, "missing required flag: -type")
		flag.Usage()
		os.Exit(1)
	}

	// Validate -type before reading the environment or opening the database,
	// so a typo fails fast and log.Fatalf (which skips deferred calls) never
	// runs while a DB handle is open.
	switch *syncType {
	case "stocks", "futures", "calendar", "klines":
	default:
		log.Fatalf("Unknown sync type: %s", *syncType)
	}

	// -underlying is parsed but no sync routine consumes it yet. Warn instead
	// of silently ignoring it; reading it also satisfies the Go compiler's
	// "declared and not used" rule for this local.
	if *underlying != "" {
		log.Printf("warning: -underlying=%q is not used by any sync type yet and will be ignored", *underlying)
	}

	// Configuration from the environment.
	tushareToken := os.Getenv("TUSHARE_TOKEN")
	if tushareToken == "" {
		log.Fatal("TUSHARE_TOKEN environment variable is required")
	}
	dbURL := os.Getenv("DATABASE_URL")
	if dbURL == "" {
		// Fallback DSN with placeholder credentials; set DATABASE_URL to override.
		dbURL = "postgres://user:password@localhost:5432/marketdata?sslmode=disable"
	}

	db, err := repository.NewDB(dbURL)
	if err != nil {
		log.Fatalf("Failed to connect to database: %v", err)
	}
	// NOTE: the sync routines exit via log.Fatal* on error, which bypasses
	// this defer; Close only runs on the success path.
	defer db.Close()

	client := tushare.NewClient(tushareToken)
	ctx := context.Background()

	// -type was validated above, so every reachable value has a case here.
	switch *syncType {
	case "stocks":
		syncStocks(ctx, client, db)
	case "futures":
		syncFutures(ctx, client, db)
	case "calendar":
		syncCalendar(ctx, client, db, *startDate, *endDate)
	case "klines":
		syncKLines(ctx, client, db, *symbol, *startDate, *endDate, *freq)
	}
}
// syncStocks fetches stock basic info from Tushare and saves the currently
// listed stocks (ListStatus == "L") to the stock repository with status
// "active". Fetch or save errors terminate the process via log.Fatalf.
//
// ListDate is left nil when the source value is empty or not in YYYYMMDD
// form, instead of being stored as the zero time (0001-01-01).
func syncStocks(ctx context.Context, client *tushare.Client, db *repository.DB) {
	log.Println("Syncing stock basic info...")
	data, err := client.GetStockBasic()
	if err != nil {
		log.Fatalf("Failed to get stock basic: %v", err)
	}

	repo := repository.NewStockRepository(db)
	symbols := make([]api.Symbol, 0, len(data))
	for _, d := range data {
		if d.ListStatus != "L" {
			continue // only sync stocks that are currently listed
		}

		// t is scoped to the if-statement, so each iteration gets its own
		// variable and the stored pointer is never shared between rows.
		var listDate *time.Time
		if t, perr := time.Parse("20060102", d.ListDate); perr == nil {
			listDate = &t
		}

		symbols = append(symbols, api.Symbol{
			SymbolID:   d.TSCode,
			SymbolType: api.SymbolTypeStock,
			Exchange:   api.Exchange(d.Exchange),
			Name:       d.Name,
			NameEN:     d.EnName,
			Industry:   d.Industry,
			ListDate:   listDate,
			Status:     "active",
		})
	}

	if err := repo.SaveSymbols(ctx, symbols); err != nil {
		log.Fatalf("Failed to save symbols: %v", err)
	}
	log.Printf("Synced %d stocks", len(symbols))
}
// syncFutures fetches futures contract basic info from Tushare and saves it
// to the futures repository. Fetch or save errors terminate the process via
// log.Fatalf.
//
// A contract is marked "expired" only when it has a parseable DelistDate
// that is already in the past. Previously an empty or unparsable DelistDate
// became the zero time, which marked the contract expired and stored
// 0001-01-01 as its delist date. Unparsable dates are now stored as nil.
func syncFutures(ctx context.Context, client *tushare.Client, db *repository.DB) {
	log.Println("Syncing futures basic info...")
	// NOTE(review): the empty argument presumably means "no exchange filter";
	// confirm against tushare.Client.GetFuturesBasic.
	data, err := client.GetFuturesBasic("")
	if err != nil {
		log.Fatalf("Failed to get futures basic: %v", err)
	}

	repo := repository.NewFuturesRepository(db)
	now := time.Now() // one reference instant for the whole batch
	symbols := make([]api.Symbol, 0, len(data))
	for _, d := range data {
		var listDate, delistDate *time.Time
		if t, perr := time.Parse("20060102", d.ListDate); perr == nil {
			listDate = &t
		}

		status := "active"
		if t, perr := time.Parse("20060102", d.DelistDate); perr == nil {
			delistDate = &t
			if now.After(t) {
				status = "expired"
			}
		}

		// ContractMonth is whatever follows the product-code-length prefix of
		// Symbol. Guard the slice so a Symbol shorter than FutCode cannot panic.
		contractMonth := ""
		if len(d.Symbol) >= len(d.FutCode) {
			contractMonth = d.Symbol[len(d.FutCode):]
		}

		symbols = append(symbols, api.Symbol{
			SymbolID:      d.TSCode,
			SymbolType:    api.SymbolTypeFutures,
			Exchange:      api.Exchange(d.Exchange),
			Name:          d.Name,
			Underlying:    d.FutCode,
			ContractMonth: contractMonth,
			ListDate:      listDate,
			DelistDate:    delistDate,
			Status:        status,
		})
	}

	if err := repo.SaveSymbols(ctx, symbols); err != nil {
		log.Fatalf("Failed to save symbols: %v", err)
	}
	log.Printf("Synced %d futures", len(symbols))
}
// syncCalendar fetches the SSE trading calendar for start..end (YYYYMMDD)
// and saves it to both the stock and the futures repository.
//
// Defaults: start = 30 days ago, end = 6 months from now. Malformed or
// inverted dates, fetch errors and save errors terminate the process via
// log.Fatalf. Calendar rows whose CalDate cannot be parsed are skipped and
// counted, rather than being stored as the zero time.
func syncCalendar(ctx context.Context, client *tushare.Client, db *repository.DB, start, end string) {
	const layout = "20060102"
	if start == "" {
		start = time.Now().AddDate(0, 0, -30).Format(layout)
	}
	if end == "" {
		end = time.Now().AddDate(0, 6, 0).Format(layout)
	}

	// Reject bad user input before calling the remote API.
	startDay, err := time.Parse(layout, start)
	if err != nil {
		log.Fatalf("Invalid start date %q (want YYYYMMDD): %v", start, err)
	}
	endDay, err := time.Parse(layout, end)
	if err != nil {
		log.Fatalf("Invalid end date %q (want YYYYMMDD): %v", end, err)
	}
	if startDay.After(endDay) {
		log.Fatalf("Start date %s is after end date %s", start, end)
	}

	log.Printf("Syncing trading calendar from %s to %s...", start, end)

	// Stock trading calendar (Shanghai Stock Exchange).
	// NOTE(review): -1 presumably disables the is_open filter; confirm
	// against tushare.Client.GetTradeCal.
	stockData, err := client.GetTradeCal("SSE", start, end, -1)
	if err != nil {
		log.Fatalf("Failed to get stock calendar: %v", err)
	}

	stockDates := make([]api.TradeCalData, 0, len(stockData))
	skipped := 0
	for _, d := range stockData {
		calDate, perr := time.Parse(layout, d.CalDate)
		if perr != nil {
			skipped++
			continue
		}
		stockDates = append(stockDates, api.TradeCalData{
			Date:         calDate,
			IsTradingDay: d.IsOpen == 1,
		})
	}
	if skipped > 0 {
		log.Printf("Skipped %d calendar rows with unparsable dates", skipped)
	}

	stockRepo := repository.NewStockRepository(db)
	if err := stockRepo.SaveTradingCalendar(ctx, stockDates); err != nil {
		log.Fatalf("Failed to save stock calendar: %v", err)
	}

	// Futures reuse the SSE calendar for now; futures exchanges may need
	// their own calendar source.
	futuresRepo := repository.NewFuturesRepository(db)
	if err := futuresRepo.SaveTradingCalendar(ctx, stockDates); err != nil {
		log.Fatalf("Failed to save futures calendar: %v", err)
	}
	log.Printf("Synced %d calendar days", len(stockDates))
}
// syncKLines fetches K-line bars for one symbol through the Tushare adapter
// and saves them. Stock symbols, detected by isStock, go to the stock
// repository; everything else goes to the futures repository.
//
// Defaults: start = 7 days ago, end = today (YYYYMMDD). A missing symbol,
// adapter, fetch or save failure terminates the process via log.Fatal*.
//
// NOTE(review): the client parameter is not used here; a separate adapter is
// built from TUSHARE_TOKEN instead. The stock SaveKLines call also does not
// receive symbol. Confirm the repository can attribute these bars correctly.
func syncKLines(ctx context.Context, client *tushare.Client, db *repository.DB, symbol, start, end, freq string) {
	if symbol == "" {
		log.Fatal("symbol is required for klines sync")
	}
	if start == "" {
		weekAgo := time.Now().AddDate(0, 0, -7)
		start = weekAgo.Format("20060102")
	}
	if end == "" {
		today := time.Now()
		end = today.Format("20060102")
	}
	log.Printf("Syncing %s klines for %s from %s to %s...", freq, symbol, start, end)

	src := tushare.NewAdapter()
	connCfg := map[string]string{"token": os.Getenv("TUSHARE_TOKEN")}
	if err := src.Connect(connCfg); err != nil {
		log.Fatalf("Failed to connect adapter: %v", err)
	}

	bars, err := src.FetchKLines(symbol, start, end, freq)
	if err != nil {
		log.Fatalf("Failed to fetch klines: %v", err)
	}

	// Convert adapter bars into api.KLineItem. Open interest is attached only
	// when positive, so a zero value is left as nil.
	items := make([]api.KLineItem, 0, len(bars))
	for _, bar := range bars {
		item := api.KLineItem{
			Time:   time.Unix(bar.Time, 0),
			Open:   bar.Open,
			High:   bar.High,
			Low:    bar.Low,
			Close:  bar.Close,
			Volume: bar.Volume,
			Amount: bar.Amount,
		}
		if bar.OpenInterest > 0 {
			openInterest := bar.OpenInterest
			item.OpenInterest = &openInterest
		}
		items = append(items, item)
	}

	period := api.Frequency(freq)
	if !isStock(symbol) {
		futuresRepo := repository.NewFuturesRepository(db)
		if err := futuresRepo.SaveKLines(ctx, period, symbol, items); err != nil {
			log.Fatalf("Failed to save futures klines: %v", err)
		}
	} else {
		stockRepo := repository.NewStockRepository(db)
		if err := stockRepo.SaveKLines(ctx, period, items); err != nil {
			log.Fatalf("Failed to save stock klines: %v", err)
		}
	}
	log.Printf("Synced %d klines", len(items))
}
// isStock reports whether symbol is a stock code, i.e. whether it ends with
// one of the exchange suffixes ".SH", ".SZ" or ".BJ". The match is
// case-sensitive.
func isStock(symbol string) bool {
	for _, suffix := range [...]string{".SH", ".SZ", ".BJ"} {
		if cut := len(symbol) - len(suffix); cut >= 0 && symbol[cut:] == suffix {
			return true
		}
	}
	return false
}
// contains reports whether s ends with substr.
//
// Despite its name, this is a suffix test (equivalent to strings.HasSuffix),
// not a substring search: contains("a.SHx", ".SH") is false.
func contains(s, substr string) bool {
	offset := len(s) - len(substr)
	if offset < 0 {
		return false // substr longer than s can never be a suffix
	}
	return s[offset:] == substr
}