add
This commit is contained in:
@@ -0,0 +1,152 @@
|
||||
// common/middleware/cors.go
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CORSOptions 跨域配置项
|
||||
// 支持自定义允许的源、方法、头、凭证、缓存时间,按需扩展
|
||||
type CORSOptions struct {
|
||||
AllowOrigins []string // 允许的跨域源,如["http://localhost:8080", "https://xxx.com"],*表示允许所有
|
||||
AllowMethods []string // 允许的HTTP方法,默认GET/POST/PUT/DELETE/PATCH/OPTIONS
|
||||
AllowHeaders []string // 允许的请求头,*表示允许所有
|
||||
AllowCredentials bool // 是否允许携带凭证(Cookie/Token),前后端联调必备
|
||||
ExposeHeaders []string // 允许前端获取的响应头
|
||||
MaxAge time.Duration // 预检请求(OPTIONS)的缓存时间,避免重复预检
|
||||
}
|
||||
|
||||
// CORSOption 选项模式函数类型,用于灵活配置跨域参数
|
||||
type CORSOption func(*CORSOptions)
|
||||
|
||||
// defaultCORSOptions 初始化默认跨域配置
|
||||
// 开发环境默认允许所有源、常用方法,生产环境可通过配置覆盖
|
||||
func defaultCORSOptions() *CORSOptions {
|
||||
return &CORSOptions{
|
||||
AllowOrigins: []string{"*"}, // 开发环境默认允许所有源
|
||||
AllowMethods: []string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodPatch, http.MethodOptions},
|
||||
AllowHeaders: []string{"*"}, // 允许所有请求头
|
||||
AllowCredentials: true, // 允许携带凭证
|
||||
ExposeHeaders: []string{"Content-Length", "Content-Type", "X-Token"},
|
||||
MaxAge: 12 * time.Hour, // 预检请求缓存12小时
|
||||
}
|
||||
}
|
||||
|
||||
// 以下为配置项设置函数,支持链式调用
|
||||
// WithAllowOrigins 设置允许的跨域源,示例:WithAllowOrigins("http://localhost:3000", "https://app.com")
|
||||
func WithAllowOrigins(origins ...string) CORSOption {
|
||||
return func(o *CORSOptions) { o.AllowOrigins = origins }
|
||||
}
|
||||
|
||||
// WithAllowCredentials 设置是否允许携带凭证(Cookie/Token)
|
||||
func WithAllowCredentials(allow bool) CORSOption {
|
||||
return func(o *CORSOptions) { o.AllowCredentials = allow }
|
||||
}
|
||||
|
||||
// WithMaxAge 设置预检请求缓存时间
|
||||
func WithMaxAge(age time.Duration) CORSOption {
|
||||
return func(o *CORSOptions) { o.MaxAge = age }
|
||||
}
|
||||
|
||||
// WithAllowHeaders 设置允许的请求头
|
||||
func WithAllowHeaders(headers ...string) CORSOption {
|
||||
return func(o *CORSOptions) { o.AllowHeaders = headers }
|
||||
}
|
||||
|
||||
// CORS 跨域中间件核心方法
|
||||
// 适配Go原生http.Handler,可直接用于Gin/Echo等框架(兼容框架中间件规范)
|
||||
func CORS(opts ...CORSOption) func(http.Handler) http.Handler {
|
||||
// 加载默认配置 + 覆盖用户自定义配置
|
||||
options := defaultCORSOptions()
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
// 处理允许的源:拼接为字符串
|
||||
allowOrigins := strings.Join(options.AllowOrigins, ", ")
|
||||
// 处理允许的方法:拼接为字符串
|
||||
allowMethods := strings.Join(options.AllowMethods, ", ")
|
||||
// 处理允许的请求头:拼接为字符串
|
||||
allowHeaders := strings.Join(options.AllowHeaders, ", ")
|
||||
// 处理允许暴露的响应头:拼接为字符串
|
||||
exposeHeaders := strings.Join(options.ExposeHeaders, ", ")
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 1. 获取前端请求的Origin(跨域核心)
|
||||
origin := r.Header.Get("Origin")
|
||||
// 若配置了*,则直接使用请求的Origin;否则使用配置的源(生产环境建议精准配置)
|
||||
if len(options.AllowOrigins) > 0 && options.AllowOrigins[0] == "*" {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
} else {
|
||||
w.Header().Set("Access-Control-Allow-Origin", allowOrigins)
|
||||
}
|
||||
|
||||
// 2. 设置跨域核心响应头
|
||||
w.Header().Set("Access-Control-Allow-Methods", allowMethods)
|
||||
w.Header().Set("Access-Control-Allow-Headers", allowHeaders)
|
||||
w.Header().Set("Access-Control-Expose-Headers", exposeHeaders)
|
||||
w.Header().Set("Access-Control-Max-Age", string(rune(options.MaxAge.Seconds())))
|
||||
// 允许携带凭证时,不能将Allow-Origin设为*,需动态匹配请求Origin(已做处理)
|
||||
if options.AllowCredentials {
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
|
||||
// 3. 处理预检请求(OPTIONS):直接返回204,无需执行后续业务逻辑
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
// 4. 非预检请求,执行后续业务逻辑
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---- 兼容Gin框架的快捷中间件(可选)----
|
||||
// 若团队使用Gin框架开发,可直接使用此方法,无需额外转换,提升开发效率
|
||||
// 需提前安装Gin:go get github.com/gin-gonic/gin
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CorsGin 适配Gin框架的跨域中间件
|
||||
func CorsGin(opts ...CORSOption) gin.HandlerFunc {
|
||||
// 复用原生CORS配置逻辑
|
||||
options := defaultCORSOptions()
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
allowOrigins := strings.Join(options.AllowOrigins, ", ")
|
||||
allowMethods := strings.Join(options.Methods, ", ")
|
||||
allowHeaders := strings.Join(options.AllowHeaders, ", ")
|
||||
exposeHeaders := strings.Join(options.ExposeHeaders, ", ")
|
||||
|
||||
return func(c *gin.Context) {
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
if options.AllowOrigins[0] == "*" {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
} else {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", allowOrigins)
|
||||
}
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", allowMethods)
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", allowHeaders)
|
||||
c.Writer.Header().Set("Access-Control-Expose-Headers", exposeHeaders)
|
||||
c.Writer.Header().Set("Access-Control-Max-Age", string(rune(options.MaxAge.Seconds())))
|
||||
if options.AllowCredentials {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
|
||||
// 处理预检请求
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
// 继续执行后续中间件/业务逻辑
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user