163 lines
4.8 KiB
Python
163 lines
4.8 KiB
Python
from typing import Optional
|
|
|
|
from fastapi import FastAPI, HTTPException, Query
|
|
from pydantic import BaseModel
|
|
|
|
from . import contracts, fetcher, notifier, scorer, storage
|
|
|
|
# ASGI application object; every route handler in this module registers on it.
# The title reads "futures data collection and scoring service".
app = FastAPI(
    title="期货数据采集与打分服务",
)
|
|
|
|
|
|
class RunRequest(BaseModel):
    """Request body accepted by ``POST /api/v1/run``.

    Both fields are optional for the caller: ``ts_code`` defaults to ``None``
    and ``symbol`` defaults to ``"FG"``.
    """

    # Explicit contract code. When it is empty the run handler asks
    # ``contracts.active_contract(symbol)`` for one instead.
    ts_code: Optional[str] = None

    # Symbol handed to ``contracts.active_contract`` for the automatic lookup.
    symbol: str = "FG"
|
|
|
|
|
|
class RunResponse(BaseModel):
    """Response body of ``POST /api/v1/run``.

    Every field is copied one-to-one from the same-named attribute of the
    object returned by ``scorer.score_daily``.
    """

    # Which contract and which trade date the score belongs to.
    ts_code: str
    trade_date: str

    # Values carried over from the scoring result.
    close: float
    oi: float
    oi_chg: float

    # Per-horizon scores followed by the composite score.
    short_term: float
    medium_term: float
    long_term: float
    composite: float

    # Signal text; it is also the last line of the push notification body.
    signal: str
|
|
|
|
|
|
# NOTE(review): FastAPI documents ``on_event`` as deprecated in favour of
# lifespan handlers; migrating requires touching the ``FastAPI(...)`` call.
@app.on_event("startup")
def startup():
    """Call ``storage.init_db()`` when the application starts."""
    storage.init_db()
|
|
|
|
|
|
@app.get("/health")
def health():
    """Health probe: always answers with ``{"status": "ok"}``."""
    payload = {"status": "ok"}
    return payload
|
|
|
|
|
|
@app.post("/api/v1/run", response_model=RunResponse)
def run_pipeline(req: RunRequest):
    """Fetch, store, score and announce one contract, then return its score.

    Steps, in order:

    1. Resolve the contract code: use ``req.ts_code`` when given, otherwise
       ask ``contracts.active_contract(req.symbol)`` and print the choice.
    2. Fetch the contract data with ``fetcher.fetch_contract``.
    3. Persist the fetched data with ``storage.save_candles``.
    4. Score it with ``scorer.score_daily`` and persist the result with
       ``storage.save_score``.
    5. Send a notification through ``notifier.push_bark``.
    6. Return a ``RunResponse`` built from the scoring result.

    Exceptions raised by any of the helpers are not caught here.
    """
    if req.ts_code:
        ts_code = req.ts_code
    else:
        # No explicit contract requested: fall back to the automatic lookup
        # and leave a trace of what was picked.
        ts_code = contracts.active_contract(req.symbol)
        print(f"[AUTO] {req.symbol} 当月主力 -> {ts_code}")

    frame = fetcher.fetch_contract(ts_code)
    storage.save_candles(frame)
    scored = scorer.score_daily(frame)
    storage.save_score(scored)

    # Title is the part of the contract code before the first dot, followed
    # by the trade date.
    code_prefix = scored.ts_code.split('.')[0]
    title = f"{code_prefix} {scored.trade_date}"
    body_lines = [
        f"综合 {scored.composite:.1f}",
        f"短期 {scored.short_term:.1f} | 中期 {scored.medium_term:.1f} | 长期 {scored.long_term:.1f}",
        f"{scored.signal}",
    ]
    notifier.push_bark(title, "\n".join(body_lines))

    # The response mirrors the scoring result attribute-for-attribute.
    response_fields = (
        "ts_code",
        "trade_date",
        "close",
        "oi",
        "oi_chg",
        "short_term",
        "medium_term",
        "long_term",
        "composite",
        "signal",
    )
    return RunResponse(**{name: getattr(scored, name) for name in response_fields})
