新增自定义时间段批量打分功能:支持设置日期区间,对区间内每天自动拉取数据并打分

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
fish
2026-05-07 21:37:20 +08:00
parent 7aa74dc9bc
commit 01edda923a
8 changed files with 293 additions and 12 deletions

View File

@@ -1,6 +1,6 @@
from typing import Optional
from datetime import date
from datetime import date, datetime, timedelta
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
@@ -16,6 +16,12 @@ class RunRequest(BaseModel):
trade_date: Optional[str] = None
class RunRangeRequest(BaseModel):
symbol: str = "FG"
start_date: str
end_date: str
class RunResponse(BaseModel):
ts_code: str
trade_date: str
@@ -88,6 +94,42 @@ def run_batch():
return {"results": results, "errors": errors}
@app.post("/api/v1/run/range")
def run_range(req: RunRangeRequest):
"""对指定日期区间内的每一天分别打分。"""
ts_code = contracts.active_contract(req.symbol)
# 为确保区间开始日有足够前置数据,拉取时 start_date 前推 60 天
start_dt = datetime.strptime(req.start_date, "%Y%m%d")
fetch_start = (start_dt - timedelta(days=60)).strftime("%Y%m%d")
df = fetcher.fetch_contract(ts_code, start_date=fetch_start, end_date=req.end_date)
storage.save_candles(df)
results, warnings = scorer.score_range(df, req.start_date, req.end_date)
for r in results:
storage.save_score(r)
return {
"ts_code": ts_code,
"start_date": req.start_date,
"end_date": req.end_date,
"scored": len(results),
"skipped": len(warnings),
"warnings": warnings,
"results": [
{
"trade_date": r.trade_date,
"close": r.close,
"composite": r.composite,
"signal": r.signal,
}
for r in results
],
}
@app.get("/api/v1/scores")
def list_scores(
ts_code: Optional[str] = Query(None),

View File

@@ -11,10 +11,19 @@ def _init_api():
return ts.pro_api()
def fetch_contract(ts_code: str, limit: int = 100) -> pd.DataFrame:
def fetch_contract(
ts_code: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
) -> pd.DataFrame:
"""拉取指定期货合约的日线数据,返回按 trade_date 升序排列的 DataFrame。"""
pro = _init_api()
df = pro.fut_daily(ts_code=ts_code)
kwargs: dict = {"ts_code": ts_code}
if start_date:
kwargs["start_date"] = start_date
if end_date:
kwargs["end_date"] = end_date
df = pro.fut_daily(**kwargs)
if df.empty:
raise RuntimeError(f"未返回 {ts_code} 的任何数据,可能合约代码错误或 token 积分不足")

View File

@@ -81,6 +81,52 @@ def run(ts_code: str, trade_date: Optional[str] = None) -> int:
return 0
def run_range(ts_code: str, start_date: str, end_date: str) -> int:
from datetime import datetime, timedelta
storage.init_db()
print(f"[1/4] 拉取 {ts_code} 数据 ({start_date} ~ {end_date})...")
start_dt = datetime.strptime(start_date, "%Y%m%d")
fetch_start = (start_dt - timedelta(days=60)).strftime("%Y%m%d")
df = fetcher.fetch_contract(ts_code, start_date=fetch_start, end_date=end_date)
print(f" 返回 {len(df)}")
print(f"[2/4] 写入/更新 PostgreSQL...")
storage.save_candles(df)
print(f"[3/4] 批量计算打分...")
results, warnings = scorer.score_range(df, start_date, end_date)
print(f"[4/4] 保存打分结果...")
for r in results:
storage.save_score(r)
print("\n" + "=" * 65)
print(f"合约: {ts_code}")
print(f"区间: {start_date} ~ {end_date}")
print(f"成功打分: {len(results)}")
if warnings:
print(f"跳过: {len(warnings)}")
for w in warnings[:5]:
print(f" - {w}")
if len(warnings) > 5:
print(f" ... 还有 {len(warnings) - 5}")
print("=" * 65)
quadrant_names = {
"accumulation": "增仓上涨", "distribution": "增仓下跌",
"covering": "减仓上涨", "liquidation": "减仓下跌", "flat": "持仓持平",
}
print(f"\n{'日期':<12} {'收盘':>10} {'综合':>8} {'信号':<20}")
print("-" * 55)
for r in results:
print(f"{r.trade_date:<12} {r.close:>10.2f} {r.composite:>8.1f} {r.signal:<20}")
print(f"\n[OK] {len(results)} 条打分已持久化到 PostgreSQL")
return 0
def main() -> int:
parser = argparse.ArgumentParser(description="期货合约三层打分模型")
parser.add_argument(
@@ -97,11 +143,26 @@ def main() -> int:
"--date",
help="指定打分日期,格式 YYYYMMDD,不传则对最新日期打分",
)
parser.add_argument(
"--start-date",
dest="start_date",
help="区间打分开始日期,格式 YYYYMMDD(与 --end-date 同时使用时生效)",
)
parser.add_argument(
"--end-date",
dest="end_date",
help="区间打分结束日期,格式 YYYYMMDD(与 --start-date 同时使用时生效)",
)
args = parser.parse_args()
ts_code = args.ts_code or contracts.active_contract(args.symbol)
if not args.ts_code:
print(f"[AUTO] {args.symbol} 当月主力 -> {ts_code}")
if args.start_date and args.end_date:
print(f"[RANGE] 区间打分: {args.start_date} ~ {args.end_date}")
return run_range(ts_code, args.start_date, args.end_date)
if args.date:
print(f"[DATE] 指定打分日期: {args.date}")
return run(ts_code, args.date)

View File

@@ -229,3 +229,28 @@ def score_daily(df: pd.DataFrame, trade_date: Optional[str] = None) -> ScoreResu
},
),
)
def score_range(
df: pd.DataFrame, start_date: str, end_date: str
) -> tuple[list[ScoreResult], list[str]]:
"""对日期区间内的每一天分别打分,返回 (结果列表, 警告列表)。"""
if len(df) < 31:
raise ValueError(f"数据量不足(仅 {len(df)} 行),需要至少 31 行")
results: list[ScoreResult] = []
warnings: list[str] = []
target_dates = df[
(df["trade_date"].astype(str) >= str(start_date))
& (df["trade_date"].astype(str) <= str(end_date))
]["trade_date"].astype(str).tolist()
for td in target_dates:
try:
result = score_daily(df, td)
results.append(result)
except ValueError as e:
warnings.append(str(e))
return results, warnings

