Files

342 lines
7.8 KiB
Go

package ws
import (
	"crypto/rand"
	"encoding/json"
	"log"
	"sync"
	"time"

	"github.com/gorilla/websocket"
)
// Connection tuning shared by the read and write pumps.
const (
	// pongWait is how long the reader waits for any sign of life from the peer
	// before the read deadline expires.
	pongWait = 60 * time.Second
	// pingPeriod is how often the writer pings the peer; it is kept at 90% of
	// pongWait so a ping always goes out before the read deadline lapses.
	pingPeriod = pongWait * 9 / 10
	// writeWait bounds how long a single write may take.
	writeWait = 10 * time.Second
	// maxMessageSize is the largest inbound message, in bytes, the reader accepts.
	maxMessageSize = 4096
)
// Client is a single WebSocket connection registered with a Hub.
//
// Outbound traffic reaches the socket from several goroutines: WritePump
// drains Send, while SendWs is called from the read loop and from goroutines
// spawned by command handlers. gorilla/websocket permits only one concurrent
// writer per connection, so every write goes through writeFrame, which
// serializes them with writeMu.
type Client struct {
	Hub  *Hub
	Conn *websocket.Conn
	// Send queues messages for WritePump to marshal and write.
	Send chan *Message

	// writeMu serializes all writes (and their write deadlines) on Conn.
	writeMu sync.Mutex
}

// NewClient returns a Client bound to hub and conn with a 256-message
// outbound queue.
func NewClient(hub *Hub, conn *websocket.Conn) *Client {
	return &Client{
		Hub:  hub,
		Conn: conn,
		Send: make(chan *Message, 256),
	}
}

// writeFrame writes one complete message of the given gorilla message type
// (TextMessage, CloseMessage, PingMessage, ...) under writeMu, so callers on
// different goroutines can never interleave frames. Each write is bounded by
// writeWait.
func (c *Client) writeFrame(messageType int, payload []byte) error {
	c.writeMu.Lock()
	defer c.writeMu.Unlock()
	// SetWriteDeadline's error is ignored, matching the rest of this file; a
	// broken connection surfaces through WriteMessage's error instead.
	c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
	return c.Conn.WriteMessage(messageType, payload)
}

// SendWs writes a JSON envelope {"seq", "cmd", "data", "timestamp"} directly
// to the connection. seq is freshly generated as
// "req_<yyyyMMddHHmmss>_<8 random chars>" and timestamp is Unix milliseconds.
// Marshal and write failures are logged and returned. Safe to call from any
// goroutine.
func (c *Client) SendWs(cmd string, data interface{}) error {
	now := time.Now()
	seq := "req_" + now.Format("20060102150405") + "_" + generateRandomString(8)

	msgBytes, err := json.Marshal(map[string]interface{}{
		"seq":       seq,
		"cmd":       cmd,
		"data":      data,
		"timestamp": now.UnixMilli(),
	})
	if err != nil {
		log.Printf("error marshaling message: %v", err)
		return err
	}
	if err := c.writeFrame(websocket.TextMessage, msgBytes); err != nil {
		log.Printf("error writing message: %v", err)
		return err
	}
	return nil
}

// generateRandomString returns length characters drawn from [a-zA-Z0-9]
// using crypto/rand. It returns "" when length <= 0.
func generateRandomString(length int) string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	if length <= 0 {
		return ""
	}
	buf := make([]byte, length)
	if _, err := rand.Read(buf); err != nil {
		// crypto/rand failing is practically unreachable; keep the id usable
		// (buf stays zeroed, i.e. all 'a') rather than failing the send.
		log.Printf("error reading random bytes: %v", err)
	}
	for i, b := range buf {
		// The modulo bias over 62 symbols is negligible for a correlation id.
		buf[i] = charset[int(b)%len(charset)]
	}
	return string(buf)
}

