package store import ( "database/sql" "errors" "fmt" "time" _ "github.com/lib/pq" ) type AuthStore struct{ db *sql.DB } type User struct { ID int64 `json:"id"` Username string `json:"username"` PasswordHash string `json:"-"` Role string `json:"role"` Disabled bool `json:"disabled"` ForcePasswordChange bool `json:"force_password_change"` CreatedAt string `json:"created_at"` UpdatedAt string `json:"updated_at"` } const ( RoleAdmin = "admin" RoleUser = "user" ) var ErrNotFound = errors.New("user not found") func OpenAuth(databaseURL string) (*AuthStore, error) { db, err := sql.Open("postgres", databaseURL) if err != nil { return nil, fmt.Errorf("open auth db: %w", err) } db.SetMaxOpenConns(8) if err := db.Ping(); err != nil { return nil, fmt.Errorf("ping auth db: %w", err) } s := &AuthStore{db: db} if err := s.init(); err != nil { return nil, err } return s, nil } func (s *AuthStore) Close() error { return s.db.Close() } func (s *AuthStore) init() error { _, err := s.db.Exec(` CREATE TABLE IF NOT EXISTS users ( id SERIAL PRIMARY KEY, username TEXT NOT NULL UNIQUE, password_hash TEXT NOT NULL, role TEXT NOT NULL CHECK(role IN ('admin','user')), disabled BOOLEAN NOT NULL DEFAULT FALSE, created_at TEXT NOT NULL, updated_at TEXT NOT NULL ); CREATE INDEX IF NOT EXISTS idx_users_username ON users(username); `) if err != nil { return err } // 兼容旧表:添加 force_password_change 列(已存在则忽略错误) _, _ = s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS force_password_change BOOLEAN NOT NULL DEFAULT FALSE`) return nil } func (s *AuthStore) CountAdmins() (int, error) { var n int err := s.db.QueryRow(`SELECT COUNT(*) FROM users WHERE role = 'admin'`).Scan(&n) return n, err } func (s *AuthStore) CreateUser(username, passwordHash, role string) (*User, error) { now := time.Now().Format("2006-01-02 15:04:05") var id int64 err := s.db.QueryRow( `INSERT INTO users(username, password_hash, role, disabled, created_at, updated_at) VALUES ($1, $2, $3, FALSE, $4, $5) RETURNING id`, username, passwordHash, role, now, now, ).Scan(&id) if err != nil { return nil, err } return &User{ID: id, Username: username, PasswordHash: passwordHash, Role: role, CreatedAt: now, UpdatedAt: now}, nil } func (s *AuthStore) GetByUsername(username string) (*User, error) { row := s.db.QueryRow(`SELECT id, username, password_hash, role, disabled, force_password_change, created_at, updated_at FROM users WHERE username = $1`, username) return scanUser(row) } func (s *AuthStore) GetByID(id int64) (*User, error) { row := s.db.QueryRow(`SELECT id, username, password_hash, role, disabled, force_password_change, created_at, updated_at FROM users WHERE id = $1`, id) return scanUser(row) } func (s *AuthStore) ListUsers() ([]User, error) { rows, err := s.db.Query(`SELECT id, username, password_hash, role, disabled, force_password_change, created_at, updated_at FROM users ORDER BY id ASC`) if err != nil { return nil, err } defer rows.Close() out := []User{} for rows.Next() { u, err := scanUserRows(rows) if err != nil { return nil, err } out = append(out, *u) } return out, rows.Err() } func (s *AuthStore) UpdatePassword(id int64, hash string) error { now := time.Now().Format("2006-01-02 15:04:05") res, err := s.db.Exec(`UPDATE users SET password_hash = $1, updated_at = $2 WHERE id = $3`, hash, now, id) if err != nil { return err } n, _ := res.RowsAffected() if n == 0 { return ErrNotFound } return nil } func (s *AuthStore) SetForcePasswordChange(id int64, v bool) error { now := time.Now().Format("2006-01-02 15:04:05") res, err := s.db.Exec(`UPDATE users SET force_password_change = $1, updated_at = $2 WHERE id = $3`, v, now, id) if err != nil { return err } n, _ := res.RowsAffected() if n == 0 { return ErrNotFound } return nil } func (s *AuthStore) SetDisabled(id int64, disabled bool) error { now := time.Now().Format("2006-01-02 15:04:05") res, err := s.db.Exec(`UPDATE users SET disabled = $1, updated_at = $2 WHERE id = $3`, disabled, now, id) if err != nil { return err } n, _ := res.RowsAffected() if n == 0 { return ErrNotFound } return nil } func (s *AuthStore) DeleteUser(id int64) error { res, err := s.db.Exec(`DELETE FROM users WHERE id = $1`, id) if err != nil { return err } n, _ := res.RowsAffected() if n == 0 { return ErrNotFound } return nil } type rowScanner interface { Scan(dest ...any) error } func scanUser(r rowScanner) (*User, error) { var u User if err := r.Scan(&u.ID, &u.Username, &u.PasswordHash, &u.Role, &u.Disabled, &u.ForcePasswordChange, &u.CreatedAt, &u.UpdatedAt); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } return nil, err } return &u, nil } func scanUserRows(rows *sql.Rows) (*User, error) { return scanUser(rows) }