feat: 实现用户注册功能,包括数据库表结构、gRPC 服务和业务逻辑
This commit is contained in:
49
backend/services/user-svc/cmd/main.go
Normal file
49
backend/services/user-svc/cmd/main.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"backend/services/user-svc/internal/config"
|
||||
"backend/services/user-svc/internal/grpcserver"
|
||||
"backend/services/user-svc/internal/repository"
|
||||
"backend/services/user-svc/internal/service"
|
||||
"backend/shared/pkg/database"
|
||||
"backend/shared/pkg/logger"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 加载配置
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// 连接数据库
|
||||
db, err := database.NewPostgresDB(database.PostgresConfig{
|
||||
Host: cfg.Database.Host,
|
||||
Port: cfg.Database.Port,
|
||||
User: cfg.Database.User,
|
||||
Password: cfg.Database.Password,
|
||||
DBName: cfg.Database.DBName,
|
||||
SSLMode: cfg.Database.SSLMode,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to connect to database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// 初始化仓库
|
||||
repo := repository.NewUserRepository(db.DB)
|
||||
|
||||
// 初始化服务
|
||||
userService := service.NewUserService(repo)
|
||||
|
||||
// 启动 gRPC 服务器
|
||||
serverAddr := fmt.Sprintf(":%d", cfg.Server.Port)
|
||||
logger.Info("User service starting on %s", serverAddr)
|
||||
|
||||
if err := grpcserver.Start(userService, cfg.Server.Port); err != nil {
|
||||
log.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
}
|
||||
11
backend/services/user-svc/go.mod
Normal file
11
backend/services/user-svc/go.mod
Normal file
@@ -0,0 +1,11 @@
|
||||
module backend/services/user-svc
|
||||
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/spf13/viper v1.19.0
|
||||
golang.org/x/crypto v0.20.0
|
||||
google.golang.org/grpc v1.64.0
|
||||
google.golang.org/protobuf v1.33.0
|
||||
)
|
||||
52
backend/services/user-svc/internal/config/config.go
Normal file
52
backend/services/user-svc/internal/config/config.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig
|
||||
Database DatabaseConfig
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Port int
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
User string
|
||||
Password string
|
||||
DBName string
|
||||
SSLMode string
|
||||
}
|
||||
|
||||
func Load() (*Config, error) {
|
||||
viper.SetConfigName("config")
|
||||
viper.SetConfigType("yaml")
|
||||
viper.AddConfigPath("./config")
|
||||
viper.AddConfigPath("../config")
|
||||
viper.AddConfigPath("../../config")
|
||||
|
||||
viper.SetDefault("server.port", 9000)
|
||||
viper.SetDefault("database.host", "postgres")
|
||||
viper.SetDefault("database.port", 5432)
|
||||
viper.SetDefault("database.user", "admin")
|
||||
viper.SetDefault("database.password", "password")
|
||||
viper.SetDefault("database.dbname", "backend")
|
||||
viper.SetDefault("database.sslmode", "disable")
|
||||
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var config Config
|
||||
if err := viper.Unmarshal(&config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
42
backend/services/user-svc/internal/domain/user.go
Normal file
42
backend/services/user-svc/internal/domain/user.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Deleted bool `json:"deleted"`
|
||||
CreateTime time.Time `json:"create_time"`
|
||||
UpdateTime time.Time `json:"update_time"`
|
||||
}
|
||||
|
||||
type UserLoginAccount struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
Account string `json:"account"`
|
||||
Deleted bool `json:"deleted"`
|
||||
CreateTime time.Time `json:"create_time"`
|
||||
UpdateTime time.Time `json:"update_time"`
|
||||
}
|
||||
|
||||
type UserLoginPassword struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
Password string `json:"password"`
|
||||
Deleted bool `json:"deleted"`
|
||||
CreateTime time.Time `json:"create_time"`
|
||||
UpdateTime time.Time `json:"update_time"`
|
||||
}
|
||||
|
||||
type RegisterRequest struct {
|
||||
Account string `json:"account" binding:"required"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
type RegisterResponse struct {
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
Account string `json:"account"`
|
||||
}
|
||||
139
backend/services/user-svc/internal/grpcserver/server.go
Normal file
139
backend/services/user-svc/internal/grpcserver/server.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"backend/services/user-svc/internal/domain"
|
||||
"backend/services/user-svc/internal/service"
|
||||
"backend/shared/pkg/errors"
|
||||
"backend/shared/pkg/logger"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
// 导入生成的 proto 代码
|
||||
userpb "backend/services/user-svc/proto"
|
||||
)
|
||||
|
||||
type UserServer struct {
|
||||
userpb.UnimplementedUserServiceServer
|
||||
service *service.UserService
|
||||
}
|
||||
|
||||
func NewUserServer(service *service.UserService) *UserServer {
|
||||
return &UserServer{service: service}
|
||||
}
|
||||
|
||||
// Register 实现用户注册接口
|
||||
func (s *UserServer) Register(ctx context.Context, req *userpb.RegisterRequest) (*userpb.RegisterResponse, error) {
|
||||
// 转换请求参数
|
||||
domainReq := &domain.RegisterRequest{
|
||||
Account: req.GetAccount(),
|
||||
Password: req.GetPassword(),
|
||||
}
|
||||
|
||||
// 调用服务层注册用户
|
||||
resp, err := s.service.Register(domainReq)
|
||||
if err != nil {
|
||||
logger.Error("Register failed: %v", err)
|
||||
|
||||
// 转换错误类型
|
||||
switch {
|
||||
case errors.IsConflict(err):
|
||||
return &userpb.RegisterResponse{
|
||||
Response: &userpb.Response{
|
||||
Code: 409,
|
||||
Message: "账号已存在",
|
||||
},
|
||||
}, status.Errorf(codes.AlreadyExists, "账号已存在")
|
||||
case errors.IsInvalidInput(err):
|
||||
return &userpb.RegisterResponse{
|
||||
Response: &userpb.Response{
|
||||
Code: 400,
|
||||
Message: "无效的输入参数",
|
||||
},
|
||||
}, status.Errorf(codes.InvalidArgument, "无效的输入参数")
|
||||
default:
|
||||
return &userpb.RegisterResponse{
|
||||
Response: &userpb.Response{
|
||||
Code: 500,
|
||||
Message: "内部服务器错误",
|
||||
},
|
||||
}, status.Errorf(codes.Internal, "内部服务器错误")
|
||||
}
|
||||
}
|
||||
|
||||
// 构造响应
|
||||
return &userpb.RegisterResponse{
|
||||
UserId: resp.UserID.String(),
|
||||
Account: resp.Account,
|
||||
Response: &userpb.Response{
|
||||
Code: 200,
|
||||
Message: "注册成功",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetUserByAccount 实现根据账号获取用户信息接口
|
||||
func (s *UserServer) GetUserByAccount(ctx context.Context, req *userpb.GetUserByAccountRequest) (*userpb.GetUserByAccountResponse, error) {
|
||||
// 调用服务层获取用户信息
|
||||
user, account, _, err := s.service.GetUserByAccount(req.GetAccount())
|
||||
if err != nil {
|
||||
logger.Error("GetUserByAccount failed: %v", err)
|
||||
|
||||
// 转换错误类型
|
||||
switch {
|
||||
case errors.IsNotFound(err):
|
||||
return &userpb.GetUserByAccountResponse{
|
||||
Response: &userpb.Response{
|
||||
Code: 404,
|
||||
Message: "用户不存在",
|
||||
},
|
||||
}, status.Errorf(codes.NotFound, "用户不存在")
|
||||
case errors.IsInvalidInput(err):
|
||||
return &userpb.GetUserByAccountResponse{
|
||||
Response: &userpb.Response{
|
||||
Code: 400,
|
||||
Message: "无效的输入参数",
|
||||
},
|
||||
}, status.Errorf(codes.InvalidArgument, "无效的输入参数")
|
||||
default:
|
||||
return &userpb.GetUserByAccountResponse{
|
||||
Response: &userpb.Response{
|
||||
Code: 500,
|
||||
Message: "内部服务器错误",
|
||||
},
|
||||
}, status.Errorf(codes.Internal, "内部服务器错误")
|
||||
}
|
||||
}
|
||||
|
||||
// 构造响应
|
||||
return &userpb.GetUserByAccountResponse{
|
||||
UserId: user.ID.String(),
|
||||
Account: account.Account,
|
||||
Response: &userpb.Response{
|
||||
Code: 200,
|
||||
Message: "获取成功",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start 启动 gRPC 服务器
|
||||
func Start(service *service.UserService, port int) error {
|
||||
server := grpc.NewServer()
|
||||
userServer := NewUserServer(service)
|
||||
|
||||
userpb.RegisterUserServiceServer(server, userServer)
|
||||
|
||||
addr := fmt.Sprintf(":%d", port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Info("User service gRPC server starting on %s", addr)
|
||||
return server.Serve(listener)
|
||||
}
|
||||
126
backend/services/user-svc/internal/repository/repo.go
Normal file
126
backend/services/user-svc/internal/repository/repo.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"backend/services/user-svc/internal/domain"
|
||||
"backend/shared/pkg/errors"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type UserRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewUserRepository(db *sql.DB) *UserRepository {
|
||||
return &UserRepository{db: db}
|
||||
}
|
||||
|
||||
// Register 用户注册
|
||||
func (r *UserRepository) Register(req *domain.RegisterRequest) (*domain.RegisterResponse, error) {
|
||||
// 开启事务
|
||||
tx, err := r.db.Begin()
|
||||
if err != nil {
|
||||
return nil, errors.WrapError(err, "failed to begin transaction")
|
||||
}
|
||||
|
||||
// 检查账号是否已存在
|
||||
var exists bool
|
||||
query := "SELECT EXISTS(SELECT 1 FROM user_login_account WHERE account = $1 AND deleted = false)"
|
||||
if err := tx.QueryRow(query, req.Account).Scan(&exists); err != nil {
|
||||
tx.Rollback()
|
||||
return nil, errors.WrapError(err, "failed to check account existence")
|
||||
}
|
||||
|
||||
if exists {
|
||||
tx.Rollback()
|
||||
return nil, errors.ErrConflict
|
||||
}
|
||||
|
||||
// 创建用户
|
||||
userID := uuid.New()
|
||||
now := time.Now()
|
||||
|
||||
userQuery := "INSERT INTO user_main (id, deleted, create_time, update_time) VALUES ($1, $2, $3, $4)"
|
||||
if _, err := tx.Exec(userQuery, userID, false, now, now); err != nil {
|
||||
tx.Rollback()
|
||||
return nil, errors.WrapError(err, "failed to create user")
|
||||
}
|
||||
|
||||
// 创建登录账号
|
||||
accountID := uuid.New()
|
||||
accountQuery := "INSERT INTO user_login_account (id, user_id, account, deleted, create_time, update_time) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||
if _, err := tx.Exec(accountQuery, accountID, userID, req.Account, false, now, now); err != nil {
|
||||
tx.Rollback()
|
||||
return nil, errors.WrapError(err, "failed to create login account")
|
||||
}
|
||||
|
||||
// 加密密码
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return nil, errors.WrapError(err, "failed to hash password")
|
||||
}
|
||||
|
||||
// 创建密码记录
|
||||
passwordID := uuid.New()
|
||||
passwordQuery := "INSERT INTO user_login_password (id, user_id, password, deleted, create_time, update_time) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||
if _, err := tx.Exec(passwordQuery, passwordID, userID, string(hashedPassword), false, now, now); err != nil {
|
||||
tx.Rollback()
|
||||
return nil, errors.WrapError(err, "failed to create password record")
|
||||
}
|
||||
|
||||
// 提交事务
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, errors.WrapError(err, "failed to commit transaction")
|
||||
}
|
||||
|
||||
return &domain.RegisterResponse{
|
||||
UserID: userID,
|
||||
Account: req.Account,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetUserByAccount 根据账号获取用户信息
|
||||
func (r *UserRepository) GetUserByAccount(account string) (*domain.User, *domain.UserLoginAccount, *domain.UserLoginPassword, error) {
|
||||
// 获取登录账号
|
||||
var accountRecord domain.UserLoginAccount
|
||||
accountQuery := "SELECT id, user_id, account, deleted, create_time, update_time FROM user_login_account WHERE account = $1 AND deleted = false"
|
||||
if err := r.db.QueryRow(accountQuery, account).Scan(
|
||||
&accountRecord.ID, &accountRecord.UserID, &accountRecord.Account, &accountRecord.Deleted, &accountRecord.CreateTime, &accountRecord.UpdateTime,
|
||||
); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil, nil, errors.ErrNotFound
|
||||
}
|
||||
return nil, nil, nil, errors.WrapError(err, "failed to get login account")
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
var user domain.User
|
||||
userQuery := "SELECT id, deleted, create_time, update_time FROM user_main WHERE id = $1 AND deleted = false"
|
||||
if err := r.db.QueryRow(userQuery, accountRecord.UserID).Scan(
|
||||
&user.ID, &user.Deleted, &user.CreateTime, &user.UpdateTime,
|
||||
); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil, nil, errors.ErrNotFound
|
||||
}
|
||||
return nil, nil, nil, errors.WrapError(err, "failed to get user")
|
||||
}
|
||||
|
||||
// 获取密码记录
|
||||
var passwordRecord domain.UserLoginPassword
|
||||
passwordQuery := "SELECT id, user_id, password, deleted, create_time, update_time FROM user_login_password WHERE user_id = $1 AND deleted = false"
|
||||
if err := r.db.QueryRow(passwordQuery, accountRecord.UserID).Scan(
|
||||
&passwordRecord.ID, &passwordRecord.UserID, &passwordRecord.Password, &passwordRecord.Deleted, &passwordRecord.CreateTime, &passwordRecord.UpdateTime,
|
||||
); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil, nil, errors.ErrNotFound
|
||||
}
|
||||
return nil, nil, nil, errors.WrapError(err, "failed to get password record")
|
||||
}
|
||||
|
||||
return &user, &accountRecord, &passwordRecord, nil
|
||||
}
|
||||
39
backend/services/user-svc/internal/service/service.go
Normal file
39
backend/services/user-svc/internal/service/service.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"backend/services/user-svc/internal/domain"
|
||||
"backend/services/user-svc/internal/repository"
|
||||
"backend/shared/pkg/errors"
|
||||
)
|
||||
|
||||
type UserService struct {
|
||||
repo *repository.UserRepository
|
||||
}
|
||||
|
||||
func NewUserService(repo *repository.UserRepository) *UserService {
|
||||
return &UserService{repo: repo}
|
||||
}
|
||||
|
||||
// Register 用户注册
|
||||
func (s *UserService) Register(req *domain.RegisterRequest) (*domain.RegisterResponse, error) {
|
||||
// 验证请求参数
|
||||
if req.Account == "" {
|
||||
return nil, errors.ErrInvalidInput
|
||||
}
|
||||
|
||||
if len(req.Password) < 6 {
|
||||
return nil, errors.ErrInvalidInput
|
||||
}
|
||||
|
||||
// 调用仓库层注册用户
|
||||
return s.repo.Register(req)
|
||||
}
|
||||
|
||||
// GetUserByAccount 根据账号获取用户信息
|
||||
func (s *UserService) GetUserByAccount(account string) (*domain.User, *domain.UserLoginAccount, *domain.UserLoginPassword, error) {
|
||||
if account == "" {
|
||||
return nil, nil, nil, errors.ErrInvalidInput
|
||||
}
|
||||
|
||||
return s.repo.GetUserByAccount(account)
|
||||
}
|
||||
33
backend/services/user-svc/migrations/001_init.sql
Normal file
33
backend/services/user-svc/migrations/001_init.sql
Normal file
@@ -0,0 +1,33 @@
|
||||
-- 创建 user_main 表
|
||||
CREATE TABLE IF NOT EXISTS user_main (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
deleted BOOLEAN DEFAULT false,
|
||||
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- 创建 user_login_account 表
|
||||
CREATE TABLE IF NOT EXISTS user_login_account (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL REFERENCES user_main(id),
|
||||
account VARCHAR(255) NOT NULL,
|
||||
deleted BOOLEAN DEFAULT false,
|
||||
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(account)
|
||||
);
|
||||
|
||||
-- 创建 user_login_password 表
|
||||
CREATE TABLE IF NOT EXISTS user_login_password (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL REFERENCES user_main(id),
|
||||
password VARCHAR(255) NOT NULL,
|
||||
deleted BOOLEAN DEFAULT false,
|
||||
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- 创建索引
|
||||
CREATE INDEX IF NOT EXISTS idx_user_login_account_user_id ON user_login_account(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_login_account_account ON user_login_account(account);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_login_password_user_id ON user_login_password(user_id);
|
||||
38
backend/services/user-svc/proto/user.proto
Normal file
38
backend/services/user-svc/proto/user.proto
Normal file
@@ -0,0 +1,38 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package user;
|
||||
|
||||
import "shared/proto/common/common.proto";
|
||||
|
||||
// 用户服务
|
||||
service UserService {
|
||||
// 注册用户
|
||||
rpc Register(RegisterRequest) returns (RegisterResponse);
|
||||
// 获取用户信息
|
||||
rpc GetUserByAccount(GetUserByAccountRequest) returns (GetUserByAccountResponse);
|
||||
}
|
||||
|
||||
// 注册请求
|
||||
message RegisterRequest {
|
||||
string account = 1;
|
||||
string password = 2;
|
||||
}
|
||||
|
||||
// 注册响应
|
||||
message RegisterResponse {
|
||||
string user_id = 1;
|
||||
string account = 2;
|
||||
common.Response response = 3;
|
||||
}
|
||||
|
||||
// 获取用户信息请求
|
||||
message GetUserByAccountRequest {
|
||||
string account = 1;
|
||||
}
|
||||
|
||||
// 获取用户信息响应
|
||||
message GetUserByAccountResponse {
|
||||
string user_id = 1;
|
||||
string account = 2;
|
||||
common.Response response = 3;
|
||||
}
|
||||
Reference in New Issue
Block a user