342 lines
7.8 KiB
Go
342 lines
7.8 KiB
Go
package ws
|
|
|
|
import (
	"encoding/json"
	"log"
	"math/rand"
	"sync"
	"time"

	"github.com/gorilla/websocket"
)
|
|
|
|
// Timing and size limits used by the connection's read and write paths.
const (
	// maxMessageSize is the read limit handed to SetReadLimit for inbound messages.
	maxMessageSize = 4096

	// writeWait bounds how long a single write to the peer may take.
	writeWait = 10 * time.Second

	// pongWait is the read-deadline window applied to the connection.
	pongWait = 60 * time.Second

	// pingPeriod is the keepalive interval. It must stay below pongWait so a
	// reply can arrive before the read deadline expires.
	pingPeriod = pongWait * 9 / 10
)
|
|
|
|
type Client struct {
|
|
Hub *Hub
|
|
Conn *websocket.Conn
|
|
Send chan *Message
|
|
}
|
|
|
|
func NewClient(hub *Hub, conn *websocket.Conn) *Client {
|
|
return &Client{
|
|
Hub: hub,
|
|
Conn: conn,
|
|
Send: make(chan *Message, 256),
|
|
}
|
|
}
|
|
|
|
// SendWs 发送 WebSocket 消息,自动生成 seq、cmd、timestamp
|
|
func (c *Client) SendWs(cmd string, data interface{}) error {
|
|
// 生成唯一请求ID
|
|
seq := "req_" + time.Now().Format("20060102150405") + "_" + generateRandomString(8)
|
|
|
|
// 构建消息
|
|
message := map[string]interface{}{
|
|
"seq": seq,
|
|
"cmd": cmd,
|
|
"data": data,
|
|
"timestamp": time.Now().UnixMilli(),
|
|
}
|
|
|
|
// 将消息转换为 JSON
|
|
msgBytes, err := json.Marshal(message)
|
|
if err != nil {
|
|
log.Printf("error marshaling message: %v", err)
|
|
return err
|
|
}
|
|
|
|
// 写入 WebSocket 连接
|
|
c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
|
|
if err := c.Conn.WriteMessage(websocket.TextMessage, msgBytes); err != nil {
|
|
log.Printf("error writing message: %v", err)
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// generateRandomString 生成随机字符串
|
|
func generateRandomString(length int) string {
|
|
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
|
result := make([]byte, length)
|
|
for i := range result {
|
|
result[i] = charset[time.Now().UnixNano()%int64(len(charset))]
|
|
time.Sleep(1 * time.Nanosecond)
|
|
}
|
|
return string(result)
|
|
}
|
|
|
|
func (c *Client) ReadPump() {
|
|
defer func() {
|
|
c.Hub.Unregister <- c
|
|
c.Conn.Close()
|
|
}()
|
|
|
|
c.Conn.SetReadLimit(maxMessageSize)
|
|
c.Conn.SetReadDeadline(time.Now().Add(pongWait))
|
|
c.Conn.SetPongHandler(func(string) error {
|
|
c.Conn.SetReadDeadline(time.Now().Add(pongWait))
|
|
return nil
|
|
})
|
|
|
|
for {
|
|
_, message, err := c.Conn.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
|
log.Printf("error: %v", err)
|
|
}
|
|
break
|
|
}
|
|
|
|
// 解析为通用的消息结构
|
|
var msgMap map[string]interface{}
|
|
if err := json.Unmarshal(message, &msgMap); err != nil {
|
|
log.Printf("error unmarshaling message: %v", err)
|
|
continue
|
|
}
|
|
|
|
// 处理消息
|
|
c.HandleMessage(msgMap)
|
|
}
|
|
}
|
|
|
|
func (c *Client) HandleMessage(msgMap map[string]interface{}) {
|
|
// 检查是否是 ping 消息(保持向后兼容)
|
|
if msgTypeStr, ok := msgMap["type"].(string); ok {
|
|
msgType := MessageType(msgTypeStr)
|
|
switch msgType {
|
|
case MessageTypePing:
|
|
// 回复 pong
|
|
pongMsg := &Message{Type: MessageTypePong}
|
|
c.Send <- pongMsg
|
|
return
|
|
case MessageTypeText:
|
|
// 广播文本消息
|
|
msg := &Message{
|
|
Type: MessageTypeText,
|
|
Content: msgMap["content"].(string),
|
|
Data: msgMap["data"],
|
|
}
|
|
c.Hub.Broadcast <- msg
|
|
return
|
|
case MessageTypeCommand:
|
|
// 处理命令(保持向后兼容)
|
|
msg := &Message{
|
|
Type: MessageTypeCommand,
|
|
Data: msgMap["data"],
|
|
}
|
|
c.HandleCommand(msg)
|
|
return
|
|
}
|
|
}
|
|
|
|
// 处理新的消息结构
|
|
if cmd, ok := msgMap["cmd"].(string); ok {
|
|
// 提取数据
|
|
data := msgMap["data"]
|
|
seq := msgMap["seq"].(string)
|
|
|
|
// 处理命令
|
|
c.HandleNewCommand(seq, cmd, data)
|
|
}
|
|
}
|
|
|
|
func (c *Client) WritePump() {
|
|
ticker := time.NewTicker(pingPeriod)
|
|
defer func() {
|
|
ticker.Stop()
|
|
c.Conn.Close()
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case message, ok := <-c.Send:
|
|
c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
|
|
if !ok {
|
|
// Hub 关闭了通道
|
|
c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
|
|
return
|
|
}
|
|
|
|
w, err := c.Conn.NextWriter(websocket.TextMessage)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
msgBytes, err := json.Marshal(message)
|
|
if err != nil {
|
|
log.Printf("error marshaling message: %v", err)
|
|
return
|
|
}
|
|
|
|
w.Write(msgBytes)
|
|
|
|
if err := w.Close(); err != nil {
|
|
return
|
|
}
|
|
|
|
case <-ticker.C:
|
|
c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
|
|
pingMsg := &Message{Type: MessageTypePing}
|
|
pingBytes, err := json.Marshal(pingMsg)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if err := c.Conn.WriteMessage(websocket.TextMessage, pingBytes); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) HandleCommand(msg *Message) {
|
|
// 处理命令逻辑
|
|
// 这里可以根据命令类型执行不同的操作
|
|
log.Printf("Received command: %v", msg.Data)
|
|
|
|
// 检查是否是注册命令
|
|
if cmdData, ok := msg.Data.(map[string]interface{}); ok {
|
|
if cmdType, ok := cmdData["type"].(string); ok && cmdType == "register" {
|
|
// 提取注册信息
|
|
account, accountOk := cmdData["account"].(string)
|
|
password, passwordOk := cmdData["password"].(string)
|
|
|
|
if !accountOk || !passwordOk {
|
|
// 回复错误信息
|
|
errorResponse := &Message{
|
|
Type: MessageTypeError,
|
|
Content: "Invalid register command: missing account or password",
|
|
}
|
|
c.Send <- errorResponse
|
|
return
|
|
}
|
|
|
|
// 调用用户服务注册
|
|
if c.Hub.UserService != nil {
|
|
// 异步调用用户服务注册
|
|
go func() {
|
|
resp, err := c.Hub.UserService.Register(nil, account, password)
|
|
if err != nil {
|
|
// 回复错误信息
|
|
errorResponse := &Message{
|
|
Type: MessageTypeError,
|
|
Content: "Register failed: " + err.Error(),
|
|
}
|
|
c.Send <- errorResponse
|
|
return
|
|
}
|
|
|
|
// 回复成功信息
|
|
successResponse := &Message{
|
|
Type: MessageTypeText,
|
|
Content: "Register successful",
|
|
Data: map[string]interface{}{
|
|
"user_id": resp.UserId,
|
|
"account": resp.Account,
|
|
"message": resp.Response.Message,
|
|
"code": resp.Response.Code,
|
|
},
|
|
}
|
|
c.Send <- successResponse
|
|
}()
|
|
} else {
|
|
// 回复错误信息
|
|
errorResponse := &Message{
|
|
Type: MessageTypeError,
|
|
Content: "User service not available",
|
|
}
|
|
c.Send <- errorResponse
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
// 其他命令处理
|
|
response := &Message{
|
|
Type: MessageTypeText,
|
|
Content: "Command executed successfully",
|
|
Data: msg.Data,
|
|
}
|
|
|
|
c.Send <- response
|
|
}
|
|
|
|
func (c *Client) HandleNewCommand(seq string, cmd string, data interface{}) {
|
|
// 处理新的命令结构
|
|
log.Printf("Received new command: %s, seq: %s, data: %v", cmd, seq, data)
|
|
|
|
// 根据 cmd 字段处理不同的命令
|
|
switch cmd {
|
|
case "user.register":
|
|
// 提取注册信息
|
|
if registerData, ok := data.(map[string]interface{}); ok {
|
|
account, accountOk := registerData["account"].(string)
|
|
password, passwordOk := registerData["password"].(string)
|
|
|
|
if !accountOk || !passwordOk {
|
|
// 回复错误信息
|
|
c.SendWs(cmd, map[string]interface{}{
|
|
"type": "error",
|
|
"content": "Invalid register command: missing account or password",
|
|
})
|
|
return
|
|
}
|
|
|
|
// 调用用户服务注册
|
|
if c.Hub.UserService != nil {
|
|
// 异步调用用户服务注册
|
|
go func() {
|
|
resp, err := c.Hub.UserService.Register(nil, account, password)
|
|
if err != nil {
|
|
// 回复错误信息
|
|
c.SendWs(cmd, map[string]interface{}{
|
|
"type": "error",
|
|
"content": "Register failed: " + err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
// 回复成功信息
|
|
c.SendWs(cmd, map[string]interface{}{
|
|
"type": "text",
|
|
"content": "Register successful",
|
|
"data": map[string]interface{}{
|
|
"user_id": resp.UserId,
|
|
"account": resp.Account,
|
|
"message": resp.Response.Message,
|
|
"code": resp.Response.Code,
|
|
},
|
|
})
|
|
}()
|
|
} else {
|
|
// 回复错误信息
|
|
c.SendWs(cmd, map[string]interface{}{
|
|
"type": "error",
|
|
"content": "User service not available",
|
|
})
|
|
}
|
|
}
|
|
return
|
|
default:
|
|
// 其他命令处理
|
|
c.SendWs(cmd, map[string]interface{}{
|
|
"type": "text",
|
|
"content": "Command executed successfully",
|
|
"data": data,
|
|
})
|
|
}
|
|
}
|
|
|
|
|