package db

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"sync"
	"time"

	"common/utils"

	_ "github.com/jackc/pgx/v5/stdlib"
)

// Options configures the PostgreSQL connection and its database/sql pool.
type Options struct {
	Host     string
	Port     string
	User     string
	Password string
	DBName   string
	SSLMode  string // passed through as the "sslmode" connection parameter

	MaxOpenConns    int           // see (*sql.DB).SetMaxOpenConns
	MaxIdleConns    int           // see (*sql.DB).SetMaxIdleConns
	ConnMaxLifetime time.Duration // see (*sql.DB).SetConnMaxLifetime
	ConnMaxIdleTime time.Duration // see (*sql.DB).SetConnMaxIdleTime
}

// Option mutates Options. Options passed to InitPostgres are applied after
// the defaults and the environment, so they take precedence over both.
type Option func(*Options)

// ErrNotInitialized is returned by GetPostgres and ClosePostgres when no pool
// is currently open.
var ErrNotInitialized = errors.New("postgres not initialized")

// pingTimeout bounds the connectivity check performed by InitPostgres.
const pingTimeout = 5 * time.Second

// defaultOptions returns the built-in baseline configuration.
func defaultOptions() *Options {
	return &Options{
		Host:            "localhost",
		Port:            "5432",
		SSLMode:         "disable",
		MaxOpenConns:    20,
		MaxIdleConns:    10,
		ConnMaxLifetime: 30 * time.Minute,
		ConnMaxIdleTime: 10 * time.Minute,
	}
}

// FromEnv reads connection settings from PG_HOST, PG_PORT, PG_USER,
// PG_PASSWORD, PG_DBNAME and PG_SSLMODE, passing "" as the fallback to
// utils.GetEnv (presumably used when the variable is unset — confirm against
// utils). Pool settings are not read from the environment and stay zero.
func FromEnv() *Options {
	return &Options{
		Host:     utils.GetEnv("PG_HOST", ""),
		Port:     utils.GetEnv("PG_PORT", ""),
		User:     utils.GetEnv("PG_USER", ""),
		Password: utils.GetEnv("PG_PASSWORD", ""),
		DBName:   utils.GetEnv("PG_DBNAME", ""),
		SSLMode:  utils.GetEnv("PG_SSLMODE", ""),
	}
}

// WithHost sets the server host.
func WithHost(v string) Option { return func(o *Options) { o.Host = v } }

// WithPort sets the server port.
func WithPort(v string) Option { return func(o *Options) { o.Port = v } }

// WithUser sets the login user.
func WithUser(v string) Option { return func(o *Options) { o.User = v } }

// WithPassword sets the login password.
func WithPassword(v string) Option { return func(o *Options) { o.Password = v } }

// WithDBName sets the database name.
func WithDBName(v string) Option { return func(o *Options) { o.DBName = v } }

// WithSSLMode sets the sslmode connection parameter.
func WithSSLMode(v string) Option { return func(o *Options) { o.SSLMode = v } }

// WithMaxOpenConns sets the pool's maximum number of open connections.
func WithMaxOpenConns(v int) Option { return func(o *Options) { o.MaxOpenConns = v } }

// WithMaxIdleConns sets the pool's maximum number of idle connections.
func WithMaxIdleConns(v int) Option { return func(o *Options) { o.MaxIdleConns = v } }

// WithConnMaxLifetime sets the maximum time a connection may be reused.
func WithConnMaxLifetime(v time.Duration) Option {
	return func(o *Options) { o.ConnMaxLifetime = v }
}

// WithConnMaxIdleTime sets the maximum time a connection may sit idle.
func WithConnMaxIdleTime(v time.Duration) Option {
	return func(o *Options) { o.ConnMaxIdleTime = v }
}

// mergeConnFields copies every non-empty connection field of src into o.
// Pool settings are deliberately not merged, so FromEnv's zero values never
// override the defaults.
func (o *Options) mergeConnFields(src *Options) {
	if src.Host != "" {
		o.Host = src.Host
	}
	if src.Port != "" {
		o.Port = src.Port
	}
	if src.User != "" {
		o.User = src.User
	}
	if src.Password != "" {
		o.Password = src.Password
	}
	if src.DBName != "" {
		o.DBName = src.DBName
	}
	if src.SSLMode != "" {
		o.SSLMode = src.SSLMode
	}
}

