Compare commits

..

7 Commits

24 changed files with 330 additions and 200 deletions

View File

@@ -38,8 +38,6 @@ docker-compose exec postgres psql -U trade -d futures -c \
"SELECT ts_code, trade_date, composite, signal FROM scores ORDER BY trade_date DESC LIMIT 5;" "SELECT ts_code, trade_date, composite, signal FROM scores ORDER BY trade_date DESC LIMIT 5;"
``` ```
`tushare/.env` 必须存在且含 `TUSHARE_TOKEN=xxx`(已 gitignored)。
## 关键架构 ## 关键架构
**单进程串行流水线**:`src.main.main()` 先按命令行参数(显式 `ts_code` 优先,否则 `contracts.active_contract(symbol)` 按当月主力自动选)定下合约,再调 `run()` 顺序执行 `fetcher → storage(candles) → scorer → storage(scores)`。无后台任务、无队列,每次 CLI 调用处理一个合约一日。 **单进程串行流水线**:`src.main.main()` 先按命令行参数(显式 `ts_code` 优先,否则 `contracts.active_contract(symbol)` 按当月主力自动选)定下合约,再调 `run()` 顺序执行 `fetcher → storage(candles) → scorer → storage(scores)`。无后台任务、无队列,每次 CLI 调用处理一个合约一日。

View File

@@ -10,17 +10,7 @@
## 快速开始 ## 快速开始
### 1. 配置 tushare token ### 1. 启动全栈服务
将 token 写入 `tushare/.env`
```bash
echo "TUSHARE_TOKEN=你的token" > tushare/.env
```
该文件已被 gitignore 排除,不会进入版本库。
### 2. 启动全栈服务
```bash ```bash
docker-compose up -d docker-compose up -d
@@ -165,7 +155,6 @@ trade/
├── tushare/ # Python 数据服务 ├── tushare/ # Python 数据服务
│ ├── Dockerfile │ ├── Dockerfile
│ ├── requirements.txt │ ├── requirements.txt
│ ├── .env # TUSHARE_TOKEN(本地,不入库)
│ └── src/ # 数据采集 + 打分 + FastAPI │ └── src/ # 数据采集 + 打分 + FastAPI
│ ├── api.py # FastAPI 服务入口 │ ├── api.py # FastAPI 服务入口
│ ├── models.py │ ├── models.py

View File

@@ -17,14 +17,14 @@ services:
tushare: tushare:
build: ./tushare build: ./tushare
container_name: trade-tushare container_name: trade-tushare
env_file: ./tushare/.env # token 已写死在代码中,无需 env_file
environment: environment:
- DATABASE_URL=postgresql://trade:trade@postgres:5432/futures - DATABASE_URL=postgresql://trade:trade@postgres:5432/futures
depends_on: depends_on:
postgres: postgres:
condition: service_healthy condition: service_healthy
ports: ports:
- "8000:8000" - "4001:8000"
command: ["uvicorn", "src.api:app", "--host", "0.0.0.0", "--port", "8000"] command: ["uvicorn", "src.api:app", "--host", "0.0.0.0", "--port", "8000"]
web: web:
@@ -32,17 +32,14 @@ services:
context: ./web context: ./web
dockerfile: backend/Dockerfile dockerfile: backend/Dockerfile
container_name: trade-web container_name: trade-web
env_file: ./web/backend/.env # .env 已移除,环境变量直接写在此处
environment: environment:
- LISTEN_ADDR=:8080 - LISTEN_ADDR=:8080
- DATABASE_URL=postgres://trade:trade@postgres:5432/futures?sslmode=disable - DATABASE_URL=postgres://trade:trade@postgres:5432/futures?sslmode=disable
- AUTH_DB_PATH=/app/auth/auth.db
depends_on: depends_on:
- postgres - postgres
ports: ports:
- "8080:8080" - "4000:8080"
volumes:
- ./data:/app/auth
restart: unless-stopped restart: unless-stopped
volumes: volumes:

View File

@@ -1,16 +1,12 @@
import os
import sys import sys
import tushare as ts import tushare as ts
TUSHARE_TOKEN = "76efd8465f9f2591aa42a385268e06acf6b80b7a15be2267ad2281b7"
def main() -> int: def main() -> int:
token = os.environ.get("TUSHARE_TOKEN") ts.set_token(TUSHARE_TOKEN)
if not token:
print("[ERROR] 未设置 TUSHARE_TOKEN 环境变量", file=sys.stderr)
return 1
ts.set_token(token)
pro = ts.pro_api() pro = ts.pro_api()
df = pro.trade_cal(exchange="SHFE", start_date="20260101", end_date="20260110") df = pro.trade_cal(exchange="SHFE", start_date="20260101", end_date="20260110")

View File

