diff --git a/backend/gateway/cmd/main.go b/backend/gateway/cmd/main.go new file mode 100644 index 0000000..f56cca3 --- /dev/null +++ b/backend/gateway/cmd/main.go @@ -0,0 +1,36 @@ +package main + +import ( + "fmt" + "log" + "net/http" + + "backend/gateway/internal/config" + "backend/gateway/internal/router" + "backend/gateway/internal/ws" +) + +func main() { + // 加载配置 + cfg, err := config.Load() + if err != nil { + log.Fatalf("Failed to load config: %v", err) + } + + // 创建 WebSocket Hub + hub := ws.NewHub() + go hub.Run() + + // 创建路由器 + r := router.NewRouter(hub) + + // 启动服务器 + serverAddr := fmt.Sprintf(":%d", cfg.Server.Port) + log.Printf("Gateway service starting on %s", serverAddr) + log.Printf("WebSocket endpoint: ws://localhost%s/ws", serverAddr) + log.Printf("Health check: http://localhost%s/health", serverAddr) + + if err := http.ListenAndServe(serverAddr, r.SetupRoutes()); err != nil { + log.Fatalf("Failed to start server: %v", err) + } +} diff --git a/backend/gateway/go.mod b/backend/gateway/go.mod new file mode 100644 index 0000000..43e3902 --- /dev/null +++ b/backend/gateway/go.mod @@ -0,0 +1,8 @@ +module backend/gateway + +go 1.26.1 + +require ( + github.com/gorilla/websocket v1.5.1 + github.com/spf13/viper v1.19.0 +) \ No newline at end of file diff --git a/backend/gateway/internal/config/config.go b/backend/gateway/internal/config/config.go new file mode 100644 index 0000000..605e455 --- /dev/null +++ b/backend/gateway/internal/config/config.go @@ -0,0 +1,56 @@ +package config + +import ( + "github.com/spf13/viper" +) + +type Config struct { + Server ServerConfig + Redis RedisConfig + Services ServicesConfig +} + +type ServerConfig struct { + Port int +} + +type RedisConfig struct { + Addr string + Password string + DB int +} + +type ServicesConfig struct { + UserService UserServiceConfig +} + +type UserServiceConfig struct { + Addr string +} + +func Load() (*Config, error) { + viper.SetConfigName("config") + viper.SetConfigType("yaml") + viper.AddConfigPath("./config") + viper.AddConfigPath("../config") + viper.AddConfigPath("../../config") + + viper.SetDefault("server.port", 8000) + viper.SetDefault("redis.addr", "redis:6379") + viper.SetDefault("redis.password", "") + viper.SetDefault("redis.db", 0) + viper.SetDefault("services.userService.addr", "user-svc:9000") + + if err := viper.ReadInConfig(); err != nil { + if _, ok := err.(viper.ConfigFileNotFoundError); !ok { + return nil, err + } + } + + var config Config + if err := viper.Unmarshal(&config); err != nil { + return nil, err + } + + return &config, nil +} \ No newline at end of file diff --git a/backend/gateway/internal/router/router.go b/backend/gateway/internal/router/router.go new file mode 100644 index 0000000..c1cabc1 --- /dev/null +++ b/backend/gateway/internal/router/router.go @@ -0,0 +1,62 @@ +package router + +import ( + "log" + "net/http" + + "backend/gateway/internal/ws" + + "github.com/gorilla/websocket" +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + // 允许所有来源的请求 + CheckOrigin: func(r *http.Request) bool { + return true + }, +} + +type Router struct { + hub *ws.Hub +} + +func NewRouter(hub *ws.Hub) *Router { + return &Router{ + hub: hub, + } +} + +func (r *Router) SetupRoutes() http.Handler { + mux := http.NewServeMux() + + // WebSocket 连接处理 + mux.HandleFunc("/ws", r.handleWebSocket) + + // 健康检查 + mux.HandleFunc("/health", r.handleHealth) + + return mux +} + +func (r *Router) handleWebSocket(w http.ResponseWriter, req *http.Request) { + conn, err := upgrader.Upgrade(w, req, nil) + if err != nil { + log.Println(err) + return + } + + client := ws.NewClient(r.hub, conn) + r.hub.register <- client + + // 启动客户端的读写协程 + go client.writePump() + go client.readPump() +} + +func (r *Router) handleHealth(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": "ok"}`)) +} diff --git a/backend/gateway/internal/ws/client.go b/backend/gateway/internal/ws/client.go new file mode 100644 index 0000000..036d5dc --- /dev/null +++ b/backend/gateway/internal/ws/client.go @@ -0,0 +1,144 @@ +package ws + +import ( + "encoding/json" + "log" + "time" + + "github.com/gorilla/websocket" +) + +const ( + // 写入超时时间 + writeWait = 10 * time.Second + + // 读取超时时间 + pongWait = 60 * time.Second + + // 发送 ping 的时间间隔,必须小于 pongWait + pingPeriod = (pongWait * 9) / 10 + + // 最大消息大小 + maxMessageSize = 4096 +) + +type Client struct { + hub *Hub + conn *websocket.Conn + send chan *Message +} + +func NewClient(hub *Hub, conn *websocket.Conn) *Client { + return &Client{ + hub: hub, + conn: conn, + send: make(chan *Message, 256), + } +} + +func (c *Client) readPump() { + defer func() { + c.hub.unregister <- c + c.conn.Close() + }() + + c.conn.SetReadLimit(maxMessageSize) + c.conn.SetReadDeadline(time.Now().Add(pongWait)) + c.conn.SetPongHandler(func(string) error { + c.conn.SetReadDeadline(time.Now().Add(pongWait)) + return nil + }) + + for { + _, message, err := c.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Printf("error: %v", err) + } + break + } + + var msg Message + if err := json.Unmarshal(message, &msg); err != nil { + log.Printf("error unmarshaling message: %v", err) + continue + } + + // 处理消息 + switch msg.Type { + case MessageTypePing: + // 回复 pong + pongMsg := &Message{Type: MessageTypePong} + c.send <- pongMsg + case MessageTypeText: + // 广播文本消息 + c.hub.broadcast <- &msg + case MessageTypeCommand: + // 处理命令 + c.handleCommand(&msg) + } + } +} + +func (c *Client) writePump() { + ticker := time.NewTicker(pingPeriod) + defer func() { + ticker.Stop() + c.conn.Close() + }() + + for { + select { + case message, ok := <-c.send: + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if !ok { + // Hub 关闭了通道 + c.conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + } + + w, err := c.conn.NextWriter(websocket.TextMessage) + if err != nil { + return + } + + msgBytes, err := json.Marshal(message) + if err != nil { + log.Printf("error marshaling message: %v", err) + return + } + + w.Write(msgBytes) + + if err := w.Close(); err != nil { + return + } + + case <-ticker.C: + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + pingMsg := &Message{Type: MessageTypePing} + pingBytes, err := json.Marshal(pingMsg) + if err != nil { + return + } + if err := c.conn.WriteMessage(websocket.TextMessage, pingBytes); err != nil { + return + } + } + } +} + +func (c *Client) handleCommand(msg *Message) { + // 处理命令逻辑 + // 这里可以根据命令类型执行不同的操作 + log.Printf("Received command: %v", msg.Data) + + // 示例:回复命令执行结果 + response := &Message{ + Type: MessageTypeText, + Content: "Command executed successfully", + Data: msg.Data, + } + + c.send <- response +} diff --git a/backend/gateway/internal/ws/hub.go b/backend/gateway/internal/ws/hub.go new file mode 100644 index 0000000..5a4ddd3 --- /dev/null +++ b/backend/gateway/internal/ws/hub.go @@ -0,0 +1,55 @@ +package ws + +import ( + "log" +) + +type Hub struct { + // 注册的客户端 + clients map[*Client]bool + + // 从客户端接收的消息 + broadcast chan *Message + + // 注册请求 + register chan *Client + + // 注销请求 + unregister chan *Client +} + +func NewHub() *Hub { + return &Hub{ + broadcast: make(chan *Message), + register: make(chan *Client), + unregister: make(chan *Client), + clients: make(map[*Client]bool), + } +} + +func (h *Hub) Run() { + for { + select { + case client := <-h.register: + h.clients[client] = true + log.Printf("Client connected. Total clients: %d", len(h.clients)) + + case client := <-h.unregister: + if _, ok := h.clients[client]; ok { + delete(h.clients, client) + close(client.send) + log.Printf("Client disconnected. Total clients: %d", len(h.clients)) + } + + case message := <-h.broadcast: + for client := range h.clients { + select { + case client.send <- message: + default: + close(client.send) + delete(h.clients, client) + } + } + } + } +} diff --git a/backend/gateway/internal/ws/message.go b/backend/gateway/internal/ws/message.go new file mode 100644 index 0000000..3a6da55 --- /dev/null +++ b/backend/gateway/internal/ws/message.go @@ -0,0 +1,18 @@ +package ws + +type MessageType string + +const ( + MessageTypePing MessageType = "ping" + MessageTypePong MessageType = "pong" + MessageTypeText MessageType = "text" + MessageTypeError MessageType = "error" + MessageTypeCommand MessageType = "command" +) + +type Message struct { + Type MessageType `json:"type"` + Content string `json:"content,omitempty"` + Data interface{} `json:"data,omitempty"` + Error string `json:"error,omitempty"` +}