feat: 实现用户注册功能,包括数据库表结构、gRPC 服务和业务逻辑

This commit is contained in:
fish
2026-03-28 20:11:54 +08:00
parent e4bb25d1ac
commit 4ff974439f
9 changed files with 529 additions and 0 deletions

View File

@@ -0,0 +1,49 @@
package main
import (
"fmt"
"log"
"backend/services/user-svc/internal/config"
"backend/services/user-svc/internal/grpcserver"
"backend/services/user-svc/internal/repository"
"backend/services/user-svc/internal/service"
"backend/shared/pkg/database"
"backend/shared/pkg/logger"
)
func main() {
// 加载配置
cfg, err := config.Load()
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
// 连接数据库
db, err := database.NewPostgresDB(database.PostgresConfig{
Host: cfg.Database.Host,
Port: cfg.Database.Port,
User: cfg.Database.User,
Password: cfg.Database.Password,
DBName: cfg.Database.DBName,
SSLMode: cfg.Database.SSLMode,
})
if err != nil {
log.Fatalf("Failed to connect to database: %v", err)
}
defer db.Close()
// 初始化仓库
repo := repository.NewUserRepository(db.DB)
// 初始化服务
userService := service.NewUserService(repo)
// 启动 gRPC 服务器
serverAddr := fmt.Sprintf(":%d", cfg.Server.Port)
logger.Info("User service starting on %s", serverAddr)
if err := grpcserver.Start(userService, cfg.Server.Port); err != nil {
log.Fatalf("Failed to start server: %v", err)
}
}

View File

@@ -0,0 +1,11 @@
module backend/services/user-svc
go 1.26.1
require (
github.com/google/uuid v1.6.0
github.com/spf13/viper v1.19.0
golang.org/x/crypto v0.20.0
google.golang.org/grpc v1.64.0
google.golang.org/protobuf v1.33.0
)

View File

@@ -0,0 +1,52 @@
package config
import (
"github.com/spf13/viper"
)
type Config struct {
Server ServerConfig
Database DatabaseConfig
}
type ServerConfig struct {
Port int
}
type DatabaseConfig struct {
Host string
Port int
User string
Password string
DBName string
SSLMode string
}
func Load() (*Config, error) {
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath("./config")
viper.AddConfigPath("../config")
viper.AddConfigPath("../../config")
viper.SetDefault("server.port", 9000)
viper.SetDefault("database.host", "postgres")
viper.SetDefault("database.port", 5432)
viper.SetDefault("database.user", "admin")
viper.SetDefault("database.password", "password")
viper.SetDefault("database.dbname", "backend")
viper.SetDefault("database.sslmode", "disable")
if err := viper.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
return nil, err
}
}
var config Config
if err := viper.Unmarshal(&config); err != nil {
return nil, err
}
return &config, nil
}

View File

@@ -0,0 +1,42 @@
package domain
import (
"time"
"github.com/google/uuid"
)
type User struct {
ID uuid.UUID `json:"id"`
Deleted bool `json:"deleted"`
CreateTime time.Time `json:"create_time"`
UpdateTime time.Time `json:"update_time"`
}
type UserLoginAccount struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
Account string `json:"account"`
Deleted bool `json:"deleted"`
CreateTime time.Time `json:"create_time"`
UpdateTime time.Time `json:"update_time"`
}
type UserLoginPassword struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
Password string `json:"password"`
Deleted bool `json:"deleted"`
CreateTime time.Time `json:"create_time"`
UpdateTime time.Time `json:"update_time"`
}
type RegisterRequest struct {
Account string `json:"account" binding:"required"`
Password string `json:"password" binding:"required,min=6"`
}
type RegisterResponse struct {
UserID uuid.UUID `json:"user_id"`
Account string `json:"account"`
}

View File

