package ws import ( "encoding/json" "log" "time" "github.com/gorilla/websocket" ) const ( // 写入超时时间 writeWait = 10 * time.Second // 读取超时时间 pongWait = 60 * time.Second // 发送 ping 的时间间隔,必须小于 pongWait pingPeriod = (pongWait * 9) / 10 // 最大消息大小 maxMessageSize = 4096 ) type Client struct { hub *Hub conn *websocket.Conn send chan *Message } func NewClient(hub *Hub, conn *websocket.Conn) *Client { return &Client{ hub: hub, conn: conn, send: make(chan *Message, 256), } } func (c *Client) readPump() { defer func() { c.hub.unregister <- c c.conn.Close() }() c.conn.SetReadLimit(maxMessageSize) c.conn.SetReadDeadline(time.Now().Add(pongWait)) c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(pongWait)) return nil }) for { _, message, err := c.conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { log.Printf("error: %v", err) } break } // 解析为通用的消息结构 var msgMap map[string]interface{} if err := json.Unmarshal(message, &msgMap); err != nil { log.Printf("error unmarshaling message: %v", err) continue } // 处理消息 c.handleMessage(msgMap) } } func (c *Client) handleMessage(msgMap map[string]interface{}) { // 检查是否是 ping 消息(保持向后兼容) if msgType, ok := msgMap["type"].(string); ok { switch msgType { case MessageTypePing: // 回复 pong pongMsg := &Message{Type: MessageTypePong} c.send <- pongMsg return case MessageTypeText: // 广播文本消息 msg := &Message{ Type: MessageTypeText, Content: msgMap["content"].(string), Data: msgMap["data"], } c.hub.broadcast <- msg return case MessageTypeCommand: // 处理命令(保持向后兼容) msg := &Message{ Type: MessageTypeCommand, Data: msgMap["data"], } c.handleCommand(msg) return } } // 处理新的消息结构 if cmd, ok := msgMap["cmd"].(string); ok { // 提取数据 data := msgMap["data"] seq := msgMap["seq"].(string) // 处理命令 c.handleNewCommand(seq, cmd, data) } } func (c *Client) writePump() { ticker := time.NewTicker(pingPeriod) defer func() { ticker.Stop() c.conn.Close() }() for { select { case message, ok := <-c.send: c.conn.SetWriteDeadline(time.Now().Add(writeWait)) if !ok { // Hub 关闭了通道 c.conn.WriteMessage(websocket.CloseMessage, []byte{}) return } w, err := c.conn.NextWriter(websocket.TextMessage) if err != nil { return } msgBytes, err := json.Marshal(message) if err != nil { log.Printf("error marshaling message: %v", err) return } w.Write(msgBytes) if err := w.Close(); err != nil { return } case <-ticker.C: c.conn.SetWriteDeadline(time.Now().Add(writeWait)) pingMsg := &Message{Type: MessageTypePing} pingBytes, err := json.Marshal(pingMsg) if err != nil { return } if err := c.conn.WriteMessage(websocket.TextMessage, pingBytes); err != nil { return } } } } func (c *Client) handleCommand(msg *Message) { // 处理命令逻辑 // 这里可以根据命令类型执行不同的操作 log.Printf("Received command: %v", msg.Data) // 检查是否是注册命令 if cmdData, ok := msg.Data.(map[string]interface{}); ok { if cmdType, ok := cmdData["type"].(string); ok && cmdType == "register" { // 提取注册信息 account, accountOk := cmdData["account"].(string) password, passwordOk := cmdData["password"].(string) if !accountOk || !passwordOk { // 回复错误信息 errorResponse := &Message{ Type: MessageTypeError, Content: "Invalid register command: missing account or password", } c.send <- errorResponse return } // 调用用户服务注册 if c.hub.userService != nil { // 异步调用用户服务注册 go func() { resp, err := c.hub.userService.Register(nil, account, password) if err != nil { // 回复错误信息 errorResponse := &Message{ Type: MessageTypeError, Content: "Register failed: " + err.Error(), } c.send <- errorResponse return } // 回复成功信息 successResponse := &Message{ Type: MessageTypeText, Content: "Register successful", Data: map[string]interface{}{ "user_id": resp.UserId, "account": resp.Account, "message": resp.Response.Message, "code": resp.Response.Code, }, } c.send <- successResponse }() } else { // 回复错误信息 errorResponse := &Message{ Type: MessageTypeError, Content: "User service not available", } c.send <- errorResponse } return } } // 其他命令处理 response := &Message{ Type: MessageTypeText, Content: "Command executed successfully", Data: msg.Data, } c.send <- response } func (c *Client) handleNewCommand(seq string, cmd string, data interface{}) { // 处理新的命令结构 log.Printf("Received new command: %s, seq: %s, data: %v", cmd, seq, data) // 根据 cmd 字段处理不同的命令 switch cmd { case "user.register": // 提取注册信息 if registerData, ok := data.(map[string]interface{}); ok { account, accountOk := registerData["account"].(string) password, passwordOk := registerData["password"].(string) if !accountOk || !passwordOk { // 回复错误信息 errorResponse := map[string]interface{}{ "seq": seq, "cmd": cmd, "type": "error", "content": "Invalid register command: missing account or password", "timestamp": time.Now().UnixMilli(), } c.sendJSON(errorResponse) return } // 调用用户服务注册 if c.hub.userService != nil { // 异步调用用户服务注册 go func() { resp, err := c.hub.userService.Register(nil, account, password) if err != nil { // 回复错误信息 errorResponse := map[string]interface{}{ "seq": seq, "cmd": cmd, "type": "error", "content": "Register failed: " + err.Error(), "timestamp": time.Now().UnixMilli(), } c.sendJSON(errorResponse) return } // 回复成功信息 successResponse := map[string]interface{}{ "seq": seq, "cmd": cmd, "type": "text", "content": "Register successful", "data": map[string]interface{}{ "user_id": resp.UserId, "account": resp.Account, "message": resp.Response.Message, "code": resp.Response.Code, }, "timestamp": time.Now().UnixMilli(), } c.sendJSON(successResponse) }() } else { // 回复错误信息 errorResponse := map[string]interface{}{ "seq": seq, "cmd": cmd, "type": "error", "content": "User service not available", "timestamp": time.Now().UnixMilli(), } c.sendJSON(errorResponse) } } return default: // 其他命令处理 response := map[string]interface{}{ "seq": seq, "cmd": cmd, "type": "text", "content": "Command executed successfully", "data": data, "timestamp": time.Now().UnixMilli(), } c.sendJSON(response) } } func (c *Client) sendJSON(data interface{}) { // 将数据转换为 JSON 并发送 msgBytes, err := json.Marshal(data) if err != nil { log.Printf("error marshaling message: %v", err) return } // 写入 WebSocket 连接 c.conn.SetWriteDeadline(time.Now().Add(writeWait)) if err := c.conn.WriteMessage(websocket.TextMessage, msgBytes); err != nil { log.Printf("error writing message: %v", err) } }