From c7b11dca359fba7328e2d73286bbfdfa3d16c6dd Mon Sep 17 00:00:00 2001 From: vipg Date: Mon, 9 Feb 2026 17:43:42 +0800 Subject: [PATCH] add --- trading_assistant_api/common/auth/jwt.go | 49 +++++++++++++++++++ trading_assistant_api/common/go.mod | 1 + .../services/user/docker-compose.yml | 2 + .../user/internal/handler/user_handler.go | 8 +-- .../user/internal/service/user_service.go | 45 ++++++++++------- 5 files changed, 83 insertions(+), 22 deletions(-) create mode 100644 trading_assistant_api/common/auth/jwt.go diff --git a/trading_assistant_api/common/auth/jwt.go b/trading_assistant_api/common/auth/jwt.go new file mode 100644 index 0000000..85c398b --- /dev/null +++ b/trading_assistant_api/common/auth/jwt.go @@ -0,0 +1,49 @@ +package auth + +import ( + "errors" + "time" + + "common/utils" + "github.com/golang-jwt/jwt/v5" +) + +func GenerateToken(userID string) (string, error) { + secret := utils.GetEnv("JWT_SECRET", "") + if secret == "" { + return "", errors.New("jwt_secret_missing") + } + ttlStr := utils.GetEnv("JWT_TTL", "24h") + ttl, err := time.ParseDuration(ttlStr) + if err != nil { + return "", err + } + issuer := utils.GetEnv("JWT_ISSUER", "trading-assistant") + now := time.Now() + claims := jwt.RegisteredClaims{ + Subject: userID, + Issuer: issuer, + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(now.Add(ttl)), + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString([]byte(secret)) +} + +func ParseToken(tokenStr string) (string, error) { + secret := utils.GetEnv("JWT_SECRET", "") + if secret == "" { + return "", errors.New("jwt_secret_missing") + } + var claims jwt.RegisteredClaims + tkn, err := jwt.ParseWithClaims(tokenStr, &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(secret), nil + }) + if err != nil { + return "", err + } + if !tkn.Valid { + return "", errors.New("token_invalid") + } + return claims.Subject, nil +} diff --git a/trading_assistant_api/common/go.mod b/trading_assistant_api/common/go.mod index d9b9569..3f41a04 100644 --- a/trading_assistant_api/common/go.mod +++ b/trading_assistant_api/common/go.mod @@ -3,3 +3,4 @@ module common go 1.25.7 require github.com/jackc/pgx/v5 v5.6.0 +require github.com/golang-jwt/jwt/v5 v5.2.1 diff --git a/trading_assistant_api/services/user/docker-compose.yml b/trading_assistant_api/services/user/docker-compose.yml index 075bd30..3f036aa 100644 --- a/trading_assistant_api/services/user/docker-compose.yml +++ b/trading_assistant_api/services/user/docker-compose.yml @@ -31,6 +31,8 @@ services: PG_SSLMODE: disable MIGRATE_ON_START: "0" PORT: "8080" + JWT_SECRET: "change_me_dev_secret" + JWT_TTL: "24h" ports: - "8080:8080" depends_on: diff --git a/trading_assistant_api/services/user/internal/handler/user_handler.go b/trading_assistant_api/services/user/internal/handler/user_handler.go index 11bc1a0..d5f7ba9 100644 --- a/trading_assistant_api/services/user/internal/handler/user_handler.go +++ b/trading_assistant_api/services/user/internal/handler/user_handler.go @@ -28,7 +28,7 @@ func (h *Handler) Register(w http.ResponseWriter, r *http.Request) { httpx.BadRequest(w, "invalid json") return } - userID, err := h.S.Register(req.Account, req.Password) + userID, token, err := h.S.Register(req.Account, req.Password) if err != nil { switch err { case service.ErrInvalidInput: @@ -40,7 +40,7 @@ func (h *Handler) Register(w http.ResponseWriter, r *http.Request) { } return } - httpx.Created(w, map[string]string{"user_id": userID}) + httpx.Created(w, map[string]string{"user_id": userID, "token": token}) } func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { @@ -53,7 +53,7 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { httpx.BadRequest(w, "invalid json") return } - userID, err := h.S.Login(req.Account, req.Password) + userID, token, err := h.S.Login(req.Account, req.Password) if err != nil { switch err { case service.ErrInvalidInput: @@ -65,7 +65,7 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { } return } - httpx.OK(w, map[string]string{"user_id": userID}) + httpx.OK(w, map[string]string{"user_id": userID, "token": token}) } func (h *Handler) Healthz(w http.ResponseWriter, r *http.Request) { diff --git a/trading_assistant_api/services/user/internal/service/user_service.go b/trading_assistant_api/services/user/internal/service/user_service.go index 6001da1..ddf8617 100644 --- a/trading_assistant_api/services/user/internal/service/user_service.go +++ b/trading_assistant_api/services/user/internal/service/user_service.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" + "common/auth" "github.com/jackc/pgconn" "golang.org/x/crypto/bcrypt" "user/internal/repository" @@ -23,64 +24,72 @@ func New(repo *repository.Repo) *Service { return &Service{Repo: repo} } -func (s *Service) Register(account, password string) (string, error) { +func (s *Service) Register(account, password string) (string, string, error) { if !validAccount(account) || !validPassword(password) { - return "", ErrInvalidInput + return "", "", ErrInvalidInput } hashed, err := bcrypt.GenerateFromPassword([]byte(password), 12) if err != nil { - return "", err + return "", "", err } tx, err := s.Repo.DB.Begin() if err != nil { - return "", err + return "", "", err } defer func() { _ = tx.Rollback() }() userID, err := s.Repo.CreateUser(tx) if err != nil { - return "", err + return "", "", err } if err := s.Repo.CreateLoginAccount(tx, userID, account); err != nil { if isUniqueViolation(err) { - return "", ErrAccountExists + return "", "", ErrAccountExists } - return "", err + return "", "", err } if err := s.Repo.CreateLoginPassword(tx, userID, string(hashed)); err != nil { - return "", err + return "", "", err } if err := tx.Commit(); err != nil { - return "", err + return "", "", err } - return userID, nil + tkn, err := auth.GenerateToken(userID) + if err != nil { + return "", "", err + } + return userID, tkn, nil } -func (s *Service) Login(account, password string) (string, error) { +func (s *Service) Login(account, password string) (string, string, error) { if !validAccount(account) || !validPassword(password) { - return "", ErrInvalidInput + return "", "", ErrInvalidInput } userID, err := s.Repo.GetUserIDByAccount(account) if err != nil { - return "", ErrUnauthorized + return "", "", ErrUnauthorized } hashed, err := s.Repo.GetHashedPassword(userID) if err != nil { - return "", ErrUnauthorized + return "", "", ErrUnauthorized } if bcrypt.CompareHashAndPassword([]byte(hashed), []byte(password)) != nil { - return "", ErrUnauthorized + return "", "", ErrUnauthorized } - return userID, nil + tkn, err := auth.GenerateToken(userID) + if err != nil { + return "", "", err + } + return userID, tkn, nil } func validAccount(a string) bool { n := len(a) - return n >= 3 && n <= 20 + return n >= 3 && n <= 100 } func validPassword(p string) bool { n := len(p) - return n >= 8 && n <= 20 + return n >= 8 && n <= 128 } func isUniqueViolation(err error) bool {