|
|
|
|
@@ -4,10 +4,9 @@ import (
|
|
|
|
|
"database/sql"
|
|
|
|
|
"errors"
|
|
|
|
|
"fmt"
|
|
|
|
|
"path/filepath"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
_ "modernc.org/sqlite"
|
|
|
|
|
_ "github.com/lib/pq"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
type AuthStore struct{ db *sql.DB }
|
|
|
|
|
@@ -30,18 +29,14 @@ const (
|
|
|
|
|
|
|
|
|
|
var ErrNotFound = errors.New("user not found")
|
|
|
|
|
|
|
|
|
|
func OpenAuth(path string) (*AuthStore, error) {
|
|
|
|
|
if dir := filepath.Dir(path); dir != "" {
|
|
|
|
|
_ = ensureDir(dir)
|
|
|
|
|
}
|
|
|
|
|
dsn := fmt.Sprintf("file:%s?_pragma=journal_mode(WAL)&_pragma=foreign_keys(1)&_pragma=busy_timeout(5000)", path)
|
|
|
|
|
db, err := sql.Open("sqlite", dsn)
|
|
|
|
|
func OpenAuth(databaseURL string) (*AuthStore, error) {
|
|
|
|
|
db, err := sql.Open("postgres", databaseURL)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, fmt.Errorf("open auth.db: %w", err)
|
|
|
|
|
return nil, fmt.Errorf("open auth db: %w", err)
|
|
|
|
|
}
|
|
|
|
|
db.SetMaxOpenConns(1) // sqlite write 单连接更稳
|
|
|
|
|
db.SetMaxOpenConns(8)
|
|
|
|
|
if err := db.Ping(); err != nil {
|
|
|
|
|
return nil, fmt.Errorf("ping auth.db: %w", err)
|
|
|
|
|
return nil, fmt.Errorf("ping auth db: %w", err)
|
|
|
|
|
}
|
|
|
|
|
s := &AuthStore{db: db}
|
|
|
|
|
if err := s.init(); err != nil {
|
|
|
|
|
@@ -55,11 +50,11 @@ func (s *AuthStore) Close() error { return s.db.Close() }
|
|
|
|
|
func (s *AuthStore) init() error {
|
|
|
|
|
_, err := s.db.Exec(`
|
|
|
|
|
CREATE TABLE IF NOT EXISTS users (
|
|
|
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
|
|
|
id SERIAL PRIMARY KEY,
|
|
|
|
|
username TEXT NOT NULL UNIQUE,
|
|
|
|
|
password_hash TEXT NOT NULL,
|
|
|
|
|
role TEXT NOT NULL CHECK(role IN ('admin','user')),
|
|
|
|
|
disabled INTEGER NOT NULL DEFAULT 0,
|
|
|
|
|
disabled BOOLEAN NOT NULL DEFAULT FALSE,
|
|
|
|
|
created_at TEXT NOT NULL,
|
|
|
|
|
updated_at TEXT NOT NULL
|
|
|
|
|
);
|
|
|
|
|
@@ -69,7 +64,7 @@ func (s *AuthStore) init() error {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
// 兼容旧表:添加 force_password_change 列(已存在则忽略错误)
|
|
|
|
|
_, _ = s.db.Exec(`ALTER TABLE users ADD COLUMN force_password_change INTEGER NOT NULL DEFAULT 0`)
|
|
|
|
|
_, _ = s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS force_password_change BOOLEAN NOT NULL DEFAULT FALSE`)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -81,28 +76,28 @@ func (s *AuthStore) CountAdmins() (int, error) {
|
|
|
|
|
|
|
|
|
|
func (s *AuthStore) CreateUser(username, passwordHash, role string) (*User, error) {
|
|
|
|
|
now := time.Now().Format("2006-01-02 15:04:05")
|
|
|
|
|
res, err := s.db.Exec(
|
|
|
|
|
var id int64
|
|
|
|
|
err := s.db.QueryRow(
|
|
|
|
|
`INSERT INTO users(username, password_hash, role, disabled, created_at, updated_at)
|
|
|
|
|
VALUES (?, ?, ?, 0, ?, ?)`,
|
|
|
|
|
VALUES ($1, $2, $3, FALSE, $4, $5) RETURNING id`,
|
|
|
|
|
username, passwordHash, role, now, now,
|
|
|
|
|
)
|
|
|
|
|
).Scan(&id)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
id, _ := res.LastInsertId()
|
|
|
|
|
return &User{ID: id, Username: username, PasswordHash: passwordHash, Role: role,
|
|
|
|
|
CreatedAt: now, UpdatedAt: now}, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *AuthStore) GetByUsername(username string) (*User, error) {
|
|
|
|
|
row := s.db.QueryRow(`SELECT id, username, password_hash, role, disabled, force_password_change, created_at, updated_at
|
|
|
|
|
FROM users WHERE username = ?`, username)
|
|
|
|
|
FROM users WHERE username = $1`, username)
|
|
|
|
|
return scanUser(row)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *AuthStore) GetByID(id int64) (*User, error) {
|
|
|
|
|
row := s.db.QueryRow(`SELECT id, username, password_hash, role, disabled, force_password_change, created_at, updated_at
|
|
|
|
|
FROM users WHERE id = ?`, id)
|
|
|
|
|
FROM users WHERE id = $1`, id)
|
|
|
|
|
return scanUser(row)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -126,7 +121,7 @@ func (s *AuthStore) ListUsers() ([]User, error) {
|
|
|
|
|
|
|
|
|
|
func (s *AuthStore) UpdatePassword(id int64, hash string) error {
|
|
|
|
|
now := time.Now().Format("2006-01-02 15:04:05")
|
|
|
|
|
res, err := s.db.Exec(`UPDATE users SET password_hash = ?, updated_at = ? WHERE id = ?`, hash, now, id)
|
|
|
|
|
res, err := s.db.Exec(`UPDATE users SET password_hash = $1, updated_at = $2 WHERE id = $3`, hash, now, id)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
@@ -139,11 +134,7 @@ func (s *AuthStore) UpdatePassword(id int64, hash string) error {
|
|
|
|
|
|
|
|
|
|
func (s *AuthStore) SetForcePasswordChange(id int64, v bool) error {
|
|
|
|
|
now := time.Now().Format("2006-01-02 15:04:05")
|
|
|
|
|
val := 0
|
|
|
|
|
if v {
|
|
|
|
|
val = 1
|
|
|
|
|
}
|
|
|
|
|
res, err := s.db.Exec(`UPDATE users SET force_password_change = ?, updated_at = ? WHERE id = ?`, val, now, id)
|
|
|
|
|
res, err := s.db.Exec(`UPDATE users SET force_password_change = $1, updated_at = $2 WHERE id = $3`, v, now, id)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
@@ -156,11 +147,7 @@ func (s *AuthStore) SetForcePasswordChange(id int64, v bool) error {
|
|
|
|
|
|
|
|
|
|
func (s *AuthStore) SetDisabled(id int64, disabled bool) error {
|
|
|
|
|
now := time.Now().Format("2006-01-02 15:04:05")
|
|
|
|
|
v := 0
|
|
|
|
|
if disabled {
|
|
|
|
|
v = 1
|
|
|
|
|
}
|
|
|
|
|
res, err := s.db.Exec(`UPDATE users SET disabled = ?, updated_at = ? WHERE id = ?`, v, now, id)
|
|
|
|
|
res, err := s.db.Exec(`UPDATE users SET disabled = $1, updated_at = $2 WHERE id = $3`, disabled, now, id)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
@@ -172,7 +159,7 @@ func (s *AuthStore) SetDisabled(id int64, disabled bool) error {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *AuthStore) DeleteUser(id int64) error {
|
|
|
|
|
res, err := s.db.Exec(`DELETE FROM users WHERE id = ?`, id)
|
|
|
|
|
res, err := s.db.Exec(`DELETE FROM users WHERE id = $1`, id)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
@@ -189,16 +176,12 @@ type rowScanner interface {
|
|
|
|
|
|
|
|
|
|
func scanUser(r rowScanner) (*User, error) {
|
|
|
|
|
var u User
|
|
|
|
|
var disabled int
|
|
|
|
|
var forceChange int
|
|
|
|
|
if err := r.Scan(&u.ID, &u.Username, &u.PasswordHash, &u.Role, &disabled, &forceChange, &u.CreatedAt, &u.UpdatedAt); err != nil {
|
|
|
|
|
if err := r.Scan(&u.ID, &u.Username, &u.PasswordHash, &u.Role, &u.Disabled, &u.ForcePasswordChange, &u.CreatedAt, &u.UpdatedAt); err != nil {
|
|
|
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
|
|
|
return nil, ErrNotFound
|
|
|
|
|
}
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
u.Disabled = disabled != 0
|
|
|
|
|
u.ForcePasswordChange = forceChange != 0
|
|
|
|
|
return &u, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|