package repository

import (
	"context"
	"database/sql"
	"fmt"
	"strings"
	"time"

	"market-data-service/api"
)

// FuturesRepository is the data-access layer for futures market data
// (K-lines, symbols and the trading calendar) stored in the "futures" schema.
type FuturesRepository struct {
	db *DB
}

// NewFuturesRepository creates a FuturesRepository backed by db.
func NewFuturesRepository(db *DB) *FuturesRepository {
	return &FuturesRepository{db: db}
}

// futuresKLineTable returns the schema-qualified K-line table for freq
// ("futures.klines_<freq>").
//
// Table names cannot be sent as bind parameters, so the frequency ends up
// inside the SQL text. It is therefore restricted to ASCII letters, digits
// and underscores; anything else is rejected with an error instead of being
// interpolated into the statement.
func futuresKLineTable(freq api.Frequency) (string, error) {
	suffix := fmt.Sprint(freq)
	if suffix == "" {
		return "", fmt.Errorf("invalid kline frequency %q", suffix)
	}
	for _, c := range suffix {
		switch {
		case c >= 'a' && c <= 'z', c >= 'A' && c <= 'Z', c >= '0' && c <= '9', c == '_':
		default:
			return "", fmt.Errorf("invalid kline frequency %q", suffix)
		}
	}
	return "futures.klines_" + suffix, nil
}

// parseFuturesCalendarDate parses a calendar date given either as YYYYMMDD
// or as YYYY-MM-DD and returns it as midnight UTC.
func parseFuturesCalendarDate(value string) (time.Time, error) {
	for _, layout := range []string{"20060102", "2006-01-02"} {
		if t, err := time.Parse(layout, value); err == nil {
			return t, nil
		}
	}
	return time.Time{}, fmt.Errorf("invalid date %q: want YYYYMMDD or YYYY-MM-DD", value)
}