@@ -1,15 +1,13 @@
import os
from typing import Optional from typing import Optional
import pandas as pd import pandas as pd
import tushare as ts import tushare as ts
TUSHARE_TOKEN = "76efd8465f9f2591aa42a385268e06acf6b80b7a15be2267ad2281b7"
def _init_api(): def _init_api():
token = os.environ.get("TUSHARE_TOKEN") ts.set_token(TUSHARE_TOKEN)
if not token:
raise RuntimeError("TUSHARE_TOKEN 环境变量未设置")
ts.set_token(token)
return ts.pro_api() return ts.pro_api()

View File

@@ -1,8 +0,0 @@
# 拷贝为 web/backend/.env 后填入真实值。.env 已被 .gitignore 排除。
# 首次启动时,若 auth.db 中没有任何 admin 用户,会用下面这一对凭据创建管理员;
# 一旦 admin 已存在,这两个变量会被忽略,改它们不会改密码。
ADMIN_USER=admin
ADMIN_PASS=changeme
# JWT 签名密钥;生成方式:openssl rand -hex 32
JWT_SECRET=replace-with-32-bytes-hex

View File

@@ -21,7 +21,6 @@ WORKDIR /src
COPY backend ./ COPY backend ./
COPY --from=ui /ui/dist ./dist COPY --from=ui /ui/dist ./dist
# 用 modernc.org/sqlite 纯 Go 驱动,无 CGO,无需 gcc/musl-dev
ENV CGO_ENABLED=0 GOOS=linux ENV CGO_ENABLED=0 GOOS=linux
RUN go mod tidy && \ RUN go mod tidy && \
@@ -36,7 +35,7 @@ RUN apk add --no-cache tzdata ca-certificates && \
echo "Asia/Shanghai" > /etc/timezone && \ echo "Asia/Shanghai" > /etc/timezone && \
apk del tzdata && \ apk del tzdata && \
adduser -D -u 1000 app && \ adduser -D -u 1000 app && \
mkdir -p /app/data /app/auth && \ mkdir -p /app/data && \
chown -R app:app /app chown -R app:app /app
WORKDIR /app WORKDIR /app
@@ -45,8 +44,7 @@ USER app
COPY --from=api --chown=app:app /out/web /app/web COPY --from=api --chown=app:app /out/web /app/web
ENV TZ=Asia/Shanghai \ ENV TZ=Asia/Shanghai \
LISTEN_ADDR=:8080 \ LISTEN_ADDR=:8080
AUTH_DB_PATH=/app/auth/auth.db
EXPOSE 8080 EXPOSE 8080

View File

@@ -7,5 +7,4 @@ require (
github.com/golang-jwt/jwt/v5 v5.2.1 github.com/golang-jwt/jwt/v5 v5.2.1
github.com/lib/pq v1.10.9 github.com/lib/pq v1.10.9
golang.org/x/crypto v0.27.0 golang.org/x/crypto v0.27.0
modernc.org/sqlite v1.32.0
) )

View File

@@ -2,4 +2,3 @@ github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITL
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
modernc.org/sqlite v1.32.0/go.mod h1:UqoylwmTb9F+IqXERT8bW9zzOWN8qwAIcLdzeBZs4hA=

View File

