74 lines
1.9 KiB
Go
74 lines
1.9 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"trade/web/internal/auth"
|
|
"trade/web/internal/store"
|
|
)
|
|
|
|
// ctxKey is the package-private type used for keys this package stores in a
// context.Context. Because the type is unexported, no other package can build
// an equal key, so values stored under it cannot be clobbered from outside.
type ctxKey string

// userKey is the context key under which RequireUser stores the CtxUser that
// FromContext reads back.
const userKey = ctxKey("user")
|
|
|
|
// CtxUser is the authenticated user that RequireUser attaches to the request
// context; handlers retrieve it with FromContext. Its fields are copied from
// the account record loaded from the auth store on each request, not from the
// token claims.
type CtxUser struct {
	// ID is the account's identifier.
	ID int64
	// Username is the account's username.
	Username string
	// Role is the account's role; RequireAdmin compares it with store.RoleAdmin.
	Role string
}
|
|
|
|
func FromContext(ctx context.Context) (CtxUser, bool) {
|
|
u, ok := ctx.Value(userKey).(CtxUser)
|
|
return u, ok
|
|
}
|
|
|
|
// RequireUser 校验 Authorization Bearer JWT,通过后把 CtxUser 写入 context。
|
|
// 同时校验数据库里的 disabled 状态,被禁用的账户即使持有 token 也会被拒。
|
|
func RequireUser(mgr *auth.Manager, s *store.AuthStore) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
tok := bearer(r)
|
|
if tok == "" {
|
|
writeJSON(w, http.StatusUnauthorized, map[string]string{"error": "missing token"})
|
|
return
|
|
}
|
|
claims, err := mgr.Parse(tok)
|
|
if err != nil {
|
|
writeJSON(w, http.StatusUnauthorized, map[string]string{"error": "invalid token"})
|
|
return
|
|
}
|
|
u, err := s.GetByID(claims.UserID)
|
|
if err != nil || u.Disabled {
|
|
writeJSON(w, http.StatusUnauthorized, map[string]string{"error": "account disabled or removed"})
|
|
return
|
|
}
|
|
ctx := context.WithValue(r.Context(), userKey, CtxUser{
|
|
ID: u.ID, Username: u.Username, Role: u.Role,
|
|
})
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
}
|
|
|
|
func RequireAdmin(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
u, ok := FromContext(r.Context())
|
|
if !ok || u.Role != store.RoleAdmin {
|
|
writeJSON(w, http.StatusForbidden, map[string]string{"error": "admin only"})
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// bearer extracts the token from an "Authorization: Bearer <token>" header.
//
// The scheme name is matched case-insensitively, since RFC 7235 §2.1 defines
// auth-scheme names as case-insensitive; "bearer" and "BEARER" are therefore
// accepted. Whitespace around the token is trimmed. It returns "" when the
// header is absent, uses a different scheme, or carries no token.
func bearer(r *http.Request) string {
	const prefix = "Bearer "
	h := r.Header.Get("Authorization")
	if len(h) < len(prefix) || !strings.EqualFold(h[:len(prefix)], prefix) {
		return ""
	}
	return strings.TrimSpace(h[len(prefix):])
}
|