// GetKLines returns the K-line bars of symbol at the given frequency whose
// timestamp lies in [start, end] (both inclusive), ordered by ascending time.
//
// A NULL open_interest column is reported as a nil OpenInterest pointer.
// An error is returned if freq is not a valid table suffix or if the query,
// a row scan, or row iteration fails; no partial result is returned on error.
func (r *FuturesRepository) GetKLines(ctx context.Context, symbol string, freq api.Frequency, start, end time.Time) ([]api.KLineItem, error) {
	tableName, err := futuresKLineTable(freq)
	if err != nil {
		return nil, err
	}

	query := fmt.Sprintf(`
		SELECT ts, open, high, low, close, volume, amount, open_interest
		FROM %s
		WHERE symbol_id = $1 AND ts >= $2 AND ts <= $3
		ORDER BY ts ASC
	`, tableName)

	rows, err := r.db.QueryContext(ctx, query, symbol, start, end)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var items []api.KLineItem
	for rows.Next() {
		var item api.KLineItem
		var oi sql.NullInt64 // declared per iteration so each item gets its own pointer target
		if err := rows.Scan(
			&item.Time, &item.Open, &item.High, &item.Low, &item.Close,
			&item.Volume, &item.Amount, &oi); err != nil {
			return nil, err
		}
		if oi.Valid {
			item.OpenInterest = &oi.Int64
		}
		items = append(items, item)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

// SaveKLines upserts K-line bars for symbol into the table of the given
// frequency. Rows that already exist for (symbol_id, ts) have all their
// price/volume columns overwritten.
//
// Rows are written with multi-row INSERT statements. Each row binds nine
// values, one per inserted column. Large inputs are split into several
// statements so that a single statement never exceeds PostgreSQL's limit of
// 65535 bind parameters; the statements are issued one after another and are
// not wrapped in a common transaction, so a failure can leave earlier batches
// written (re-running the upsert is safe).
//
// An empty items slice is a no-op.
func (r *FuturesRepository) SaveKLines(ctx context.Context, freq api.Frequency, symbol string, items []api.KLineItem) error {
	if len(items) == 0 {
		return nil
	}

	tableName, err := futuresKLineTable(freq)
	if err != nil {
		return err
	}

	const (
		// columnsPerRow must match the column list of the INSERT below.
		columnsPerRow = 9
		// maxBindParams is the per-statement bind parameter limit of the
		// PostgreSQL wire protocol (16-bit parameter count).
		maxBindParams       = 65535
		maxRowsPerStatement = maxBindParams / columnsPerRow
	)

	for first := 0; first < len(items); first += maxRowsPerStatement {
		last := first + maxRowsPerStatement
		if last > len(items) {
			last = len(items)
		}
		batch := items[first:last]

		valueStrs := make([]string, 0, len(batch))
		args := make([]interface{}, 0, len(batch)*columnsPerRow)
		for _, item := range batch {
			// Placeholders are numbered from the args already queued, which
			// keeps them in lockstep with the values appended below.
			base := len(args)
			valueStrs = append(valueStrs, fmt.Sprintf(
				"($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d)",
				base+1, base+2, base+3, base+4, base+5, base+6, base+7, base+8, base+9))

			// A nil pointer is stored as SQL NULL.
			var openInterest interface{}
			if item.OpenInterest != nil {
				openInterest = *item.OpenInterest
			}
			args = append(args,
				symbol, item.Time, item.Open, item.High, item.Low, item.Close,
				item.Volume, item.Amount, openInterest)
		}

		query := fmt.Sprintf(`
			INSERT INTO %s (symbol_id, ts, open, high, low, close, volume, amount, open_interest)
			VALUES %s
			ON CONFLICT (symbol_id, ts) DO UPDATE SET
				open = EXCLUDED.open,
				high = EXCLUDED.high,
				low = EXCLUDED.low,
				close = EXCLUDED.close,
				volume = EXCLUDED.volume,
				amount = EXCLUDED.amount,
				open_interest = EXCLUDED.open_interest
		`, tableName, strings.Join(valueStrs, ","))

		if _, err := r.db.ExecContext(ctx, query, args...); err != nil {
			return err
		}
	}
	return nil
}

// ListSymbols returns one page of symbols matching the optional filters in
// req together with the total number of matching rows.
//
// Exchange and Underlying are exact-match filters; Keyword is a
// case-insensitive substring match against symbol_id or name. Empty filter
// fields are ignored. Results are ordered by symbol_id, limited to req.Size
// rows and offset by (page-1)*req.Size, where a req.Page below 1 is treated
// as page 1 so the OFFSET is never negative.
func (r *FuturesRepository) ListSymbols(ctx context.Context, req *api.SymbolListRequest) ([]api.Symbol, int, error) {
	whereClause := "WHERE 1=1"
	args := []interface{}{}
	argIdx := 1

	if req.Exchange != "" {
		whereClause += fmt.Sprintf(" AND exchange = $%d", argIdx)
		args = append(args, req.Exchange)
		argIdx++
	}
	if req.Underlying != "" {
		whereClause += fmt.Sprintf(" AND underlying = $%d", argIdx)
		args = append(args, req.Underlying)
		argIdx++
	}
	if req.Keyword != "" {
		// The same bind parameter is referenced twice, once per column.
		whereClause += fmt.Sprintf(" AND (symbol_id ILIKE $%d OR name ILIKE $%d)", argIdx, argIdx)
		args = append(args, "%"+req.Keyword+"%")
		argIdx++
	}

	// Total number of rows matching the filters, ignoring pagination.
	countQuery := "SELECT COUNT(*) FROM futures.symbols " + whereClause
	var total int
	if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
		return nil, 0, err
	}

	// Pages are 1-based; anything lower would produce a negative OFFSET.
	page := req.Page
	if page < 1 {
		page = 1
	}

	// Page of data.
	query := fmt.Sprintf(`
		SELECT symbol_id, symbol_type, exchange, name, underlying, contract_month,
		       list_date, delist_date, status
		FROM futures.symbols
		%s
		ORDER BY symbol_id
		LIMIT $%d OFFSET $%d
	`, whereClause, argIdx, argIdx+1)
	args = append(args, req.Size, (page-1)*req.Size)

	rows, err := r.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, 0, err
	}
	defer rows.Close()

	var symbols []api.Symbol
	for rows.Next() {
		var s api.Symbol
		var listDate, delistDate sql.NullTime
		if err := rows.Scan(
			&s.SymbolID, &s.SymbolType, &s.Exchange, &s.Name, &s.Underlying,
			&s.ContractMonth, &listDate, &delistDate, &s.Status); err != nil {
			return nil, 0, err
		}
		// NULL dates stay nil on the result.
		if listDate.Valid {
			s.ListDate = &listDate.Time
		}
		if delistDate.Valid {
			s.DelistDate = &delistDate.Time
		}
		symbols = append(symbols, s)
	}
	if err := rows.Err(); err != nil {
		return nil, 0, err
	}
	return symbols, total, nil
}