|
|
|
|
|
|
@app.get("/api/v1/scores")
def list_scores(
    ts_code: Optional[str] = Query(None),
    start: Optional[str] = Query(None),
    end: Optional[str] = Query(None),
    limit: int = Query(200, ge=1, le=500),
):
    """List rows of the ``scores`` table, newest first.

    Args:
        ts_code: When non-empty, keep only rows with this exact ``ts_code``.
        start: When non-empty, keep only rows with ``trade_date >= start``.
        end: When non-empty, keep only rows with ``trade_date <= end``.
        limit: Maximum number of rows; validated to the range 1..500,
            default 200.

    Returns:
        A list of dicts keyed by the column names reported by the cursor,
        ordered by ``trade_date`` descending and then ``id`` descending.
    """
    # Each optional filter contributes one SQL fragment and one bound value;
    # the user-supplied values never get interpolated into the SQL text.
    optional_filters = (
        (" AND ts_code = %s", ts_code),
        (" AND trade_date >= %s", start),
        (" AND trade_date <= %s", end),
    )

    sql = """SELECT id, ts_code, trade_date, close, oi, oi_chg, short_term, medium_term,
                      long_term, composite, signal, created_at FROM scores WHERE 1=1"""
    params = []
    for fragment, value in optional_filters:
        if value:
            sql += fragment
            params.append(value)
    sql += " ORDER BY trade_date DESC, id DESC LIMIT %s"
    params.append(limit)

    conn = storage._get_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(sql, params)
            names = [column[0] for column in cur.description]
            return [dict(zip(names, record)) for record in cur.fetchall()]
    finally:
        # Always hand the connection back, including on errors.
        conn.close()
|
|
|
|
|
|
@app.get("/api/v1/scores/{score_id}")
def get_score(score_id: str):
    """Return one row of ``scores`` (including ``detail_json``) by its ``id``.

    Args:
        score_id: Path parameter compared against the ``id`` column.

    Returns:
        A dict keyed by the column names reported by the cursor.

    Raises:
        HTTPException: 404 with detail ``"not found"`` when no row matches.
    """
    # NOTE(review): score_id is declared as str while it is matched against
    # ``id``; confirm the column type accepts string comparison.
    conn = storage._get_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """SELECT id, ts_code, trade_date, close, oi, oi_chg, short_term, medium_term,
                          long_term, composite, signal, detail_json, created_at
                   FROM scores WHERE id = %s""",
                (score_id,),
            )
            record = cur.fetchone()
            if record:
                names = [column[0] for column in cur.description]
                return dict(zip(names, record))
            raise HTTPException(status_code=404, detail="not found")
    finally:
        # Runs for both the success path and the 404 path.
        conn.close()
|
|
|
|
|
|
@app.get("/api/v1/contracts")
def list_contracts():
    """Return the distinct ``ts_code`` values found in ``scores``, ascending."""
    conn = storage._get_conn()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT DISTINCT ts_code FROM scores ORDER BY ts_code ASC")
            fetched = cur.fetchall()
        # Each row has a single column; unwrap it into a flat list.
        return [record[0] for record in fetched]
    finally:
        conn.close()
|
|
|
|
|
|
@app.get("/api/v1/candles")
def list_candles(
    ts_code: str = Query(...),
    start: Optional[str] = Query(None),
    end: Optional[str] = Query(None),
    limit: int = Query(1000, ge=1, le=1000),
):
    """List rows of the ``candles`` table for one contract, oldest first.

    Args:
        ts_code: Required; only rows with this exact ``ts_code`` are returned.
        start: When non-empty, keep only rows with ``trade_date >= start``.
        end: When non-empty, keep only rows with ``trade_date <= end``.
        limit: Maximum number of rows, validated to the range 1..1000. The
            default of 1000 matches the cap that used to be hard-coded in the
            SQL text, so callers that do not pass it see no change.

    Returns:
        A list of dicts with the keys ``ts_code``, ``trade_date``, ``open``,
        ``high``, ``low``, ``close``, ``vol``, ``amount``, ``oi``, ``oi_chg``
        and ``pre_close``. NULL numeric columns are reported as 0 because the
        query wraps them in ``COALESCE(..., 0)``.
    """
    sql = """SELECT ts_code, trade_date,
                    COALESCE(open, 0), COALESCE(high, 0), COALESCE(low, 0), COALESCE(close, 0),
                    COALESCE(vol, 0), COALESCE(amount, 0),
                    COALESCE(oi, 0), COALESCE(oi_chg, 0), COALESCE(pre_close, 0)
             FROM candles WHERE ts_code = %s"""
    params = [ts_code]
    if start:
        sql += " AND trade_date >= %s"
        params.append(start)
    if end:
        sql += " AND trade_date <= %s"
        params.append(end)
    # Bind the row cap as a parameter, the same way list_scores does.
    sql += " ORDER BY trade_date ASC LIMIT %s"
    params.append(limit)

    # The COALESCE expressions have no usable column names in the cursor
    # description, so the output keys are spelled out in SELECT order.
    keys = [
        "ts_code", "trade_date", "open", "high", "low", "close",
        "vol", "amount", "oi", "oi_chg", "pre_close",
    ]

    conn = storage._get_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(sql, params)
            return [dict(zip(keys, record)) for record in cur.fetchall()]
    finally:
        # Always release the connection, including on errors.
        conn.close()
|