package store

import (
	"database/sql"
	"errors"
	"fmt"
	"path/filepath"
	"time"

	_ "modernc.org/sqlite"
)

// AuthStore persists user accounts in a SQLite database ("auth.db").
type AuthStore struct{ db *sql.DB }

// User is one row of the users table. PasswordHash is never serialized to JSON.
// CreatedAt and UpdatedAt hold local-time strings in authTimeLayout format.
type User struct {
	ID                  int64  `json:"id"`
	Username            string `json:"username"`
	PasswordHash        string `json:"-"`
	Role                string `json:"role"`
	Disabled            bool   `json:"disabled"`
	ForcePasswordChange bool   `json:"force_password_change"`
	CreatedAt           string `json:"created_at"`
	UpdatedAt           string `json:"updated_at"`
}

// Roles accepted by the CHECK constraint on users.role.
const (
	RoleAdmin = "admin"
	RoleUser  = "user"
)

// ErrNotFound is returned when a lookup, update or delete matches no user row.
var ErrNotFound = errors.New("user not found")

// authTimeLayout is the format of the created_at / updated_at columns.
const authTimeLayout = "2006-01-02 15:04:05"

// userSelect lists the users columns in exactly the order scanUser scans them.
const userSelect = `SELECT id, username, password_hash, role, disabled, force_password_change, created_at, updated_at FROM users`

// authTimestamp returns the current local time formatted with authTimeLayout.
func authTimestamp() string { return time.Now().Format(authTimeLayout) }

// checkAffected reports ErrNotFound when a write statement touched no rows.
func checkAffected(res sql.Result) error {
	n, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("rows affected: %w", err)
	}
	if n == 0 {
		return ErrNotFound
	}
	return nil
}

// OpenAuth opens (creating if necessary) the auth database at path, enables
// WAL journaling, foreign keys and a 5s busy timeout, and ensures the schema
// exists. The returned store must be released with Close.
func OpenAuth(path string) (*AuthStore, error) {
	// Best-effort: filepath.Dir never returns "", so this always runs. If the
	// directory cannot be created, opening the database below fails with the
	// underlying error.
	_ = ensureDir(filepath.Dir(path))

	dsn := fmt.Sprintf("file:%s?_pragma=journal_mode(WAL)&_pragma=foreign_keys(1)&_pragma=busy_timeout(5000)", path)
	db, err := sql.Open("sqlite", dsn)
	if err != nil {
		return nil, fmt.Errorf("open auth.db: %w", err)
	}
	// A single connection is more stable for SQLite writes.
	db.SetMaxOpenConns(1)

	if err := db.Ping(); err != nil {
		_ = db.Close() // don't leak the handle; the ping error is the one to report
		return nil, fmt.Errorf("ping auth.db: %w", err)
	}
	s := &AuthStore{db: db}
	if err := s.init(); err != nil {
		_ = db.Close() // don't leak the handle; the init error is the one to report
		return nil, fmt.Errorf("init auth.db: %w", err)
	}
	return s, nil
}

// Close closes the underlying database handle.
func (s *AuthStore) Close() error { return s.db.Close() }

// init creates the users table and index if missing, and migrates older
// databases that predate the force_password_change column.
func (s *AuthStore) init() error {
	_, err := s.db.Exec(`
CREATE TABLE IF NOT EXISTS users (
  id INTEGER PRIMARY KEY AUTOINCREMENT,
  username TEXT NOT NULL UNIQUE,
  password_hash TEXT NOT NULL,
  role TEXT NOT NULL CHECK(role IN ('admin','user')),
  disabled INTEGER NOT NULL DEFAULT 0,
  created_at TEXT NOT NULL,
  updated_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_users_username ON users(username);
`)
	if err != nil {
		return err
	}

	// Compatibility with old tables: add force_password_change only when it is
	// missing, so genuine failures (I/O, locking) are reported instead of being
	// swallowed together with the expected "duplicate column" error.
	var present int
	if err := s.db.QueryRow(
		`SELECT COUNT(*) FROM pragma_table_info('users') WHERE name = 'force_password_change'`,
	).Scan(&present); err != nil {
		return fmt.Errorf("inspect users schema: %w", err)
	}
	if present == 0 {
		if _, err := s.db.Exec(`ALTER TABLE users ADD COLUMN force_password_change INTEGER NOT NULL DEFAULT 0`); err != nil {
			return fmt.Errorf("add force_password_change column: %w", err)
		}
	}
	return nil
}

