feat: 实现网关服务的长连接功能

This commit is contained in:
fish
2026-03-28 19:57:20 +08:00
parent 03728d743e
commit be24b465b1
7 changed files with 379 additions and 0 deletions

View File

@@ -0,0 +1,144 @@
package ws
import (
"encoding/json"
"log"
"time"
"github.com/gorilla/websocket"
)
const (
// 写入超时时间
writeWait = 10 * time.Second
// 读取超时时间
pongWait = 60 * time.Second
// 发送 ping 的时间间隔,必须小于 pongWait
pingPeriod = (pongWait * 9) / 10
// 最大消息大小
maxMessageSize = 4096
)
type Client struct {
hub *Hub
conn *websocket.Conn
send chan *Message
}
func NewClient(hub *Hub, conn *websocket.Conn) *Client {
return &Client{
hub: hub,
conn: conn,
send: make(chan *Message, 256),
}
}
func (c *Client) readPump() {
defer func() {
c.hub.unregister <- c
c.conn.Close()
}()
c.conn.SetReadLimit(maxMessageSize)
c.conn.SetReadDeadline(time.Now().Add(pongWait))
c.conn.SetPongHandler(func(string) error {
c.conn.SetReadDeadline(time.Now().Add(pongWait))
return nil
})
for {
_, message, err := c.conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Printf("error: %v", err)
}
break
}
var msg Message
if err := json.Unmarshal(message, &msg); err != nil {
log.Printf("error unmarshaling message: %v", err)
continue
}
// 处理消息
switch msg.Type {
case MessageTypePing:
// 回复 pong
pongMsg := &Message{Type: MessageTypePong}
c.send <- pongMsg
case MessageTypeText:
// 广播文本消息
c.hub.broadcast <- &msg
case MessageTypeCommand:
// 处理命令
c.handleCommand(&msg)
}
}
}
func (c *Client) writePump() {
ticker := time.NewTicker(pingPeriod)
defer func() {
ticker.Stop()
c.conn.Close()
}()
for {
select {
case message, ok := <-c.send:
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
if !ok {
// Hub 关闭了通道
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}
w, err := c.conn.NextWriter(websocket.TextMessage)
if err != nil {
return
}
msgBytes, err := json.Marshal(message)
if err != nil {
log.Printf("error marshaling message: %v", err)
return
}
w.Write(msgBytes)
if err := w.Close(); err != nil {
return
}
case <-ticker.C:
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
pingMsg := &Message{Type: MessageTypePing}
pingBytes, err := json.Marshal(pingMsg)
if err != nil {
return
}
if err := c.conn.WriteMessage(websocket.TextMessage, pingBytes); err != nil {
return
}
}
}
}
func (c *Client) handleCommand(msg *Message) {
// 处理命令逻辑
// 这里可以根据命令类型执行不同的操作
log.Printf("Received command: %v", msg.Data)
// 示例:回复命令执行结果
response := &Message{
Type: MessageTypeText,
Content: "Command executed successfully",
Data: msg.Data,
}
c.send <- response
}

View File

@@ -0,0 +1,55 @@
package ws
import (
"log"
)
type Hub struct {
// 注册的客户端
clients map[*Client]bool
// 从客户端接收的消息
broadcast chan *Message
// 注册请求
register chan *Client
// 注销请求
unregister chan *Client
}
func NewHub() *Hub {
return &Hub{
broadcast: make(chan *Message),
register: make(chan *Client),
unregister: make(chan *Client),
clients: make(map[*Client]bool),
}
}
func (h *Hub) Run() {
for {
select {
case client := <-h.register:
h.clients[client] = true
log.Printf("Client connected. Total clients: %d", len(h.clients))
case client := <-h.unregister:
if _, ok := h.clients[client]; ok {
delete(h.clients, client)
close(client.send)
log.Printf("Client disconnected. Total clients: %d", len(h.clients))
}
case message := <-h.broadcast:
for client := range h.clients {
select {
case client.send <- message:
default:
close(client.send)
delete(h.clients, client)
}
}
}
}
}

View File

@@ -0,0 +1,18 @@
package ws
type MessageType string
const (
MessageTypePing MessageType = "ping"
MessageTypePong MessageType = "pong"
MessageTypeText MessageType = "text"
MessageTypeError MessageType = "error"
MessageTypeCommand MessageType = "command"
)
type Message struct {
Type MessageType `json:"type"`
Content string `json:"content,omitempty"`
Data interface{} `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}