diff --git a/trading_assistant_api/common/db/postgres.go b/trading_assistant_api/common/db/postgres.go new file mode 100644 index 0000000..32104b6 --- /dev/null +++ b/trading_assistant_api/common/db/postgres.go @@ -0,0 +1,146 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + "sync" + "time" + + "common/utils" + _ "github.com/jackc/pgx/v5/stdlib" +) + +type Options struct { + Host string + Port string + User string + Password string + DBName string + SSLMode string + MaxOpenConns int + MaxIdleConns int + ConnMaxLifetime time.Duration + ConnMaxIdleTime time.Duration +} + +type Option func(*Options) + +func defaultOptions() *Options { + return &Options{ + Host: "localhost", + Port: "5432", + SSLMode: "disable", + MaxOpenConns: 20, + MaxIdleConns: 10, + ConnMaxLifetime: 30 * time.Minute, + ConnMaxIdleTime: 10 * time.Minute, + } +} + +func FromEnv() *Options { + return &Options{ + Host: utils.GetEnv("PG_HOST", ""), + Port: utils.GetEnv("PG_PORT", ""), + User: utils.GetEnv("PG_USER", ""), + Password: utils.GetEnv("PG_PASSWORD", ""), + DBName: utils.GetEnv("PG_DBNAME", ""), + SSLMode: utils.GetEnv("PG_SSLMODE", ""), + } +} + +func WithHost(v string) Option { return func(o *Options) { o.Host = v } } +func WithPort(v string) Option { return func(o *Options) { o.Port = v } } +func WithUser(v string) Option { return func(o *Options) { o.User = v } } +func WithPassword(v string) Option { return func(o *Options) { o.Password = v } } +func WithDBName(v string) Option { return func(o *Options) { o.DBName = v } } +func WithSSLMode(v string) Option { return func(o *Options) { o.SSLMode = v } } +func WithMaxOpenConns(v int) Option { return func(o *Options) { o.MaxOpenConns = v } } +func WithMaxIdleConns(v int) Option { return func(o *Options) { o.MaxIdleConns = v } } +func WithConnMaxLifetime(v time.Duration) Option { + return func(o *Options) { o.ConnMaxLifetime = v } +} +func WithConnMaxIdleTime(v time.Duration) Option { + return func(o *Options) { o.ConnMaxIdleTime = v } +} + +var ( + pg *sql.DB + pgOnce sync.Once + pgErr error +) + +func InitPostgres(opts ...Option) (*sql.DB, error) { + pgOnce.Do(func() { + o := defaultOptions() + env := FromEnv() + merge := func(dst *Options, src *Options) { + if src.Host != "" { + dst.Host = src.Host + } + if src.Port != "" { + dst.Port = src.Port + } + if src.User != "" { + dst.User = src.User + } + if src.Password != "" { + dst.Password = src.Password + } + if src.DBName != "" { + dst.DBName = src.DBName + } + if src.SSLMode != "" { + dst.SSLMode = src.SSLMode + } + } + merge(o, env) + for _, f := range opts { + f(o) + } + if o.User == "" || o.Password == "" || o.DBName == "" { + pgErr = fmt.Errorf("postgres config missing") + return + } + dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s", o.Host, o.Port, o.User, o.Password, o.DBName, o.SSLMode) + db, err := sql.Open("pgx", dsn) + if err != nil { + pgErr = err + return + } + db.SetMaxOpenConns(o.MaxOpenConns) + db.SetMaxIdleConns(o.MaxIdleConns) + db.SetConnMaxLifetime(o.ConnMaxLifetime) + db.SetConnMaxIdleTime(o.ConnMaxIdleTime) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := db.PingContext(ctx); err != nil { + _ = db.Close() + pgErr = err + return + } + pg = db + pgErr = nil + }) + return pg, pgErr +} + +func GetPostgres() (*sql.DB, error) { + if pg == nil { + return nil, fmt.Errorf("postgres not initialized") + } + return pg, nil +} + +func ClosePostgres() error { + if pg == nil { + return fmt.Errorf("postgres not initialized") + } + err := pg.Close() + pg = nil + pgOnce = sync.Once{} + if err != nil { + return err + } + return nil +} diff --git a/trading_assistant_api/common/go.mod b/trading_assistant_api/common/go.mod index b93cdd1..d9b9569 100644 --- a/trading_assistant_api/common/go.mod +++ b/trading_assistant_api/common/go.mod @@ -1,3 +1,5 @@ module common go 1.25.7 + +require github.com/jackc/pgx/v5 v5.6.0