@@ -0,0 +1,139 @@
package grpcserver
import (
"context"
"fmt"
"net"
"backend/services/user-svc/internal/domain"
"backend/services/user-svc/internal/service"
"backend/shared/pkg/errors"
"backend/shared/pkg/logger"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
// 导入生成的 proto 代码
userpb "backend/services/user-svc/proto"
)
type UserServer struct {
userpb.UnimplementedUserServiceServer
service *service.UserService
}
func NewUserServer(service *service.UserService) *UserServer {
return &UserServer{service: service}
}
// Register 实现用户注册接口
func (s *UserServer) Register(ctx context.Context, req *userpb.RegisterRequest) (*userpb.RegisterResponse, error) {
// 转换请求参数
domainReq := &domain.RegisterRequest{
Account: req.GetAccount(),
Password: req.GetPassword(),
}
// 调用服务层注册用户
resp, err := s.service.Register(domainReq)
if err != nil {
logger.Error("Register failed: %v", err)
// 转换错误类型
switch {
case errors.IsConflict(err):
return &userpb.RegisterResponse{
Response: &userpb.Response{
Code: 409,
Message: "账号已存在",
},
}, status.Errorf(codes.AlreadyExists, "账号已存在")
case errors.IsInvalidInput(err):
return &userpb.RegisterResponse{
Response: &userpb.Response{
Code: 400,
Message: "无效的输入参数",
},
}, status.Errorf(codes.InvalidArgument, "无效的输入参数")
default:
return &userpb.RegisterResponse{
Response: &userpb.Response{
Code: 500,
Message: "内部服务器错误",
},
}, status.Errorf(codes.Internal, "内部服务器错误")
}
}
// 构造响应
return &userpb.RegisterResponse{
UserId: resp.UserID.String(),
Account: resp.Account,
Response: &userpb.Response{
Code: 200,
Message: "注册成功",
},
}, nil
}
// GetUserByAccount 实现根据账号获取用户信息接口
func (s *UserServer) GetUserByAccount(ctx context.Context, req *userpb.GetUserByAccountRequest) (*userpb.GetUserByAccountResponse, error) {
// 调用服务层获取用户信息
user, account, _, err := s.service.GetUserByAccount(req.GetAccount())
if err != nil {
logger.Error("GetUserByAccount failed: %v", err)
// 转换错误类型
switch {
case errors.IsNotFound(err):
return &userpb.GetUserByAccountResponse{
Response: &userpb.Response{
Code: 404,
Message: "用户不存在",
},
}, status.Errorf(codes.NotFound, "用户不存在")
case errors.IsInvalidInput(err):
return &userpb.GetUserByAccountResponse{
Response: &userpb.Response{
Code: 400,
Message: "无效的输入参数",
},
}, status.Errorf(codes.InvalidArgument, "无效的输入参数")
default:
return &userpb.GetUserByAccountResponse{
Response: &userpb.Response{
Code: 500,
Message: "内部服务器错误",
},
}, status.Errorf(codes.Internal, "内部服务器错误")
}
}
// 构造响应
return &userpb.GetUserByAccountResponse{
UserId: user.ID.String(),
Account: account.Account,
Response: &userpb.Response{
Code: 200,
Message: "获取成功",
},
}, nil
}
// Start 启动 gRPC 服务器
func Start(service *service.UserService, port int) error {
server := grpc.NewServer()
userServer := NewUserServer(service)
userpb.RegisterUserServiceServer(server, userServer)
addr := fmt.Sprintf(":%d", port)
listener, err := net.Listen("tcp", addr)
if err != nil {
return err
}
logger.Info("User service gRPC server starting on %s", addr)
return server.Serve(listener)
}

View File

@@ -0,0 +1,126 @@
package repository
import (
"database/sql"
"time"
"backend/services/user-svc/internal/domain"
"backend/shared/pkg/errors"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
type UserRepository struct {
db *sql.DB
}
func NewUserRepository(db *sql.DB) *UserRepository {
return &UserRepository{db: db}
}
// Register 用户注册
func (r *UserRepository) Register(req *domain.RegisterRequest) (*domain.RegisterResponse, error) {
// 开启事务
tx, err := r.db.Begin()
if err != nil {
return nil, errors.WrapError(err, "failed to begin transaction")
}
// 检查账号是否已存在
var exists bool
query := "SELECT EXISTS(SELECT 1 FROM user_login_account WHERE account = $1 AND deleted = false)"
if err := tx.QueryRow(query, req.Account).Scan(&exists); err != nil {
tx.Rollback()
return nil, errors.WrapError(err, "failed to check account existence")
}
if exists {
tx.Rollback()
return nil, errors.ErrConflict
}
// 创建用户
userID := uuid.New()
now := time.Now()
userQuery := "INSERT INTO user_main (id, deleted, create_time, update_time) VALUES ($1, $2, $3, $4)"
if _, err := tx.Exec(userQuery, userID, false, now, now); err != nil {
tx.Rollback()
return nil, errors.WrapError(err, "failed to create user")
}
// 创建登录账号
accountID := uuid.New()
accountQuery := "INSERT INTO user_login_account (id, user_id, account, deleted, create_time, update_time) VALUES ($1, $2, $3, $4, $5, $6)"
if _, err := tx.Exec(accountQuery, accountID, userID, req.Account, false, now, now); err != nil {
tx.Rollback()
return nil, errors.WrapError(err, "failed to create login account")
}
// 加密密码
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
if err != nil {
tx.Rollback()
return nil, errors.WrapError(err, "failed to hash password")
}
// 创建密码记录
passwordID := uuid.New()
passwordQuery := "INSERT INTO user_login_password (id, user_id, password, deleted, create_time, update_time) VALUES ($1, $2, $3, $4, $5, $6)"
if _, err := tx.Exec(passwordQuery, passwordID, userID, string(hashedPassword), false, now, now); err != nil {
tx.Rollback()
return nil, errors.WrapError(err, "failed to create password record")
}
// 提交事务
if err := tx.Commit(); err != nil {
return nil, errors.WrapError(err, "failed to commit transaction")
}
return &domain.RegisterResponse{
UserID: userID,
Account: req.Account,
}, nil
}
// GetUserByAccount 根据账号获取用户信息
func (r *UserRepository) GetUserByAccount(account string) (*domain.User, *domain.UserLoginAccount, *domain.UserLoginPassword, error) {
// 获取登录账号
var accountRecord domain.UserLoginAccount
accountQuery := "SELECT id, user_id, account, deleted, create_time, update_time FROM user_login_account WHERE account = $1 AND deleted = false"
if err := r.db.QueryRow(accountQuery, account).Scan(
&accountRecord.ID, &accountRecord.UserID, &accountRecord.Account, &accountRecord.Deleted, &accountRecord.CreateTime, &accountRecord.UpdateTime,
); err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil, errors.ErrNotFound
}
return nil, nil, nil, errors.WrapError(err, "failed to get login account")
}
// 获取用户信息
var user domain.User
userQuery := "SELECT id, deleted, create_time, update_time FROM user_main WHERE id = $1 AND deleted = false"
if err := r.db.QueryRow(userQuery, accountRecord.UserID).Scan(
&user.ID, &user.Deleted, &user.CreateTime, &user.UpdateTime,
); err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil, errors.ErrNotFound
}
return nil, nil, nil, errors.WrapError(err, "failed to get user")
}
// 获取密码记录
var passwordRecord domain.UserLoginPassword
passwordQuery := "SELECT id, user_id, password, deleted, create_time, update_time FROM user_login_password WHERE user_id = $1 AND deleted = false"
if err := r.db.QueryRow(passwordQuery, accountRecord.UserID).Scan(
&passwordRecord.ID, &passwordRecord.UserID, &passwordRecord.Password, &passwordRecord.Deleted, &passwordRecord.CreateTime, &passwordRecord.UpdateTime,
); err != nil {
if err == sql.ErrNoRows {
return nil, nil, nil, errors.ErrNotFound
}
return nil, nil, nil, errors.WrapError(err, "failed to get password record")
}
return &user, &accountRecord, &passwordRecord, nil
}

