迁移 PostgreSQL 并新增 Python API 服务
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -1,12 +1,29 @@
|
|||||||
services:
|
services:
|
||||||
|
postgres:
|
||||||
|
image: postgres:18.3-alpine3.23
|
||||||
|
environment:
|
||||||
|
POSTGRES_USER: trade
|
||||||
|
POSTGRES_PASSWORD: trade
|
||||||
|
POSTGRES_DB: futures
|
||||||
|
volumes:
|
||||||
|
- pgdata:/var/lib/postgresql/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "pg_isready -U trade -d futures"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
|
||||||
tushare:
|
tushare:
|
||||||
build: ./tushare
|
build: ./tushare
|
||||||
env_file: ./tushare/.env
|
env_file: ./tushare/.env
|
||||||
environment:
|
environment:
|
||||||
- DB_PATH=/app/data/futures.db
|
- DATABASE_URL=postgresql://trade:trade@postgres:5432/futures
|
||||||
volumes:
|
depends_on:
|
||||||
- ./data:/app/data
|
postgres:
|
||||||
command: ["python", "-m", "src.main"]
|
condition: service_healthy
|
||||||
|
ports:
|
||||||
|
- "8000:8000"
|
||||||
|
command: ["uvicorn", "src.api:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
|
|
||||||
web:
|
web:
|
||||||
build:
|
build:
|
||||||
@@ -15,14 +32,15 @@ services:
|
|||||||
env_file: ./web/backend/.env
|
env_file: ./web/backend/.env
|
||||||
environment:
|
environment:
|
||||||
- LISTEN_ADDR=:8080
|
- LISTEN_ADDR=:8080
|
||||||
- FUTURES_DB_PATH=/app/data/futures.db
|
- DATABASE_URL=postgres://trade:trade@postgres:5432/futures?sslmode=disable
|
||||||
- AUTH_DB_PATH=/app/auth/auth.db
|
- AUTH_DB_PATH=/app/auth/auth.db
|
||||||
volumes:
|
depends_on:
|
||||||
# futures.db 由 tushare 写入,web 端通过 DSN mode=ro&query_only 只读访问;
|
- postgres
|
||||||
# 不在容器层加 :ro,因为 WAL 模式下读访问也需要写 -shm 同步文件
|
|
||||||
- ./data:/app/data
|
|
||||||
# auth.db 由 web 自己写,落在 ./data/auth.db (已被 .gitignore)
|
|
||||||
- ./data:/app/auth
|
|
||||||
ports:
|
ports:
|
||||||
- "8080:8080"
|
- "8080:8080"
|
||||||
|
volumes:
|
||||||
|
- ./data:/app/auth
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
pgdata:
|
||||||
|
|||||||
@@ -5,12 +5,12 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
|
|||||||
PIP_NO_CACHE_DIR=1 \
|
PIP_NO_CACHE_DIR=1 \
|
||||||
PIP_DISABLE_PIP_VERSION_CHECK=1 \
|
PIP_DISABLE_PIP_VERSION_CHECK=1 \
|
||||||
TZ=Asia/Shanghai \
|
TZ=Asia/Shanghai \
|
||||||
DB_PATH=/app/data/futures.db
|
DATABASE_URL=postgresql://trade:trade@postgres:5432/futures
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# 运行时依赖 + 时区
|
# 运行时依赖 + 时区 + libpq(psycopg2)
|
||||||
RUN apk add --no-cache tzdata \
|
RUN apk add --no-cache tzdata libpq \
|
||||||
&& cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \
|
&& cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \
|
||||||
&& echo "Asia/Shanghai" > /etc/timezone
|
&& echo "Asia/Shanghai" > /etc/timezone
|
||||||
|
|
||||||
@@ -25,4 +25,4 @@ RUN adduser -D -u 1000 app \
|
|||||||
COPY --chown=app:app src ./src
|
COPY --chown=app:app src ./src
|
||||||
USER app
|
USER app
|
||||||
|
|
||||||
CMD ["python", "-m", "src.main"]
|
CMD ["uvicorn", "src.api:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
tushare>=1.4.0
|
tushare>=1.4.0
|
||||||
pandas>=2.2.0
|
pandas>=2.2.0
|
||||||
requests>=2.31.0
|
requests>=2.31.0
|
||||||
|
fastapi>=0.115.0
|
||||||
|
uvicorn[standard]>=0.34.0
|
||||||
|
psycopg2-binary>=2.9.10
|
||||||
|
|||||||
162
tushare/src/api.py
Normal file
162
tushare/src/api.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import FastAPI, HTTPException, Query
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from . import contracts, fetcher, notifier, scorer, storage
|
||||||
|
|
||||||
|
app = FastAPI(title="期货数据采集与打分服务")
|
||||||
|
|
||||||
|
|
||||||
|
class RunRequest(BaseModel):
|
||||||
|
ts_code: Optional[str] = None
|
||||||
|
symbol: str = "FG"
|
||||||
|
|
||||||
|
|
||||||
|
class RunResponse(BaseModel):
|
||||||
|
ts_code: str
|
||||||
|
trade_date: str
|
||||||
|
close: float
|
||||||
|
oi: float
|
||||||
|
oi_chg: float
|
||||||
|
short_term: float
|
||||||
|
medium_term: float
|
||||||
|
long_term: float
|
||||||
|
composite: float
|
||||||
|
signal: str
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
def startup():
|
||||||
|
storage.init_db()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
def health():
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/v1/run", response_model=RunResponse)
|
||||||
|
def run_pipeline(req: RunRequest):
|
||||||
|
ts_code = req.ts_code or contracts.active_contract(req.symbol)
|
||||||
|
if not req.ts_code:
|
||||||
|
print(f"[AUTO] {req.symbol} 当月主力 -> {ts_code}")
|
||||||
|
|
||||||
|
df = fetcher.fetch_contract(ts_code)
|
||||||
|
storage.save_candles(df)
|
||||||
|
result = scorer.score_daily(df)
|
||||||
|
storage.save_score(result)
|
||||||
|
|
||||||
|
push_title = f"{result.ts_code.split('.')[0]} {result.trade_date}"
|
||||||
|
push_body = (
|
||||||
|
f"综合 {result.composite:.1f}\n"
|
||||||
|
f"短期 {result.short_term:.1f} | 中期 {result.medium_term:.1f} | 长期 {result.long_term:.1f}\n"
|
||||||
|
f"{result.signal}"
|
||||||
|
)
|
||||||
|
notifier.push_bark(push_title, push_body)
|
||||||
|
|
||||||
|
return RunResponse(
|
||||||
|
ts_code=result.ts_code,
|
||||||
|
trade_date=result.trade_date,
|
||||||
|
close=result.close,
|
||||||
|
oi=result.oi,
|
||||||
|
oi_chg=result.oi_chg,
|
||||||
|
short_term=result.short_term,
|
||||||
|
medium_term=result.medium_term,
|
||||||
|
long_term=result.long_term,
|
||||||
|
composite=result.composite,
|
||||||
|
signal=result.signal,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/v1/scores")
|
||||||
|
def list_scores(
|
||||||
|
ts_code: Optional[str] = Query(None),
|
||||||
|
start: Optional[str] = Query(None),
|
||||||
|
end: Optional[str] = Query(None),
|
||||||
|
limit: int = Query(200, ge=1, le=500),
|
||||||
|
):
|
||||||
|
conn = storage._get_conn()
|
||||||
|
try:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
q = """SELECT id, ts_code, trade_date, close, oi, oi_chg, short_term, medium_term,
|
||||||
|
long_term, composite, signal, created_at FROM scores WHERE 1=1"""
|
||||||
|
args = []
|
||||||
|
if ts_code:
|
||||||
|
q += " AND ts_code = %s"
|
||||||
|
args.append(ts_code)
|
||||||
|
if start:
|
||||||
|
q += " AND trade_date >= %s"
|
||||||
|
args.append(start)
|
||||||
|
if end:
|
||||||
|
q += " AND trade_date <= %s"
|
||||||
|
args.append(end)
|
||||||
|
q += " ORDER BY trade_date DESC, id DESC LIMIT %s"
|
||||||
|
args.append(limit)
|
||||||
|
cur.execute(q, args)
|
||||||
|
cols = [d[0] for d in cur.description]
|
||||||
|
rows = [dict(zip(cols, row)) for row in cur.fetchall()]
|
||||||
|
return rows
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/v1/scores/{score_id}")
|
||||||
|
def get_score(score_id: int):
|
||||||
|
conn = storage._get_conn()
|
||||||
|
try:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"""SELECT id, ts_code, trade_date, close, oi, oi_chg, short_term, medium_term,
|
||||||
|
long_term, composite, signal, detail_json, created_at
|
||||||
|
FROM scores WHERE id = %s""",
|
||||||
|
(score_id,),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
if not row:
|
||||||
|
raise HTTPException(status_code=404, detail="not found")
|
||||||
|
cols = [d[0] for d in cur.description]
|
||||||
|
return dict(zip(cols, row))
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/v1/contracts")
|
||||||
|
def list_contracts():
|
||||||
|
conn = storage._get_conn()
|
||||||
|
try:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
cur.execute("SELECT DISTINCT ts_code FROM scores ORDER BY ts_code ASC")
|
||||||
|
return [r[0] for r in cur.fetchall()]
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/v1/candles")
|
||||||
|
def list_candles(
|
||||||
|
ts_code: str = Query(...),
|
||||||
|
start: Optional[str] = Query(None),
|
||||||
|
end: Optional[str] = Query(None),
|
||||||
|
):
|
||||||
|
conn = storage._get_conn()
|
||||||
|
try:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
q = """SELECT ts_code, trade_date,
|
||||||
|
COALESCE(open, 0), COALESCE(high, 0), COALESCE(low, 0), COALESCE(close, 0),
|
||||||
|
COALESCE(vol, 0), COALESCE(amount, 0),
|
||||||
|
COALESCE(oi, 0), COALESCE(oi_chg, 0), COALESCE(pre_close, 0)
|
||||||
|
FROM candles WHERE ts_code = %s"""
|
||||||
|
args = [ts_code]
|
||||||
|
if start:
|
||||||
|
q += " AND trade_date >= %s"
|
||||||
|
args.append(start)
|
||||||
|
if end:
|
||||||
|
q += " AND trade_date <= %s"
|
||||||
|
args.append(end)
|
||||||
|
q += " ORDER BY trade_date ASC LIMIT 1000"
|
||||||
|
cur.execute(q, args)
|
||||||
|
cols = ["ts_code", "trade_date", "open", "high", "low", "close",
|
||||||
|
"vol", "amount", "oi", "oi_chg", "pre_close"]
|
||||||
|
return [dict(zip(cols, row)) for row in cur.fetchall()]
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
@@ -1,28 +1,26 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sqlite3
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import psycopg2
|
||||||
|
from psycopg2.extras import RealDictCursor
|
||||||
|
|
||||||
from .models import ScoreResult
|
from .models import ScoreResult
|
||||||
|
|
||||||
DEFAULT_DB_PATH = os.environ.get("DB_PATH", "/app/data/futures.db")
|
DEFAULT_DB_URL = os.environ.get("DATABASE_URL", "postgresql://trade:trade@postgres:5432/futures")
|
||||||
|
|
||||||
|
|
||||||
def _get_conn(db_path: str = DEFAULT_DB_PATH) -> sqlite3.Connection:
|
def _get_conn(db_url: str = DEFAULT_DB_URL):
|
||||||
conn = sqlite3.connect(db_path)
|
return psycopg2.connect(db_url)
|
||||||
conn.row_factory = sqlite3.Row
|
|
||||||
return conn
|
|
||||||
|
|
||||||
|
|
||||||
def init_db(db_path: str = DEFAULT_DB_PATH):
|
def init_db(db_url: str = DEFAULT_DB_URL):
|
||||||
"""初始化数据库,创建 candles 和 scores 表。"""
|
"""初始化数据库,创建 candles 和 scores 表。"""
|
||||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
conn = _get_conn(db_url)
|
||||||
conn = _get_conn(db_path)
|
|
||||||
try:
|
try:
|
||||||
conn.execute("PRAGMA journal_mode=WAL")
|
with conn.cursor() as cur:
|
||||||
conn.execute("""
|
cur.execute("""
|
||||||
CREATE TABLE IF NOT EXISTS candles (
|
CREATE TABLE IF NOT EXISTS candles (
|
||||||
ts_code TEXT NOT NULL,
|
ts_code TEXT NOT NULL,
|
||||||
trade_date TEXT NOT NULL,
|
trade_date TEXT NOT NULL,
|
||||||
@@ -38,9 +36,9 @@ def init_db(db_path: str = DEFAULT_DB_PATH):
|
|||||||
PRIMARY KEY (ts_code, trade_date)
|
PRIMARY KEY (ts_code, trade_date)
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
conn.execute("""
|
cur.execute("""
|
||||||
CREATE TABLE IF NOT EXISTS scores (
|
CREATE TABLE IF NOT EXISTS scores (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
|
||||||
ts_code TEXT NOT NULL,
|
ts_code TEXT NOT NULL,
|
||||||
trade_date TEXT NOT NULL,
|
trade_date TEXT NOT NULL,
|
||||||
close REAL,
|
close REAL,
|
||||||
@@ -52,7 +50,7 @@ def init_db(db_path: str = DEFAULT_DB_PATH):
|
|||||||
composite REAL,
|
composite REAL,
|
||||||
signal TEXT,
|
signal TEXT,
|
||||||
detail_json TEXT,
|
detail_json TEXT,
|
||||||
created_at TEXT DEFAULT (datetime('now', 'localtime')),
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
UNIQUE (ts_code, trade_date)
|
UNIQUE (ts_code, trade_date)
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
@@ -61,21 +59,32 @@ def init_db(db_path: str = DEFAULT_DB_PATH):
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
def save_candles(df: pd.DataFrame, db_path: str = DEFAULT_DB_PATH):
|
def save_candles(df: pd.DataFrame, db_url: str = DEFAULT_DB_URL):
|
||||||
"""批量写入/更新日线数据。"""
|
"""批量写入/更新日线数据。"""
|
||||||
if df.empty:
|
if df.empty:
|
||||||
return
|
return
|
||||||
conn = _get_conn(db_path)
|
conn = _get_conn(db_url)
|
||||||
try:
|
try:
|
||||||
df = df.copy()
|
df = df.copy()
|
||||||
df = df.where(pd.notna(df), None)
|
df = df.where(pd.notna(df), None)
|
||||||
records = df.to_dict(orient="records")
|
records = df.to_dict(orient="records")
|
||||||
conn.executemany(
|
with conn.cursor() as cur:
|
||||||
|
cur.executemany(
|
||||||
"""
|
"""
|
||||||
INSERT OR REPLACE INTO candles
|
INSERT INTO candles
|
||||||
(ts_code, trade_date, open, high, low, close, vol, amount, oi, oi_chg, pre_close)
|
(ts_code, trade_date, open, high, low, close, vol, amount, oi, oi_chg, pre_close)
|
||||||
VALUES (:ts_code, :trade_date, :open, :high, :low, :close,
|
VALUES (%(ts_code)s, %(trade_date)s, %(open)s, %(high)s, %(low)s, %(close)s,
|
||||||
:vol, :amount, :oi, :oi_chg, :pre_close)
|
%(vol)s, %(amount)s, %(oi)s, %(oi_chg)s, %(pre_close)s)
|
||||||
|
ON CONFLICT (ts_code, trade_date) DO UPDATE SET
|
||||||
|
open = EXCLUDED.open,
|
||||||
|
high = EXCLUDED.high,
|
||||||
|
low = EXCLUDED.low,
|
||||||
|
close = EXCLUDED.close,
|
||||||
|
vol = EXCLUDED.vol,
|
||||||
|
amount = EXCLUDED.amount,
|
||||||
|
oi = EXCLUDED.oi,
|
||||||
|
oi_chg = EXCLUDED.oi_chg,
|
||||||
|
pre_close = EXCLUDED.pre_close
|
||||||
""",
|
""",
|
||||||
records,
|
records,
|
||||||
)
|
)
|
||||||
@@ -84,16 +93,28 @@ def save_candles(df: pd.DataFrame, db_path: str = DEFAULT_DB_PATH):
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
def save_score(score: ScoreResult, db_path: str = DEFAULT_DB_PATH):
|
def save_score(score: ScoreResult, db_url: str = DEFAULT_DB_URL):
|
||||||
"""写入打分结果。"""
|
"""写入打分结果。"""
|
||||||
conn = _get_conn(db_path)
|
conn = _get_conn(db_url)
|
||||||
try:
|
try:
|
||||||
conn.execute(
|
with conn.cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
INSERT OR REPLACE INTO scores
|
INSERT INTO scores
|
||||||
(ts_code, trade_date, close, oi, oi_chg,
|
(ts_code, trade_date, close, oi, oi_chg,
|
||||||
short_term, medium_term, long_term, composite, signal, detail_json)
|
short_term, medium_term, long_term, composite, signal, detail_json)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||||
|
ON CONFLICT (ts_code, trade_date) DO UPDATE SET
|
||||||
|
close = EXCLUDED.close,
|
||||||
|
oi = EXCLUDED.oi,
|
||||||
|
oi_chg = EXCLUDED.oi_chg,
|
||||||
|
short_term = EXCLUDED.short_term,
|
||||||
|
medium_term = EXCLUDED.medium_term,
|
||||||
|
long_term = EXCLUDED.long_term,
|
||||||
|
composite = EXCLUDED.composite,
|
||||||
|
signal = EXCLUDED.signal,
|
||||||
|
detail_json = EXCLUDED.detail_json,
|
||||||
|
created_at = CURRENT_TIMESTAMP
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
score.ts_code,
|
score.ts_code,
|
||||||
@@ -118,14 +139,16 @@ def save_score(score: ScoreResult, db_path: str = DEFAULT_DB_PATH):
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
def get_latest_score(ts_code: str, db_path: str = DEFAULT_DB_PATH) -> Optional[dict]:
|
def get_latest_score(ts_code: str, db_url: str = DEFAULT_DB_URL) -> Optional[dict]:
|
||||||
"""查询最新打分记录。"""
|
"""查询最新打分记录。"""
|
||||||
conn = _get_conn(db_path)
|
conn = _get_conn(db_url)
|
||||||
try:
|
try:
|
||||||
row = conn.execute(
|
with conn.cursor(cursor_factory=RealDictCursor) as cur:
|
||||||
"SELECT * FROM scores WHERE ts_code = ? ORDER BY trade_date DESC LIMIT 1",
|
cur.execute(
|
||||||
|
"SELECT * FROM scores WHERE ts_code = %s ORDER BY trade_date DESC LIMIT 1",
|
||||||
(ts_code,),
|
(ts_code,),
|
||||||
).fetchone()
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
return dict(row) if row else None
|
return dict(row) if row else None
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|||||||
@@ -46,7 +46,6 @@ COPY --from=api --chown=app:app /out/web /app/web
|
|||||||
|
|
||||||
ENV TZ=Asia/Shanghai \
|
ENV TZ=Asia/Shanghai \
|
||||||
LISTEN_ADDR=:8080 \
|
LISTEN_ADDR=:8080 \
|
||||||
FUTURES_DB_PATH=/app/data/futures.db \
|
|
||||||
AUTH_DB_PATH=/app/auth/auth.db
|
AUTH_DB_PATH=/app/auth/auth.db
|
||||||
|
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ go 1.25.8
|
|||||||
require (
|
require (
|
||||||
github.com/go-chi/chi/v5 v5.1.0
|
github.com/go-chi/chi/v5 v5.1.0
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.1
|
github.com/golang-jwt/jwt/v5 v5.2.1
|
||||||
|
github.com/lib/pq v1.10.9
|
||||||
golang.org/x/crypto v0.27.0
|
golang.org/x/crypto v0.27.0
|
||||||
modernc.org/sqlite v1.32.0
|
modernc.org/sqlite v1.32.0
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
ListenAddr string
|
ListenAddr string
|
||||||
FuturesDBPath string
|
DatabaseURL string
|
||||||
AuthDBPath string
|
AuthDBPath string
|
||||||
JWTSecret []byte
|
JWTSecret []byte
|
||||||
AdminUser string
|
AdminUser string
|
||||||
@@ -18,11 +18,14 @@ type Config struct {
|
|||||||
func Load() (*Config, error) {
|
func Load() (*Config, error) {
|
||||||
cfg := &Config{
|
cfg := &Config{
|
||||||
ListenAddr: getenv("LISTEN_ADDR", ":8080"),
|
ListenAddr: getenv("LISTEN_ADDR", ":8080"),
|
||||||
FuturesDBPath: getenv("FUTURES_DB_PATH", "/app/data/futures.db"),
|
DatabaseURL: os.Getenv("DATABASE_URL"),
|
||||||
AuthDBPath: getenv("AUTH_DB_PATH", "/app/auth/auth.db"),
|
AuthDBPath: getenv("AUTH_DB_PATH", "/app/auth/auth.db"),
|
||||||
AdminUser: strings.TrimSpace(os.Getenv("ADMIN_USER")),
|
AdminUser: strings.TrimSpace(os.Getenv("ADMIN_USER")),
|
||||||
AdminPass: os.Getenv("ADMIN_PASS"),
|
AdminPass: os.Getenv("ADMIN_PASS"),
|
||||||
}
|
}
|
||||||
|
if cfg.DatabaseURL == "" {
|
||||||
|
return nil, fmt.Errorf("DATABASE_URL 环境变量未设置")
|
||||||
|
}
|
||||||
secret := strings.TrimSpace(os.Getenv("JWT_SECRET"))
|
secret := strings.TrimSpace(os.Getenv("JWT_SECRET"))
|
||||||
if len(secret) < 16 {
|
if len(secret) < 16 {
|
||||||
return nil, fmt.Errorf("JWT_SECRET 必须至少 16 个字符 (建议 openssl rand -hex 32)")
|
return nil, fmt.Errorf("JWT_SECRET 必须至少 16 个字符 (建议 openssl rand -hex 32)")
|
||||||
|
|||||||
@@ -6,21 +6,22 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
_ "github.com/lib/pq"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrMissingTsCode = errors.New("ts_code 必填")
|
var ErrMissingTsCode = errors.New("ts_code 必填")
|
||||||
|
|
||||||
type FuturesStore struct{ db *sql.DB }
|
type FuturesStore struct{ db *sql.DB }
|
||||||
|
|
||||||
func OpenFutures(path string) (*FuturesStore, error) {
|
func OpenFutures(databaseURL string) (*FuturesStore, error) {
|
||||||
dsn := fmt.Sprintf("file:%s?mode=ro&_pragma=query_only(true)", path)
|
db, err := sql.Open("postgres", databaseURL)
|
||||||
db, err := sql.Open("sqlite", dsn)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("open futures.db: %w", err)
|
return nil, fmt.Errorf("open futures db: %w", err)
|
||||||
}
|
}
|
||||||
db.SetMaxOpenConns(4)
|
db.SetMaxOpenConns(8)
|
||||||
if err := db.Ping(); err != nil {
|
if err := db.Ping(); err != nil {
|
||||||
return nil, fmt.Errorf("ping futures.db: %w", err)
|
return nil, fmt.Errorf("ping futures db: %w", err)
|
||||||
}
|
}
|
||||||
return &FuturesStore{db: db}, nil
|
return &FuturesStore{db: db}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ func main() {
|
|||||||
log.Fatalf("config: %v", err)
|
log.Fatalf("config: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
futures, err := store.OpenFutures(cfg.FuturesDBPath)
|
futures, err := store.OpenFutures(cfg.DatabaseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("open futures: %v", err)
|
log.Fatalf("open futures: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user