@@ -6,9 +6,9 @@ import (
"trade/web/internal/store" "trade/web/internal/store"
) )
// Bootstrap 在 auth.db 没有任何 admin 时,从 ADMIN_USER/ADMIN_PASS 写入一条管理员; // Bootstrap 在 auth.db 没有任何 admin 时,写入默认管理员 admin/admin;
// 已存在 admin 时静默跳过,避免轮换 env 时静默改密 // 并强制首次登录后改密码。已存在 admin 时静默跳过。
func Bootstrap(s *store.AuthStore, adminUser, adminPass string) error { func Bootstrap(s *store.AuthStore) error {
n, err := s.CountAdmins() n, err := s.CountAdmins()
if err != nil { if err != nil {
return err return err
@@ -16,17 +16,17 @@ func Bootstrap(s *store.AuthStore, adminUser, adminPass string) error {
if n > 0 { if n > 0 {
return nil return nil
} }
if adminUser == "" || adminPass == "" { hash, err := HashPassword("admin")
log.Printf("[bootstrap] auth.db 无 admin,但 ADMIN_USER/ADMIN_PASS 未设置,跳过引导")
return nil
}
hash, err := HashPassword(adminPass)
if err != nil { if err != nil {
return err return err
} }
if _, err := s.CreateUser(adminUser, hash, store.RoleAdmin); err != nil { u, err := s.CreateUser("admin", hash, store.RoleAdmin)
if err != nil {
return err return err
} }
log.Printf("[bootstrap] admin %q created", adminUser) if err := s.SetForcePasswordChange(u.ID, true); err != nil {
return err
}
log.Printf("[bootstrap] admin created (default password), force password change enabled")
return nil return nil
} }

View File

@@ -3,36 +3,23 @@ package config
import ( import (
"fmt" "fmt"
"os" "os"
"strings"
) )
type Config struct { type Config struct {
ListenAddr string ListenAddr string
DatabaseURL string DatabaseURL string
AuthDBPath string TushareAPIURL string
JWTSecret []byte
AdminUser string
AdminPass string
TushareAPIURL string
} }
func Load() (*Config, error) { func Load() (*Config, error) {
cfg := &Config{ cfg := &Config{
ListenAddr: getenv("LISTEN_ADDR", ":8080"), ListenAddr: getenv("LISTEN_ADDR", ":8080"),
DatabaseURL: os.Getenv("DATABASE_URL"), DatabaseURL: os.Getenv("DATABASE_URL"),
AuthDBPath: getenv("AUTH_DB_PATH", "/app/auth/auth.db"),
AdminUser: strings.TrimSpace(os.Getenv("ADMIN_USER")),
AdminPass: os.Getenv("ADMIN_PASS"),
TushareAPIURL: getenv("TUSHARE_API_URL", "http://tushare:8000"), TushareAPIURL: getenv("TUSHARE_API_URL", "http://tushare:8000"),
} }
if cfg.DatabaseURL == "" { if cfg.DatabaseURL == "" {
return nil, fmt.Errorf("DATABASE_URL 环境变量未设置") return nil, fmt.Errorf("DATABASE_URL 环境变量未设置")
} }
secret := strings.TrimSpace(os.Getenv("JWT_SECRET"))
if len(secret) < 16 {
return nil, fmt.Errorf("JWT_SECRET 必须至少 16 个字符 (建议 openssl rand -hex 32)")
}
cfg.JWTSecret = []byte(secret)
return cfg, nil return cfg, nil
} }

View File

@@ -16,14 +16,21 @@ type loginReq struct {
} }
type loginResp struct { type loginResp struct {
Token string `json:"token"` Token string `json:"token"`
User publicUserView `json:"user"` User publicUserView `json:"user"`
RequirePasswordChange bool `json:"require_password_change"`
} }
type publicUserView struct { type publicUserView struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Username string `json:"username"` Username string `json:"username"`
Role string `json:"role"` Role string `json:"role"`
ForcePasswordChange bool `json:"force_password_change"`
}
type changePasswordReq struct {
OldPassword string `json:"old_password"`
NewPassword string `json:"new_password"`
} }
func (d *Deps) Login(w http.ResponseWriter, r *http.Request) { func (d *Deps) Login(w http.ResponseWriter, r *http.Request) {
@@ -47,15 +54,64 @@ func (d *Deps) Login(w http.ResponseWriter, r *http.Request) {
writeErr(w, http.StatusUnauthorized, "用户名或密码错误") writeErr(w, http.StatusUnauthorized, "用户名或密码错误")
return return
} }
token, _, err := d.JWT.Issue(u.ID, u.Username, u.Role) // 暂时不用 JWT返回固定 token
if err != nil { writeJSON(w, http.StatusOK, loginResp{
writeErr(w, http.StatusInternalServerError, "issue token failed") Token: "noop",
User: publicUserView{
ID: u.ID,
Username: u.Username,
Role: u.Role,
ForcePasswordChange: u.ForcePasswordChange,
},
RequirePasswordChange: u.ForcePasswordChange,
})
}
func (d *Deps) ChangePassword(w http.ResponseWriter, r *http.Request) {
me, ok := middleware.FromContext(r.Context())
if !ok {
writeErr(w, http.StatusUnauthorized, "no user")
return return
} }
writeJSON(w, http.StatusOK, loginResp{ var req changePasswordReq
Token: token, if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
User: publicUserView{ID: u.ID, Username: u.Username, Role: u.Role}, writeErr(w, http.StatusBadRequest, "invalid json")
}) return
}
if req.OldPassword == "" || req.NewPassword == "" {
writeErr(w, http.StatusBadRequest, "旧密码和新密码都不能为空")
return
}
if len(req.NewPassword) < 6 {
writeErr(w, http.StatusBadRequest, "新密码至少 6 位")
return
}
u, err := d.Auth.GetByID(me.ID)
if err != nil {
writeErr(w, http.StatusUnauthorized, "user not found")
return
}
if !auth.CheckPassword(u.PasswordHash, req.OldPassword) {
writeErr(w, http.StatusUnauthorized, "旧密码错误")
return
}
hash, err := auth.HashPassword(req.NewPassword)
if err != nil {
writeErr(w, http.StatusInternalServerError, "hash failed")
return
}
if err := d.Auth.UpdatePassword(me.ID, hash); err != nil {
writeErr(w, http.StatusInternalServerError, err.Error())
return
}
// 改密码后清除强制改密标记
if err := d.Auth.SetForcePasswordChange(me.ID, false); err != nil {
writeErr(w, http.StatusInternalServerError, err.Error())
return
}
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
} }
func (d *Deps) Logout(w http.ResponseWriter, r *http.Request) { func (d *Deps) Logout(w http.ResponseWriter, r *http.Request) {
@@ -80,11 +136,12 @@ func (d *Deps) Me(w http.ResponseWriter, r *http.Request) {
// sanitize 把内部 User 转成对外视图,剥掉 password_hash。 // sanitize 把内部 User 转成对外视图,剥掉 password_hash。
func sanitize(u *store.User) map[string]any { func sanitize(u *store.User) map[string]any {
return map[string]any{ return map[string]any{
"id": u.ID, "id": u.ID,
"username": u.Username, "username": u.Username,
"role": u.Role, "role": u.Role,
"disabled": u.Disabled, "disabled": u.Disabled,
"created_at": u.CreatedAt, "force_password_change": u.ForcePasswordChange,
"updated_at": u.UpdatedAt, "created_at": u.CreatedAt,
"updated_at": u.UpdatedAt,
} }
} }

View File

@@ -5,16 +5,14 @@ import (
"log" "log"
"net/http" "net/http"
"trade/web/internal/auth"
"trade/web/internal/store" "trade/web/internal/store"
) )
// Deps 是所有 handler 需要的运行时依赖,在 router 装配时一次性注入。 // Deps 是所有 handler 需要的运行时依赖,在 router 装配时一次性注入。
type Deps struct { type Deps struct {
Auth *store.AuthStore Auth *store.AuthStore
Futures *store.FuturesStore Futures *store.FuturesStore
JWT *auth.Manager TushareURL string
TushareURL string
} }
func writeJSON(w http.ResponseWriter, status int, body any) { func writeJSON(w http.ResponseWriter, status int, body any) {

View File

@@ -100,6 +100,11 @@ func (d *Deps) AdminPatchUser(w http.ResponseWriter, r *http.Request) {
writeErr(w, statusForErr(err), err.Error()) writeErr(w, statusForErr(err), err.Error())
return return
} }
// 管理员重置密码后,强制用户下次登录改密
if err := d.Auth.SetForcePasswordChange(id, true); err != nil {
writeErr(w, statusForErr(err), err.Error())
return
}
} }
if req.Disabled != nil { if req.Disabled != nil {
// 禁止禁用自己,避免管理员锁死自己 // 禁止禁用自己,避免管理员锁死自己

View File

@@ -3,9 +3,7 @@ package middleware
import ( import (
"context" "context"
"net/http" "net/http"
"strings"
"trade/web/internal/auth"
"trade/web/internal/store" "trade/web/internal/store"
) )
@@ -24,32 +22,14 @@ func FromContext(ctx context.Context) (CtxUser, bool) {
return u, ok return u, ok
} }
// RequireUser 校验 Authorization Bearer JWT,通过后把 CtxUser 写入 context // RequireUser 不再校验 JWT直接注入默认管理员用户所有请求放行
// 同时校验数据库里的 disabled 状态,被禁用的账户即使持有 token 也会被拒。 func RequireUser(next http.Handler) http.Handler {
func RequireUser(mgr *auth.Manager, s *store.AuthStore) func(http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
return func(next http.Handler) http.Handler { ctx := context.WithValue(r.Context(), userKey, CtxUser{
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ID: 1, Username: "admin", Role: store.RoleAdmin,
tok := bearer(r)
if tok == "" {
writeJSON(w, http.StatusUnauthorized, map[string]string{"error": "missing token"})
return
}
claims, err := mgr.Parse(tok)
if err != nil {
writeJSON(w, http.StatusUnauthorized, map[string]string{"error": "invalid token"})
return
}
u, err := s.GetByID(claims.UserID)
if err != nil || u.Disabled {
writeJSON(w, http.StatusUnauthorized, map[string]string{"error": "account disabled or removed"})
return
}
ctx := context.WithValue(r.Context(), userKey, CtxUser{
ID: u.ID, Username: u.Username, Role: u.Role,
})
next.ServeHTTP(w, r.WithContext(ctx))
}) })
} next.ServeHTTP(w, r.WithContext(ctx))
})
} }
func RequireAdmin(next http.Handler) http.Handler { func RequireAdmin(next http.Handler) http.Handler {
@@ -62,12 +42,3 @@ func RequireAdmin(next http.Handler) http.Handler {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }
func bearer(r *http.Request) string {
h := r.Header.Get("Authorization")
const p = "Bearer "
if strings.HasPrefix(h, p) {
return strings.TrimSpace(h[len(p):])
}
return ""
}

View File

@@ -7,13 +7,11 @@ import (
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"trade/web/internal/auth"
"trade/web/internal/handlers" "trade/web/internal/handlers"
mw "trade/web/internal/middleware" mw "trade/web/internal/middleware"
"trade/web/internal/store"
) )
func New(d *handlers.Deps, mgr *auth.Manager, authStore *store.AuthStore, dist fs.FS) http.Handler { func New(d *handlers.Deps, dist fs.FS) http.Handler {
r := chi.NewRouter() r := chi.NewRouter()
r.Use(mw.Recover) r.Use(mw.Recover)
r.Use(mw.Logger) r.Use(mw.Logger)
@@ -22,10 +20,11 @@ func New(d *handlers.Deps, mgr *auth.Manager, authStore *store.AuthStore, dist f
r.Post("/login", d.Login) r.Post("/login", d.Login)
r.Group(func(r chi.Router) { r.Group(func(r chi.Router) {
r.Use(mw.RequireUser(mgr, authStore)) r.Use(mw.RequireUser)
r.Post("/logout", d.Logout) r.Post("/logout", d.Logout)
r.Get("/me", d.Me) r.Get("/me", d.Me)
r.Post("/change-password", d.ChangePassword)
r.Get("/scores", d.ListScores) r.Get("/scores", d.ListScores)
r.Get("/scores/{id}", d.GetScore) r.Get("/scores/{id}", d.GetScore)
r.Get("/contracts", d.ListContracts) r.Get("/contracts", d.ListContracts)

View File

@@ -4,22 +4,22 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"path/filepath"
"time" "time"
_ "modernc.org/sqlite" _ "github.com/lib/pq"
) )
type AuthStore struct{ db *sql.DB } type AuthStore struct{ db *sql.DB }
type User struct { type User struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Username string `json:"username"` Username string `json:"username"`
PasswordHash string `json:"-"` PasswordHash string `json:"-"`
Role string `json:"role"` Role string `json:"role"`
Disabled bool `json:"disabled"` Disabled bool `json:"disabled"`
CreatedAt string `json:"created_at"` ForcePasswordChange bool `json:"force_password_change"`
UpdatedAt string `json:"updated_at"` CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
} }
const ( const (
@@ -29,18 +29,14 @@ const (
var ErrNotFound = errors.New("user not found") var ErrNotFound = errors.New("user not found")
func OpenAuth(path string) (*AuthStore, error) { func OpenAuth(databaseURL string) (*AuthStore, error) {
if dir := filepath.Dir(path); dir != "" { db, err := sql.Open("postgres", databaseURL)
_ = ensureDir(dir)
}
dsn := fmt.Sprintf("file:%s?_pragma=journal_mode(WAL)&_pragma=foreign_keys(1)&_pragma=busy_timeout(5000)", path)
db, err := sql.Open("sqlite", dsn)
if err != nil { if err != nil {
return nil, fmt.Errorf("open auth.db: %w", err) return nil, fmt.Errorf("open auth db: %w", err)
} }
db.SetMaxOpenConns(1) // sqlite write 单连接更稳 db.SetMaxOpenConns(8)
if err := db.Ping(); err != nil { if err := db.Ping(); err != nil {
return nil, fmt.Errorf("ping auth.db: %w", err) return nil, fmt.Errorf("ping auth db: %w", err)
} }
s := &AuthStore{db: db} s := &AuthStore{db: db}
if err := s.init(); err != nil { if err := s.init(); err != nil {
@@ -54,17 +50,22 @@ func (s *AuthStore) Close() error { return s.db.Close() }
func (s *AuthStore) init() error { func (s *AuthStore) init() error {
_, err := s.db.Exec(` _, err := s.db.Exec(`
CREATE TABLE IF NOT EXISTS users ( CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT, id SERIAL PRIMARY KEY,
username TEXT NOT NULL UNIQUE, username TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL, password_hash TEXT NOT NULL,
role TEXT NOT NULL CHECK(role IN ('admin','user')), role TEXT NOT NULL CHECK(role IN ('admin','user')),
disabled INTEGER NOT NULL DEFAULT 0, disabled BOOLEAN NOT NULL DEFAULT FALSE,
created_at TEXT NOT NULL, created_at TEXT NOT NULL,
updated_at TEXT NOT NULL updated_at TEXT NOT NULL
); );
CREATE INDEX IF NOT EXISTS idx_users_username ON users(username); CREATE INDEX IF NOT EXISTS idx_users_username ON users(username);
`) `)
return err if err != nil {
return err
}
// 兼容旧表:添加 force_password_change 列(已存在则忽略错误)
_, _ = s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS force_password_change BOOLEAN NOT NULL DEFAULT FALSE`)
return nil
} }
func (s *AuthStore) CountAdmins() (int, error) { func (s *AuthStore) CountAdmins() (int, error) {
@@ -75,33 +76,33 @@ func (s *AuthStore) CountAdmins() (int, error) {
func (s *AuthStore) CreateUser(username, passwordHash, role string) (*User, error) { func (s *AuthStore) CreateUser(username, passwordHash, role string) (*User, error) {
now := time.Now().Format("2006-01-02 15:04:05") now := time.Now().Format("2006-01-02 15:04:05")
res, err := s.db.Exec( var id int64
err := s.db.QueryRow(
`INSERT INTO users(username, password_hash, role, disabled, created_at, updated_at) `INSERT INTO users(username, password_hash, role, disabled, created_at, updated_at)
VALUES (?, ?, ?, 0, ?, ?)`, VALUES ($1, $2, $3, FALSE, $4, $5) RETURNING id`,
username, passwordHash, role, now, now, username, passwordHash, role, now, now,
) ).Scan(&id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
id, _ := res.LastInsertId()
return &User{ID: id, Username: username, PasswordHash: passwordHash, Role: role, return &User{ID: id, Username: username, PasswordHash: passwordHash, Role: role,
CreatedAt: now, UpdatedAt: now}, nil CreatedAt: now, UpdatedAt: now}, nil
} }
func (s *AuthStore) GetByUsername(username string) (*User, error) { func (s *AuthStore) GetByUsername(username string) (*User, error) {
row := s.db.QueryRow(`SELECT id, username, password_hash, role, disabled, created_at, updated_at row := s.db.QueryRow(`SELECT id, username, password_hash, role, disabled, force_password_change, created_at, updated_at
FROM users WHERE username = ?`, username) FROM users WHERE username = $1`, username)
return scanUser(row) return scanUser(row)
} }
func (s *AuthStore) GetByID(id int64) (*User, error) { func (s *AuthStore) GetByID(id int64) (*User, error) {
row := s.db.QueryRow(`SELECT id, username, password_hash, role, disabled, created_at, updated_at row := s.db.QueryRow(`SELECT id, username, password_hash, role, disabled, force_password_change, created_at, updated_at
FROM users WHERE id = ?`, id) FROM users WHERE id = $1`, id)
return scanUser(row) return scanUser(row)
} }
func (s *AuthStore) ListUsers() ([]User, error) { func (s *AuthStore) ListUsers() ([]User, error) {
rows, err := s.db.Query(`SELECT id, username, password_hash, role, disabled, created_at, updated_at rows, err := s.db.Query(`SELECT id, username, password_hash, role, disabled, force_password_change, created_at, updated_at
FROM users ORDER BY id ASC`) FROM users ORDER BY id ASC`)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -120,7 +121,20 @@ func (s *AuthStore) ListUsers() ([]User, error) {
func (s *AuthStore) UpdatePassword(id int64, hash string) error { func (s *AuthStore) UpdatePassword(id int64, hash string) error {
now := time.Now().Format("2006-01-02 15:04:05") now := time.Now().Format("2006-01-02 15:04:05")
res, err := s.db.Exec(`UPDATE users SET password_hash = ?, updated_at = ? WHERE id = ?`, hash, now, id) res, err := s.db.Exec(`UPDATE users SET password_hash = $1, updated_at = $2 WHERE id = $3`, hash, now, id)
if err != nil {
return err
}
n, _ := res.RowsAffected()
if n == 0 {
return ErrNotFound
}
return nil
}
func (s *AuthStore) SetForcePasswordChange(id int64, v bool) error {
now := time.Now().Format("2006-01-02 15:04:05")
res, err := s.db.Exec(`UPDATE users SET force_password_change = $1, updated_at = $2 WHERE id = $3`, v, now, id)
if err != nil { if err != nil {
return err return err
} }
@@ -133,11 +147,7 @@ func (s *AuthStore) UpdatePassword(id int64, hash string) error {
func (s *AuthStore) SetDisabled(id int64, disabled bool) error { func (s *AuthStore) SetDisabled(id int64, disabled bool) error {
now := time.Now().Format("2006-01-02 15:04:05") now := time.Now().Format("2006-01-02 15:04:05")
v := 0 res, err := s.db.Exec(`UPDATE users SET disabled = $1, updated_at = $2 WHERE id = $3`, disabled, now, id)
if disabled {
v = 1
}
res, err := s.db.Exec(`UPDATE users SET disabled = ?, updated_at = ? WHERE id = ?`, v, now, id)
if err != nil { if err != nil {
return err return err
} }
@@ -149,7 +159,7 @@ func (s *AuthStore) SetDisabled(id int64, disabled bool) error {
} }
func (s *AuthStore) DeleteUser(id int64) error { func (s *AuthStore) DeleteUser(id int64) error {
res, err := s.db.Exec(`DELETE FROM users WHERE id = ?`, id) res, err := s.db.Exec(`DELETE FROM users WHERE id = $1`, id)
if err != nil { if err != nil {
return err return err
} }
@@ -166,14 +176,12 @@ type rowScanner interface {
func scanUser(r rowScanner) (*User, error) { func scanUser(r rowScanner) (*User, error) {
var u User var u User
var disabled int if err := r.Scan(&u.ID, &u.Username, &u.PasswordHash, &u.Role, &u.Disabled, &u.ForcePasswordChange, &u.CreatedAt, &u.UpdatedAt); err != nil {
if err := r.Scan(&u.ID, &u.Username, &u.PasswordHash, &u.Role, &disabled, &u.CreatedAt, &u.UpdatedAt); err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound return nil, ErrNotFound
} }
return nil, err return nil, err
} }
u.Disabled = disabled != 0
return &u, nil return &u, nil
} }

View File

@@ -1,10 +0,0 @@
package store
import "os"
func ensureDir(dir string) error {
if _, err := os.Stat(dir); err == nil {
return nil
}
return os.MkdirAll(dir, 0o755)
}

View File

@@ -30,18 +30,17 @@ func main() {
} }
defer futures.Close() defer futures.Close()
authDB, err := store.OpenAuth(cfg.AuthDBPath) authDB, err := store.OpenAuth(cfg.DatabaseURL)
if err != nil { if err != nil {
log.Fatalf("open auth: %v", err) log.Fatalf("open auth: %v", err)
} }
defer authDB.Close() defer authDB.Close()
if err := auth.Bootstrap(authDB, cfg.AdminUser, cfg.AdminPass); err != nil { if err := auth.Bootstrap(authDB); err != nil {
log.Fatalf("bootstrap: %v", err) log.Fatalf("bootstrap: %v", err)
} }
mgr := auth.NewManager(cfg.JWTSecret) deps := &handlers.Deps{Auth: authDB, Futures: futures, TushareURL: cfg.TushareAPIURL}
deps := &handlers.Deps{Auth: authDB, Futures: futures, JWT: mgr, TushareURL: cfg.TushareAPIURL}
dist, err := fs.Sub(distFS, "dist") dist, err := fs.Sub(distFS, "dist")
if err != nil { if err != nil {
@@ -50,7 +49,7 @@ func main() {
srv := &http.Server{ srv := &http.Server{
Addr: cfg.ListenAddr, Addr: cfg.ListenAddr,
Handler: router.New(deps, mgr, authDB, dist), Handler: router.New(deps, dist),
ReadHeaderTimeout: 10 * time.Second, ReadHeaderTimeout: 10 * time.Second,
} }

View File

@@ -4,6 +4,7 @@ import type { AuthUser } from '@/stores/auth'
export interface LoginResp { export interface LoginResp {
token: string token: string
user: AuthUser user: AuthUser
require_password_change: boolean
} }
export function login(username: string, password: string) { export function login(username: string, password: string) {
@@ -17,3 +18,7 @@ export function logout() {
export function me() { export function me() {
return client.get<AuthUser>('/me').then((r) => r.data) return client.get<AuthUser>('/me').then((r) => r.data)
} }
export function changePassword(oldPassword: string, newPassword: string) {
return client.post('/change-password', { old_password: oldPassword, new_password: newPassword }).then((r) => r.data)
}

View File

@@ -8,6 +8,12 @@ const routes: RouteRecordRaw[] = [
component: () => import('@/views/LoginView.vue'), component: () => import('@/views/LoginView.vue'),
meta: { layout: 'blank', public: true }, meta: { layout: 'blank', public: true },
}, },
{
path: '/change-password',
name: 'change-password',
component: () => import('@/views/ChangePasswordView.vue'),
meta: { layout: 'blank' },
},
{ path: '/', redirect: '/scores' }, { path: '/', redirect: '/scores' },
{ {
path: '/scores', path: '/scores',
@@ -44,6 +50,9 @@ router.beforeEach((to) => {
if (!auth.token) { if (!auth.token) {
return { path: '/login', query: { redirect: to.fullPath } } return { path: '/login', query: { redirect: to.fullPath } }
} }
if (auth.requirePasswordChange && to.path !== '/change-password') {
return { path: '/change-password' }
}
if (to.meta.adminOnly && !auth.isAdmin) { if (to.meta.adminOnly && !auth.isAdmin) {
return { path: '/scores' } return { path: '/scores' }
} }

View File

@@ -4,11 +4,13 @@ export interface AuthUser {
id: number id: number
username: string username: string
role: 'admin' | 'user' role: 'admin' | 'user'
force_password_change?: boolean
} }
interface State { interface State {
token: string token: string
user: AuthUser | null user: AuthUser | null
requirePasswordChange: boolean
} }
const STORAGE_KEY = 'trade.auth' const STORAGE_KEY = 'trade.auth'
@@ -16,10 +18,15 @@ const STORAGE_KEY = 'trade.auth'
function load(): State { function load(): State {
try { try {
const raw = localStorage.getItem(STORAGE_KEY) const raw = localStorage.getItem(STORAGE_KEY)
if (!raw) return { token: '', user: null } if (!raw) return { token: '', user: null, requirePasswordChange: false }
return JSON.parse(raw) as State const parsed = JSON.parse(raw) as Partial<State>
return {
token: parsed.token || '',
user: parsed.user || null,
requirePasswordChange: parsed.requirePasswordChange ?? false,
}
} catch { } catch {
return { token: '', user: null } return { token: '', user: null, requirePasswordChange: false }
} }
} }
@@ -29,14 +36,29 @@ export const useAuthStore = defineStore('auth', {
isAdmin: (s) => s.user?.role === 'admin', isAdmin: (s) => s.user?.role === 'admin',
}, },
actions: { actions: {
setSession(token: string, user: AuthUser) { setSession(token: string, user: AuthUser, requirePasswordChange?: boolean) {
this.token = token this.token = token
this.user = user this.user = user
localStorage.setItem(STORAGE_KEY, JSON.stringify({ token, user })) this.requirePasswordChange = requirePasswordChange ?? false
localStorage.setItem(
STORAGE_KEY,
JSON.stringify({ token, user, requirePasswordChange: this.requirePasswordChange }),
)
},
clearRequirePasswordChange() {
this.requirePasswordChange = false
if (this.user) {
this.user.force_password_change = false
localStorage.setItem(
STORAGE_KEY,
JSON.stringify({ token: this.token, user: this.user, requirePasswordChange: false }),
)
}
}, },
logout() { logout() {
this.token = '' this.token = ''
this.user = null this.user = null
this.requirePasswordChange = false
localStorage.removeItem(STORAGE_KEY) localStorage.removeItem(STORAGE_KEY)
}, },
}, },

View File

@@ -0,0 +1,109 @@
<script setup lang="ts">
import { reactive, ref } from 'vue'
import { useRouter } from 'vue-router'
import { ElMessage } from 'element-plus'
import { changePassword } from '@/api/auth'
import { useAuthStore } from '@/stores/auth'
const auth = useAuthStore()
const router = useRouter()
const form = reactive({ oldPassword: '', newPassword: '', confirmPassword: '' })
const loading = ref(false)
async function submit() {
if (!form.oldPassword || !form.newPassword) {
ElMessage.warning('请输入旧密码和新密码')
return
}
if (form.newPassword.length < 6) {
ElMessage.warning('新密码至少 6 位')
return
}
if (form.newPassword !== form.confirmPassword) {
ElMessage.warning('两次输入的新密码不一致')
return
}
loading.value = true
try {
await changePassword(form.oldPassword, form.newPassword)
ElMessage.success('密码修改成功')
auth.clearRequirePasswordChange()
router.replace('/scores')
} catch {
// axios 拦截器已弹错
} finally {
loading.value = false
}
}
</script>
<template>
<div class="login">
<div class="card">
<h2>修改密码</h2>
<p class="hint">首次登录或管理员重置密码后请修改密码</p>
<el-form @submit.prevent="submit" label-width="0">
<el-form-item>
<el-input
v-model="form.oldPassword"
type="password"
placeholder="旧密码"
show-password
autocomplete="current-password"
/>
</el-form-item>
<el-form-item>
<el-input
v-model="form.newPassword"
type="password"
placeholder="新密码"
show-password
autocomplete="new-password"
/>
</el-form-item>
<el-form-item>
<el-input
v-model="form.confirmPassword"
type="password"
placeholder="确认新密码"
show-password
autocomplete="new-password"
@keyup.enter="submit"
/>
</el-form-item>
<el-button type="primary" :loading="loading" style="width: 100%" @click="submit">
确认修改
</el-button>
</el-form>
</div>
</div>
</template>
<style scoped>
.login {
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
background: linear-gradient(135deg, #1f2d3d 0%, #3a506b 100%);
}
.card {
width: 360px;
padding: 36px 32px;
background: var(--el-bg-color);
color: var(--el-text-color-primary);
border-radius: 8px;
box-shadow: 0 12px 32px rgba(0, 0, 0, 0.18);
}
.card h2 {
margin: 0 0 8px;
text-align: center;
}
.hint {
margin: 0 0 24px;
color: var(--el-text-color-secondary);
font-size: 12px;
text-align: center;
}
</style>

View File

@@ -20,9 +20,13 @@ async function submit() {
loading.value = true loading.value = true
try { try {
const resp = await login(form.username.trim(), form.password) const resp = await login(form.username.trim(), form.password)
auth.setSession(resp.token, resp.user) auth.setSession(resp.token, resp.user, resp.require_password_change)
const redirect = (route.query.redirect as string) || '/scores' if (resp.require_password_change) {
router.replace(redirect) router.replace('/change-password')
} else {
const redirect = (route.query.redirect as string) || '/scores'
router.replace(redirect)
}
} catch { } catch {
// axios 拦截器已弹错 // axios 拦截器已弹错
} finally { } finally {
@@ -60,11 +64,12 @@ async function submit() {
<style scoped> <style scoped>
.login { .login {
min-height: 100vh; height: 100vh;
display: flex; display: flex;
align-items: center; align-items: center;
justify-content: center; justify-content: center;
background: linear-gradient(135deg, #1f2d3d 0%, #3a506b 100%); background: linear-gradient(135deg, #1f2d3d 0%, #3a506b 100%);
overflow: hidden;
} }
.card { .card {
width: 360px; width: 360px;