迁移 PostgreSQL 并新增 Python API 服务

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
fish
2026-05-03 14:58:01 +08:00
parent 750584e619
commit 220f4acc45
10 changed files with 333 additions and 123 deletions

View File

@@ -1,12 +1,29 @@
services: services:
postgres:
image: postgres:18.3-alpine3.23
environment:
POSTGRES_USER: trade
POSTGRES_PASSWORD: trade
POSTGRES_DB: futures
volumes:
- pgdata:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U trade -d futures"]
interval: 5s
timeout: 5s
retries: 5
tushare: tushare:
build: ./tushare build: ./tushare
env_file: ./tushare/.env env_file: ./tushare/.env
environment: environment:
- DB_PATH=/app/data/futures.db - DATABASE_URL=postgresql://trade:trade@postgres:5432/futures
volumes: depends_on:
- ./data:/app/data postgres:
command: ["python", "-m", "src.main"] condition: service_healthy
ports:
- "8000:8000"
command: ["uvicorn", "src.api:app", "--host", "0.0.0.0", "--port", "8000"]
web: web:
build: build:
@@ -15,14 +32,15 @@ services:
env_file: ./web/backend/.env env_file: ./web/backend/.env
environment: environment:
- LISTEN_ADDR=:8080 - LISTEN_ADDR=:8080
- FUTURES_DB_PATH=/app/data/futures.db - DATABASE_URL=postgres://trade:trade@postgres:5432/futures?sslmode=disable
- AUTH_DB_PATH=/app/auth/auth.db - AUTH_DB_PATH=/app/auth/auth.db
volumes: depends_on:
# futures.db 由 tushare 写入,web 端通过 DSN mode=ro&query_only 只读访问; - postgres
# 不在容器层加 :ro,因为 WAL 模式下读访问也需要写 -shm 同步文件
- ./data:/app/data
# auth.db 由 web 自己写,落在 ./data/auth.db (已被 .gitignore)
- ./data:/app/auth
ports: ports:
- "8080:8080" - "8080:8080"
volumes:
- ./data:/app/auth
restart: unless-stopped restart: unless-stopped
volumes:
pgdata:

View File

@@ -5,12 +5,12 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
PIP_NO_CACHE_DIR=1 \ PIP_NO_CACHE_DIR=1 \
PIP_DISABLE_PIP_VERSION_CHECK=1 \ PIP_DISABLE_PIP_VERSION_CHECK=1 \
TZ=Asia/Shanghai \ TZ=Asia/Shanghai \
DB_PATH=/app/data/futures.db DATABASE_URL=postgresql://trade:trade@postgres:5432/futures
WORKDIR /app WORKDIR /app
# 运行时依赖 + 时区 # 运行时依赖 + 时区 + libpq(psycopg2)
RUN apk add --no-cache tzdata \ RUN apk add --no-cache tzdata libpq \
&& cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \ && cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \
&& echo "Asia/Shanghai" > /etc/timezone && echo "Asia/Shanghai" > /etc/timezone
@@ -25,4 +25,4 @@ RUN adduser -D -u 1000 app \
COPY --chown=app:app src ./src COPY --chown=app:app src ./src
USER app USER app
CMD ["python", "-m", "src.main"] CMD ["uvicorn", "src.api:app", "--host", "0.0.0.0", "--port", "8000"]

View File

@@ -1,3 +1,6 @@
tushare>=1.4.0 tushare>=1.4.0
pandas>=2.2.0 pandas>=2.2.0
requests>=2.31.0 requests>=2.31.0
fastapi>=0.115.0
uvicorn[standard]>=0.34.0
psycopg2-binary>=2.9.10

162
tushare/src/api.py Normal file
View File

