package websocket import ( "context" "encoding/json" "log" "net/http" "sync" "time" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "market-data-service/api" ) var upgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true // 允许所有来源,生产环境需要限制 }, ReadBufferSize: 1024, WriteBufferSize: 1024, } // Hub WebSocket连接管理中心 type Hub struct { // 已注册的客户端 clients map[*Client]bool // 广播消息通道 broadcast chan []byte // 注册请求通道 register chan *Client // 注销请求通道 unregister chan *Client // 标的订阅映射: symbol -> clients subscriptions map[string]map[*Client]bool // 保护subscriptions的锁 subMu sync.RWMutex // 最大订阅标的数 maxSymbolsPerClient int } // NewHub 创建Hub func NewHub() *Hub { return &Hub{ clients: make(map[*Client]bool), broadcast: make(chan []byte), register: make(chan *Client), unregister: make(chan *Client), subscriptions: make(map[string]map[*Client]bool), maxSymbolsPerClient: 100, } } // Run 启动Hub func (h *Hub) Run() { for { select { case client := <-h.register: h.clients[client] = true log.Printf("Client registered, total: %d", len(h.clients)) case client := <-h.unregister: if _, ok := h.clients[client]; ok { delete(h.clients, client) close(client.send) // 清理订阅 h.removeAllSubscriptions(client) log.Printf("Client unregistered, total: %d", len(h.clients)) } case message := <-h.broadcast: for client := range h.clients { select { case client.send <- message: default: // 发送缓冲满,关闭连接 close(client.send) delete(h.clients, client) } } } } } // Subscribe 客户端订阅标的 func (h *Hub) Subscribe(client *Client, symbols []string) error { if len(client.subscriptions)+len(symbols) > h.maxSymbolsPerClient { return api.ErrRateLimit } h.subMu.Lock() defer h.subMu.Unlock() for _, symbol := range symbols { if _, ok := h.subscriptions[symbol]; !ok { h.subscriptions[symbol] = make(map[*Client]bool) } h.subscriptions[symbol][client] = true client.subscriptions[symbol] = true } return nil } // Unsubscribe 客户端取消订阅 func (h *Hub) Unsubscribe(client *Client, symbols []string) { h.subMu.Lock() defer h.subMu.Unlock() for _, symbol := range symbols { if clients, ok := h.subscriptions[symbol]; ok { delete(clients, client) if len(clients) == 0 { delete(h.subscriptions, symbol) } } delete(client.subscriptions, symbol) } } // removeAllSubscriptions 移除客户端所有订阅 func (h *Hub) removeAllSubscriptions(client *Client) { h.subMu.Lock() defer h.subMu.Unlock() for symbol := range client.subscriptions { if clients, ok := h.subscriptions[symbol]; ok { delete(clients, client) if len(clients) == 0 { delete(h.subscriptions, symbol) } } } } // BroadcastToSymbol 向订阅了某标的的所有客户端广播 func (h *Hub) BroadcastToSymbol(symbol string, data []byte) { h.subMu.RLock() clients := h.subscriptions[symbol] h.subMu.RUnlock() for client := range clients { select { case client.send <- data: default: // 发送缓冲满,稍后处理 } } } // GetSubscriptionStats 获取订阅统计 func (h *Hub) GetSubscriptionStats() map[string]interface{} { h.subMu.RLock() defer h.subMu.RUnlock() return map[string]interface{}{ "total_clients": len(h.clients), "total_subscriptions": len(h.subscriptions), } } // Client WebSocket客户端 type Client struct { hub *Hub conn *websocket.Conn send chan []byte // 已订阅的标的 subscriptions map[string]bool subMu sync.RWMutex } // NewClient 创建客户端 func NewClient(hub *Hub, conn *websocket.Conn) *Client { return &Client{ hub: hub, conn: conn, send: make(chan []byte, 256), subscriptions: make(map[string]bool), } } // ReadPump 读取客户端消息 func (c *Client) ReadPump() { defer func() { c.hub.unregister <- c c.conn.Close() }() c.conn.SetReadLimit(512 * 1024) // 512KB c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)) c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)) return nil }) for { _, message, err := c.conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { log.Printf("WebSocket error: %v", err) } break } // 处理客户端消息 c.handleMessage(message) } } // WritePump 向客户端写入消息 func (c *Client) WritePump() { ticker := time.NewTicker(30 * time.Second) defer func() { ticker.Stop() c.conn.Close() }() for { select { case message, ok := <-c.send: c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if !ok { c.conn.WriteMessage(websocket.CloseMessage, []byte{}) return } c.conn.WriteMessage(websocket.TextMessage, message) case <-ticker.C: c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { return } } } } // handleMessage 处理客户端消息 func (c *Client) handleMessage(data []byte) { var msg ClientMessage if err := json.Unmarshal(data, &msg); err != nil { c.sendError(1000, "Invalid message format") return } switch msg.Action { case "subscribe": c.handleSubscribe(msg.Symbols) case "unsubscribe": c.handleUnsubscribe(msg.Symbols) default: c.sendError(1001, "Unknown action") } } // handleSubscribe 处理订阅请求 func (c *Client) handleSubscribe(symbols []string) { if len(symbols) == 0 { c.sendError(1002, "Symbols cannot be empty") return } if err := c.hub.Subscribe(c, symbols); err != nil { c.sendError(1003, err.Error()) return } // 发送确认 ack := map[string]interface{}{ "type": "ack", "action": "subscribe", "symbols": symbols, "ts": time.Now().Format(time.RFC3339), } data, _ := json.Marshal(ack) c.send <- data } // handleUnsubscribe 处理取消订阅请求 func (c *Client) handleUnsubscribe(symbols []string) { c.hub.Unsubscribe(c, symbols) ack := map[string]interface{}{ "type": "ack", "action": "unsubscribe", "symbols": symbols, "ts": time.Now().Format(time.RFC3339), } data, _ := json.Marshal(ack) c.send <- data } // sendError 发送错误消息 func (c *Client) sendError(code int, message string) { err := map[string]interface{}{ "type": "error", "code": code, "message": message, "ts": time.Now().Format(time.RFC3339), } data, _ := json.Marshal(err) c.send <- data } // ClientMessage 客户端消息结构 type ClientMessage struct { Action string `json:"action"` Symbols []string `json:"symbols"` } // Server WebSocket服务器 type Server struct { hub *Hub } // NewServer 创建WebSocket服务器 func NewServer(hub *Hub) *Server { return &Server{hub: hub} } // HandleWebSocket 处理WebSocket连接 func (s *Server) HandleWebSocket(c *gin.Context) { // 认证检查 apiKey := c.GetHeader("X-API-Key") if apiKey == "" { c.JSON(http.StatusUnauthorized, gin.H{"error": "Missing API Key"}) return } conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { log.Printf("WebSocket upgrade failed: %v", err) return } client := NewClient(s.hub, conn) s.hub.register <- client go client.WritePump() go client.ReadPump() } // BroadcastTick 广播Tick数据 func (s *Server) BroadcastTick(symbol string, tick map[string]interface{}) { data, err := json.Marshal(tick) if err != nil { return } s.hub.BroadcastToSymbol(symbol, data) } // BroadcastKLine 广播K线闭合数据 func (s *Server) BroadcastKLine(symbol string, freq string, kline map[string]interface{}) { msg := map[string]interface{}{ "type": "klines", "symbol": symbol, "freq": freq, "data": kline, "ts": time.Now().Format(time.RFC3339), } data, err := json.Marshal(msg) if err != nil { return } s.hub.BroadcastToSymbol(symbol, data) }