将 auth 数据库从 SQLite 迁移到 PostgreSQL

This commit is contained in:
fish
2026-05-03 20:50:28 +08:00
parent fbcde3cc71
commit 79d2f292f1
8 changed files with 23 additions and 59 deletions

View File

@@ -4,10 +4,9 @@ import (
"database/sql"
"errors"
"fmt"
"path/filepath"
"time"
_ "modernc.org/sqlite"
_ "github.com/lib/pq"
)
type AuthStore struct{ db *sql.DB }
@@ -30,18 +29,14 @@ const (
var ErrNotFound = errors.New("user not found")
func OpenAuth(path string) (*AuthStore, error) {
if dir := filepath.Dir(path); dir != "" {
_ = ensureDir(dir)
}
dsn := fmt.Sprintf("file:%s?_pragma=journal_mode(WAL)&_pragma=foreign_keys(1)&_pragma=busy_timeout(5000)", path)
db, err := sql.Open("sqlite", dsn)
func OpenAuth(databaseURL string) (*AuthStore, error) {
db, err := sql.Open("postgres", databaseURL)
if err != nil {
return nil, fmt.Errorf("open auth.db: %w", err)
return nil, fmt.Errorf("open auth db: %w", err)
}
db.SetMaxOpenConns(1) // sqlite write 单连接更稳
db.SetMaxOpenConns(8)
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("ping auth.db: %w", err)
return nil, fmt.Errorf("ping auth db: %w", err)
}
s := &AuthStore{db: db}
if err := s.init(); err != nil {
@@ -55,11 +50,11 @@ func (s *AuthStore) Close() error { return s.db.Close() }
func (s *AuthStore) init() error {
_, err := s.db.Exec(`
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
id SERIAL PRIMARY KEY,
username TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
role TEXT NOT NULL CHECK(role IN ('admin','user')),
disabled INTEGER NOT NULL DEFAULT 0,
disabled BOOLEAN NOT NULL DEFAULT FALSE,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);
@@ -69,7 +64,7 @@ func (s *AuthStore) init() error {
return err
}
// 兼容旧表:添加 force_password_change 列(已存在则忽略错误)
_, _ = s.db.Exec(`ALTER TABLE users ADD COLUMN force_password_change INTEGER NOT NULL DEFAULT 0`)
_, _ = s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS force_password_change BOOLEAN NOT NULL DEFAULT FALSE`)
return nil
}
@@ -81,28 +76,28 @@ func (s *AuthStore) CountAdmins() (int, error) {
func (s *AuthStore) CreateUser(username, passwordHash, role string) (*User, error) {
now := time.Now().Format("2006-01-02 15:04:05")
res, err := s.db.Exec(
var id int64
err := s.db.QueryRow(
`INSERT INTO users(username, password_hash, role, disabled, created_at, updated_at)
VALUES (?, ?, ?, 0, ?, ?)`,
VALUES ($1, $2, $3, FALSE, $4, $5) RETURNING id`,
username, passwordHash, role, now, now,
)
).Scan(&id)
if err != nil {
return nil, err
}
id, _ := res.LastInsertId()
return &User{ID: id, Username: username, PasswordHash: passwordHash, Role: role,
CreatedAt: now, UpdatedAt: now}, nil
}
func (s *AuthStore) GetByUsername(username string) (*User, error) {
row := s.db.QueryRow(`SELECT id, username, password_hash, role, disabled, force_password_change, created_at, updated_at
FROM users WHERE username = ?`, username)
FROM users WHERE username = $1`, username)
return scanUser(row)
}
func (s *AuthStore) GetByID(id int64) (*User, error) {
row := s.db.QueryRow(`SELECT id, username, password_hash, role, disabled, force_password_change, created_at, updated_at
FROM users WHERE id = ?`, id)
FROM users WHERE id = $1`, id)
return scanUser(row)
}
@@ -126,7 +121,7 @@ func (s *AuthStore) ListUsers() ([]User, error) {
func (s *AuthStore) UpdatePassword(id int64, hash string) error {
now := time.Now().Format("2006-01-02 15:04:05")
res, err := s.db.Exec(`UPDATE users SET password_hash = ?, updated_at = ? WHERE id = ?`, hash, now, id)
res, err := s.db.Exec(`UPDATE users SET password_hash = $1, updated_at = $2 WHERE id = $3`, hash, now, id)
if err != nil {
return err
}
@@ -139,11 +134,7 @@ func (s *AuthStore) UpdatePassword(id int64, hash string) error {
func (s *AuthStore) SetForcePasswordChange(id int64, v bool) error {
now := time.Now().Format("2006-01-02 15:04:05")
val := 0
if v {
val = 1
}
res, err := s.db.Exec(`UPDATE users SET force_password_change = ?, updated_at = ? WHERE id = ?`, val, now, id)
res, err := s.db.Exec(`UPDATE users SET force_password_change = $1, updated_at = $2 WHERE id = $3`, v, now, id)
if err != nil {
return err
}
@@ -156,11 +147,7 @@ func (s *AuthStore) SetForcePasswordChange(id int64, v bool) error {
func (s *AuthStore) SetDisabled(id int64, disabled bool) error {
now := time.Now().Format("2006-01-02 15:04:05")
v := 0
if disabled {
v = 1
}
res, err := s.db.Exec(`UPDATE users SET disabled = ?, updated_at = ? WHERE id = ?`, v, now, id)
res, err := s.db.Exec(`UPDATE users SET disabled = $1, updated_at = $2 WHERE id = $3`, disabled, now, id)
if err != nil {
return err
}
@@ -172,7 +159,7 @@ func (s *AuthStore) SetDisabled(id int64, disabled bool) error {
}
func (s *AuthStore) DeleteUser(id int64) error {
res, err := s.db.Exec(`DELETE FROM users WHERE id = ?`, id)
res, err := s.db.Exec(`DELETE FROM users WHERE id = $1`, id)
if err != nil {
return err
}
@@ -189,16 +176,12 @@ type rowScanner interface {
func scanUser(r rowScanner) (*User, error) {
var u User
var disabled int
var forceChange int
if err := r.Scan(&u.ID, &u.Username, &u.PasswordHash, &u.Role, &disabled, &forceChange, &u.CreatedAt, &u.UpdatedAt); err != nil {
if err := r.Scan(&u.ID, &u.Username, &u.PasswordHash, &u.Role, &u.Disabled, &u.ForcePasswordChange, &u.CreatedAt, &u.UpdatedAt); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
return nil, err
}
u.Disabled = disabled != 0
u.ForcePasswordChange = forceChange != 0
return &u, nil
}

View File

@@ -1,10 +0,0 @@
package store
import "os"
func ensureDir(dir string) error {
if _, err := os.Stat(dir); err == nil {
return nil
}
return os.MkdirAll(dir, 0o755)
}