View File

@@ -0,0 +1,39 @@
package service
import (
"backend/services/user-svc/internal/domain"
"backend/services/user-svc/internal/repository"
"backend/shared/pkg/errors"
)
type UserService struct {
repo *repository.UserRepository
}
func NewUserService(repo *repository.UserRepository) *UserService {
return &UserService{repo: repo}
}
// Register 用户注册
func (s *UserService) Register(req *domain.RegisterRequest) (*domain.RegisterResponse, error) {
// 验证请求参数
if req.Account == "" {
return nil, errors.ErrInvalidInput
}
if len(req.Password) < 6 {
return nil, errors.ErrInvalidInput
}
// 调用仓库层注册用户
return s.repo.Register(req)
}
// GetUserByAccount 根据账号获取用户信息
func (s *UserService) GetUserByAccount(account string) (*domain.User, *domain.UserLoginAccount, *domain.UserLoginPassword, error) {
if account == "" {
return nil, nil, nil, errors.ErrInvalidInput
}
return s.repo.GetUserByAccount(account)
}

View File

@@ -0,0 +1,33 @@
-- 创建 user_main 表
CREATE TABLE IF NOT EXISTS user_main (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
deleted BOOLEAN DEFAULT false,
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- 创建 user_login_account 表
CREATE TABLE IF NOT EXISTS user_login_account (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES user_main(id),
account VARCHAR(255) NOT NULL,
deleted BOOLEAN DEFAULT false,
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE(account)
);
-- 创建 user_login_password 表
CREATE TABLE IF NOT EXISTS user_login_password (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES user_main(id),
password VARCHAR(255) NOT NULL,
deleted BOOLEAN DEFAULT false,
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- 创建索引
CREATE INDEX IF NOT EXISTS idx_user_login_account_user_id ON user_login_account(user_id);
CREATE INDEX IF NOT EXISTS idx_user_login_account_account ON user_login_account(account);
CREATE INDEX IF NOT EXISTS idx_user_login_password_user_id ON user_login_password(user_id);

View File

@@ -0,0 +1,38 @@
syntax = "proto3";
package user;
import "shared/proto/common/common.proto";
// 用户服务
service UserService {
// 注册用户
rpc Register(RegisterRequest) returns (RegisterResponse);
// 获取用户信息
rpc GetUserByAccount(GetUserByAccountRequest) returns (GetUserByAccountResponse);
}
// 注册请求
message RegisterRequest {
string account = 1;
string password = 2;
}
// 注册响应
message RegisterResponse {
string user_id = 1;
string account = 2;
common.Response response = 3;
}
// 获取用户信息请求
message GetUserByAccountRequest {
string account = 1;
}
// 获取用户信息响应
message GetUserByAccountResponse {
string user_id = 1;
string account = 2;
common.Response response = 3;
}