You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

375 lines
8.0 KiB

package websocket
import (
"context"
"encoding/json"
"log"
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"market-data-service/api"
)
// upgrader turns incoming HTTP requests into WebSocket connections, using
// 1024-byte read and write buffers.
//
// NOTE(review): CheckOrigin accepts every Origin, which leaves the endpoint
// open to cross-site WebSocket hijacking; restrict it before production use.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	CheckOrigin: func(*http.Request) bool {
		// Allow every origin; production deployments must tighten this.
		return true
	},
}
// Hub is the central registry for WebSocket clients. It tracks every connected
// client and, per symbol, the set of clients subscribed to that symbol.
//
// Concurrency model:
//   - clients is written only by the Run goroutine. Writes are made under
//     clientsMu so GetSubscriptionStats can read it from other goroutines.
//   - subscriptions, and each Client.subscriptions map it mirrors, are guarded
//     by subMu.
//   - A client's send channel is closed in exactly one place: the unregister
//     branch of Run. That happens only after the client has been removed from
//     subscriptions, so no broadcaster can send on a closed channel.
type Hub struct {
	// clients is the set of registered clients.
	clients map[*Client]bool
	// clientsMu guards clients for readers outside the Run goroutine.
	clientsMu sync.RWMutex
	// broadcast carries messages to deliver to every registered client.
	broadcast chan []byte
	// register receives clients to add to the hub.
	register chan *Client
	// unregister receives clients to remove from the hub.
	unregister chan *Client
	// subscriptions maps symbol -> set of subscribed clients.
	subscriptions map[string]map[*Client]bool
	// subMu guards subscriptions and every Client.subscriptions map.
	subMu sync.RWMutex
	// maxSymbolsPerClient caps how many distinct symbols one client may hold.
	maxSymbolsPerClient int
}

// NewHub creates an empty Hub that allows up to 100 symbols per client.
// The caller must start Run in its own goroutine before registering clients.
func NewHub() *Hub {
	return &Hub{
		clients:             make(map[*Client]bool),
		broadcast:           make(chan []byte),
		register:            make(chan *Client),
		unregister:          make(chan *Client),
		subscriptions:       make(map[string]map[*Client]bool),
		maxSymbolsPerClient: 100,
	}
}

// Run is the hub's event loop. It never returns and should run in its own
// goroutine. It handles registration, unregistration and fan-out of
// h.broadcast. Run is the only writer of h.clients, so it may read the map
// without taking clientsMu.
func (h *Hub) Run() {
	for {
		select {
		case client := <-h.register:
			h.clientsMu.Lock()
			h.clients[client] = true
			total := len(h.clients)
			h.clientsMu.Unlock()
			log.Printf("Client registered, total: %d", total)

		case client := <-h.unregister:
			h.clientsMu.Lock()
			_, registered := h.clients[client]
			delete(h.clients, client)
			total := len(h.clients)
			h.clientsMu.Unlock()
			if registered {
				// Drop subscriptions BEFORE closing send: once
				// removeAllSubscriptions returns, no BroadcastToSymbol call can
				// still be iterating over this client.
				h.removeAllSubscriptions(client)
				close(client.send)
				log.Printf("Client unregistered, total: %d", total)
			}

		case message := <-h.broadcast:
			for client := range h.clients {
				select {
				case client.send <- message:
				default:
					// Send buffer full: drop the connection. Closing the conn
					// makes the client's read loop fail and send itself to
					// h.unregister, which performs the ordered cleanup above.
					// Closing client.send here instead would leave the client
					// subscribed, and later sends to it would panic.
					client.conn.Close()
				}
			}
		}
	}
}

// Subscribe adds symbols to client's subscriptions.
//
// It returns api.ErrRateLimit, and leaves every subscription unchanged, if the
// client would end up holding more than maxSymbolsPerClient distinct symbols.
// Symbols the client already holds, or that repeat within symbols, are not
// counted twice.
func (h *Hub) Subscribe(client *Client, symbols []string) error {
	h.subMu.Lock()
	defer h.subMu.Unlock()

	// The quota check runs under subMu because client.subscriptions is
	// guarded by it.
	fresh := make(map[string]struct{}, len(symbols))
	for _, symbol := range symbols {
		if !client.subscriptions[symbol] {
			fresh[symbol] = struct{}{}
		}
	}
	if len(client.subscriptions)+len(fresh) > h.maxSymbolsPerClient {
		return api.ErrRateLimit
	}

	for symbol := range fresh {
		subs, ok := h.subscriptions[symbol]
		if !ok {
			subs = make(map[*Client]bool)
			h.subscriptions[symbol] = subs
		}
		subs[client] = true
		client.subscriptions[symbol] = true
	}
	return nil
}

