145 lines
2.8 KiB
Go
145 lines
2.8 KiB
Go
package ws
|
|
|
|
import (
|
|
"encoding/json"
|
|
"log"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// Connection limits and keep-alive timing shared by readPump and writePump.
const (
	// maxMessageSize is the largest inbound message accepted, passed to
	// SetReadLimit (bytes).
	maxMessageSize = 4096

	// writeWait bounds how long a single write to the peer may take.
	writeWait = 10 * time.Second

	// pongWait is the read-deadline window: it is applied when reading
	// starts and again each time the deadline is refreshed.
	pongWait = 60 * time.Second

	// pingPeriod is how often keep-alive pings are sent. It must stay below
	// pongWait so a ping can be answered before the read deadline expires;
	// nine tenths of pongWait leaves that margin.
	pingPeriod = pongWait / 10 * 9
)
|
|
|
|
type Client struct {
|
|
hub *Hub
|
|
conn *websocket.Conn
|
|
send chan *Message
|
|
}
|
|
|
|
func NewClient(hub *Hub, conn *websocket.Conn) *Client {
|
|
return &Client{
|
|
hub: hub,
|
|
conn: conn,
|
|
send: make(chan *Message, 256),
|
|
}
|
|
}
|
|
|
|
func (c *Client) readPump() {
|
|
defer func() {
|
|
c.hub.unregister <- c
|
|
c.conn.Close()
|
|
}()
|
|
|
|
c.conn.SetReadLimit(maxMessageSize)
|
|
c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
|
c.conn.SetPongHandler(func(string) error {
|
|
c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
|
return nil
|
|
})
|
|
|
|
for {
|
|
_, message, err := c.conn.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
|
log.Printf("error: %v", err)
|
|
}
|
|
break
|
|
}
|
|
|
|
var msg Message
|
|
if err := json.Unmarshal(message, &msg); err != nil {
|
|
log.Printf("error unmarshaling message: %v", err)
|
|
continue
|
|
}
|
|
|
|
// 处理消息
|
|
switch msg.Type {
|
|
case MessageTypePing:
|
|
// 回复 pong
|
|
pongMsg := &Message{Type: MessageTypePong}
|
|
c.send <- pongMsg
|
|
case MessageTypeText:
|
|
// 广播文本消息
|
|
c.hub.broadcast <- &msg
|
|
case MessageTypeCommand:
|
|
// 处理命令
|
|
c.handleCommand(&msg)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) writePump() {
|
|
ticker := time.NewTicker(pingPeriod)
|
|
defer func() {
|
|
ticker.Stop()
|
|
c.conn.Close()
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case message, ok := <-c.send:
|
|
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
|
if !ok {
|
|
// Hub 关闭了通道
|
|
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
|
|
return
|
|
}
|
|
|
|
w, err := c.conn.NextWriter(websocket.TextMessage)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
msgBytes, err := json.Marshal(message)
|
|
if err != nil {
|
|
log.Printf("error marshaling message: %v", err)
|
|
return
|
|
}
|
|
|
|
w.Write(msgBytes)
|
|
|
|
if err := w.Close(); err != nil {
|
|
return
|
|
}
|
|
|
|
case <-ticker.C:
|
|
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
|
pingMsg := &Message{Type: MessageTypePing}
|
|
pingBytes, err := json.Marshal(pingMsg)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if err := c.conn.WriteMessage(websocket.TextMessage, pingBytes); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) handleCommand(msg *Message) {
|
|
// 处理命令逻辑
|
|
// 这里可以根据命令类型执行不同的操作
|
|
log.Printf("Received command: %v", msg.Data)
|
|
|
|
// 示例:回复命令执行结果
|
|
response := &Message{
|
|
Type: MessageTypeText,
|
|
Content: "Command executed successfully",
|
|
Data: msg.Data,
|
|
}
|
|
|
|
c.send <- response
|
|
}
|