You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

291 lines
7.0 KiB

package repository
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"market-data-service/api"
)
// DB is the repository's PostgreSQL handle. It embeds *sql.DB, so every
// *sql.DB method (QueryContext, ExecContext, BeginTx, Close, ...) can be
// called directly on a *DB.
type DB struct{ *sql.DB }
// NewDB opens a database handle with the "postgres" driver and verifies
// connectivity with Ping.
//
// connStr is passed unchanged to sql.Open. The "postgres" driver is not
// imported here; presumably it is registered elsewhere via a blank import
// (TODO confirm). If Ping fails, the handle is closed before returning so the
// pool created by sql.Open does not leak.
//
// The returned errors deliberately do not include connStr, which may contain
// credentials.
func NewDB(connStr string) (*DB, error) {
	db, err := sql.Open("postgres", connStr)
	if err != nil {
		return nil, fmt.Errorf("open postgres: %w", err)
	}
	if err := db.Ping(); err != nil {
		// Release the pool; the ping error is the one worth reporting.
		_ = db.Close()
		return nil, fmt.Errorf("ping postgres: %w", err)
	}
	return &DB{db}, nil
}
// ============================================
// 股票Repository
// ============================================
// StockRepository reads and writes stock market data (K-lines, symbols and
// the trading calendar) stored in the "stock" schema.
type StockRepository struct {
	db *DB // handle through which every query in this repository is issued
}
// NewStockRepository returns a StockRepository backed by db. The handle is
// stored as-is; no connectivity check is performed here.
func NewStockRepository(db *DB) *StockRepository {
	repo := &StockRepository{}
	repo.db = db
	return repo
}
// GetKLines returns the K-line rows for symbol stored in the
// stock.klines_<freq> table whose ts lies in the closed interval [start, end],
// ordered by ts ascending. The result is nil when no rows match.
//
// freq becomes part of the SQL table name (identifiers cannot be bind
// parameters), so it is restricted to ASCII letters, digits and '_'; any
// other value is rejected with an error instead of reaching the SQL text.
//
// NOTE(review): adjust is currently ignored; rows are returned exactly as
// stored. Confirm whether price adjustment should be applied here.
func (r *StockRepository) GetKLines(ctx context.Context, symbol string, freq api.Frequency, start, end time.Time, adjust api.AdjustType) ([]api.KLineItem, error) {
	freqName := fmt.Sprint(freq)
	notIdentRune := func(c rune) bool {
		return !(c == '_' || ('0' <= c && c <= '9') || ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z'))
	}
	if freqName == "" || strings.IndexFunc(freqName, notIdentRune) >= 0 {
		return nil, fmt.Errorf("get klines: invalid frequency %q", freqName)
	}
	tableName := "stock.klines_" + freqName

	query := fmt.Sprintf(`
SELECT ts, open, high, low, close, volume, amount
FROM %s
WHERE symbol_id = $1 AND ts >= $2 AND ts <= $3
ORDER BY ts ASC
`, tableName)

	rows, err := r.db.QueryContext(ctx, query, symbol, start, end)
	if err != nil {
		return nil, fmt.Errorf("get klines: query %s for %q: %w", tableName, symbol, err)
	}
	defer rows.Close()

	var items []api.KLineItem
	for rows.Next() {
		var item api.KLineItem
		if err := rows.Scan(&item.Time, &item.Open, &item.High, &item.Low, &item.Close, &item.Volume, &item.Amount); err != nil {
			return nil, fmt.Errorf("get klines: scan %s row: %w", tableName, err)
		}
		items = append(items, item)
	}
	// Do not hand back a partially read result alongside an iteration error.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("get klines: iterate %s rows: %w", tableName, err)
	}
	return items, nil
}
// SaveKLines upserts K-line rows into the stock.klines_<freq> table, keyed by
// (symbol_id, ts). On conflict, open, high, low, close, volume and amount are
// overwritten with the new values. An empty items slice is a no-op.
//
// Each row binds 8 parameters (symbol_id, ts, open, high, low, close, volume,
// amount). Items are written in multi-row INSERT batches of at most
// rowsPerBatch rows, so one statement stays well under PostgreSQL's limit of
// 65535 bind parameters. All batches run in a single transaction, so either
// every item is saved or none is.
//
// freq becomes part of the SQL table name, so it is restricted to ASCII
// letters, digits and '_'.
//
// NOTE(review): PostgreSQL rejects an ON CONFLICT DO UPDATE statement that
// would update the same row twice, so a batch must not contain two items with
// the same (Symbol, Time).
func (r *StockRepository) SaveKLines(ctx context.Context, freq api.Frequency, items []api.KLineItem) error {
	if len(items) == 0 {
		return nil
	}

	freqName := fmt.Sprint(freq)
	notIdentRune := func(c rune) bool {
		return !(c == '_' || ('0' <= c && c <= '9') || ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z'))
	}
	if freqName == "" || strings.IndexFunc(freqName, notIdentRune) >= 0 {
		return fmt.Errorf("save klines: invalid frequency %q", freqName)
	}
	tableName := "stock.klines_" + freqName

	const (
		colsPerRow   = 8    // must match the column list of the INSERT below
		rowsPerBatch = 1000 // 1000 * 8 = 8000 bind parameters per statement
	)

	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("save klines: begin transaction: %w", err)
	}
	// Undoes earlier batches if a later one fails; a no-op after Commit.
	defer tx.Rollback()

	for lo := 0; lo < len(items); lo += rowsPerBatch {
		hi := lo + rowsPerBatch
		if hi > len(items) {
			hi = len(items)
		}
		batch := items[lo:hi]

		placeholders := make([]string, 0, len(batch))
		args := make([]interface{}, 0, len(batch)*colsPerRow)
		for i, item := range batch {
			p := i*colsPerRow + 1 // PostgreSQL placeholders are 1-based
			placeholders = append(placeholders, fmt.Sprintf("($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d)",
				p, p+1, p+2, p+3, p+4, p+5, p+6, p+7))
			args = append(args, item.Symbol, item.Time, item.Open, item.High, item.Low, item.Close, item.Volume, item.Amount)
		}

		query := fmt.Sprintf(`
INSERT INTO %s (symbol_id, ts, open, high, low, close, volume, amount)
VALUES %s
ON CONFLICT (symbol_id, ts) DO UPDATE SET
open = EXCLUDED.open,
high = EXCLUDED.high,
low = EXCLUDED.low,
close = EXCLUDED.close,
volume = EXCLUDED.volume,
amount = EXCLUDED.amount
`, tableName, strings.Join(placeholders, ","))

		if _, err := tx.ExecContext(ctx, query, args...); err != nil {
			return fmt.Errorf("save klines: upsert %d rows into %s: %w", len(batch), tableName, err)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("save klines: commit: %w", err)
	}
	return nil
}
// ListSymbols returns one page of stock.symbols rows ordered by symbol_id,
// together with the total number of rows matching the filters.
//
// Optional filters:
//   - req.Exchange: exact match on the exchange column.
//   - req.Keyword: case-insensitive substring match (ILIKE) on symbol_id or
//     name. LIKE metacharacters ('%', '_') and the escape character '\' in
//     the keyword are escaped, so they match literally.
//
// Pagination uses req.Page (1-based) and req.Size as LIMIT req.Size OFFSET
// (req.Page-1)*req.Size. PostgreSQL rejects negative LIMIT/OFFSET values, so
// negative results (e.g. Page < 1) are clamped to 0.
//
// NOTE(review): if name_en or industry can be NULL, scanning them into
// non-nullable fields will fail; confirm the schema.
func (r *StockRepository) ListSymbols(ctx context.Context, req *api.SymbolListRequest) ([]api.Symbol, int, error) {
	whereClause := "WHERE 1=1"
	args := []interface{}{}
	argIdx := 1
	if req.Exchange != "" {
		whereClause += fmt.Sprintf(" AND exchange = $%d", argIdx)
		args = append(args, req.Exchange)
		argIdx++
	}
	if req.Keyword != "" {
		// Both ILIKE operands reference the same placeholder, so one argument is bound.
		whereClause += fmt.Sprintf(" AND (symbol_id ILIKE $%d OR name ILIKE $%d)", argIdx, argIdx)
		// '\' is PostgreSQL's default LIKE escape character.
		likeEscaper := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`)
		args = append(args, "%"+likeEscaper.Replace(string(req.Keyword))+"%")
		argIdx++
	}

	// Total number of matching rows, independent of pagination.
	countQuery := "SELECT COUNT(*) FROM stock.symbols " + whereClause
	var total int
	if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
		return nil, 0, fmt.Errorf("list symbols: count: %w", err)
	}

	limit := req.Size
	if limit < 0 {
		limit = 0
	}
	offset := (req.Page - 1) * req.Size
	if offset < 0 {
		offset = 0
	}

	query := fmt.Sprintf(`
SELECT symbol_id, symbol_type, exchange, name, name_en, list_date, delist_date, industry, status
FROM stock.symbols
%s
ORDER BY symbol_id
LIMIT $%d OFFSET $%d
`, whereClause, argIdx, argIdx+1)
	args = append(args, limit, offset)

	rows, err := r.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, 0, fmt.Errorf("list symbols: query: %w", err)
	}
	defer rows.Close()

	var symbols []api.Symbol
	for rows.Next() {
		var s api.Symbol
		var listDate, delistDate sql.NullTime
		if err := rows.Scan(&s.SymbolID, &s.SymbolType, &s.Exchange, &s.Name, &s.NameEN,
			&listDate, &delistDate, &s.Industry, &s.Status); err != nil {
			return nil, 0, fmt.Errorf("list symbols: scan: %w", err)
		}
		// NULL dates stay as nil pointers.
		if listDate.Valid {
			s.ListDate = &listDate.Time
		}
		if delistDate.Valid {
			s.DelistDate = &delistDate.Time
		}
		symbols = append(symbols, s)
	}
	if err := rows.Err(); err != nil {
		return nil, 0, fmt.Errorf("list symbols: iterate rows: %w", err)
	}
	return symbols, total, nil
}
// GetTradingDates returns the trading days recorded in stock.trading_calendar
// between start and end (both inclusive), ordered ascending.
//
// start and end must be "YYYYMMDD" (layout "20060102") or "YYYY-MM-DD" (the
// form SaveTradingCalendar writes). Both are validated before querying, so
// TotalDays is never computed from an unparsed zero time. They are passed to
// the query unchanged.
//
// In the result:
//   - TotalDays is the number of calendar days in [start, end], or 0 when
//     end precedes start.
//   - TradingDays is len(TradingDates).
//
// NOTE(review): trade_date is scanned into a string. If the driver returns
// the column as time.Time (e.g. a DATE column), database/sql formats it as
// RFC 3339 text ("2024-01-02T00:00:00Z"), not "YYYYMMDD". Confirm the
// intended output format.
func (r *StockRepository) GetTradingDates(ctx context.Context, start, end string) (*api.TradingDatesData, error) {
	parseDay := func(s string) (time.Time, error) {
		if t, err := time.Parse("20060102", s); err == nil {
			return t, nil
		}
		return time.Parse("2006-01-02", s)
	}
	startDate, err := parseDay(start)
	if err != nil {
		return nil, fmt.Errorf("get trading dates: invalid start %q: %w", start, err)
	}
	endDate, err := parseDay(end)
	if err != nil {
		return nil, fmt.Errorf("get trading dates: invalid end %q: %w", end, err)
	}
	// time.Parse without a zone yields UTC midnight, so the span is a whole
	// number of 24h days. An inverted range covers no days.
	totalDays := 0
	if !endDate.Before(startDate) {
		totalDays = int(endDate.Sub(startDate).Hours()/24) + 1
	}

	query := `
SELECT trade_date
FROM stock.trading_calendar
WHERE trade_date >= $1 AND trade_date <= $2 AND is_trading_day = true
ORDER BY trade_date ASC
`
	rows, err := r.db.QueryContext(ctx, query, start, end)
	if err != nil {
		return nil, fmt.Errorf("get trading dates: query: %w", err)
	}
	defer rows.Close()

	var dates []string
	for rows.Next() {
		var date string
		if err := rows.Scan(&date); err != nil {
			return nil, fmt.Errorf("get trading dates: scan: %w", err)
		}
		dates = append(dates, date)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("get trading dates: iterate rows: %w", err)
	}

	return &api.TradingDatesData{
		Start:        start,
		End:          end,
		TotalDays:    totalDays,
		TradingDays:  len(dates),
		TradingDates: dates,
	}, nil
}
// SaveSymbols upserts symbols into stock.symbols inside one transaction,
// reusing a single prepared statement for every row.
//
// Rows are matched on symbol_id. On conflict, name, name_en, list_date,
// delist_date, industry and status are overwritten, and updated_at is set to
// NOW(); symbol_type and exchange keep their stored values. A nil
// ListDate/DelistDate is written as SQL NULL. An empty slice is a no-op.
// Any error aborts the whole batch, and the deferred rollback discards the
// earlier rows.
func (r *StockRepository) SaveSymbols(ctx context.Context, symbols []api.Symbol) error {
	if len(symbols) == 0 {
		return nil
	}

	tx, beginErr := r.db.BeginTx(ctx, nil)
	if beginErr != nil {
		return beginErr
	}
	// After a successful Commit this rollback is a harmless no-op.
	defer tx.Rollback()

	const upsertSQL = `
INSERT INTO stock.symbols (symbol_id, symbol_type, exchange, name, name_en, list_date, delist_date, industry, status)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
ON CONFLICT (symbol_id) DO UPDATE SET
name = EXCLUDED.name,
name_en = EXCLUDED.name_en,
list_date = EXCLUDED.list_date,
delist_date = EXCLUDED.delist_date,
industry = EXCLUDED.industry,
status = EXCLUDED.status,
updated_at = NOW()
`
	upsert, prepErr := tx.PrepareContext(ctx, upsertSQL)
	if prepErr != nil {
		return prepErr
	}
	defer upsert.Close()

	for _, sym := range symbols {
		// Slots 5 and 6 (list_date, delist_date) stay nil, i.e. SQL NULL,
		// unless the symbol carries a date.
		params := []interface{}{
			sym.SymbolID, sym.SymbolType, sym.Exchange, sym.Name, sym.NameEN,
			nil, nil, sym.Industry, sym.Status,
		}
		if sym.ListDate != nil {
			params[5] = *sym.ListDate
		}
		if sym.DelistDate != nil {
			params[6] = *sym.DelistDate
		}
		if _, execErr := upsert.ExecContext(ctx, params...); execErr != nil {
			return execErr
		}
	}
	return tx.Commit()
}
// SaveTradingCalendar upserts calendar days into stock.trading_calendar
// inside one transaction, reusing a single prepared statement for all rows.
//
// Each entry is written with these values:
//   - trade_date: d.Date formatted as "2006-01-02".
//   - is_trading_day: d.IsTradingDay.
//   - week_day: int(d.Date.Weekday())+1, i.e. Sunday=1, Monday=2, ...,
//     Saturday=7.
//
// Rows matched on trade_date get is_trading_day and week_day refreshed and
// updated_at set to NOW(). An empty slice is a no-op. Any error aborts the
// whole batch.
//
// NOTE(review): confirm that consumers of week_day expect Sunday=1 rather
// than an ISO-style Monday=1 numbering.
func (r *StockRepository) SaveTradingCalendar(ctx context.Context, dates []api.TradeCalData) error {
	if len(dates) == 0 {
		return nil
	}

	tx, beginErr := r.db.BeginTx(ctx, nil)
	if beginErr != nil {
		return beginErr
	}
	// Discards partial work on early return; a no-op after Commit.
	defer tx.Rollback()

	const upsertSQL = `
INSERT INTO stock.trading_calendar (trade_date, is_trading_day, week_day)
VALUES ($1, $2, $3)
ON CONFLICT (trade_date) DO UPDATE SET
is_trading_day = EXCLUDED.is_trading_day,
week_day = EXCLUDED.week_day,
updated_at = NOW()
`
	upsert, prepErr := tx.PrepareContext(ctx, upsertSQL)
	if prepErr != nil {
		return prepErr
	}
	defer upsert.Close()

	for i := range dates {
		day := dates[i].Date
		dayText := day.Format("2006-01-02")
		weekDay := int(day.Weekday()) + 1 // time.Sunday is 0
		if _, execErr := upsert.ExecContext(ctx, dayText, dates[i].IsTradingDay, weekDay); execErr != nil {
			return execErr
		}
	}
	return tx.Commit()
}