// Unsubscribe removes symbols from client's subscriptions. Symbols the client
// was not subscribed to are ignored. A symbol left with no subscribers is
// removed from the hub entirely.
func (h *Hub) Unsubscribe(client *Client, symbols []string) {
	h.subMu.Lock()
	defer h.subMu.Unlock()
	for _, symbol := range symbols {
		if clients, ok := h.subscriptions[symbol]; ok {
			delete(clients, client)
			if len(clients) == 0 {
				delete(h.subscriptions, symbol)
			}
		}
		delete(client.subscriptions, symbol)
	}
}

// removeAllSubscriptions detaches client from every symbol it subscribed to,
// removing any symbol that is left with no subscribers. client.subscriptions
// itself is left as-is.
func (h *Hub) removeAllSubscriptions(client *Client) {
	h.subMu.Lock()
	defer h.subMu.Unlock()
	for symbol := range client.subscriptions {
		if clients, ok := h.subscriptions[symbol]; ok {
			delete(clients, client)
			if len(clients) == 0 {
				delete(h.subscriptions, symbol)
			}
		}
	}
}

// BroadcastToSymbol delivers data to every client subscribed to symbol.
//
// Delivery is non-blocking: a client whose send buffer is full misses this
// message. subMu is held for reading during the whole fan-out because the
// per-symbol client set is mutated by Subscribe and Unsubscribe. Iterating it
// after releasing the lock would race with those writes.
func (h *Hub) BroadcastToSymbol(symbol string, data []byte) {
	h.subMu.RLock()
	defer h.subMu.RUnlock()
	for client := range h.subscriptions[symbol] {
		select {
		case client.send <- data:
		default:
			// Buffer full: drop this message for this client.
		}
	}
}

// GetSubscriptionStats returns a snapshot with two entries:
//   - "total_clients": the number of registered clients.
//   - "total_subscriptions": the number of distinct symbols that have at least
//     one subscriber.
func (h *Hub) GetSubscriptionStats() map[string]interface{} {
	h.clientsMu.RLock()
	totalClients := len(h.clients)
	h.clientsMu.RUnlock()

	h.subMu.RLock()
	totalSymbols := len(h.subscriptions)
	h.subMu.RUnlock()

	return map[string]interface{}{
		"total_clients":       totalClients,
		"total_subscriptions": totalSymbols,
	}
}
// Client is a single WebSocket connection attached to a Hub.
type Client struct {
	hub  *Hub
	conn *websocket.Conn
	// send queues outbound frames for the client's write loop.
	send chan []byte

	// subscriptions is the set of symbols this client is subscribed to.
	subscriptions map[string]bool
	// NOTE(review): the Hub methods that touch subscriptions lock Hub.subMu,
	// not this mutex — confirm whether subMu is used anywhere before relying
	// on it.
	subMu sync.RWMutex
}