// ReadPump reads messages from the connection until it fails or closes,
// decodes each one as a JSON object and dispatches it to HandleMessage. On
// exit it unregisters the client from the hub and closes the connection.
//
// The read deadline is pushed out by pongWait whenever a pong control frame
// or any data message arrives. Counting data messages as well as pongs keeps
// clients that only use the JSON ping/pong messages (and therefore never
// send pong control frames) from being dropped pongWait after connecting.
func (c *Client) ReadPump() {
	defer func() {
		c.Hub.Unregister <- c
		c.Conn.Close()
	}()

	extendReadDeadline := func() {
		c.Conn.SetReadDeadline(time.Now().Add(pongWait))
	}

	c.Conn.SetReadLimit(maxMessageSize)
	extendReadDeadline()
	c.Conn.SetPongHandler(func(string) error {
		extendReadDeadline()
		return nil
	})

	for {
		_, message, err := c.Conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Printf("error: %v", err)
			}
			return
		}
		// Any inbound frame proves the peer is still alive.
		extendReadDeadline()

		var msgMap map[string]interface{}
		if err := json.Unmarshal(message, &msgMap); err != nil {
			log.Printf("error unmarshaling message: %v", err)
			continue
		}
		c.HandleMessage(msgMap)
	}
}

// HandleMessage dispatches one decoded inbound message.
//
// Legacy messages carry a string "type": ping is answered with a pong on
// Send, text is broadcast through the hub, and command is passed to
// HandleCommand. Any other message with a string "cmd" goes to
// HandleNewCommand. The payload comes from the remote peer, so fields are
// read with checked type assertions. A missing or non-string "content" or
// "seq" becomes "" instead of panicking the read goroutine.
func (c *Client) HandleMessage(msgMap map[string]interface{}) {
	if msgTypeStr, ok := msgMap["type"].(string); ok {
		switch MessageType(msgTypeStr) {
		case MessageTypePing:
			c.Send <- &Message{Type: MessageTypePong}
			return
		case MessageTypeText:
			content, _ := msgMap["content"].(string)
			c.Hub.Broadcast <- &Message{
				Type:    MessageTypeText,
				Content: content,
				Data:    msgMap["data"],
			}
			return
		case MessageTypeCommand:
			c.HandleCommand(&Message{
				Type: MessageTypeCommand,
				Data: msgMap["data"],
			})
			return
		}
	}

	cmd, ok := msgMap["cmd"].(string)
	if !ok {
		return
	}
	seq, _ := msgMap["seq"].(string)
	c.HandleNewCommand(seq, cmd, msgMap["data"])
}

