diff --git a/trading_assistant_api/common/db/options.go b/trading_assistant_api/common/db/options.go index e69de29..b3cd39d 100644 --- a/trading_assistant_api/common/db/options.go +++ b/trading_assistant_api/common/db/options.go @@ -0,0 +1,112 @@ +// common/db/options.go +package db + +import ( + "time" + + "gorm.io/gorm" +) + +// PostgresOptions PostgreSQL连接配置项 +// 包含基础连接信息+连接池配置,按需暴露,无冗余 +type PostgresOptions struct { + Host string // 数据库地址 + Port string // 数据库端口 + User string // 数据库账号 + Password string // 数据库密码 + DBName string // 数据库名 + SSLMode string // SSL模式(开发环境一般disable,生产可enable) + TimeZone string // 时区(如Asia/Shanghai) + // 连接池配置 + MaxOpenConns int // 最大打开连接数 + MaxIdleConns int // 最大空闲连接数 + ConnMaxLifetime time.Duration // 连接最大生命周期 + ConnMaxIdleTime time.Duration // 连接最大空闲时间 + // GORM配置 + LogLevel gorm.LogLevel // GORM日志级别(开发:Info,生产:Error) +} + +// PostgresOption 选项模式函数类型 +type PostgresOption func(*PostgresOptions) + +// 初始化默认配置,避免用户传参不全导致连接失败 +// 开发环境常用默认值,生产环境通过业务层传参覆盖 +func defaultPostgresOptions() *PostgresOptions { + return &PostgresOptions{ + Host: "postgres", // 默认匹配根目录Compose的服务名,开发环境直接用 + Port: "5432", + SSLMode: "disable", + TimeZone: "Asia/Shanghai", + MaxOpenConns: 20, + MaxIdleConns: 10, + ConnMaxLifetime: 30 * time.Minute, + ConnMaxIdleTime: 10 * time.Minute, + LogLevel: gorm.LogLevelInfo, // 开发环境默认打印Info日志 + } +} + +// 以下为配置项的设置函数,用户可通过链式调用配置 +// 示例:db.WithHost("192.168.1.100").WithUser("prod_user") + +// WithHost 设置数据库地址 +func WithHost(host string) PostgresOption { + return func(o *PostgresOptions) { + o.Host = host + } +} + +// WithPort 设置数据库端口 +func WithPort(port string) PostgresOption { + return func(o *PostgresOptions) { + o.Port = port + } +} + +// WithUser 设置数据库账号 +func WithUser(user string) PostgresOption { + return func(o *PostgresOptions) { + o.User = user + } +} + +// WithPassword 设置数据库密码 +func WithPassword(pwd string) PostgresOption { + return func(o *PostgresOptions) { + o.Password = pwd + } +} + +// WithDBName 设置数据库名 +func WithDBName(dbName string) PostgresOption { + return func(o *PostgresOptions) { + o.DBName = dbName + } +} + +// WithSSLMode 设置SSL模式 +func WithSSLMode(mode string) PostgresOption { + return func(o *PostgresOptions) { + o.SSLMode = mode + } +} + +// WithTimeZone 设置时区 +func WithTimeZone(tz string) PostgresOption { + return func(o *PostgresOptions) { + o.TimeZone = tz + } +} + +// WithMaxOpenConns 设置最大打开连接数 +func WithMaxOpenConns(num int) PostgresOption { + return func(o *PostgresOptions) { + o.MaxOpenConns = num + } +} + +// WithLogLevel 设置GORM日志级别 +func WithLogLevel(level gorm.LogLevel) PostgresOption { + return func(o *PostgresOptions) { + o.LogLevel = level + } +} \ No newline at end of file diff --git a/trading_assistant_api/common/db/postgres.go b/trading_assistant_api/common/db/postgres.go index e69de29..b6ca194 100644 --- a/trading_assistant_api/common/db/postgres.go +++ b/trading_assistant_api/common/db/postgres.go @@ -0,0 +1,136 @@ +// common/db/postgres.go +package db + +import ( + "context" + "database/sql" + "fmt" + "sync" + + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +// 全局单例连接,避免重复创建连接(连接池全局复用) +var ( + pgInstance *gorm.DB + pgOnce sync.Once + pgErr error +) + +// InitPostgres 初始化PostgreSQL连接(单例模式,仅执行一次) +// 入参:选项模式的配置函数,用户可灵活配置 +// 返回:原生*gorm.DB + 错误,业务层可直接使用GORM所有方法 +func InitPostgres(opts ...PostgresOption) (*gorm.DB, error) { + pgOnce.Do(func() { + // 加载默认配置 + 覆盖用户自定义配置 + options := defaultPostgresOptions() + for _, opt := range opts { + opt(options) + } + + // 校验必传配置(账号、密码、数据库名不能为空) + if options.User == "" || options.Password == "" || options.DBName == "" { + pgErr = fmt.Errorf("postgres config error: user/password/dbname can not be empty") + return + } + + // 拼接PostgreSQL DSN(数据源名称),GORM官方规范 + dsn := fmt.Sprintf( + "host=%s port=%s user=%s password=%s dbname=%s sslmode=%s TimeZone=%s", + options.Host, + options.Port, + options.User, + options.Password, + options.DBName, + options.SSLMode, + options.TimeZone, + ) + + // 配置GORM日志(按传入的日志级别,生产环境可关闭) + gormConfig := &gorm.Config{ + Logger: logger.Default.LogMode(options.LogLevel), + } + + // 建立GORM连接 + pgInstance, pgErr = gorm.Open(postgres.Open(dsn), gormConfig) + if pgErr != nil { + pgErr = fmt.Errorf("postgres connect failed: %w", pgErr) + return + } + + // 获取底层*sql.DB,配置连接池(关键:避免连接泄漏,提升性能) + sqlDB, err := pgInstance.DB() + if err != nil { + pgErr = fmt.Errorf("get postgres sql.DB failed: %w", err) + return + } + + // 设置连接池参数 + sqlDB.SetMaxOpenConns(options.MaxOpenConns) + sqlDB.SetMaxIdleConns(options.MaxIdleConns) + sqlDB.SetConnMaxLifetime(options.ConnMaxLifetime) + sqlDB.SetConnMaxIdleTime(options.ConnMaxIdleTime) + + // 测试连接(ping一下,确保连接有效) + if err := sqlDB.PingContext(context.Background()); err != nil { + pgErr = fmt.Errorf("postgres ping failed: %w", err) + pgInstance = nil // 连接失败,置空实例 + return + } + }) + + // 单例执行完成后,返回实例和错误 + return pgInstance, pgErr +} + +// GetPostgres 获取全局PostgreSQL单例连接 +// 业务层初始化后,可通过此方法直接获取连接,无需重复调用InitPostgres +func GetPostgres() (*gorm.DB, error) { + if pgInstance == nil { + return nil, fmt.Errorf("postgres not initialized, please call InitPostgres first") + } + return pgInstance, nil +} + +// ClosePostgres 关闭PostgreSQL连接(应用退出时调用,释放资源) +// 一般在main函数的defer中调用,如:defer db.ClosePostgres() +func ClosePostgres() error { + if pgInstance == nil { + return fmt.Errorf("postgres not initialized") + } + + // 获取底层*sql.DB,执行关闭 + sqlDB, err := pgInstance.DB() + if err != nil { + return fmt.Errorf("get postgres sql.DB failed: %w", err) + } + + if err := sqlDB.Close(); err != nil { + return fmt.Errorf("close postgres failed: %w", err) + } + + // 关闭后重置单例,避免重复关闭 + pgOnce = sync.Once{} + pgInstance = nil + pgErr = nil + + return nil +} + +// PingPostgres 测试PostgreSQL连接是否有效 +// 业务层可定时调用(如健康检查),确保连接未断开 +func PingPostgres() error { + db, err := GetPostgres() + if err != nil { + return err + } + + sqlDB, err := db.DB() + if err != nil { + return fmt.Errorf("get postgres sql.DB failed: %w", err) + } + + return sqlDB.PingContext(context.Background()) +} \ No newline at end of file