@@ -0,0 +1,162 @@
from typing import Optional
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
from . import contracts, fetcher, notifier, scorer, storage
app = FastAPI(title="期货数据采集与打分服务")
class RunRequest(BaseModel):
ts_code: Optional[str] = None
symbol: str = "FG"
class RunResponse(BaseModel):
ts_code: str
trade_date: str
close: float
oi: float
oi_chg: float
short_term: float
medium_term: float
long_term: float
composite: float
signal: str
@app.on_event("startup")
def startup():
storage.init_db()
@app.get("/health")
def health():
return {"status": "ok"}
@app.post("/api/v1/run", response_model=RunResponse)
def run_pipeline(req: RunRequest):
ts_code = req.ts_code or contracts.active_contract(req.symbol)
if not req.ts_code:
print(f"[AUTO] {req.symbol} 当月主力 -> {ts_code}")
df = fetcher.fetch_contract(ts_code)
storage.save_candles(df)
result = scorer.score_daily(df)
storage.save_score(result)
push_title = f"{result.ts_code.split('.')[0]} {result.trade_date}"
push_body = (
f"综合 {result.composite:.1f}\n"
f"短期 {result.short_term:.1f} | 中期 {result.medium_term:.1f} | 长期 {result.long_term:.1f}\n"
f"{result.signal}"
)
notifier.push_bark(push_title, push_body)
return RunResponse(
ts_code=result.ts_code,
trade_date=result.trade_date,
close=result.close,
oi=result.oi,
oi_chg=result.oi_chg,
short_term=result.short_term,
medium_term=result.medium_term,
long_term=result.long_term,
composite=result.composite,
signal=result.signal,
)
@app.get("/api/v1/scores")
def list_scores(
ts_code: Optional[str] = Query(None),
start: Optional[str] = Query(None),
end: Optional[str] = Query(None),
limit: int = Query(200, ge=1, le=500),
):
conn = storage._get_conn()
try:
with conn.cursor() as cur:
q = """SELECT id, ts_code, trade_date, close, oi, oi_chg, short_term, medium_term,
long_term, composite, signal, created_at FROM scores WHERE 1=1"""
args = []
if ts_code:
q += " AND ts_code = %s"
args.append(ts_code)
if start:
q += " AND trade_date >= %s"
args.append(start)
if end:
q += " AND trade_date <= %s"
args.append(end)
q += " ORDER BY trade_date DESC, id DESC LIMIT %s"
args.append(limit)
cur.execute(q, args)
cols = [d[0] for d in cur.description]
rows = [dict(zip(cols, row)) for row in cur.fetchall()]
return rows
finally:
conn.close()
@app.get("/api/v1/scores/{score_id}")
def get_score(score_id: int):
conn = storage._get_conn()
try:
with conn.cursor() as cur:
cur.execute(
"""SELECT id, ts_code, trade_date, close, oi, oi_chg, short_term, medium_term,
long_term, composite, signal, detail_json, created_at
FROM scores WHERE id = %s""",
(score_id,),
)
row = cur.fetchone()
if not row:
raise HTTPException(status_code=404, detail="not found")
cols = [d[0] for d in cur.description]
return dict(zip(cols, row))
finally:
conn.close()
@app.get("/api/v1/contracts")
def list_contracts():
conn = storage._get_conn()
try:
with conn.cursor() as cur:
cur.execute("SELECT DISTINCT ts_code FROM scores ORDER BY ts_code ASC")
return [r[0] for r in cur.fetchall()]
finally:
conn.close()
@app.get("/api/v1/candles")
def list_candles(
ts_code: str = Query(...),
start: Optional[str] = Query(None),
end: Optional[str] = Query(None),
):
conn = storage._get_conn()
try:
with conn.cursor() as cur:
q = """SELECT ts_code, trade_date,
COALESCE(open, 0), COALESCE(high, 0), COALESCE(low, 0), COALESCE(close, 0),
COALESCE(vol, 0), COALESCE(amount, 0),
COALESCE(oi, 0), COALESCE(oi_chg, 0), COALESCE(pre_close, 0)
FROM candles WHERE ts_code = %s"""
args = [ts_code]
if start:
q += " AND trade_date >= %s"
args.append(start)
if end:
q += " AND trade_date <= %s"
args.append(end)
q += " ORDER BY trade_date ASC LIMIT 1000"
cur.execute(q, args)
cols = ["ts_code", "trade_date", "open", "high", "low", "close",
"vol", "amount", "oi", "oi_chg", "pre_close"]
return [dict(zip(cols, row)) for row in cur.fetchall()]
finally:
conn.close()

