189 lines
4.9 KiB
Go
189 lines
4.9 KiB
Go
package store
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
_ "github.com/lib/pq"
|
|
)
|
|
|
|
// AuthStore reads and writes user accounts in a PostgreSQL database.
// Create one with OpenAuth and release it with Close.
type AuthStore struct {
	// db is the database handle (a connection pool) opened by OpenAuth.
	db *sql.DB
}
|
|
|
|
// User is a single row of the users table.
//
// CreatedAt and UpdatedAt are stored as text. The writers in this store fill
// them from server-local time using the layout "2006-01-02 15:04:05".
type User struct {
	ID       int64  `json:"id"`
	Username string `json:"username"`
	// PasswordHash is excluded from JSON output by its "-" tag.
	PasswordHash string `json:"-"`
	// Role is RoleAdmin or RoleUser; the table's CHECK constraint rejects
	// anything else.
	Role                string `json:"role"`
	Disabled            bool   `json:"disabled"`
	ForcePasswordChange bool   `json:"force_password_change"`
	CreatedAt           string `json:"created_at"`
	UpdatedAt           string `json:"updated_at"`
}
|
|
|
|
// Role values accepted by the users.role column.
const (
	// RoleAdmin marks an administrator account.
	RoleAdmin = "admin"
	// RoleUser marks a regular account.
	RoleUser = "user"
)
|
|
|
|
// ErrNotFound reports that no user row matched the requested username or id.
// Lookups and the update/delete methods of AuthStore return it; compare with
// errors.Is.
var ErrNotFound = errors.New("user not found")
|
|
|
|
func OpenAuth(databaseURL string) (*AuthStore, error) {
|
|
db, err := sql.Open("postgres", databaseURL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("open auth db: %w", err)
|
|
}
|
|
db.SetMaxOpenConns(8)
|
|
if err := db.Ping(); err != nil {
|
|
return nil, fmt.Errorf("ping auth db: %w", err)
|
|
}
|
|
s := &AuthStore{db: db}
|
|
if err := s.init(); err != nil {
|
|
return nil, err
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
// Close closes the store's database handle and returns any error from it.
func (s *AuthStore) Close() error {
	return s.db.Close()
}
|
|
|
|
func (s *AuthStore) init() error {
|
|
_, err := s.db.Exec(`
|
|
CREATE TABLE IF NOT EXISTS users (
|
|
id SERIAL PRIMARY KEY,
|
|
username TEXT NOT NULL UNIQUE,
|
|
password_hash TEXT NOT NULL,
|
|
role TEXT NOT NULL CHECK(role IN ('admin','user')),
|
|
disabled BOOLEAN NOT NULL DEFAULT FALSE,
|
|
created_at TEXT NOT NULL,
|
|
updated_at TEXT NOT NULL
|
|
);
|
|
CREATE INDEX IF NOT EXISTS idx_users_username ON users(username);
|
|
`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// 兼容旧表:添加 force_password_change 列(已存在则忽略错误)
|
|
_, _ = s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS force_password_change BOOLEAN NOT NULL DEFAULT FALSE`)
|
|
return nil
|
|
}
|
|
|
|
func (s *AuthStore) CountAdmins() (int, error) {
|
|
var n int
|
|
err := s.db.QueryRow(`SELECT COUNT(*) FROM users WHERE role = 'admin'`).Scan(&n)
|
|
return n, err
|
|
}
|
|
|
|
func (s *AuthStore) CreateUser(username, passwordHash, role string) (*User, error) {
|
|
now := time.Now().Format("2006-01-02 15:04:05")
|
|
var id int64
|
|
err := s.db.QueryRow(
|
|
`INSERT INTO users(username, password_hash, role, disabled, created_at, updated_at)
|
|
VALUES ($1, $2, $3, FALSE, $4, $5) RETURNING id`,
|
|
username, passwordHash, role, now, now,
|
|
).Scan(&id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &User{ID: id, Username: username, PasswordHash: passwordHash, Role: role,
|
|
CreatedAt: now, UpdatedAt: now}, nil
|
|
}
|
|
|
|
func (s *AuthStore) GetByUsername(username string) (*User, error) {
|
|
row := s.db.QueryRow(`SELECT id, username, password_hash, role, disabled, force_password_change, created_at, updated_at
|
|
FROM users WHERE username = $1`, username)
|
|
return scanUser(row)
|
|
}
|
|
|
|
func (s *AuthStore) GetByID(id int64) (*User, error) {
|
|
row := s.db.QueryRow(`SELECT id, username, password_hash, role, disabled, force_password_change, created_at, updated_at
|
|
FROM users WHERE id = $1`, id)
|
|
return scanUser(row)
|
|
}
|
|
|
|
func (s *AuthStore) ListUsers() ([]User, error) {
|
|
rows, err := s.db.Query(`SELECT id, username, password_hash, role, disabled, force_password_change, created_at, updated_at
|
|
FROM users ORDER BY id ASC`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
out := []User{}
|
|
for rows.Next() {
|
|
u, err := scanUserRows(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, *u)
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
func (s *AuthStore) UpdatePassword(id int64, hash string) error {
|
|
now := time.Now().Format("2006-01-02 15:04:05")
|
|
res, err := s.db.Exec(`UPDATE users SET password_hash = $1, updated_at = $2 WHERE id = $3`, hash, now, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
n, _ := res.RowsAffected()
|
|
if n == 0 {
|
|
return ErrNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *AuthStore) SetForcePasswordChange(id int64, v bool) error {
|
|
now := time.Now().Format("2006-01-02 15:04:05")
|
|
res, err := s.db.Exec(`UPDATE users SET force_password_change = $1, updated_at = $2 WHERE id = $3`, v, now, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
n, _ := res.RowsAffected()
|
|
if n == 0 {
|
|
return ErrNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *AuthStore) SetDisabled(id int64, disabled bool) error {
|
|
now := time.Now().Format("2006-01-02 15:04:05")
|
|
res, err := s.db.Exec(`UPDATE users SET disabled = $1, updated_at = $2 WHERE id = $3`, disabled, now, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
n, _ := res.RowsAffected()
|
|
if n == 0 {
|
|
return ErrNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *AuthStore) DeleteUser(id int64) error {
|
|
res, err := s.db.Exec(`DELETE FROM users WHERE id = $1`, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
n, _ := res.RowsAffected()
|
|
if n == 0 {
|
|
return ErrNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// rowScanner is the Scan method shared by *sql.Row and *sql.Rows. It lets a
// single helper decode a user from either a QueryRow result or a Query
// cursor.
type rowScanner interface {
	// Scan copies the current row's columns into dest.
	Scan(dest ...any) error
}
|
|
|
|
func scanUser(r rowScanner) (*User, error) {
|
|
var u User
|
|
if err := r.Scan(&u.ID, &u.Username, &u.PasswordHash, &u.Role, &u.Disabled, &u.ForcePasswordChange, &u.CreatedAt, &u.UpdatedAt); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return &u, nil
|
|
}
|
|
|
|
// scanUserRows decodes the row at the current position of a Query cursor. It
// is a typed wrapper around scanUser.
func scanUserRows(rows *sql.Rows) (*User, error) {
	return scanUser(rows)
}
|