127 lines
4.2 KiB
Go
127 lines
4.2 KiB
Go
package repository
|
|
|
|
import (
|
|
"database/sql"
|
|
"time"
|
|
|
|
"backend/services/user-svc/internal/domain"
|
|
"backend/shared/pkg/errors"
|
|
|
|
"github.com/google/uuid"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
type UserRepository struct {
|
|
db *sql.DB
|
|
}
|
|
|
|
func NewUserRepository(db *sql.DB) *UserRepository {
|
|
return &UserRepository{db: db}
|
|
}
|
|
|
|
// Register 用户注册
|
|
func (r *UserRepository) Register(req *domain.RegisterRequest) (*domain.RegisterResponse, error) {
|
|
// 开启事务
|
|
tx, err := r.db.Begin()
|
|
if err != nil {
|
|
return nil, errors.WrapError(err, "failed to begin transaction")
|
|
}
|
|
|
|
// 检查账号是否已存在
|
|
var exists bool
|
|
query := "SELECT EXISTS(SELECT 1 FROM user_login_account WHERE account = $1 AND deleted = false)"
|
|
if err := tx.QueryRow(query, req.Account).Scan(&exists); err != nil {
|
|
tx.Rollback()
|
|
return nil, errors.WrapError(err, "failed to check account existence")
|
|
}
|
|
|
|
if exists {
|
|
tx.Rollback()
|
|
return nil, errors.ErrConflict
|
|
}
|
|
|
|
// 创建用户
|
|
userID := uuid.New()
|
|
now := time.Now()
|
|
|
|
userQuery := "INSERT INTO user_main (id, deleted, create_time, update_time) VALUES ($1, $2, $3, $4)"
|
|
if _, err := tx.Exec(userQuery, userID, false, now, now); err != nil {
|
|
tx.Rollback()
|
|
return nil, errors.WrapError(err, "failed to create user")
|
|
}
|
|
|
|
// 创建登录账号
|
|
accountID := uuid.New()
|
|
accountQuery := "INSERT INTO user_login_account (id, user_id, account, deleted, create_time, update_time) VALUES ($1, $2, $3, $4, $5, $6)"
|
|
if _, err := tx.Exec(accountQuery, accountID, userID, req.Account, false, now, now); err != nil {
|
|
tx.Rollback()
|
|
return nil, errors.WrapError(err, "failed to create login account")
|
|
}
|
|
|
|
// 加密密码
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return nil, errors.WrapError(err, "failed to hash password")
|
|
}
|
|
|
|
// 创建密码记录
|
|
passwordID := uuid.New()
|
|
passwordQuery := "INSERT INTO user_login_password (id, user_id, password, deleted, create_time, update_time) VALUES ($1, $2, $3, $4, $5, $6)"
|
|
if _, err := tx.Exec(passwordQuery, passwordID, userID, string(hashedPassword), false, now, now); err != nil {
|
|
tx.Rollback()
|
|
return nil, errors.WrapError(err, "failed to create password record")
|
|
}
|
|
|
|
// 提交事务
|
|
if err := tx.Commit(); err != nil {
|
|
return nil, errors.WrapError(err, "failed to commit transaction")
|
|
}
|
|
|
|
return &domain.RegisterResponse{
|
|
UserID: userID,
|
|
Account: req.Account,
|
|
}, nil
|
|
}
|
|
|
|
// GetUserByAccount 根据账号获取用户信息
|
|
func (r *UserRepository) GetUserByAccount(account string) (*domain.User, *domain.UserLoginAccount, *domain.UserLoginPassword, error) {
|
|
// 获取登录账号
|
|
var accountRecord domain.UserLoginAccount
|
|
accountQuery := "SELECT id, user_id, account, deleted, create_time, update_time FROM user_login_account WHERE account = $1 AND deleted = false"
|
|
if err := r.db.QueryRow(accountQuery, account).Scan(
|
|
&accountRecord.ID, &accountRecord.UserID, &accountRecord.Account, &accountRecord.Deleted, &accountRecord.CreateTime, &accountRecord.UpdateTime,
|
|
); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil, nil, errors.ErrNotFound
|
|
}
|
|
return nil, nil, nil, errors.WrapError(err, "failed to get login account")
|
|
}
|
|
|
|
// 获取用户信息
|
|
var user domain.User
|
|
userQuery := "SELECT id, deleted, create_time, update_time FROM user_main WHERE id = $1 AND deleted = false"
|
|
if err := r.db.QueryRow(userQuery, accountRecord.UserID).Scan(
|
|
&user.ID, &user.Deleted, &user.CreateTime, &user.UpdateTime,
|
|
); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil, nil, errors.ErrNotFound
|
|
}
|
|
return nil, nil, nil, errors.WrapError(err, "failed to get user")
|
|
}
|
|
|
|
// 获取密码记录
|
|
var passwordRecord domain.UserLoginPassword
|
|
passwordQuery := "SELECT id, user_id, password, deleted, create_time, update_time FROM user_login_password WHERE user_id = $1 AND deleted = false"
|
|
if err := r.db.QueryRow(passwordQuery, accountRecord.UserID).Scan(
|
|
&passwordRecord.ID, &passwordRecord.UserID, &passwordRecord.Password, &passwordRecord.Deleted, &passwordRecord.CreateTime, &passwordRecord.UpdateTime,
|
|
); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil, nil, errors.ErrNotFound
|
|
}
|
|
return nil, nil, nil, errors.WrapError(err, "failed to get password record")
|
|
}
|
|
|
|
return &user, &accountRecord, &passwordRecord, nil
|
|
}
|