From 220f4acc4547bfa8cd894894f7a79b9e6024b1ae Mon Sep 17 00:00:00 2001 From: fish Date: Sun, 3 May 2026 14:58:01 +0800 Subject: [PATCH] =?UTF-8?q?=E8=BF=81=E7=A7=BB=20PostgreSQL=20=E5=B9=B6?= =?UTF-8?q?=E6=96=B0=E5=A2=9E=20Python=20API=20=E6=9C=8D=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.7 --- docker-compose.yml | 40 +++-- tushare/Dockerfile | 8 +- tushare/requirements.txt | 3 + tushare/src/api.py | 162 +++++++++++++++++++++ tushare/src/storage.py | 201 ++++++++++++++------------ web/backend/Dockerfile | 1 - web/backend/go.mod | 1 + web/backend/internal/config/config.go | 25 ++-- web/backend/internal/store/futures.go | 13 +- web/backend/main.go | 2 +- 10 files changed, 333 insertions(+), 123 deletions(-) create mode 100644 tushare/src/api.py diff --git a/docker-compose.yml b/docker-compose.yml index 6249e61..369973e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,12 +1,29 @@ services: + postgres: + image: postgres:18.3-alpine3.23 + environment: + POSTGRES_USER: trade + POSTGRES_PASSWORD: trade + POSTGRES_DB: futures + volumes: + - pgdata:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U trade -d futures"] + interval: 5s + timeout: 5s + retries: 5 + tushare: build: ./tushare env_file: ./tushare/.env environment: - - DB_PATH=/app/data/futures.db - volumes: - - ./data:/app/data - command: ["python", "-m", "src.main"] + - DATABASE_URL=postgresql://trade:trade@postgres:5432/futures + depends_on: + postgres: + condition: service_healthy + ports: + - "8000:8000" + command: ["uvicorn", "src.api:app", "--host", "0.0.0.0", "--port", "8000"] web: build: @@ -15,14 +32,15 @@ services: env_file: ./web/backend/.env environment: - LISTEN_ADDR=:8080 - - FUTURES_DB_PATH=/app/data/futures.db + - DATABASE_URL=postgres://trade:trade@postgres:5432/futures?sslmode=disable - AUTH_DB_PATH=/app/auth/auth.db - volumes: - # futures.db 由 tushare 写入,web 端通过 DSN mode=ro&query_only 只读访问; - # 不在容器层加 :ro,因为 WAL 模式下读访问也需要写 -shm 同步文件 - - ./data:/app/data - # auth.db 由 web 自己写,落在 ./data/auth.db (已被 .gitignore) - - ./data:/app/auth + depends_on: + - postgres ports: - "8080:8080" + volumes: + - ./data:/app/auth restart: unless-stopped + +volumes: + pgdata: diff --git a/tushare/Dockerfile b/tushare/Dockerfile index 3aae6ab..ad931dc 100644 --- a/tushare/Dockerfile +++ b/tushare/Dockerfile @@ -5,12 +5,12 @@ ENV PYTHONDONTWRITEBYTECODE=1 \ PIP_NO_CACHE_DIR=1 \ PIP_DISABLE_PIP_VERSION_CHECK=1 \ TZ=Asia/Shanghai \ - DB_PATH=/app/data/futures.db + DATABASE_URL=postgresql://trade:trade@postgres:5432/futures WORKDIR /app -# 运行时依赖 + 时区 -RUN apk add --no-cache tzdata \ +# 运行时依赖 + 时区 + libpq(psycopg2) +RUN apk add --no-cache tzdata libpq \ && cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \ && echo "Asia/Shanghai" > /etc/timezone @@ -25,4 +25,4 @@ RUN adduser -D -u 1000 app \ COPY --chown=app:app src ./src USER app -CMD ["python", "-m", "src.main"] +CMD ["uvicorn", "src.api:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/tushare/requirements.txt b/tushare/requirements.txt index 8ce90c6..d54b57d 100644 --- a/tushare/requirements.txt +++ b/tushare/requirements.txt @@ -1,3 +1,6 @@ tushare>=1.4.0 pandas>=2.2.0 requests>=2.31.0 +fastapi>=0.115.0 +uvicorn[standard]>=0.34.0 +psycopg2-binary>=2.9.10 diff --git a/tushare/src/api.py b/tushare/src/api.py new file mode 100644 index 0000000..5a33c2c --- /dev/null +++ b/tushare/src/api.py @@ -0,0 +1,162 @@ +from typing import Optional + +from fastapi import FastAPI, HTTPException, Query +from pydantic import BaseModel + +from . import contracts, fetcher, notifier, scorer, storage + +app = FastAPI(title="期货数据采集与打分服务") + + +class RunRequest(BaseModel): + ts_code: Optional[str] = None + symbol: str = "FG" + + +class RunResponse(BaseModel): + ts_code: str + trade_date: str + close: float + oi: float + oi_chg: float + short_term: float + medium_term: float + long_term: float + composite: float + signal: str + + +@app.on_event("startup") +def startup(): + storage.init_db() + + +@app.get("/health") +def health(): + return {"status": "ok"} + + +@app.post("/api/v1/run", response_model=RunResponse) +def run_pipeline(req: RunRequest): + ts_code = req.ts_code or contracts.active_contract(req.symbol) + if not req.ts_code: + print(f"[AUTO] {req.symbol} 当月主力 -> {ts_code}") + + df = fetcher.fetch_contract(ts_code) + storage.save_candles(df) + result = scorer.score_daily(df) + storage.save_score(result) + + push_title = f"{result.ts_code.split('.')[0]} {result.trade_date}" + push_body = ( + f"综合 {result.composite:.1f}\n" + f"短期 {result.short_term:.1f} | 中期 {result.medium_term:.1f} | 长期 {result.long_term:.1f}\n" + f"{result.signal}" + ) + notifier.push_bark(push_title, push_body) + + return RunResponse( + ts_code=result.ts_code, + trade_date=result.trade_date, + close=result.close, + oi=result.oi, + oi_chg=result.oi_chg, + short_term=result.short_term, + medium_term=result.medium_term, + long_term=result.long_term, + composite=result.composite, + signal=result.signal, + ) + + +@app.get("/api/v1/scores") +def list_scores( + ts_code: Optional[str] = Query(None), + start: Optional[str] = Query(None), + end: Optional[str] = Query(None), + limit: int = Query(200, ge=1, le=500), +): + conn = storage._get_conn() + try: + with conn.cursor() as cur: + q = """SELECT id, ts_code, trade_date, close, oi, oi_chg, short_term, medium_term, + long_term, composite, signal, created_at FROM scores WHERE 1=1""" + args = [] + if ts_code: + q += " AND ts_code = %s" + args.append(ts_code) + if start: + q += " AND trade_date >= %s" + args.append(start) + if end: + q += " AND trade_date <= %s" + args.append(end) + q += " ORDER BY trade_date DESC, id DESC LIMIT %s" + args.append(limit) + cur.execute(q, args) + cols = [d[0] for d in cur.description] + rows = [dict(zip(cols, row)) for row in cur.fetchall()] + return rows + finally: + conn.close() + + +@app.get("/api/v1/scores/{score_id}") +def get_score(score_id: int): + conn = storage._get_conn() + try: + with conn.cursor() as cur: + cur.execute( + """SELECT id, ts_code, trade_date, close, oi, oi_chg, short_term, medium_term, + long_term, composite, signal, detail_json, created_at + FROM scores WHERE id = %s""", + (score_id,), + ) + row = cur.fetchone() + if not row: + raise HTTPException(status_code=404, detail="not found") + cols = [d[0] for d in cur.description] + return dict(zip(cols, row)) + finally: + conn.close() + + +@app.get("/api/v1/contracts") +def list_contracts(): + conn = storage._get_conn() + try: + with conn.cursor() as cur: + cur.execute("SELECT DISTINCT ts_code FROM scores ORDER BY ts_code ASC") + return [r[0] for r in cur.fetchall()] + finally: + conn.close() + + +@app.get("/api/v1/candles") +def list_candles( + ts_code: str = Query(...), + start: Optional[str] = Query(None), + end: Optional[str] = Query(None), +): + conn = storage._get_conn() + try: + with conn.cursor() as cur: + q = """SELECT ts_code, trade_date, + COALESCE(open, 0), COALESCE(high, 0), COALESCE(low, 0), COALESCE(close, 0), + COALESCE(vol, 0), COALESCE(amount, 0), + COALESCE(oi, 0), COALESCE(oi_chg, 0), COALESCE(pre_close, 0) + FROM candles WHERE ts_code = %s""" + args = [ts_code] + if start: + q += " AND trade_date >= %s" + args.append(start) + if end: + q += " AND trade_date <= %s" + args.append(end) + q += " ORDER BY trade_date ASC LIMIT 1000" + cur.execute(q, args) + cols = ["ts_code", "trade_date", "open", "high", "low", "close", + "vol", "amount", "oi", "oi_chg", "pre_close"] + return [dict(zip(cols, row)) for row in cur.fetchall()] + finally: + conn.close() diff --git a/tushare/src/storage.py b/tushare/src/storage.py index 42a634a..0b10e66 100644 --- a/tushare/src/storage.py +++ b/tushare/src/storage.py @@ -1,131 +1,154 @@ import json import os -import sqlite3 from typing import Optional import pandas as pd +import psycopg2 +from psycopg2.extras import RealDictCursor from .models import ScoreResult -DEFAULT_DB_PATH = os.environ.get("DB_PATH", "/app/data/futures.db") +DEFAULT_DB_URL = os.environ.get("DATABASE_URL", "postgresql://trade:trade@postgres:5432/futures") -def _get_conn(db_path: str = DEFAULT_DB_PATH) -> sqlite3.Connection: - conn = sqlite3.connect(db_path) - conn.row_factory = sqlite3.Row - return conn +def _get_conn(db_url: str = DEFAULT_DB_URL): + return psycopg2.connect(db_url) -def init_db(db_path: str = DEFAULT_DB_PATH): +def init_db(db_url: str = DEFAULT_DB_URL): """初始化数据库,创建 candles 和 scores 表。""" - os.makedirs(os.path.dirname(db_path), exist_ok=True) - conn = _get_conn(db_path) + conn = _get_conn(db_url) try: - conn.execute("PRAGMA journal_mode=WAL") - conn.execute(""" - CREATE TABLE IF NOT EXISTS candles ( - ts_code TEXT NOT NULL, - trade_date TEXT NOT NULL, - open REAL, - high REAL, - low REAL, - close REAL, - vol REAL, - amount REAL, - oi REAL, - oi_chg REAL, - pre_close REAL, - PRIMARY KEY (ts_code, trade_date) - ) - """) - conn.execute(""" - CREATE TABLE IF NOT EXISTS scores ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - ts_code TEXT NOT NULL, - trade_date TEXT NOT NULL, - close REAL, - oi REAL, - oi_chg REAL, - short_term REAL, - medium_term REAL, - long_term REAL, - composite REAL, - signal TEXT, - detail_json TEXT, - created_at TEXT DEFAULT (datetime('now', 'localtime')), - UNIQUE (ts_code, trade_date) - ) - """) + with conn.cursor() as cur: + cur.execute(""" + CREATE TABLE IF NOT EXISTS candles ( + ts_code TEXT NOT NULL, + trade_date TEXT NOT NULL, + open REAL, + high REAL, + low REAL, + close REAL, + vol REAL, + amount REAL, + oi REAL, + oi_chg REAL, + pre_close REAL, + PRIMARY KEY (ts_code, trade_date) + ) + """) + cur.execute(""" + CREATE TABLE IF NOT EXISTS scores ( + id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY, + ts_code TEXT NOT NULL, + trade_date TEXT NOT NULL, + close REAL, + oi REAL, + oi_chg REAL, + short_term REAL, + medium_term REAL, + long_term REAL, + composite REAL, + signal TEXT, + detail_json TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE (ts_code, trade_date) + ) + """) conn.commit() finally: conn.close() -def save_candles(df: pd.DataFrame, db_path: str = DEFAULT_DB_PATH): +def save_candles(df: pd.DataFrame, db_url: str = DEFAULT_DB_URL): """批量写入/更新日线数据。""" if df.empty: return - conn = _get_conn(db_path) + conn = _get_conn(db_url) try: df = df.copy() df = df.where(pd.notna(df), None) records = df.to_dict(orient="records") - conn.executemany( - """ - INSERT OR REPLACE INTO candles - (ts_code, trade_date, open, high, low, close, vol, amount, oi, oi_chg, pre_close) - VALUES (:ts_code, :trade_date, :open, :high, :low, :close, - :vol, :amount, :oi, :oi_chg, :pre_close) - """, - records, - ) + with conn.cursor() as cur: + cur.executemany( + """ + INSERT INTO candles + (ts_code, trade_date, open, high, low, close, vol, amount, oi, oi_chg, pre_close) + VALUES (%(ts_code)s, %(trade_date)s, %(open)s, %(high)s, %(low)s, %(close)s, + %(vol)s, %(amount)s, %(oi)s, %(oi_chg)s, %(pre_close)s) + ON CONFLICT (ts_code, trade_date) DO UPDATE SET + open = EXCLUDED.open, + high = EXCLUDED.high, + low = EXCLUDED.low, + close = EXCLUDED.close, + vol = EXCLUDED.vol, + amount = EXCLUDED.amount, + oi = EXCLUDED.oi, + oi_chg = EXCLUDED.oi_chg, + pre_close = EXCLUDED.pre_close + """, + records, + ) conn.commit() finally: conn.close() -def save_score(score: ScoreResult, db_path: str = DEFAULT_DB_PATH): +def save_score(score: ScoreResult, db_url: str = DEFAULT_DB_URL): """写入打分结果。""" - conn = _get_conn(db_path) + conn = _get_conn(db_url) try: - conn.execute( - """ - INSERT OR REPLACE INTO scores - (ts_code, trade_date, close, oi, oi_chg, - short_term, medium_term, long_term, composite, signal, detail_json) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - score.ts_code, - score.trade_date, - score.close, - score.oi, - score.oi_chg, - score.short_term, - score.medium_term, - score.long_term, - score.composite, - score.signal, - json.dumps({ - "short_details": score.detail.short_details, - "medium_detail": score.detail.medium_detail, - "long_detail": score.detail.long_detail, - }, ensure_ascii=False, default=str), - ), - ) + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO scores + (ts_code, trade_date, close, oi, oi_chg, + short_term, medium_term, long_term, composite, signal, detail_json) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + ON CONFLICT (ts_code, trade_date) DO UPDATE SET + close = EXCLUDED.close, + oi = EXCLUDED.oi, + oi_chg = EXCLUDED.oi_chg, + short_term = EXCLUDED.short_term, + medium_term = EXCLUDED.medium_term, + long_term = EXCLUDED.long_term, + composite = EXCLUDED.composite, + signal = EXCLUDED.signal, + detail_json = EXCLUDED.detail_json, + created_at = CURRENT_TIMESTAMP + """, + ( + score.ts_code, + score.trade_date, + score.close, + score.oi, + score.oi_chg, + score.short_term, + score.medium_term, + score.long_term, + score.composite, + score.signal, + json.dumps({ + "short_details": score.detail.short_details, + "medium_detail": score.detail.medium_detail, + "long_detail": score.detail.long_detail, + }, ensure_ascii=False, default=str), + ), + ) conn.commit() finally: conn.close() -def get_latest_score(ts_code: str, db_path: str = DEFAULT_DB_PATH) -> Optional[dict]: +def get_latest_score(ts_code: str, db_url: str = DEFAULT_DB_URL) -> Optional[dict]: """查询最新打分记录。""" - conn = _get_conn(db_path) + conn = _get_conn(db_url) try: - row = conn.execute( - "SELECT * FROM scores WHERE ts_code = ? ORDER BY trade_date DESC LIMIT 1", - (ts_code,), - ).fetchone() - return dict(row) if row else None + with conn.cursor(cursor_factory=RealDictCursor) as cur: + cur.execute( + "SELECT * FROM scores WHERE ts_code = %s ORDER BY trade_date DESC LIMIT 1", + (ts_code,), + ) + row = cur.fetchone() + return dict(row) if row else None finally: conn.close() diff --git a/web/backend/Dockerfile b/web/backend/Dockerfile index 0866937..87b588d 100644 --- a/web/backend/Dockerfile +++ b/web/backend/Dockerfile @@ -46,7 +46,6 @@ COPY --from=api --chown=app:app /out/web /app/web ENV TZ=Asia/Shanghai \ LISTEN_ADDR=:8080 \ - FUTURES_DB_PATH=/app/data/futures.db \ AUTH_DB_PATH=/app/auth/auth.db EXPOSE 8080 diff --git a/web/backend/go.mod b/web/backend/go.mod index edd0dc6..2f0eadb 100644 --- a/web/backend/go.mod +++ b/web/backend/go.mod @@ -5,6 +5,7 @@ go 1.25.8 require ( github.com/go-chi/chi/v5 v5.1.0 github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/lib/pq v1.10.9 golang.org/x/crypto v0.27.0 modernc.org/sqlite v1.32.0 ) diff --git a/web/backend/internal/config/config.go b/web/backend/internal/config/config.go index 0693f85..82e514d 100644 --- a/web/backend/internal/config/config.go +++ b/web/backend/internal/config/config.go @@ -7,21 +7,24 @@ import ( ) type Config struct { - ListenAddr string - FuturesDBPath string - AuthDBPath string - JWTSecret []byte - AdminUser string - AdminPass string + ListenAddr string + DatabaseURL string + AuthDBPath string + JWTSecret []byte + AdminUser string + AdminPass string } func Load() (*Config, error) { cfg := &Config{ - ListenAddr: getenv("LISTEN_ADDR", ":8080"), - FuturesDBPath: getenv("FUTURES_DB_PATH", "/app/data/futures.db"), - AuthDBPath: getenv("AUTH_DB_PATH", "/app/auth/auth.db"), - AdminUser: strings.TrimSpace(os.Getenv("ADMIN_USER")), - AdminPass: os.Getenv("ADMIN_PASS"), + ListenAddr: getenv("LISTEN_ADDR", ":8080"), + DatabaseURL: os.Getenv("DATABASE_URL"), + AuthDBPath: getenv("AUTH_DB_PATH", "/app/auth/auth.db"), + AdminUser: strings.TrimSpace(os.Getenv("ADMIN_USER")), + AdminPass: os.Getenv("ADMIN_PASS"), + } + if cfg.DatabaseURL == "" { + return nil, fmt.Errorf("DATABASE_URL 环境变量未设置") } secret := strings.TrimSpace(os.Getenv("JWT_SECRET")) if len(secret) < 16 { diff --git a/web/backend/internal/store/futures.go b/web/backend/internal/store/futures.go index 5e0d976..3059da2 100644 --- a/web/backend/internal/store/futures.go +++ b/web/backend/internal/store/futures.go @@ -6,21 +6,22 @@ import ( "errors" "fmt" "strings" + + _ "github.com/lib/pq" ) var ErrMissingTsCode = errors.New("ts_code 必填") type FuturesStore struct{ db *sql.DB } -func OpenFutures(path string) (*FuturesStore, error) { - dsn := fmt.Sprintf("file:%s?mode=ro&_pragma=query_only(true)", path) - db, err := sql.Open("sqlite", dsn) +func OpenFutures(databaseURL string) (*FuturesStore, error) { + db, err := sql.Open("postgres", databaseURL) if err != nil { - return nil, fmt.Errorf("open futures.db: %w", err) + return nil, fmt.Errorf("open futures db: %w", err) } - db.SetMaxOpenConns(4) + db.SetMaxOpenConns(8) if err := db.Ping(); err != nil { - return nil, fmt.Errorf("ping futures.db: %w", err) + return nil, fmt.Errorf("ping futures db: %w", err) } return &FuturesStore{db: db}, nil } diff --git a/web/backend/main.go b/web/backend/main.go index 88b6dc2..9b506ae 100644 --- a/web/backend/main.go +++ b/web/backend/main.go @@ -24,7 +24,7 @@ func main() { log.Fatalf("config: %v", err) } - futures, err := store.OpenFutures(cfg.FuturesDBPath) + futures, err := store.OpenFutures(cfg.DatabaseURL) if err != nil { log.Fatalf("open futures: %v", err) }