View File

@@ -56,6 +56,32 @@ func (d *Deps) RunBatch(w http.ResponseWriter, r *http.Request) {
_, _ = io.Copy(w, resp.Body)
}
func (d *Deps) RunRange(w http.ResponseWriter, r *http.Request) {
var req runRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeErr(w, http.StatusBadRequest, "invalid json")
return
}
body, err := json.Marshal(req)
if err != nil {
writeErr(w, http.StatusInternalServerError, "encode request failed")
return
}
client := &http.Client{Timeout: 180 * time.Second}
resp, err := client.Post(d.TushareURL+"/api/v1/run/range", "application/json", bytes.NewReader(body))
if err != nil {
writeErr(w, http.StatusBadGateway, fmt.Sprintf("tushare service unavailable: %v", err))
return
}
defer resp.Body.Close()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(resp.StatusCode)
_, _ = io.Copy(w, resp.Body)
}
func (d *Deps) GetActiveContract(w http.ResponseWriter, r *http.Request) {
symbol := r.URL.Query().Get("symbol")
if symbol == "" {

View File

@@ -32,6 +32,7 @@ func New(d *handlers.Deps, dist fs.FS) http.Handler {
r.Get("/candles", d.ListCandles)
r.Post("/run", d.RunPipeline)
r.Post("/run/batch", d.RunBatch)
r.Post("/run/range", d.RunRange)
r.Group(func(r chi.Router) {
r.Use(mw.RequireAdmin)

View File

@@ -30,6 +30,33 @@ export function runPipeline(req: RunRequest) {
return client.post<RunResponse>('/run', req).then((r) => r.data)
}
export interface RunRangeRequest {
symbol: string
start_date: string
end_date: string
}
export interface RunRangeResult {
trade_date: string
close: number
composite: number
signal: string
}
export interface RunRangeResponse {
ts_code: string
start_date: string
end_date: string
scored: number
skipped: number
warnings: string[]
results: RunRangeResult[]
}
export function runRange(req: RunRangeRequest) {
return client.post<RunRangeResponse>('/run/range', req, { timeout: 180_000 }).then((r) => r.data)
}
export function runBatch() {
return client.post('/run/batch', null, { timeout: 180_000 }).then((r) => r.data)
}

View File

@@ -3,9 +3,11 @@ import { nextTick, onMounted, reactive, ref, watch } from 'vue'
import { ElMessage } from 'element-plus'
import {
runPipeline,
runRange,
getActiveContract,
type ActiveContract,
type RunResponse,
type RunRangeResponse,
} from '@/api/run'
import { parseTsCode } from '@/utils/contract'
import { useMobile } from '@/composables/useMobile'
@@ -14,6 +16,8 @@ const { isMobile } = useMobile()
const SYMBOLS = ['FG', 'SA', 'RB', 'MA', 'CF', 'M']
const mode = ref<'single' | 'range'>('single')
const form = reactive<{
symbol: string
trade_date: string
@@ -22,20 +26,32 @@ const form = reactive<{
trade_date: '',
})
const range = reactive<{
dates: [string, string] | []
}>({
dates: [],
})
const active = ref<ActiveContract | null>(null)
const activeLoading = ref(false)
const loading = ref(false)
const result = ref<RunResponse | null>(null)
const rangeResult = ref<RunRangeResponse | null>(null)
const resultRef = ref<HTMLElement | null>(null)
async function loadActive() {
activeLoading.value = true
try {
active.value = await getActiveContract(form.symbol)
// 切换品种后,如果原日期落在新合约的可选范围之外,清空它
if (form.trade_date && !isDateAllowed(toDate(form.trade_date))) {
form.trade_date = ''
}
if (Array.isArray(range.dates) && range.dates.length === 2) {
const [s, e] = range.dates
if (!isDateAllowed(toDate(s)) || !isDateAllowed(toDate(e))) {
range.dates = []
}
}
} catch (err: any) {
active.value = null
ElMessage.error(err?.response?.data?.error || '加载主力合约失败')
@@ -45,7 +61,6 @@ async function loadActive() {
}
function toDate(s: string) {
// s 形如 'YYYY-MM-DD'
const [y, m, d] = s.split('-').map(Number)
return new Date(y, m - 1, d)
}
@@ -69,12 +84,29 @@ async function submit() {
}
loading.value = true
result.value = null
rangeResult.value = null
try {
if (mode.value === 'single') {
const req: { symbol: string; trade_date?: string } = { symbol: form.symbol }
if (form.trade_date) req.trade_date = form.trade_date.replace(/-/g, '')
const resp = await runPipeline(req)
result.value = resp
ElMessage.success('打分完成')
} else {
if (!Array.isArray(range.dates) || range.dates.length !== 2) {
ElMessage.warning('请选择日期区间')
loading.value = false
return
}
const [start, end] = range.dates
const resp = await runRange({
symbol: form.symbol,
start_date: start.replace(/-/g, ''),
end_date: end.replace(/-/g, ''),
})
rangeResult.value = resp
ElMessage.success(`区间打分完成,成功 ${resp.scored}`)
}
await nextTick()
resultRef.value?.scrollIntoView({ behavior: 'smooth', block: 'start' })
} catch (err: any) {
@@ -104,6 +136,12 @@ onMounted(loadActive)
<span>手动打分</span>
</template>
<el-form :model="form" label-width="100px" style="max-width: 480px">
<el-form-item label="模式">
<el-radio-group v-model="mode">
<el-radio-button label="single">单日打分</el-radio-button>
<el-radio-button label="range">区间打分</el-radio-button>
</el-radio-group>
</el-form-item>
<el-form-item label="品种">
<el-select v-model="form.symbol" :loading="activeLoading" style="width: 100%">
<el-option v-for="s in SYMBOLS" :key="s" :label="s" :value="s" />
@@ -118,7 +156,7 @@ onMounted(loadActive)
</span>
<el-text v-else type="info">加载中</el-text>
</el-form-item>
<el-form-item label="打分日期">
<el-form-item v-if="mode === 'single'" label="打分日期">
<el-date-picker
v-model="form.trade_date"
type="date"
@@ -129,9 +167,23 @@ onMounted(loadActive)
style="width: 100%"
/>
</el-form-item>
<el-form-item v-else label="日期区间">
<el-date-picker
v-model="range.dates"
type="daterange"
:placeholder="active ? `${active.min_date} ~ ${active.max_date}` : '加载中…'"
value-format="YYYY-MM-DD"
:disabled-date="disabledDate"
:disabled="!active"
range-separator=""
start-placeholder="开始"
end-placeholder="结束"
style="width: 100%"
/>
</el-form-item>
<el-form-item>
<el-button type="primary" :loading="loading" :disabled="!active" @click="submit">
执行打分
{{ mode === 'single' ? '执行打分' : '批量打分' }}
</el-button>
</el-form-item>
</el-form>
@@ -160,6 +212,41 @@ onMounted(loadActive)
</el-descriptions>
</el-card>
</div>
<div v-if="rangeResult" ref="resultRef">
<el-card shadow="never" class="result-card">
<template #header>
<span>区间打分结果</span>
</template>
<el-descriptions :column="isMobile ? 1 : 2" border>
<el-descriptions-item label="合约">{{ parseTsCode(rangeResult.ts_code).symbol }}</el-descriptions-item>
<el-descriptions-item label="区间">{{ rangeResult.start_date }} ~ {{ rangeResult.end_date }}</el-descriptions-item>
<el-descriptions-item label="成功">{{ rangeResult.scored }} </el-descriptions-item>
<el-descriptions-item label="跳过">{{ rangeResult.skipped }} </el-descriptions-item>
</el-descriptions>
<el-alert
v-if="rangeResult.warnings.length > 0"
:title="`警告 (${rangeResult.warnings.length} 条)`"
type="warning"
:closable="false"
style="margin-top: 12px"
>
<div style="max-height: 120px; overflow-y: auto">
<div v-for="(w, i) in rangeResult.warnings" :key="i" style="font-size: 12px">{{ w }}</div>
</div>
</el-alert>
<el-table :data="rangeResult.results" stripe style="margin-top: 16px" max-height="400">
<el-table-column prop="trade_date" label="日期" width="110" />
<el-table-column prop="close" label="收盘" width="90" />
<el-table-column prop="composite" label="综合" width="80" />
<el-table-column label="信号">
<template #default="{ row }">
<el-tag :type="signalTagType(row.signal)" size="small">{{ row.signal }}</el-tag>
</template>
</el-table-column>
</el-table>
</el-card>
</div>
</div>
</template>
@@ -169,4 +256,7 @@ onMounted(loadActive)
flex-direction: column;
gap: 16px;
}
.result-card {
margin-top: 8px;
}
</style>