View File

@@ -1,28 +1,26 @@
import json import json
import os import os
import sqlite3
from typing import Optional from typing import Optional
import pandas as pd import pandas as pd
import psycopg2
from psycopg2.extras import RealDictCursor
from .models import ScoreResult from .models import ScoreResult
DEFAULT_DB_PATH = os.environ.get("DB_PATH", "/app/data/futures.db") DEFAULT_DB_URL = os.environ.get("DATABASE_URL", "postgresql://trade:trade@postgres:5432/futures")
def _get_conn(db_path: str = DEFAULT_DB_PATH) -> sqlite3.Connection: def _get_conn(db_url: str = DEFAULT_DB_URL):
conn = sqlite3.connect(db_path) return psycopg2.connect(db_url)
conn.row_factory = sqlite3.Row
return conn
def init_db(db_path: str = DEFAULT_DB_PATH): def init_db(db_url: str = DEFAULT_DB_URL):
"""初始化数据库,创建 candles 和 scores 表。""" """初始化数据库,创建 candles 和 scores 表。"""
os.makedirs(os.path.dirname(db_path), exist_ok=True) conn = _get_conn(db_url)
conn = _get_conn(db_path)
try: try:
conn.execute("PRAGMA journal_mode=WAL") with conn.cursor() as cur:
conn.execute(""" cur.execute("""
CREATE TABLE IF NOT EXISTS candles ( CREATE TABLE IF NOT EXISTS candles (
ts_code TEXT NOT NULL, ts_code TEXT NOT NULL,
trade_date TEXT NOT NULL, trade_date TEXT NOT NULL,
@@ -38,9 +36,9 @@ def init_db(db_path: str = DEFAULT_DB_PATH):
PRIMARY KEY (ts_code, trade_date) PRIMARY KEY (ts_code, trade_date)
) )
""") """)
conn.execute(""" cur.execute("""
CREATE TABLE IF NOT EXISTS scores ( CREATE TABLE IF NOT EXISTS scores (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
ts_code TEXT NOT NULL, ts_code TEXT NOT NULL,
trade_date TEXT NOT NULL, trade_date TEXT NOT NULL,
close REAL, close REAL,
@@ -52,7 +50,7 @@ def init_db(db_path: str = DEFAULT_DB_PATH):
composite REAL, composite REAL,
signal TEXT, signal TEXT,
detail_json TEXT, detail_json TEXT,
created_at TEXT DEFAULT (datetime('now', 'localtime')), created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE (ts_code, trade_date) UNIQUE (ts_code, trade_date)
) )
""") """)
@@ -61,21 +59,32 @@ def init_db(db_path: str = DEFAULT_DB_PATH):
conn.close() conn.close()
def save_candles(df: pd.DataFrame, db_path: str = DEFAULT_DB_PATH): def save_candles(df: pd.DataFrame, db_url: str = DEFAULT_DB_URL):
"""批量写入/更新日线数据。""" """批量写入/更新日线数据。"""
if df.empty: if df.empty:
return return
conn = _get_conn(db_path) conn = _get_conn(db_url)
try: try:
df = df.copy() df = df.copy()
df = df.where(pd.notna(df), None) df = df.where(pd.notna(df), None)
records = df.to_dict(orient="records") records = df.to_dict(orient="records")
conn.executemany( with conn.cursor() as cur:
cur.executemany(
""" """
INSERT OR REPLACE INTO candles INSERT INTO candles
(ts_code, trade_date, open, high, low, close, vol, amount, oi, oi_chg, pre_close) (ts_code, trade_date, open, high, low, close, vol, amount, oi, oi_chg, pre_close)
VALUES (:ts_code, :trade_date, :open, :high, :low, :close, VALUES (%(ts_code)s, %(trade_date)s, %(open)s, %(high)s, %(low)s, %(close)s,
:vol, :amount, :oi, :oi_chg, :pre_close) %(vol)s, %(amount)s, %(oi)s, %(oi_chg)s, %(pre_close)s)
ON CONFLICT (ts_code, trade_date) DO UPDATE SET
open = EXCLUDED.open,
high = EXCLUDED.high,
low = EXCLUDED.low,
close = EXCLUDED.close,
vol = EXCLUDED.vol,
amount = EXCLUDED.amount,
oi = EXCLUDED.oi,
oi_chg = EXCLUDED.oi_chg,
pre_close = EXCLUDED.pre_close
""", """,
records, records,
) )
@@ -84,16 +93,28 @@ def save_candles(df: pd.DataFrame, db_path: str = DEFAULT_DB_PATH):
conn.close() conn.close()
def save_score(score: ScoreResult, db_path: str = DEFAULT_DB_PATH): def save_score(score: ScoreResult, db_url: str = DEFAULT_DB_URL):
"""写入打分结果。""" """写入打分结果。"""
conn = _get_conn(db_path) conn = _get_conn(db_url)
try: try:
conn.execute( with conn.cursor() as cur:
cur.execute(
""" """
INSERT OR REPLACE INTO scores INSERT INTO scores
(ts_code, trade_date, close, oi, oi_chg, (ts_code, trade_date, close, oi, oi_chg,
short_term, medium_term, long_term, composite, signal, detail_json) short_term, medium_term, long_term, composite, signal, detail_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (ts_code, trade_date) DO UPDATE SET
close = EXCLUDED.close,
oi = EXCLUDED.oi,
oi_chg = EXCLUDED.oi_chg,
short_term = EXCLUDED.short_term,
medium_term = EXCLUDED.medium_term,
long_term = EXCLUDED.long_term,
composite = EXCLUDED.composite,
signal = EXCLUDED.signal,
detail_json = EXCLUDED.detail_json,
created_at = CURRENT_TIMESTAMP
""", """,
( (
score.ts_code, score.ts_code,
@@ -118,14 +139,16 @@ def save_score(score: ScoreResult, db_path: str = DEFAULT_DB_PATH):
conn.close() conn.close()
def get_latest_score(ts_code: str, db_path: str = DEFAULT_DB_PATH) -> Optional[dict]: def get_latest_score(ts_code: str, db_url: str = DEFAULT_DB_URL) -> Optional[dict]:
"""查询最新打分记录。""" """查询最新打分记录。"""
conn = _get_conn(db_path) conn = _get_conn(db_url)
try: try:
row = conn.execute( with conn.cursor(cursor_factory=RealDictCursor) as cur:
"SELECT * FROM scores WHERE ts_code = ? ORDER BY trade_date DESC LIMIT 1", cur.execute(
"SELECT * FROM scores WHERE ts_code = %s ORDER BY trade_date DESC LIMIT 1",
(ts_code,), (ts_code,),
).fetchone() )
row = cur.fetchone()
return dict(row) if row else None return dict(row) if row else None
finally: finally:
conn.close() conn.close()

View File

@@ -46,7 +46,6 @@ COPY --from=api --chown=app:app /out/web /app/web
ENV TZ=Asia/Shanghai \ ENV TZ=Asia/Shanghai \
LISTEN_ADDR=:8080 \ LISTEN_ADDR=:8080 \
FUTURES_DB_PATH=/app/data/futures.db \
AUTH_DB_PATH=/app/auth/auth.db AUTH_DB_PATH=/app/auth/auth.db
EXPOSE 8080 EXPOSE 8080

View File

@@ -5,6 +5,7 @@ go 1.25.8
require ( require (
github.com/go-chi/chi/v5 v5.1.0 github.com/go-chi/chi/v5 v5.1.0
github.com/golang-jwt/jwt/v5 v5.2.1 github.com/golang-jwt/jwt/v5 v5.2.1
github.com/lib/pq v1.10.9
golang.org/x/crypto v0.27.0 golang.org/x/crypto v0.27.0
modernc.org/sqlite v1.32.0 modernc.org/sqlite v1.32.0
) )

View File

@@ -8,7 +8,7 @@ import (
type Config struct { type Config struct {
ListenAddr string ListenAddr string
FuturesDBPath string DatabaseURL string
AuthDBPath string AuthDBPath string
JWTSecret []byte JWTSecret []byte
AdminUser string AdminUser string
@@ -18,11 +18,14 @@ type Config struct {
func Load() (*Config, error) { func Load() (*Config, error) {
cfg := &Config{ cfg := &Config{
ListenAddr: getenv("LISTEN_ADDR", ":8080"), ListenAddr: getenv("LISTEN_ADDR", ":8080"),
FuturesDBPath: getenv("FUTURES_DB_PATH", "/app/data/futures.db"), DatabaseURL: os.Getenv("DATABASE_URL"),
AuthDBPath: getenv("AUTH_DB_PATH", "/app/auth/auth.db"), AuthDBPath: getenv("AUTH_DB_PATH", "/app/auth/auth.db"),
AdminUser: strings.TrimSpace(os.Getenv("ADMIN_USER")), AdminUser: strings.TrimSpace(os.Getenv("ADMIN_USER")),
AdminPass: os.Getenv("ADMIN_PASS"), AdminPass: os.Getenv("ADMIN_PASS"),
} }
if cfg.DatabaseURL == "" {
return nil, fmt.Errorf("DATABASE_URL 环境变量未设置")
}
secret := strings.TrimSpace(os.Getenv("JWT_SECRET")) secret := strings.TrimSpace(os.Getenv("JWT_SECRET"))
if len(secret) < 16 { if len(secret) < 16 {
return nil, fmt.Errorf("JWT_SECRET 必须至少 16 个字符 (建议 openssl rand -hex 32)") return nil, fmt.Errorf("JWT_SECRET 必须至少 16 个字符 (建议 openssl rand -hex 32)")

View File

@@ -6,21 +6,22 @@ import (
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
_ "github.com/lib/pq"
) )
var ErrMissingTsCode = errors.New("ts_code 必填") var ErrMissingTsCode = errors.New("ts_code 必填")
type FuturesStore struct{ db *sql.DB } type FuturesStore struct{ db *sql.DB }
func OpenFutures(path string) (*FuturesStore, error) { func OpenFutures(databaseURL string) (*FuturesStore, error) {
dsn := fmt.Sprintf("file:%s?mode=ro&_pragma=query_only(true)", path) db, err := sql.Open("postgres", databaseURL)
db, err := sql.Open("sqlite", dsn)
if err != nil { if err != nil {
return nil, fmt.Errorf("open futures.db: %w", err) return nil, fmt.Errorf("open futures db: %w", err)
} }
db.SetMaxOpenConns(4) db.SetMaxOpenConns(8)
if err := db.Ping(); err != nil { if err := db.Ping(); err != nil {
return nil, fmt.Errorf("ping futures.db: %w", err) return nil, fmt.Errorf("ping futures db: %w", err)
} }
return &FuturesStore{db: db}, nil return &FuturesStore{db: db}, nil
} }

View File

@@ -24,7 +24,7 @@ func main() {
log.Fatalf("config: %v", err) log.Fatalf("config: %v", err)
} }
futures, err := store.OpenFutures(cfg.FuturesDBPath) futures, err := store.OpenFutures(cfg.DatabaseURL)
if err != nil { if err != nil {
log.Fatalf("open futures: %v", err) log.Fatalf("open futures: %v", err)
} }