// missingFields lists the required connection fields that are still empty.
func (o *Options) missingFields() []string {
	var missing []string
	if o.User == "" {
		missing = append(missing, "user")
	}
	if o.Password == "" {
		missing = append(missing, "password")
	}
	if o.DBName == "" {
		missing = append(missing, "dbname")
	}
	return missing
}

// connValueEscaper backslash-escapes the two characters that are special
// inside a single-quoted keyword/value connection-string value.
var connValueEscaper = strings.NewReplacer(`\`, `\\`, `'`, `\'`)

// quoteConnValue single-quotes v for a keyword/value connection string, so
// values containing spaces or quotes (typically passwords) can neither break
// parsing nor inject additional parameters.
func quoteConnValue(v string) string {
	return "'" + connValueEscaper.Replace(v) + "'"
}

// dsn renders o as a keyword/value connection string with every value quoted.
func (o *Options) dsn() string {
	return fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
		quoteConnValue(o.Host),
		quoteConnValue(o.Port),
		quoteConnValue(o.User),
		quoteConnValue(o.Password),
		quoteConnValue(o.DBName),
		quoteConnValue(o.SSLMode),
	)
}

var (
	// pgMu guards pg. InitPostgres holds it for the whole open+ping, so
	// concurrent callers wait on a single attempt instead of racing.
	pgMu sync.Mutex
	// pg is the shared pool: nil until InitPostgres succeeds and after
	// ClosePostgres.
	pg *sql.DB
)

// openPostgres resolves the configuration (defaults, then non-empty
// environment values, then opts), opens a pool, applies the pool settings
// and verifies connectivity with a ping bounded by pingTimeout.
func openPostgres(opts []Option) (*sql.DB, error) {
	o := defaultOptions()
	o.mergeConnFields(FromEnv())
	for _, f := range opts {
		if f != nil {
			f(o)
		}
	}

	if missing := o.missingFields(); len(missing) > 0 {
		return nil, fmt.Errorf("postgres config missing: %s", strings.Join(missing, ", "))
	}

	db, err := sql.Open("pgx", o.dsn())
	if err != nil {
		return nil, fmt.Errorf("open postgres: %w", err)
	}
	db.SetMaxOpenConns(o.MaxOpenConns)
	db.SetMaxIdleConns(o.MaxIdleConns)
	db.SetConnMaxLifetime(o.ConnMaxLifetime)
	db.SetConnMaxIdleTime(o.ConnMaxIdleTime)

	ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
	defer cancel()
	if err := db.PingContext(ctx); err != nil {
		// The ping error is what the caller needs; a Close failure on a pool
		// that never served a query would only obscure it.
		_ = db.Close()
		return nil, fmt.Errorf("ping postgres: %w", err)
	}
	return db, nil
}

// InitPostgres returns the shared pool, creating it on first use.
//
// Configuration is layered: built-in defaults, then non-empty PG_* variables
// (see FromEnv), then opts. User, Password and DBName must end up non-empty.
// Once a pool exists, later calls return it unchanged and ignore opts. A
// failed attempt is not cached, so a subsequent call retries from scratch.
// It is safe for concurrent use.
func InitPostgres(opts ...Option) (*sql.DB, error) {
	pgMu.Lock()
	defer pgMu.Unlock()

	if pg != nil {
		return pg, nil
	}
	db, err := openPostgres(opts)
	if err != nil {
		return nil, err
	}
	pg = db
	return pg, nil
}

// GetPostgres returns the pool created by InitPostgres, or ErrNotInitialized
// if none is open. It blocks while an InitPostgres call is in progress.
func GetPostgres() (*sql.DB, error) {
	pgMu.Lock()
	defer pgMu.Unlock()

	if pg == nil {
		return nil, ErrNotInitialized
	}
	return pg, nil
}

// ClosePostgres closes the shared pool and forgets it, so a later
// InitPostgres opens a fresh one. It returns ErrNotInitialized if no pool is
// open, otherwise the error from (*sql.DB).Close.
func ClosePostgres() error {
	pgMu.Lock()
	defer pgMu.Unlock()

	if pg == nil {
		return ErrNotInitialized
	}
	err := pg.Close()
	// Drop the reference regardless of Close's result so re-initialisation
	// is always possible.
	pg = nil
	return err
}