// CountAdmins returns the number of users with role "admin", including
// disabled ones.
func (s *AuthStore) CountAdmins() (int, error) {
	var n int
	err := s.db.QueryRow(`SELECT COUNT(*) FROM users WHERE role = 'admin'`).Scan(&n)
	return n, err
}

// CreateUser inserts an enabled user with the given pre-computed password
// hash and role, and returns the stored record. A duplicate username or an
// unknown role is rejected by the table constraints and returned as an error.
func (s *AuthStore) CreateUser(username, passwordHash, role string) (*User, error) {
	now := authTimestamp()
	res, err := s.db.Exec(
		`INSERT INTO users(username, password_hash, role, disabled, created_at, updated_at) VALUES (?, ?, ?, 0, ?, ?)`,
		username, passwordHash, role, now, now,
	)
	if err != nil {
		return nil, err
	}
	id, err := res.LastInsertId()
	if err != nil {
		return nil, fmt.Errorf("read new user id: %w", err)
	}
	return &User{ID: id, Username: username, PasswordHash: passwordHash, Role: role, CreatedAt: now, UpdatedAt: now}, nil
}

// GetByUsername returns the user with the given username, or ErrNotFound.
func (s *AuthStore) GetByUsername(username string) (*User, error) {
	return scanUser(s.db.QueryRow(userSelect+` WHERE username = ?`, username))
}

// GetByID returns the user with the given id, or ErrNotFound.
func (s *AuthStore) GetByID(id int64) (*User, error) {
	return scanUser(s.db.QueryRow(userSelect+` WHERE id = ?`, id))
}

// ListUsers returns all users ordered by id. The result is a non-nil slice
// (so it encodes as [] rather than null in JSON) even when there are no users.
func (s *AuthStore) ListUsers() ([]User, error) {
	rows, err := s.db.Query(userSelect + ` ORDER BY id ASC`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	out := []User{}
	for rows.Next() {
		u, err := scanUserRows(rows)
		if err != nil {
			return nil, err
		}
		out = append(out, *u)
	}
	return out, rows.Err()
}

// UpdatePassword replaces the password hash of user id and bumps updated_at.
// It returns ErrNotFound if no such user exists.
func (s *AuthStore) UpdatePassword(id int64, hash string) error {
	res, err := s.db.Exec(`UPDATE users SET password_hash = ?, updated_at = ? WHERE id = ?`, hash, authTimestamp(), id)
	if err != nil {
		return err
	}
	return checkAffected(res)
}

// SetForcePasswordChange sets the force_password_change flag of user id and
// bumps updated_at. It returns ErrNotFound if no such user exists.
func (s *AuthStore) SetForcePasswordChange(id int64, v bool) error {
	val := 0 // SQLite stores flags as 0/1 integers
	if v {
		val = 1
	}
	res, err := s.db.Exec(`UPDATE users SET force_password_change = ?, updated_at = ? WHERE id = ?`, val, authTimestamp(), id)
	if err != nil {
		return err
	}
	return checkAffected(res)
}

// SetDisabled sets the disabled flag of user id and bumps updated_at.
// It returns ErrNotFound if no such user exists.
func (s *AuthStore) SetDisabled(id int64, disabled bool) error {
	v := 0 // SQLite stores flags as 0/1 integers
	if disabled {
		v = 1
	}
	res, err := s.db.Exec(`UPDATE users SET disabled = ?, updated_at = ? WHERE id = ?`, v, authTimestamp(), id)
	if err != nil {
		return err
	}
	return checkAffected(res)
}

// DeleteUser removes user id. It returns ErrNotFound if no such user exists.
func (s *AuthStore) DeleteUser(id int64) error {
	res, err := s.db.Exec(`DELETE FROM users WHERE id = ?`, id)
	if err != nil {
		return err
	}
	return checkAffected(res)
}

// rowScanner is the common Scan method of *sql.Row and *sql.Rows.
type rowScanner interface {
	Scan(dest ...any) error
}

// scanUser reads one row in userSelect column order into a User, converting
// the 0/1 integer flags to bools. sql.ErrNoRows is mapped to ErrNotFound.
func scanUser(r rowScanner) (*User, error) {
	var u User
	var disabled int
	var forceChange int
	if err := r.Scan(&u.ID, &u.Username, &u.PasswordHash, &u.Role, &disabled, &forceChange, &u.CreatedAt, &u.UpdatedAt); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, ErrNotFound
		}
		return nil, err
	}
	u.Disabled = disabled != 0
	u.ForcePasswordChange = forceChange != 0
	return &u, nil
}

// scanUserRows scans the current row of rows; see scanUser.
func scanUserRows(rows *sql.Rows) (*User, error) { return scanUser(rows) }