// NewClient wraps conn in a Client bound to hub. The client starts with a
// send buffer of 256 messages and no subscriptions.
func NewClient(hub *Hub, conn *websocket.Conn) *Client {
	c := &Client{hub: hub, conn: conn}
	c.send = make(chan []byte, 256)
	c.subscriptions = make(map[string]bool)
	return c
}
// ReadPump 读取客户端消息
func (c *Client) ReadPump() {
defer func() {
c.hub.unregister <- c
c.conn.Close()
}()
c.conn.SetReadLimit(512 * 1024) // 512KB
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
c.conn.SetPongHandler(func(string) error {
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
return nil
})
for {
_, message, err := c.conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Printf("WebSocket error: %v", err)
}
break
}
// 处理客户端消息
c.handleMessage(message)
}
}
// WritePump 向客户端写入消息
func (c *Client) WritePump() {
ticker := time.NewTicker(30 * time.Second)
defer func() {
ticker.Stop()
c.conn.Close()
}()
for {
select {
case message, ok := <-c.send:
c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if !ok {
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}
c.conn.WriteMessage(websocket.TextMessage, message)
case <-ticker.C:
c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
}
}
}
// handleMessage 处理客户端消息
func (c *Client) handleMessage(data []byte) {
var msg ClientMessage
if err := json.Unmarshal(data, &msg); err != nil {
c.sendError(1000, "Invalid message format")
return
}
switch msg.Action {
case "subscribe":
c.handleSubscribe(msg.Symbols)
case "unsubscribe":
c.handleUnsubscribe(msg.Symbols)
default:
c.sendError(1001, "Unknown action")
}
}
// handleSubscribe processes a "subscribe" request.
//
// An empty symbol list is rejected with code 1002. A failure from the hub
// (e.g. the per-client symbol limit) is reported with code 1003 and the
// error's text. Otherwise an ack echoing the requested symbols is queued.
func (c *Client) handleSubscribe(symbols []string) {
	if len(symbols) == 0 {
		c.sendError(1002, "Symbols cannot be empty")
		return
	}
	if err := c.hub.Subscribe(c, symbols); err != nil {
		c.sendError(1003, err.Error())
		return
	}

	ack := map[string]interface{}{
		"type":    "ack",
		"action":  "subscribe",
		"symbols": symbols,
		"ts":      time.Now().Format(time.RFC3339),
	}
	// Marshalling strings and a string slice cannot fail.
	data, _ := json.Marshal(ack)

	// Non-blocking send: this runs on the read-loop goroutine, and a blocking
	// send would stall reads forever if the write loop has stopped draining
	// c.send. If the buffer is full, the ack is dropped.
	select {
	case c.send <- data:
	default:
		log.Printf("dropping subscribe ack: send buffer full")
	}
}
// handleUnsubscribe processes an "unsubscribe" request. It removes the given
// symbols from the client's subscriptions via the hub and queues an ack
// echoing them. Unknown symbols and an empty list are accepted silently.
func (c *Client) handleUnsubscribe(symbols []string) {
	c.hub.Unsubscribe(c, symbols)

	ack := map[string]interface{}{
		"type":    "ack",
		"action":  "unsubscribe",
		"symbols": symbols,
		"ts":      time.Now().Format(time.RFC3339),
	}
	// Marshalling strings and a string slice cannot fail.
	data, _ := json.Marshal(ack)

	// Non-blocking send so a stalled write loop cannot wedge the read loop;
	// the ack is dropped if the buffer is full.
	select {
	case c.send <- data:
	default:
		log.Printf("dropping unsubscribe ack: send buffer full")
	}
}
// sendError queues an error frame for the client. The frame is a JSON object
// with four fields: "type" set to "error", the numeric code, the message, and
// a "ts" timestamp in RFC 3339 format.
func (c *Client) sendError(code int, message string) {
	payload := map[string]interface{}{
		"type":    "error",
		"code":    code,
		"message": message,
		"ts":      time.Now().Format(time.RFC3339),
	}
	// Marshalling an int and strings cannot fail.
	data, _ := json.Marshal(payload)

	// Non-blocking send: callers run on the read-loop goroutine, which must
	// never block on a full buffer. The error frame is dropped if it is full.
	select {
	case c.send <- data:
	default:
		log.Printf("dropping error frame (code %d): send buffer full", code)
	}
}
// ClientMessage is the JSON request a client sends over the socket, for
// example {"action":"subscribe","symbols":["SYM1"]}. Action names the
// operation ("subscribe" or "unsubscribe"), and Symbols lists the symbols it
// applies to.
type ClientMessage struct {
	Action  string   `json:"action"`
	Symbols []string `json:"symbols"`
}
// Server exposes a Hub to WebSocket clients over HTTP.
type Server struct {
	hub *Hub
}

// NewServer returns a Server that registers new connections with hub.
func NewServer(hub *Hub) *Server {
	s := new(Server)
	s.hub = hub
	return s
}
// HandleWebSocket is a gin handler that upgrades the request to a WebSocket
// connection.
//
// Requests without an X-API-Key header get a 401 JSON response and are not
// upgraded. On a successful upgrade the new client is registered with the hub,
// and its write and read loops are started in their own goroutines.
//
// NOTE(review): the API key is only checked for presence, not validity —
// confirm whether real validation happens elsewhere.
func (s *Server) HandleWebSocket(c *gin.Context) {
	if c.GetHeader("X-API-Key") == "" {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "Missing API Key"})
		return
	}

	conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		// Per the gorilla/websocket docs, Upgrade has already written an HTTP
		// error response to the client.
		log.Printf("WebSocket upgrade failed: %v", err)
		return
	}

	client := NewClient(s.hub, conn)
	s.hub.register <- client
	go client.WritePump()
	go client.ReadPump()
}
// BroadcastTick JSON-encodes tick as-is (no envelope) and delivers it to every
// client subscribed to symbol.
//
// If encoding fails, the error is logged and nothing is sent. This happens,
// for example, when a float value is NaN or ±Inf, which encoding/json rejects.
func (s *Server) BroadcastTick(symbol string, tick map[string]interface{}) {
	data, err := json.Marshal(tick)
	if err != nil {
		log.Printf("BroadcastTick: encoding tick for %q: %v", symbol, err)
		return
	}
	s.hub.BroadcastToSymbol(symbol, data)
}
// BroadcastKLine wraps a closed K-line in an envelope and delivers it to every
// client subscribed to symbol.
//
// The envelope has the fields "type" (set to "klines"), "symbol", "freq",
// "data" and "ts" (an RFC 3339 timestamp). If encoding fails (e.g. a NaN or
// ±Inf float in kline), the error is logged and nothing is sent.
func (s *Server) BroadcastKLine(symbol string, freq string, kline map[string]interface{}) {
	msg := map[string]interface{}{
		"type":   "klines",
		"symbol": symbol,
		"freq":   freq,
		"data":   kline,
		"ts":     time.Now().Format(time.RFC3339),
	}
	data, err := json.Marshal(msg)
	if err != nil {
		log.Printf("BroadcastKLine: encoding %s kline for %q: %v", freq, symbol, err)
		return
	}
	s.hub.BroadcastToSymbol(symbol, data)
}