This commit is contained in:
vipg
2026-02-09 17:43:42 +08:00
parent 25c4628b5f
commit c7b11dca35
5 changed files with 83 additions and 22 deletions

View File

@@ -0,0 +1,49 @@
package auth
import (
"errors"
"time"
"common/utils"
"github.com/golang-jwt/jwt/v5"
)
func GenerateToken(userID string) (string, error) {
secret := utils.GetEnv("JWT_SECRET", "")
if secret == "" {
return "", errors.New("jwt_secret_missing")
}
ttlStr := utils.GetEnv("JWT_TTL", "24h")
ttl, err := time.ParseDuration(ttlStr)
if err != nil {
return "", err
}
issuer := utils.GetEnv("JWT_ISSUER", "trading-assistant")
now := time.Now()
claims := jwt.RegisteredClaims{
Subject: userID,
Issuer: issuer,
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(secret))
}
func ParseToken(tokenStr string) (string, error) {
secret := utils.GetEnv("JWT_SECRET", "")
if secret == "" {
return "", errors.New("jwt_secret_missing")
}
var claims jwt.RegisteredClaims
tkn, err := jwt.ParseWithClaims(tokenStr, &claims, func(token *jwt.Token) (interface{}, error) {
return []byte(secret), nil
})
if err != nil {
return "", err
}
if !tkn.Valid {
return "", errors.New("token_invalid")
}
return claims.Subject, nil
}

View File

@@ -3,3 +3,4 @@ module common
go 1.25.7
require github.com/jackc/pgx/v5 v5.6.0
require github.com/golang-jwt/jwt/v5 v5.2.1

View File

@@ -31,6 +31,8 @@ services:
PG_SSLMODE: disable
MIGRATE_ON_START: "0"
PORT: "8080"
JWT_SECRET: "change_me_dev_secret"
JWT_TTL: "24h"
ports:
- "8080:8080"
depends_on:

View File

@@ -28,7 +28,7 @@ func (h *Handler) Register(w http.ResponseWriter, r *http.Request) {
httpx.BadRequest(w, "invalid json")
return
}
userID, err := h.S.Register(req.Account, req.Password)
userID, token, err := h.S.Register(req.Account, req.Password)
if err != nil {
switch err {
case service.ErrInvalidInput:
@@ -40,7 +40,7 @@ func (h *Handler) Register(w http.ResponseWriter, r *http.Request) {
}
return
}
httpx.Created(w, map[string]string{"user_id": userID})
httpx.Created(w, map[string]string{"user_id": userID, "token": token})
}
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
@@ -53,7 +53,7 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
httpx.BadRequest(w, "invalid json")
return
}
userID, err := h.S.Login(req.Account, req.Password)
userID, token, err := h.S.Login(req.Account, req.Password)
if err != nil {
switch err {
case service.ErrInvalidInput:
@@ -65,7 +65,7 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
}
return
}
httpx.OK(w, map[string]string{"user_id": userID})
httpx.OK(w, map[string]string{"user_id": userID, "token": token})
}
func (h *Handler) Healthz(w http.ResponseWriter, r *http.Request) {

View File

@@ -4,6 +4,7 @@ import (
"database/sql"
"errors"
"common/auth"
"github.com/jackc/pgconn"
"golang.org/x/crypto/bcrypt"
"user/internal/repository"
@@ -23,64 +24,72 @@ func New(repo *repository.Repo) *Service {
return &Service{Repo: repo}
}
func (s *Service) Register(account, password string) (string, error) {
func (s *Service) Register(account, password string) (string, string, error) {
if !validAccount(account) || !validPassword(password) {
return "", ErrInvalidInput
return "", "", ErrInvalidInput
}
hashed, err := bcrypt.GenerateFromPassword([]byte(password), 12)
if err != nil {
return "", err
return "", "", err
}
tx, err := s.Repo.DB.Begin()
if err != nil {
return "", err
return "", "", err
}
defer func() { _ = tx.Rollback() }()
userID, err := s.Repo.CreateUser(tx)
if err != nil {
return "", err
return "", "", err
}
if err := s.Repo.CreateLoginAccount(tx, userID, account); err != nil {
if isUniqueViolation(err) {
return "", ErrAccountExists
return "", "", ErrAccountExists
}
return "", err
return "", "", err
}
if err := s.Repo.CreateLoginPassword(tx, userID, string(hashed)); err != nil {
return "", err
return "", "", err
}
if err := tx.Commit(); err != nil {
return "", err
return "", "", err
}
return userID, nil
tkn, err := auth.GenerateToken(userID)
if err != nil {
return "", "", err
}
return userID, tkn, nil
}
func (s *Service) Login(account, password string) (string, error) {
func (s *Service) Login(account, password string) (string, string, error) {
if !validAccount(account) || !validPassword(password) {
return "", ErrInvalidInput
return "", "", ErrInvalidInput
}
userID, err := s.Repo.GetUserIDByAccount(account)
if err != nil {
return "", ErrUnauthorized
return "", "", ErrUnauthorized
}
hashed, err := s.Repo.GetHashedPassword(userID)
if err != nil {
return "", ErrUnauthorized
return "", "", ErrUnauthorized
}
if bcrypt.CompareHashAndPassword([]byte(hashed), []byte(password)) != nil {
return "", ErrUnauthorized
return "", "", ErrUnauthorized
}
return userID, nil
tkn, err := auth.GenerateToken(userID)
if err != nil {
return "", "", err
}
return userID, tkn, nil
}
func validAccount(a string) bool {
n := len(a)
return n >= 3 && n <= 20
return n >= 3 && n <= 100
}
func validPassword(p string) bool {
n := len(p)
return n >= 8 && n <= 20
return n >= 8 && n <= 128
}
func isUniqueViolation(err error) bool {