diff --git a/trading_assistant_api/common/auth/jwt.go b/trading_assistant_api/common/auth/jwt.go index 85c398b..34fc4a6 100644 --- a/trading_assistant_api/common/auth/jwt.go +++ b/trading_assistant_api/common/auth/jwt.go @@ -5,6 +5,7 @@ import ( "time" "common/utils" + "common/logger" "github.com/golang-jwt/jwt/v5" ) @@ -27,7 +28,11 @@ func GenerateToken(userID string) (string, error) { ExpiresAt: jwt.NewNumericDate(now.Add(ttl)), } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - return token.SignedString([]byte(secret)) + signed, err := token.SignedString([]byte(secret)) + if err == nil { + logger.L().Printf("jwt generate success user=%s exp=%s", userID, claims.ExpiresAt.Time.Format(time.RFC3339)) + } + return signed, err } func ParseToken(tokenStr string) (string, error) { @@ -40,10 +45,13 @@ func ParseToken(tokenStr string) (string, error) { return []byte(secret), nil }) if err != nil { + logger.L().Printf("jwt parse error: %v", err) return "", err } if !tkn.Valid { + logger.L().Printf("jwt invalid") return "", errors.New("token_invalid") } + logger.L().Printf("jwt parse success user=%s exp=%s", claims.Subject, claims.ExpiresAt.Time.Format(time.RFC3339)) return claims.Subject, nil } diff --git a/trading_assistant_api/common/httpx/auth.go b/trading_assistant_api/common/httpx/auth.go new file mode 100644 index 0000000..784f464 --- /dev/null +++ b/trading_assistant_api/common/httpx/auth.go @@ -0,0 +1,39 @@ +package httpx + +import ( + "context" + "net/http" + "strings" + + "common/auth" +) + +type userIDKey struct{} + +func AuthRequired() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ah := r.Header.Get("Authorization") + if ah == "" || !strings.HasPrefix(ah, "Bearer ") { + Unauthorized(w, "unauthorized") + return + } + token := strings.TrimSpace(strings.TrimPrefix(ah, "Bearer ")) + sub, err := auth.ParseToken(token) + if err != nil || sub == "" { + Unauthorized(w, "unauthorized") + return + } + ctx := context.WithValue(r.Context(), userIDKey{}, sub) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +func UserID(r *http.Request) string { + v := r.Context().Value(userIDKey{}) + if id, ok := v.(string); ok { + return id + } + return "" +} diff --git a/trading_assistant_api/services/user/internal/handler/user_handler.go b/trading_assistant_api/services/user/internal/handler/user_handler.go index d5f7ba9..71f4222 100644 --- a/trading_assistant_api/services/user/internal/handler/user_handler.go +++ b/trading_assistant_api/services/user/internal/handler/user_handler.go @@ -89,5 +89,5 @@ func (h *Handler) Root(w http.ResponseWriter, r *http.Request) { httpx.MethodNotAllowed(w, string(codes.MethodNotAllowed)) return } - httpx.OK(w, map[string]string{"service": "user"}) + httpx.OK(w, map[string]string{"service": "user", "user_id": httpx.UserID(r)}) } diff --git a/trading_assistant_api/services/user/internal/router/router.go b/trading_assistant_api/services/user/internal/router/router.go index 55d9005..607aed2 100644 --- a/trading_assistant_api/services/user/internal/router/router.go +++ b/trading_assistant_api/services/user/internal/router/router.go @@ -3,6 +3,7 @@ package router import ( "net/http" + "common/httpx" "user/internal/handler" ) @@ -10,7 +11,7 @@ func New(h *handler.Handler) http.Handler { mux := http.NewServeMux() mux.HandleFunc("/healthz", h.Healthz) mux.HandleFunc("/version", h.Version) - mux.HandleFunc("/", h.Root) + mux.Handle("/", httpx.AuthRequired()(http.HandlerFunc(h.Root))) mux.HandleFunc("/register", h.Register) mux.HandleFunc("/login", h.Login) return mux