diff --git a/tushare/src/api.py b/tushare/src/api.py index 4fb72f6..3cd5347 100644 --- a/tushare/src/api.py +++ b/tushare/src/api.py @@ -1,6 +1,6 @@ from typing import Optional -from datetime import date +from datetime import date, datetime, timedelta from fastapi import FastAPI, HTTPException, Query from pydantic import BaseModel @@ -16,6 +16,12 @@ class RunRequest(BaseModel): trade_date: Optional[str] = None +class RunRangeRequest(BaseModel): + symbol: str = "FG" + start_date: str + end_date: str + + class RunResponse(BaseModel): ts_code: str trade_date: str @@ -88,6 +94,42 @@ def run_batch(): return {"results": results, "errors": errors} +@app.post("/api/v1/run/range") +def run_range(req: RunRangeRequest): + """对指定日期区间内的每一天分别打分。""" + ts_code = contracts.active_contract(req.symbol) + + # 为确保区间开始日有足够前置数据,拉取时 start_date 前推 60 天 + start_dt = datetime.strptime(req.start_date, "%Y%m%d") + fetch_start = (start_dt - timedelta(days=60)).strftime("%Y%m%d") + + df = fetcher.fetch_contract(ts_code, start_date=fetch_start, end_date=req.end_date) + storage.save_candles(df) + + results, warnings = scorer.score_range(df, req.start_date, req.end_date) + + for r in results: + storage.save_score(r) + + return { + "ts_code": ts_code, + "start_date": req.start_date, + "end_date": req.end_date, + "scored": len(results), + "skipped": len(warnings), + "warnings": warnings, + "results": [ + { + "trade_date": r.trade_date, + "close": r.close, + "composite": r.composite, + "signal": r.signal, + } + for r in results + ], + } + + @app.get("/api/v1/scores") def list_scores( ts_code: Optional[str] = Query(None), diff --git a/tushare/src/fetcher.py b/tushare/src/fetcher.py index 9122bd0..b226e99 100644 --- a/tushare/src/fetcher.py +++ b/tushare/src/fetcher.py @@ -11,10 +11,19 @@ def _init_api(): return ts.pro_api() -def fetch_contract(ts_code: str, limit: int = 100) -> pd.DataFrame: +def fetch_contract( + ts_code: str, + start_date: Optional[str] = None, + end_date: Optional[str] = None, +) -> pd.DataFrame: """拉取指定期货合约的日线数据,返回按 trade_date 升序排列的 DataFrame。""" pro = _init_api() - df = pro.fut_daily(ts_code=ts_code) + kwargs: dict = {"ts_code": ts_code} + if start_date: + kwargs["start_date"] = start_date + if end_date: + kwargs["end_date"] = end_date + df = pro.fut_daily(**kwargs) if df.empty: raise RuntimeError(f"未返回 {ts_code} 的任何数据,可能合约代码错误或 token 积分不足") diff --git a/tushare/src/main.py b/tushare/src/main.py index b340c2f..8333169 100644 --- a/tushare/src/main.py +++ b/tushare/src/main.py @@ -81,6 +81,52 @@ def run(ts_code: str, trade_date: Optional[str] = None) -> int: return 0 +def run_range(ts_code: str, start_date: str, end_date: str) -> int: + from datetime import datetime, timedelta + + storage.init_db() + + print(f"[1/4] 拉取 {ts_code} 数据 ({start_date} ~ {end_date})...") + start_dt = datetime.strptime(start_date, "%Y%m%d") + fetch_start = (start_dt - timedelta(days=60)).strftime("%Y%m%d") + df = fetcher.fetch_contract(ts_code, start_date=fetch_start, end_date=end_date) + print(f" 返回 {len(df)} 行") + + print(f"[2/4] 写入/更新 PostgreSQL...") + storage.save_candles(df) + + print(f"[3/4] 批量计算打分...") + results, warnings = scorer.score_range(df, start_date, end_date) + + print(f"[4/4] 保存打分结果...") + for r in results: + storage.save_score(r) + + print("\n" + "=" * 65) + print(f"合约: {ts_code}") + print(f"区间: {start_date} ~ {end_date}") + print(f"成功打分: {len(results)} 条") + if warnings: + print(f"跳过: {len(warnings)} 条") + for w in warnings[:5]: + print(f" - {w}") + if len(warnings) > 5: + print(f" ... 还有 {len(warnings) - 5} 条") + print("=" * 65) + + quadrant_names = { + "accumulation": "增仓上涨", "distribution": "增仓下跌", + "covering": "减仓上涨", "liquidation": "减仓下跌", "flat": "持仓持平", + } + print(f"\n{'日期':<12} {'收盘':>10} {'综合':>8} {'信号':<20}") + print("-" * 55) + for r in results: + print(f"{r.trade_date:<12} {r.close:>10.2f} {r.composite:>8.1f} {r.signal:<20}") + + print(f"\n[OK] {len(results)} 条打分已持久化到 PostgreSQL") + return 0 + + def main() -> int: parser = argparse.ArgumentParser(description="期货合约三层打分模型") parser.add_argument( @@ -97,11 +143,26 @@ def main() -> int: "--date", help="指定打分日期,格式 YYYYMMDD,不传则对最新日期打分", ) + parser.add_argument( + "--start-date", + dest="start_date", + help="区间打分开始日期,格式 YYYYMMDD(与 --end-date 同时使用时生效)", + ) + parser.add_argument( + "--end-date", + dest="end_date", + help="区间打分结束日期,格式 YYYYMMDD(与 --start-date 同时使用时生效)", + ) args = parser.parse_args() ts_code = args.ts_code or contracts.active_contract(args.symbol) if not args.ts_code: print(f"[AUTO] {args.symbol} 当月主力 -> {ts_code}") + + if args.start_date and args.end_date: + print(f"[RANGE] 区间打分: {args.start_date} ~ {args.end_date}") + return run_range(ts_code, args.start_date, args.end_date) + if args.date: print(f"[DATE] 指定打分日期: {args.date}") return run(ts_code, args.date) diff --git a/tushare/src/scorer.py b/tushare/src/scorer.py index ad9a891..b43ada5 100644 --- a/tushare/src/scorer.py +++ b/tushare/src/scorer.py @@ -229,3 +229,28 @@ def score_daily(df: pd.DataFrame, trade_date: Optional[str] = None) -> ScoreResu }, ), ) + + +def score_range( + df: pd.DataFrame, start_date: str, end_date: str +) -> tuple[list[ScoreResult], list[str]]: + """对日期区间内的每一天分别打分,返回 (结果列表, 警告列表)。""" + if len(df) < 31: + raise ValueError(f"数据量不足(仅 {len(df)} 行),需要至少 31 行") + + results: list[ScoreResult] = [] + warnings: list[str] = [] + + target_dates = df[ + (df["trade_date"].astype(str) >= str(start_date)) + & (df["trade_date"].astype(str) <= str(end_date)) + ]["trade_date"].astype(str).tolist() + + for td in target_dates: + try: + result = score_daily(df, td) + results.append(result) + except ValueError as e: + warnings.append(str(e)) + + return results, warnings diff --git a/web/backend/internal/handlers/run.go b/web/backend/internal/handlers/run.go index 72c3685..3baa569 100644 --- a/web/backend/internal/handlers/run.go +++ b/web/backend/internal/handlers/run.go @@ -56,6 +56,32 @@ func (d *Deps) RunBatch(w http.ResponseWriter, r *http.Request) { _, _ = io.Copy(w, resp.Body) } +func (d *Deps) RunRange(w http.ResponseWriter, r *http.Request) { + var req runRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeErr(w, http.StatusBadRequest, "invalid json") + return + } + + body, err := json.Marshal(req) + if err != nil { + writeErr(w, http.StatusInternalServerError, "encode request failed") + return + } + + client := &http.Client{Timeout: 180 * time.Second} + resp, err := client.Post(d.TushareURL+"/api/v1/run/range", "application/json", bytes.NewReader(body)) + if err != nil { + writeErr(w, http.StatusBadGateway, fmt.Sprintf("tushare service unavailable: %v", err)) + return + } + defer resp.Body.Close() + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(resp.StatusCode) + _, _ = io.Copy(w, resp.Body) +} + func (d *Deps) GetActiveContract(w http.ResponseWriter, r *http.Request) { symbol := r.URL.Query().Get("symbol") if symbol == "" { diff --git a/web/backend/internal/router/router.go b/web/backend/internal/router/router.go index 3dcccc6..7dd191c 100644 --- a/web/backend/internal/router/router.go +++ b/web/backend/internal/router/router.go @@ -32,6 +32,7 @@ func New(d *handlers.Deps, dist fs.FS) http.Handler { r.Get("/candles", d.ListCandles) r.Post("/run", d.RunPipeline) r.Post("/run/batch", d.RunBatch) + r.Post("/run/range", d.RunRange) r.Group(func(r chi.Router) { r.Use(mw.RequireAdmin) diff --git a/web/frontend/src/api/run.ts b/web/frontend/src/api/run.ts index 23797e1..95521f6 100644 --- a/web/frontend/src/api/run.ts +++ b/web/frontend/src/api/run.ts @@ -30,6 +30,33 @@ export function runPipeline(req: RunRequest) { return client.post('/run', req).then((r) => r.data) } +export interface RunRangeRequest { + symbol: string + start_date: string + end_date: string +} + +export interface RunRangeResult { + trade_date: string + close: number + composite: number + signal: string +} + +export interface RunRangeResponse { + ts_code: string + start_date: string + end_date: string + scored: number + skipped: number + warnings: string[] + results: RunRangeResult[] +} + +export function runRange(req: RunRangeRequest) { + return client.post('/run/range', req, { timeout: 180_000 }).then((r) => r.data) +} + export function runBatch() { return client.post('/run/batch', null, { timeout: 180_000 }).then((r) => r.data) } diff --git a/web/frontend/src/views/RunView.vue b/web/frontend/src/views/RunView.vue index 09dc96b..2bee0fc 100644 --- a/web/frontend/src/views/RunView.vue +++ b/web/frontend/src/views/RunView.vue @@ -3,9 +3,11 @@ import { nextTick, onMounted, reactive, ref, watch } from 'vue' import { ElMessage } from 'element-plus' import { runPipeline, + runRange, getActiveContract, type ActiveContract, type RunResponse, + type RunRangeResponse, } from '@/api/run' import { parseTsCode } from '@/utils/contract' import { useMobile } from '@/composables/useMobile' @@ -14,6 +16,8 @@ const { isMobile } = useMobile() const SYMBOLS = ['FG', 'SA', 'RB', 'MA', 'CF', 'M'] +const mode = ref<'single' | 'range'>('single') + const form = reactive<{ symbol: string trade_date: string @@ -22,20 +26,32 @@ const form = reactive<{ trade_date: '', }) +const range = reactive<{ + dates: [string, string] | [] +}>({ + dates: [], +}) + const active = ref(null) const activeLoading = ref(false) const loading = ref(false) const result = ref(null) +const rangeResult = ref(null) const resultRef = ref(null) async function loadActive() { activeLoading.value = true try { active.value = await getActiveContract(form.symbol) - // 切换品种后,如果原日期落在新合约的可选范围之外,清空它 if (form.trade_date && !isDateAllowed(toDate(form.trade_date))) { form.trade_date = '' } + if (Array.isArray(range.dates) && range.dates.length === 2) { + const [s, e] = range.dates + if (!isDateAllowed(toDate(s)) || !isDateAllowed(toDate(e))) { + range.dates = [] + } + } } catch (err: any) { active.value = null ElMessage.error(err?.response?.data?.error || '加载主力合约失败') @@ -45,7 +61,6 @@ async function loadActive() { } function toDate(s: string) { - // s 形如 'YYYY-MM-DD' const [y, m, d] = s.split('-').map(Number) return new Date(y, m - 1, d) } @@ -69,12 +84,29 @@ async function submit() { } loading.value = true result.value = null + rangeResult.value = null try { - const req: { symbol: string; trade_date?: string } = { symbol: form.symbol } - if (form.trade_date) req.trade_date = form.trade_date.replace(/-/g, '') - const resp = await runPipeline(req) - result.value = resp - ElMessage.success('打分完成') + if (mode.value === 'single') { + const req: { symbol: string; trade_date?: string } = { symbol: form.symbol } + if (form.trade_date) req.trade_date = form.trade_date.replace(/-/g, '') + const resp = await runPipeline(req) + result.value = resp + ElMessage.success('打分完成') + } else { + if (!Array.isArray(range.dates) || range.dates.length !== 2) { + ElMessage.warning('请选择日期区间') + loading.value = false + return + } + const [start, end] = range.dates + const resp = await runRange({ + symbol: form.symbol, + start_date: start.replace(/-/g, ''), + end_date: end.replace(/-/g, ''), + }) + rangeResult.value = resp + ElMessage.success(`区间打分完成,成功 ${resp.scored} 条`) + } await nextTick() resultRef.value?.scrollIntoView({ behavior: 'smooth', block: 'start' }) } catch (err: any) { @@ -104,6 +136,12 @@ onMounted(loadActive) 手动打分 + + + 单日打分 + 区间打分 + + @@ -118,7 +156,7 @@ onMounted(loadActive) 加载中… - + + + + - 执行打分 + {{ mode === 'single' ? '执行打分' : '批量打分' }} @@ -160,6 +212,41 @@ onMounted(loadActive) + +
+ + + + {{ parseTsCode(rangeResult.ts_code).symbol }} + {{ rangeResult.start_date }} ~ {{ rangeResult.end_date }} + {{ rangeResult.scored }} 条 + {{ rangeResult.skipped }} 条 + + +
+
{{ w }}
+
+
+ + + + + + + + +
+
@@ -169,4 +256,7 @@ onMounted(loadActive) flex-direction: column; gap: 16px; } +.result-card { + margin-top: 8px; +}