Files

137 lines
4.4 KiB
Go

package repository
import (
"database/sql"
"user-svc/internal/domain"
"shared/pkg/errors"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
// UserRepository groups the user-related persistence operations of this
// package. All of its queries and transactions are issued through the
// database/sql handle it holds.
type UserRepository struct {
// db is the database handle used for every query and transaction started by the repository's methods.
db *sql.DB
}
// NewUserRepository returns a UserRepository that issues its queries through db.
// The handle is stored as given; no validation or connection check is performed.
func NewUserRepository(db *sql.DB) *UserRepository {
	repo := new(UserRepository)
	repo.db = db
	return repo
}
// Register creates a new user together with its login account and password
// record inside a single database transaction.
//
// The bcrypt hash and the three UUIDv7 identifiers are computed before the
// transaction is opened so that this work does not prolong the transaction.
// As a consequence, a hashing or ID-generation failure is reported before the
// account-existence check runs.
//
// It returns errors.ErrConflict when a non-deleted login account with the same
// account value already exists. Any other failure is returned through
// errors.WrapError with a message naming the step that failed. On success the
// response carries the generated user ID and the requested account.
//
// NOTE(review): the existence check and the insert are separate statements, so
// two concurrent registrations of the same account can presumably both pass the
// check unless user_login_account enforces uniqueness — confirm against the schema.
func (r *UserRepository) Register(req *domain.RegisterRequest) (*domain.RegisterResponse, error) {
	// Hash the password up front: bcrypt is deliberately slow and has no need
	// to run while a transaction is open.
	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
	if err != nil {
		return nil, errors.WrapError(err, "failed to hash password")
	}

	// Generate all primary keys before touching the database.
	userID, err := uuid.NewV7()
	if err != nil {
		return nil, errors.WrapError(err, "failed to generate user ID")
	}
	accountID, err := uuid.NewV7()
	if err != nil {
		return nil, errors.WrapError(err, "failed to generate account ID")
	}
	passwordID, err := uuid.NewV7()
	if err != nil {
		return nil, errors.WrapError(err, "failed to generate password ID")
	}

	// Begin the transaction.
	tx, err := r.db.Begin()
	if err != nil {
		return nil, errors.WrapError(err, "failed to begin transaction")
	}
	// Roll back on every early return and on panic. Once Commit has been
	// called the transaction is finished and this Rollback only returns
	// sql.ErrTxDone, which is safe to ignore.
	defer func() { _ = tx.Rollback() }()

	// Check whether the account already exists.
	var exists bool
	query := "SELECT EXISTS(SELECT 1 FROM user_login_account WHERE account = $1 AND deleted = false)"
	if err := tx.QueryRow(query, req.Account).Scan(&exists); err != nil {
		return nil, errors.WrapError(err, "failed to check account existence")
	}
	if exists {
		return nil, errors.ErrConflict
	}

	// Create the user.
	userQuery := "INSERT INTO user_main (id, deleted) VALUES ($1, $2)"
	if _, err := tx.Exec(userQuery, userID, false); err != nil {
		return nil, errors.WrapError(err, "failed to create user")
	}

	// Create the login account.
	accountQuery := "INSERT INTO user_login_account (id, user_id, account, deleted) VALUES ($1, $2, $3, $4)"
	if _, err := tx.Exec(accountQuery, accountID, userID, req.Account, false); err != nil {
		return nil, errors.WrapError(err, "failed to create login account")
	}

	// Create the password record; only the bcrypt hash is stored.
	passwordQuery := "INSERT INTO user_login_password (id, user_id, password, deleted) VALUES ($1, $2, $3, $4)"
	if _, err := tx.Exec(passwordQuery, passwordID, userID, string(hashedPassword), false); err != nil {
		return nil, errors.WrapError(err, "failed to create password record")
	}

	// Commit the transaction.
	if err := tx.Commit(); err != nil {
		return nil, errors.WrapError(err, "failed to commit transaction")
	}

	return &domain.RegisterResponse{
		UserID:  userID,
		Account: req.Account,
	}, nil
}
// GetUserByAccount resolves a login account by its account value and then loads
// the owning user and that user's password record, using three separate
// queries on the repository's database handle.
//
// Each lookup only considers rows with deleted = false. If any of the three
// rows is missing the result is errors.ErrNotFound; any other query or scan
// failure is returned through errors.WrapError. On success the user, login
// account and password record are returned in that order.
func (r *UserRepository) GetUserByAccount(account string) (*domain.User, *domain.UserLoginAccount, *domain.UserLoginPassword, error) {
	const (
		selectLogin    = "SELECT id, user_id, account, deleted, create_time, update_time FROM user_login_account WHERE account = $1 AND deleted = false"
		selectUser     = "SELECT id, deleted, create_time, update_time FROM user_main WHERE id = $1 AND deleted = false"
		selectPassword = "SELECT id, user_id, password, deleted, create_time, update_time FROM user_login_password WHERE user_id = $1 AND deleted = false"
	)

	// Step 1: the login account, looked up by the caller-supplied account value.
	login := new(domain.UserLoginAccount)
	err := r.db.QueryRow(selectLogin, account).Scan(
		&login.ID,
		&login.UserID,
		&login.Account,
		&login.Deleted,
		&login.CreateTime,
		&login.UpdateTime,
	)
	switch {
	case err == sql.ErrNoRows:
		return nil, nil, nil, errors.ErrNotFound
	case err != nil:
		return nil, nil, nil, errors.WrapError(err, "failed to get login account")
	}

	// Step 2: the user that owns the login account.
	owner := new(domain.User)
	err = r.db.QueryRow(selectUser, login.UserID).Scan(
		&owner.ID,
		&owner.Deleted,
		&owner.CreateTime,
		&owner.UpdateTime,
	)
	switch {
	case err == sql.ErrNoRows:
		return nil, nil, nil, errors.ErrNotFound
	case err != nil:
		return nil, nil, nil, errors.WrapError(err, "failed to get user")
	}

	// Step 3: the password record belonging to the same user ID.
	secret := new(domain.UserLoginPassword)
	err = r.db.QueryRow(selectPassword, login.UserID).Scan(
		&secret.ID,
		&secret.UserID,
		&secret.Password,
		&secret.Deleted,
		&secret.CreateTime,
		&secret.UpdateTime,
	)
	switch {
	case err == sql.ErrNoRows:
		return nil, nil, nil, errors.ErrNotFound
	case err != nil:
		return nil, nil, nil, errors.WrapError(err, "failed to get password record")
	}

	return owner, login, secret, nil
}