新增自定义时间段批量打分功能:支持设置日期区间,对区间内每天自动拉取数据并打分
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from datetime import date
|
||||
from datetime import date, datetime, timedelta
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
@@ -16,6 +16,12 @@ class RunRequest(BaseModel):
|
||||
trade_date: Optional[str] = None
|
||||
|
||||
|
||||
class RunRangeRequest(BaseModel):
|
||||
symbol: str = "FG"
|
||||
start_date: str
|
||||
end_date: str
|
||||
|
||||
|
||||
class RunResponse(BaseModel):
|
||||
ts_code: str
|
||||
trade_date: str
|
||||
@@ -88,6 +94,42 @@ def run_batch():
|
||||
return {"results": results, "errors": errors}
|
||||
|
||||
|
||||
@app.post("/api/v1/run/range")
|
||||
def run_range(req: RunRangeRequest):
|
||||
"""对指定日期区间内的每一天分别打分。"""
|
||||
ts_code = contracts.active_contract(req.symbol)
|
||||
|
||||
# 为确保区间开始日有足够前置数据,拉取时 start_date 前推 60 天
|
||||
start_dt = datetime.strptime(req.start_date, "%Y%m%d")
|
||||
fetch_start = (start_dt - timedelta(days=60)).strftime("%Y%m%d")
|
||||
|
||||
df = fetcher.fetch_contract(ts_code, start_date=fetch_start, end_date=req.end_date)
|
||||
storage.save_candles(df)
|
||||
|
||||
results, warnings = scorer.score_range(df, req.start_date, req.end_date)
|
||||
|
||||
for r in results:
|
||||
storage.save_score(r)
|
||||
|
||||
return {
|
||||
"ts_code": ts_code,
|
||||
"start_date": req.start_date,
|
||||
"end_date": req.end_date,
|
||||
"scored": len(results),
|
||||
"skipped": len(warnings),
|
||||
"warnings": warnings,
|
||||
"results": [
|
||||
{
|
||||
"trade_date": r.trade_date,
|
||||
"close": r.close,
|
||||
"composite": r.composite,
|
||||
"signal": r.signal,
|
||||
}
|
||||
for r in results
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/v1/scores")
|
||||
def list_scores(
|
||||
ts_code: Optional[str] = Query(None),
|
||||
|
||||
@@ -11,10 +11,19 @@ def _init_api():
|
||||
return ts.pro_api()
|
||||
|
||||
|
||||
def fetch_contract(ts_code: str, limit: int = 100) -> pd.DataFrame:
|
||||
def fetch_contract(
|
||||
ts_code: str,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
) -> pd.DataFrame:
|
||||
"""拉取指定期货合约的日线数据,返回按 trade_date 升序排列的 DataFrame。"""
|
||||
pro = _init_api()
|
||||
df = pro.fut_daily(ts_code=ts_code)
|
||||
kwargs: dict = {"ts_code": ts_code}
|
||||
if start_date:
|
||||
kwargs["start_date"] = start_date
|
||||
if end_date:
|
||||
kwargs["end_date"] = end_date
|
||||
df = pro.fut_daily(**kwargs)
|
||||
|
||||
if df.empty:
|
||||
raise RuntimeError(f"未返回 {ts_code} 的任何数据,可能合约代码错误或 token 积分不足")
|
||||
|
||||
@@ -81,6 +81,52 @@ def run(ts_code: str, trade_date: Optional[str] = None) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
def run_range(ts_code: str, start_date: str, end_date: str) -> int:
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
storage.init_db()
|
||||
|
||||
print(f"[1/4] 拉取 {ts_code} 数据 ({start_date} ~ {end_date})...")
|
||||
start_dt = datetime.strptime(start_date, "%Y%m%d")
|
||||
fetch_start = (start_dt - timedelta(days=60)).strftime("%Y%m%d")
|
||||
df = fetcher.fetch_contract(ts_code, start_date=fetch_start, end_date=end_date)
|
||||
print(f" 返回 {len(df)} 行")
|
||||
|
||||
print(f"[2/4] 写入/更新 PostgreSQL...")
|
||||
storage.save_candles(df)
|
||||
|
||||
print(f"[3/4] 批量计算打分...")
|
||||
results, warnings = scorer.score_range(df, start_date, end_date)
|
||||
|
||||
print(f"[4/4] 保存打分结果...")
|
||||
for r in results:
|
||||
storage.save_score(r)
|
||||
|
||||
print("\n" + "=" * 65)
|
||||
print(f"合约: {ts_code}")
|
||||
print(f"区间: {start_date} ~ {end_date}")
|
||||
print(f"成功打分: {len(results)} 条")
|
||||
if warnings:
|
||||
print(f"跳过: {len(warnings)} 条")
|
||||
for w in warnings[:5]:
|
||||
print(f" - {w}")
|
||||
if len(warnings) > 5:
|
||||
print(f" ... 还有 {len(warnings) - 5} 条")
|
||||
print("=" * 65)
|
||||
|
||||
quadrant_names = {
|
||||
"accumulation": "增仓上涨", "distribution": "增仓下跌",
|
||||
"covering": "减仓上涨", "liquidation": "减仓下跌", "flat": "持仓持平",
|
||||
}
|
||||
print(f"\n{'日期':<12} {'收盘':>10} {'综合':>8} {'信号':<20}")
|
||||
print("-" * 55)
|
||||
for r in results:
|
||||
print(f"{r.trade_date:<12} {r.close:>10.2f} {r.composite:>8.1f} {r.signal:<20}")
|
||||
|
||||
print(f"\n[OK] {len(results)} 条打分已持久化到 PostgreSQL")
|
||||
return 0
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="期货合约三层打分模型")
|
||||
parser.add_argument(
|
||||
@@ -97,11 +143,26 @@ def main() -> int:
|
||||
"--date",
|
||||
help="指定打分日期,格式 YYYYMMDD,不传则对最新日期打分",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--start-date",
|
||||
dest="start_date",
|
||||
help="区间打分开始日期,格式 YYYYMMDD(与 --end-date 同时使用时生效)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--end-date",
|
||||
dest="end_date",
|
||||
help="区间打分结束日期,格式 YYYYMMDD(与 --start-date 同时使用时生效)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
ts_code = args.ts_code or contracts.active_contract(args.symbol)
|
||||
if not args.ts_code:
|
||||
print(f"[AUTO] {args.symbol} 当月主力 -> {ts_code}")
|
||||
|
||||
if args.start_date and args.end_date:
|
||||
print(f"[RANGE] 区间打分: {args.start_date} ~ {args.end_date}")
|
||||
return run_range(ts_code, args.start_date, args.end_date)
|
||||
|
||||
if args.date:
|
||||
print(f"[DATE] 指定打分日期: {args.date}")
|
||||
return run(ts_code, args.date)
|
||||
|
||||
@@ -229,3 +229,28 @@ def score_daily(df: pd.DataFrame, trade_date: Optional[str] = None) -> ScoreResu
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def score_range(
|
||||
df: pd.DataFrame, start_date: str, end_date: str
|
||||
) -> tuple[list[ScoreResult], list[str]]:
|
||||
"""对日期区间内的每一天分别打分,返回 (结果列表, 警告列表)。"""
|
||||
if len(df) < 31:
|
||||
raise ValueError(f"数据量不足(仅 {len(df)} 行),需要至少 31 行")
|
||||
|
||||
results: list[ScoreResult] = []
|
||||
warnings: list[str] = []
|
||||
|
||||
target_dates = df[
|
||||
(df["trade_date"].astype(str) >= str(start_date))
|
||||
& (df["trade_date"].astype(str) <= str(end_date))
|
||||
]["trade_date"].astype(str).tolist()
|
||||
|
||||
for td in target_dates:
|
||||
try:
|
||||
result = score_daily(df, td)
|
||||
results.append(result)
|
||||
except ValueError as e:
|
||||
warnings.append(str(e))
|
||||
|
||||
return results, warnings
|
||||
|
||||
Reference in New Issue
Block a user