// GetContractsByUnderlying returns the contracts with status 'active' for the
// given underlying, ordered by ascending contract month. When exchange is
// non-empty the result is additionally restricted to that exchange.
//
// On error the returned data pointer is nil.
func (r *FuturesRepository) GetContractsByUnderlying(ctx context.Context, underlying string, exchange string) (*api.FuturesContractsData, error) {
	query := `
		SELECT symbol_id, symbol_type, exchange, name, underlying, contract_month,
		       list_date, delist_date, status
		FROM futures.symbols
		WHERE underlying = $1 AND status = 'active'
	`
	args := []interface{}{underlying}
	if exchange != "" {
		query += " AND exchange = $2"
		args = append(args, exchange)
	}
	query += " ORDER BY contract_month ASC"

	rows, err := r.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var contracts []api.FuturesContractInfo
	for rows.Next() {
		var c api.FuturesContractInfo
		var listDate, delistDate sql.NullTime
		if err := rows.Scan(
			&c.SymbolID, &c.SymbolType, &c.Exchange, &c.Name, &c.Underlying,
			&c.ContractMonth, &listDate, &delistDate, &c.Status); err != nil {
			return nil, err
		}
		// NULL dates stay nil on the result.
		if listDate.Valid {
			c.ListDate = &listDate.Time
		}
		if delistDate.Valid {
			c.DelistDate = &delistDate.Time
		}
		contracts = append(contracts, c)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return &api.FuturesContractsData{
		Underlying: underlying,
		Count:      len(contracts),
		Items:      contracts,
	}, nil
}

// GetTradingDates returns the trading days within [start, end] (inclusive)
// from the trading calendar, in ascending order.
//
// start and end may be given as YYYYMMDD or YYYY-MM-DD; any other format is
// rejected with an error before the database is queried. The strings are
// passed to the query unchanged. TotalDays is the number of calendar days in
// the range (0 when end precedes start) and TradingDays is the number of
// trading days found.
func (r *FuturesRepository) GetTradingDates(ctx context.Context, start, end string) (*api.TradingDatesData, error) {
	startDate, err := parseFuturesCalendarDate(start)
	if err != nil {
		return nil, fmt.Errorf("parsing start date: %w", err)
	}
	endDate, err := parseFuturesCalendarDate(end)
	if err != nil {
		return nil, fmt.Errorf("parsing end date: %w", err)
	}

	query := `
		SELECT trade_date
		FROM futures.trading_calendar
		WHERE trade_date >= $1 AND trade_date <= $2 AND is_trading_day = true
		ORDER BY trade_date ASC
	`
	rows, err := r.db.QueryContext(ctx, query, start, end)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var dates []string
	for rows.Next() {
		var date string
		if err := rows.Scan(&date); err != nil {
			return nil, err
		}
		dates = append(dates, date)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	// Both dates are midnight UTC, so the difference is a whole number of
	// 24-hour days; +1 makes the range inclusive of both endpoints.
	totalDays := 0
	if !endDate.Before(startDate) {
		totalDays = int(endDate.Sub(startDate).Hours()/24) + 1
	}

	return &api.TradingDatesData{
		Start:        start,
		End:          end,
		TotalDays:    totalDays,
		TradingDays:  len(dates),
		TradingDates: dates,
	}, nil
}

// SaveSymbols upserts the given symbols inside a single transaction: either
// every symbol is written or none is.
//
// For a symbol_id that already exists, name, underlying, contract_month,
// list_date, delist_date and status are overwritten and updated_at is set to
// NOW(); symbol_type and exchange of the existing row are left untouched.
// Nil ListDate/DelistDate are stored as NULL. An empty slice is a no-op.
func (r *FuturesRepository) SaveSymbols(ctx context.Context, symbols []api.Symbol) error {
	if len(symbols) == 0 {
		return nil
	}

	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Undoes the work on every early return; its result is intentionally
	// ignored because it is expected to fail harmlessly once Commit succeeded.
	defer tx.Rollback()

	stmt, err := tx.PrepareContext(ctx, `
		INSERT INTO futures.symbols (symbol_id, symbol_type, exchange, name, underlying,
			contract_month, list_date, delist_date, status)
		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
		ON CONFLICT (symbol_id) DO UPDATE SET
			name = EXCLUDED.name,
			underlying = EXCLUDED.underlying,
			contract_month = EXCLUDED.contract_month,
			list_date = EXCLUDED.list_date,
			delist_date = EXCLUDED.delist_date,
			status = EXCLUDED.status,
			updated_at = NOW()
	`)
	if err != nil {
		return err
	}
	defer stmt.Close()

	for _, s := range symbols {
		// Left as nil interfaces (SQL NULL) unless the pointer is set.
		var listDate, delistDate interface{}
		if s.ListDate != nil {
			listDate = *s.ListDate
		}
		if s.DelistDate != nil {
			delistDate = *s.DelistDate
		}

		if _, err := stmt.ExecContext(ctx,
			s.SymbolID, s.SymbolType, s.Exchange, s.Name, s.Underlying,
			s.ContractMonth, listDate, delistDate, s.Status); err != nil {
			return err
		}
	}

	return tx.Commit()
}

// SaveTradingCalendar upserts trading-calendar entries inside a single
// transaction: either every entry is written or none is.
//
// trade_date is written as YYYY-MM-DD. week_day is derived from the date as
// time.Weekday + 1, i.e. Sunday = 1 through Saturday = 7. Existing rows for
// the same trade_date are overwritten and their updated_at is set to NOW().
// An empty slice is a no-op.
func (r *FuturesRepository) SaveTradingCalendar(ctx context.Context, dates []api.TradeCalData) error {
	if len(dates) == 0 {
		return nil
	}

	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Undoes the work on every early return; its result is intentionally
	// ignored because it is expected to fail harmlessly once Commit succeeded.
	defer tx.Rollback()

	stmt, err := tx.PrepareContext(ctx, `
		INSERT INTO futures.trading_calendar (trade_date, is_trading_day, has_night_session, week_day)
		VALUES ($1, $2, $3, $4)
		ON CONFLICT (trade_date) DO UPDATE SET
			is_trading_day = EXCLUDED.is_trading_day,
			has_night_session = EXCLUDED.has_night_session,
			week_day = EXCLUDED.week_day,
			updated_at = NOW()
	`)
	if err != nil {
		return err
	}
	defer stmt.Close()

	for _, d := range dates {
		if _, err := stmt.ExecContext(ctx,
			d.Date.Format("2006-01-02"),
			d.IsTradingDay,
			d.HasNightSession,
			int(d.Date.Weekday())+1); err != nil {
			return err
		}
	}

	return tx.Commit()
}