feat: 实现网关服务的长连接功能

This commit is contained in:
fish
2026-03-28 19:57:20 +08:00
parent 03728d743e
commit be24b465b1
7 changed files with 379 additions and 0 deletions

View File

@@ -0,0 +1,36 @@
package main
import (
"fmt"
"log"
"net/http"
"backend/gateway/internal/config"
"backend/gateway/internal/router"
"backend/gateway/internal/ws"
)
func main() {
// 加载配置
cfg, err := config.Load()
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
// 创建 WebSocket Hub
hub := ws.NewHub()
go hub.Run()
// 创建路由器
r := router.NewRouter(hub)
// 启动服务器
serverAddr := fmt.Sprintf(":%d", cfg.Server.Port)
log.Printf("Gateway service starting on %s", serverAddr)
log.Printf("WebSocket endpoint: ws://localhost%s/ws", serverAddr)
log.Printf("Health check: http://localhost%s/health", serverAddr)
if err := http.ListenAndServe(serverAddr, r.SetupRoutes()); err != nil {
log.Fatalf("Failed to start server: %v", err)
}
}

8
backend/gateway/go.mod Normal file
View File

@@ -0,0 +1,8 @@
module backend/gateway
go 1.26.1
require (
github.com/gorilla/websocket v1.5.1
github.com/spf13/viper v1.19.0
)

View File

@@ -0,0 +1,56 @@
package config
import (
"github.com/spf13/viper"
)
type Config struct {
Server ServerConfig
Redis RedisConfig
Services ServicesConfig
}
type ServerConfig struct {
Port int
}
type RedisConfig struct {
Addr string
Password string
DB int
}
type ServicesConfig struct {
UserService UserServiceConfig
}
type UserServiceConfig struct {
Addr string
}
func Load() (*Config, error) {
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath("./config")
viper.AddConfigPath("../config")
viper.AddConfigPath("../../config")
viper.SetDefault("server.port", 8000)
viper.SetDefault("redis.addr", "redis:6379")
viper.SetDefault("redis.password", "")
viper.SetDefault("redis.db", 0)
viper.SetDefault("services.userService.addr", "user-svc:9000")
if err := viper.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
return nil, err
}
}
var config Config
if err := viper.Unmarshal(&config); err != nil {
return nil, err
}
return &config, nil
}

View File

@@ -0,0 +1,62 @@
package router
import (
"log"
"net/http"
"backend/gateway/internal/ws"
"github.com/gorilla/websocket"
)
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
// 允许所有来源的请求
CheckOrigin: func(r *http.Request) bool {
return true
},
}
type Router struct {
hub *ws.Hub
}
func NewRouter(hub *ws.Hub) *Router {
return &Router{
hub: hub,
}
}
func (r *Router) SetupRoutes() http.Handler {
mux := http.NewServeMux()
// WebSocket 连接处理
mux.HandleFunc("/ws", r.handleWebSocket)
// 健康检查
mux.HandleFunc("/health", r.handleHealth)
return mux
}
func (r *Router) handleWebSocket(w http.ResponseWriter, req *http.Request) {
conn, err := upgrader.Upgrade(w, req, nil)
if err != nil {
log.Println(err)
return
}
client := ws.NewClient(r.hub, conn)
r.hub.register <- client
// 启动客户端的读写协程
go client.writePump()
go client.readPump()
}
func (r *Router) handleHealth(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"status": "ok"}`))
}

View File

@@ -0,0 +1,144 @@
package ws
import (
"encoding/json"
"log"
"time"
"github.com/gorilla/websocket"
)
const (
// 写入超时时间
writeWait = 10 * time.Second
// 读取超时时间
pongWait = 60 * time.Second
// 发送 ping 的时间间隔,必须小于 pongWait
pingPeriod = (pongWait * 9) / 10
// 最大消息大小
maxMessageSize = 4096
)
type Client struct {
hub *Hub
conn *websocket.Conn
send chan *Message
}
func NewClient(hub *Hub, conn *websocket.Conn) *Client {
return &Client{
hub: hub,
conn: conn,
send: make(chan *Message, 256),
}
}
func (c *Client) readPump() {
defer func() {
c.hub.unregister <- c
c.conn.Close()
}()
c.conn.SetReadLimit(maxMessageSize)
c.conn.SetReadDeadline(time.Now().Add(pongWait))
c.conn.SetPongHandler(func(string) error {
c.conn.SetReadDeadline(time.Now().Add(pongWait))
return nil
})
for {
_, message, err := c.conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Printf("error: %v", err)
}
break
}
var msg Message
if err := json.Unmarshal(message, &msg); err != nil {
log.Printf("error unmarshaling message: %v", err)
continue
}
// 处理消息
switch msg.Type {
case MessageTypePing:
// 回复 pong
pongMsg := &Message{Type: MessageTypePong}
c.send <- pongMsg
case MessageTypeText:
// 广播文本消息
c.hub.broadcast <- &msg
case MessageTypeCommand:
// 处理命令
c.handleCommand(&msg)
}
}
}
func (c *Client) writePump() {
ticker := time.NewTicker(pingPeriod)
defer func() {
ticker.Stop()
c.conn.Close()
}()
for {
select {
case message, ok := <-c.send:
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
if !ok {
// Hub 关闭了通道
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}
w, err := c.conn.NextWriter(websocket.TextMessage)
if err != nil {
return
}
msgBytes, err := json.Marshal(message)
if err != nil {
log.Printf("error marshaling message: %v", err)
return
}
w.Write(msgBytes)
if err := w.Close(); err != nil {
return
}
case <-ticker.C:
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
pingMsg := &Message{Type: MessageTypePing}
pingBytes, err := json.Marshal(pingMsg)
if err != nil {
return
}
if err := c.conn.WriteMessage(websocket.TextMessage, pingBytes); err != nil {
return
}
}
}
}
func (c *Client) handleCommand(msg *Message) {
// 处理命令逻辑
// 这里可以根据命令类型执行不同的操作
log.Printf("Received command: %v", msg.Data)
// 示例:回复命令执行结果
response := &Message{
Type: MessageTypeText,
Content: "Command executed successfully",
Data: msg.Data,
}
c.send <- response
}

View File

@@ -0,0 +1,55 @@
package ws
import (
"log"
)
type Hub struct {
// 注册的客户端
clients map[*Client]bool
// 从客户端接收的消息
broadcast chan *Message
// 注册请求
register chan *Client
// 注销请求
unregister chan *Client
}
func NewHub() *Hub {
return &Hub{
broadcast: make(chan *Message),
register: make(chan *Client),
unregister: make(chan *Client),
clients: make(map[*Client]bool),
}
}
func (h *Hub) Run() {
for {
select {
case client := <-h.register:
h.clients[client] = true
log.Printf("Client connected. Total clients: %d", len(h.clients))
case client := <-h.unregister:
if _, ok := h.clients[client]; ok {
delete(h.clients, client)
close(client.send)
log.Printf("Client disconnected. Total clients: %d", len(h.clients))
}
case message := <-h.broadcast:
for client := range h.clients {
select {
case client.send <- message:
default:
close(client.send)
delete(h.clients, client)
}
}
}
}
}

View File

@@ -0,0 +1,18 @@
package ws
type MessageType string
const (
MessageTypePing MessageType = "ping"
MessageTypePong MessageType = "pong"
MessageTypeText MessageType = "text"
MessageTypeError MessageType = "error"
MessageTypeCommand MessageType = "command"
)
type Message struct {
Type MessageType `json:"type"`
Content string `json:"content,omitempty"`
Data interface{} `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}