// WritePump drains Send and writes each message as a JSON text frame until
// Send is closed (a close frame is sent) or a write fails. It then closes
// the connection.
//
// Every pingPeriod it sends both a marshaled MessageTypePing message, for
// clients that speak the application-level ping, and a WebSocket ping
// control frame. The peer's reply to the control frame is picked up by
// ReadPump's pong handler.
func (c *Client) WritePump() {
	ticker := time.NewTicker(pingPeriod)
	defer func() {
		ticker.Stop()
		c.Conn.Close()
	}()

	for {
		select {
		case message, ok := <-c.Send:
			if !ok {
				// The hub closed the channel.
				c.writeFrame(websocket.CloseMessage, []byte{})
				return
			}
			msgBytes, err := json.Marshal(message)
			if err != nil {
				// One unencodable message should not tear down the connection.
				log.Printf("error marshaling message: %v", err)
				continue
			}
			if err := c.writeFrame(websocket.TextMessage, msgBytes); err != nil {
				return
			}
		case <-ticker.C:
			pingBytes, err := json.Marshal(&Message{Type: MessageTypePing})
			if err != nil {
				return
			}
			if err := c.writeFrame(websocket.TextMessage, pingBytes); err != nil {
				return
			}
			if err := c.writeFrame(websocket.PingMessage, nil); err != nil {
				return
			}
		}
	}
}
// HandleCommand processes a legacy {"type": "command", "data": ...} message.
//
// If data is an object whose "type" is "register", its string "account" and
// "password" fields are passed to c.Hub.UserService.Register on a separate
// goroutine, and the outcome is queued on c.Send. Validation failures and a
// missing user service are reported as MessageTypeError messages. Every other
// command is acknowledged with a text message that echoes data back.
//
// The password is masked before the command is logged.
func (c *Client) HandleCommand(msg *Message) {
	// Non-object data leaves cmdData nil; reading a nil map is safe and
	// yields zero values, so the checks below simply fail for it.
	cmdData, _ := msg.Data.(map[string]interface{})

	// Log a copy with the password masked so credentials never reach the logs.
	logged := msg.Data
	if _, hasPassword := cmdData["password"]; hasPassword {
		masked := make(map[string]interface{}, len(cmdData))
		for k, v := range cmdData {
			masked[k] = v
		}
		masked["password"] = "***"
		logged = masked
	}
	log.Printf("Received command: %v", logged)

	if cmdType, _ := cmdData["type"].(string); cmdType != "register" {
		c.Send <- &Message{
			Type:    MessageTypeText,
			Content: "Command executed successfully",
			Data:    msg.Data,
		}
		return
	}

	account, accountOk := cmdData["account"].(string)
	password, passwordOk := cmdData["password"].(string)
	if !accountOk || !passwordOk {
		c.Send <- &Message{
			Type:    MessageTypeError,
			Content: "Invalid register command: missing account or password",
		}
		return
	}

	if c.Hub.UserService == nil {
		c.Send <- &Message{
			Type:    MessageTypeError,
			Content: "User service not available",
		}
		return
	}

	// Register runs on its own goroutine so the read loop is not blocked
	// while it completes.
	// NOTE(review): the first argument is nil; if it is a context.Context,
	// pass context.Background() instead — confirm Register's signature.
	// NOTE(review): if the Hub closes c.Send on unregister, these sends can
	// panic when the client disconnects before Register returns — confirm.
	go func() {
		resp, err := c.Hub.UserService.Register(nil, account, password)
		if err != nil {
			c.Send <- &Message{
				Type:    MessageTypeError,
				Content: "Register failed: " + err.Error(),
			}
			return
		}
		c.Send <- &Message{
			Type:    MessageTypeText,
			Content: "Register successful",
			Data: map[string]interface{}{
				"user_id": resp.UserId,
				"account": resp.Account,
				"message": resp.Response.Message,
				"code":    resp.Response.Code,
			},
		}
	}()
}
// HandleNewCommand processes an envelope-style {"cmd", "seq", "data"}
// message and replies through SendWs using the same cmd.
//
// "user.register" expects data to be an object with string "account" and
// "password" fields. Those are passed to c.Hub.UserService.Register on a
// separate goroutine. Any invalid payload, including data that is not an
// object at all, gets an error reply rather than silence. Every other cmd is
// acknowledged by echoing data back.
//
// The password is masked before the command is logged. SendWs logs its own
// marshal/write failures, so its return values are intentionally ignored.
//
// NOTE(review): replies carry a freshly generated seq rather than the
// request's seq, so clients cannot correlate them by seq — confirm intent.
func (c *Client) HandleNewCommand(seq string, cmd string, data interface{}) {
	// Non-object data leaves fields nil; reading a nil map is safe and yields
	// zero values, so the field checks below simply fail for it.
	fields, _ := data.(map[string]interface{})

	// Log a copy with the password masked so credentials never reach the logs.
	logged := data
	if _, hasPassword := fields["password"]; hasPassword {
		masked := make(map[string]interface{}, len(fields))
		for k, v := range fields {
			masked[k] = v
		}
		masked["password"] = "***"
		logged = masked
	}
	log.Printf("Received new command: %s, seq: %s, data: %v", cmd, seq, logged)

	if cmd != "user.register" {
		c.SendWs(cmd, map[string]interface{}{
			"type":    "text",
			"content": "Command executed successfully",
			"data":    data,
		})
		return
	}

	account, accountOk := fields["account"].(string)
	password, passwordOk := fields["password"].(string)
	if !accountOk || !passwordOk {
		c.SendWs(cmd, map[string]interface{}{
			"type":    "error",
			"content": "Invalid register command: missing account or password",
		})
		return
	}

	if c.Hub.UserService == nil {
		c.SendWs(cmd, map[string]interface{}{
			"type":    "error",
			"content": "User service not available",
		})
		return
	}

	// Register runs on its own goroutine so the read loop is not blocked
	// while it completes.
	// NOTE(review): the first argument is nil; if it is a context.Context,
	// pass context.Background() instead — confirm Register's signature.
	go func() {
		resp, err := c.Hub.UserService.Register(nil, account, password)
		if err != nil {
			c.SendWs(cmd, map[string]interface{}{
				"type":    "error",
				"content": "Register failed: " + err.Error(),
			})
			return
		}
		c.SendWs(cmd, map[string]interface{}{
			"type":    "text",
			"content": "Register successful",
			"data": map[string]interface{}{
				"user_id": resp.UserId,
				"account": resp.Account,
				"message": resp.Response.Message,
				"code":    resp.Response.Code,
			},
		})
	}()
}