147 lines
3.5 KiB
Go
147 lines
3.5 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"common/utils"
|
|
_ "github.com/jackc/pgx/v5/stdlib"
|
|
)
|
|
|
|
// Options collects the connection parameters and connection-pool tuning
// knobs consumed by InitPostgres.
type Options struct {
	// Connection parameters, rendered into the driver connection string.
	Host, Port     string
	User, Password string
	DBName         string
	SSLMode        string

	// Pool tuning, applied to the *sql.DB once it has been opened.
	MaxOpenConns, MaxIdleConns       int
	ConnMaxLifetime, ConnMaxIdleTime time.Duration
}
|
|
|
|
// Option is a functional option that mutates an Options value in place.
// InitPostgres applies its Option arguments after the defaults and the
// environment-derived settings have been merged.
type Option func(o *Options)
|
|
|
|
func defaultOptions() *Options {
|
|
return &Options{
|
|
Host: "localhost",
|
|
Port: "5432",
|
|
SSLMode: "disable",
|
|
MaxOpenConns: 20,
|
|
MaxIdleConns: 10,
|
|
ConnMaxLifetime: 30 * time.Minute,
|
|
ConnMaxIdleTime: 10 * time.Minute,
|
|
}
|
|
}
|
|
|
|
func FromEnv() *Options {
|
|
return &Options{
|
|
Host: utils.GetEnv("PG_HOST", ""),
|
|
Port: utils.GetEnv("PG_PORT", ""),
|
|
User: utils.GetEnv("PG_USER", ""),
|
|
Password: utils.GetEnv("PG_PASSWORD", ""),
|
|
DBName: utils.GetEnv("PG_DBNAME", ""),
|
|
SSLMode: utils.GetEnv("PG_SSLMODE", ""),
|
|
}
|
|
}
|
|
|
|
func WithHost(v string) Option { return func(o *Options) { o.Host = v } }
|
|
func WithPort(v string) Option { return func(o *Options) { o.Port = v } }
|
|
func WithUser(v string) Option { return func(o *Options) { o.User = v } }
|
|
func WithPassword(v string) Option { return func(o *Options) { o.Password = v } }
|
|
func WithDBName(v string) Option { return func(o *Options) { o.DBName = v } }
|
|
func WithSSLMode(v string) Option { return func(o *Options) { o.SSLMode = v } }
|
|
func WithMaxOpenConns(v int) Option { return func(o *Options) { o.MaxOpenConns = v } }
|
|
func WithMaxIdleConns(v int) Option { return func(o *Options) { o.MaxIdleConns = v } }
|
|
func WithConnMaxLifetime(v time.Duration) Option {
|
|
return func(o *Options) { o.ConnMaxLifetime = v }
|
|
}
|
|
func WithConnMaxIdleTime(v time.Duration) Option {
|
|
return func(o *Options) { o.ConnMaxIdleTime = v }
|
|
}
|
|
|
|
var (
|
|
pg *sql.DB
|
|
pgOnce sync.Once
|
|
pgErr error
|
|
)
|
|
|
|
func InitPostgres(opts ...Option) (*sql.DB, error) {
|
|
pgOnce.Do(func() {
|
|
o := defaultOptions()
|
|
env := FromEnv()
|
|
merge := func(dst *Options, src *Options) {
|
|
if src.Host != "" {
|
|
dst.Host = src.Host
|
|
}
|
|
if src.Port != "" {
|
|
dst.Port = src.Port
|
|
}
|
|
if src.User != "" {
|
|
dst.User = src.User
|
|
}
|
|
if src.Password != "" {
|
|
dst.Password = src.Password
|
|
}
|
|
if src.DBName != "" {
|
|
dst.DBName = src.DBName
|
|
}
|
|
if src.SSLMode != "" {
|
|
dst.SSLMode = src.SSLMode
|
|
}
|
|
}
|
|
merge(o, env)
|
|
for _, f := range opts {
|
|
f(o)
|
|
}
|
|
if o.User == "" || o.Password == "" || o.DBName == "" {
|
|
pgErr = fmt.Errorf("postgres config missing")
|
|
return
|
|
}
|
|
dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s", o.Host, o.Port, o.User, o.Password, o.DBName, o.SSLMode)
|
|
db, err := sql.Open("pgx", dsn)
|
|
if err != nil {
|
|
pgErr = err
|
|
return
|
|
}
|
|
db.SetMaxOpenConns(o.MaxOpenConns)
|
|
db.SetMaxIdleConns(o.MaxIdleConns)
|
|
db.SetConnMaxLifetime(o.ConnMaxLifetime)
|
|
db.SetConnMaxIdleTime(o.ConnMaxIdleTime)
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
if err := db.PingContext(ctx); err != nil {
|
|
_ = db.Close()
|
|
pgErr = err
|
|
return
|
|
}
|
|
pg = db
|
|
pgErr = nil
|
|
})
|
|
return pg, pgErr
|
|
}
|
|
|
|
func GetPostgres() (*sql.DB, error) {
|
|
if pg == nil {
|
|
return nil, fmt.Errorf("postgres not initialized")
|
|
}
|
|
return pg, nil
|
|
}
|
|
|
|
func ClosePostgres() error {
|
|
if pg == nil {
|
|
return fmt.Errorf("postgres not initialized")
|
|
}
|
|
err := pg.Close()
|
|
pg = nil
|
|
pgOnce = sync.Once{}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|