Compare commits

...

4 Commits

8 changed files with 302 additions and 24 deletions

View File

@@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from datetime import date from datetime import date, datetime, timedelta
from fastapi import FastAPI, HTTPException, Query from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel from pydantic import BaseModel
@@ -16,6 +16,12 @@ class RunRequest(BaseModel):
trade_date: Optional[str] = None trade_date: Optional[str] = None
class RunRangeRequest(BaseModel):
symbol: str = "FG"
start_date: str
end_date: str
class RunResponse(BaseModel): class RunResponse(BaseModel):
ts_code: str ts_code: str
trade_date: str trade_date: str
@@ -41,7 +47,8 @@ def health():
@app.post("/api/v1/run", response_model=RunResponse) @app.post("/api/v1/run", response_model=RunResponse)
def run_pipeline(req: RunRequest): def run_pipeline(req: RunRequest):
ts_code = req.ts_code or contracts.active_contract(req.symbol) ref_date = datetime.strptime(req.trade_date, "%Y%m%d").date() if req.trade_date else None
ts_code = req.ts_code or contracts.active_contract(req.symbol, ref_date)
if not req.ts_code: if not req.ts_code:
print(f"[AUTO] {req.symbol} 当月主力 -> {ts_code}") print(f"[AUTO] {req.symbol} 当月主力 -> {ts_code}")
@@ -88,6 +95,42 @@ def run_batch():
return {"results": results, "errors": errors} return {"results": results, "errors": errors}
@app.post("/api/v1/run/range")
def run_range(req: RunRangeRequest):
"""对指定日期区间内的每一天分别打分。"""
start_dt = datetime.strptime(req.start_date, "%Y%m%d").date()
ts_code = contracts.active_contract(req.symbol, start_dt)
# 为确保区间开始日有足够前置数据,拉取时 start_date 前推 60 天
fetch_start = (start_dt - timedelta(days=60)).strftime("%Y%m%d")
df = fetcher.fetch_contract(ts_code, start_date=fetch_start, end_date=req.end_date)
storage.save_candles(df)
results, warnings = scorer.score_range(df, req.start_date, req.end_date)
for r in results:
storage.save_score(r)
return {
"ts_code": ts_code,
"start_date": req.start_date,
"end_date": req.end_date,
"scored": len(results),
"skipped": len(warnings),
"warnings": warnings,
"results": [
{
"trade_date": r.trade_date,
"close": r.close,
"composite": r.composite,
"signal": r.signal,
}
for r in results
],
}
@app.get("/api/v1/scores") @app.get("/api/v1/scores")
def list_scores( def list_scores(
ts_code: Optional[str] = Query(None), ts_code: Optional[str] = Query(None),

View File

@@ -11,10 +11,19 @@ def _init_api():
return ts.pro_api() return ts.pro_api()
def fetch_contract(ts_code: str, limit: int = 100) -> pd.DataFrame: def fetch_contract(
ts_code: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
) -> pd.DataFrame:
"""拉取指定期货合约的日线数据,返回按 trade_date 升序排列的 DataFrame。""" """拉取指定期货合约的日线数据,返回按 trade_date 升序排列的 DataFrame。"""
pro = _init_api() pro = _init_api()
df = pro.fut_daily(ts_code=ts_code) kwargs: dict = {"ts_code": ts_code}
if start_date:
kwargs["start_date"] = start_date
if end_date:
kwargs["end_date"] = end_date
df = pro.fut_daily(**kwargs)
if df.empty: if df.empty:
raise RuntimeError(f"未返回 {ts_code} 的任何数据,可能合约代码错误或 token 积分不足") raise RuntimeError(f"未返回 {ts_code} 的任何数据,可能合约代码错误或 token 积分不足")

View File

@@ -81,6 +81,52 @@ def run(ts_code: str, trade_date: Optional[str] = None) -> int:
return 0 return 0
def run_range(ts_code: str, start_date: str, end_date: str) -> int:
from datetime import datetime, timedelta
storage.init_db()
print(f"[1/4] 拉取 {ts_code} 数据 ({start_date} ~ {end_date})...")
start_dt = datetime.strptime(start_date, "%Y%m%d")
fetch_start = (start_dt - timedelta(days=60)).strftime("%Y%m%d")
df = fetcher.fetch_contract(ts_code, start_date=fetch_start, end_date=end_date)
print(f" 返回 {len(df)}")
print(f"[2/4] 写入/更新 PostgreSQL...")
storage.save_candles(df)
print(f"[3/4] 批量计算打分...")
results, warnings = scorer.score_range(df, start_date, end_date)
print(f"[4/4] 保存打分结果...")
for r in results:
storage.save_score(r)
print("\n" + "=" * 65)
print(f"合约: {ts_code}")
print(f"区间: {start_date} ~ {end_date}")
print(f"成功打分: {len(results)}")
if warnings:
print(f"跳过: {len(warnings)}")
for w in warnings[:5]:
print(f" - {w}")
if len(warnings) > 5:
print(f" ... 还有 {len(warnings) - 5}")
print("=" * 65)
quadrant_names = {
"accumulation": "增仓上涨", "distribution": "增仓下跌",
"covering": "减仓上涨", "liquidation": "减仓下跌", "flat": "持仓持平",
}
print(f"\n{'日期':<12} {'收盘':>10} {'综合':>8} {'信号':<20}")
print("-" * 55)
for r in results:
print(f"{r.trade_date:<12} {r.close:>10.2f} {r.composite:>8.1f} {r.signal:<20}")
print(f"\n[OK] {len(results)} 条打分已持久化到 PostgreSQL")
return 0
def main() -> int: def main() -> int:
parser = argparse.ArgumentParser(description="期货合约三层打分模型") parser = argparse.ArgumentParser(description="期货合约三层打分模型")
parser.add_argument( parser.add_argument(
@@ -97,11 +143,26 @@ def main() -> int:
"--date", "--date",
help="指定打分日期,格式 YYYYMMDD,不传则对最新日期打分", help="指定打分日期,格式 YYYYMMDD,不传则对最新日期打分",
) )
parser.add_argument(
"--start-date",
dest="start_date",
help="区间打分开始日期,格式 YYYYMMDD(与 --end-date 同时使用时生效)",
)
parser.add_argument(
"--end-date",
dest="end_date",
help="区间打分结束日期,格式 YYYYMMDD(与 --start-date 同时使用时生效)",
)
args = parser.parse_args() args = parser.parse_args()
ts_code = args.ts_code or contracts.active_contract(args.symbol) ts_code = args.ts_code or contracts.active_contract(args.symbol)
if not args.ts_code: if not args.ts_code:
print(f"[AUTO] {args.symbol} 当月主力 -> {ts_code}") print(f"[AUTO] {args.symbol} 当月主力 -> {ts_code}")
if args.start_date and args.end_date:
print(f"[RANGE] 区间打分: {args.start_date} ~ {args.end_date}")
return run_range(ts_code, args.start_date, args.end_date)
if args.date: if args.date:
print(f"[DATE] 指定打分日期: {args.date}") print(f"[DATE] 指定打分日期: {args.date}")
return run(ts_code, args.date) return run(ts_code, args.date)

View File

@@ -229,3 +229,28 @@ def score_daily(df: pd.DataFrame, trade_date: Optional[str] = None) -> ScoreResu
}, },
), ),
) )
def score_range(
df: pd.DataFrame, start_date: str, end_date: str
) -> tuple[list[ScoreResult], list[str]]:
"""对日期区间内的每一天分别打分,返回 (结果列表, 警告列表)。"""
if len(df) < 31:
raise ValueError(f"数据量不足(仅 {len(df)} 行),需要至少 31 行")
results: list[ScoreResult] = []
warnings: list[str] = []
target_dates = df[
(df["trade_date"].astype(str) >= str(start_date))
& (df["trade_date"].astype(str) <= str(end_date))
]["trade_date"].astype(str).tolist()
for td in target_dates:
try:
result = score_daily(df, td)
results.append(result)
except ValueError as e:
warnings.append(str(e))
return results, warnings

View File

@@ -56,6 +56,38 @@ func (d *Deps) RunBatch(w http.ResponseWriter, r *http.Request) {
_, _ = io.Copy(w, resp.Body) _, _ = io.Copy(w, resp.Body)
} }
type runRangeRequest struct {
Symbol string `json:"symbol,omitempty"`
StartDate string `json:"start_date,omitempty"`
EndDate string `json:"end_date,omitempty"`
}
func (d *Deps) RunRange(w http.ResponseWriter, r *http.Request) {
var req runRangeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeErr(w, http.StatusBadRequest, "invalid json")
return
}
body, err := json.Marshal(req)
if err != nil {
writeErr(w, http.StatusInternalServerError, "encode request failed")
return
}
client := &http.Client{Timeout: 180 * time.Second}
resp, err := client.Post(d.TushareURL+"/api/v1/run/range", "application/json", bytes.NewReader(body))
if err != nil {
writeErr(w, http.StatusBadGateway, fmt.Sprintf("tushare service unavailable: %v", err))
return
}
defer resp.Body.Close()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(resp.StatusCode)
_, _ = io.Copy(w, resp.Body)
}
func (d *Deps) GetActiveContract(w http.ResponseWriter, r *http.Request) { func (d *Deps) GetActiveContract(w http.ResponseWriter, r *http.Request) {
symbol := r.URL.Query().Get("symbol") symbol := r.URL.Query().Get("symbol")
if symbol == "" { if symbol == "" {

View File

@@ -32,6 +32,7 @@ func New(d *handlers.Deps, dist fs.FS) http.Handler {
r.Get("/candles", d.ListCandles) r.Get("/candles", d.ListCandles)
r.Post("/run", d.RunPipeline) r.Post("/run", d.RunPipeline)
r.Post("/run/batch", d.RunBatch) r.Post("/run/batch", d.RunBatch)
r.Post("/run/range", d.RunRange)
r.Group(func(r chi.Router) { r.Group(func(r chi.Router) {
r.Use(mw.RequireAdmin) r.Use(mw.RequireAdmin)

View File

@@ -30,6 +30,33 @@ export function runPipeline(req: RunRequest) {
return client.post<RunResponse>('/run', req).then((r) => r.data) return client.post<RunResponse>('/run', req).then((r) => r.data)
} }
export interface RunRangeRequest {
symbol: string
start_date: string
end_date: string
}
export interface RunRangeResult {
trade_date: string
close: number
composite: number
signal: string
}
export interface RunRangeResponse {
ts_code: string
start_date: string
end_date: string
scored: number
skipped: number
warnings: string[]
results: RunRangeResult[]
}
export function runRange(req: RunRangeRequest) {
return client.post<RunRangeResponse>('/run/range', req, { timeout: 180_000 }).then((r) => r.data)
}
export function runBatch() { export function runBatch() {
return client.post('/run/batch', null, { timeout: 180_000 }).then((r) => r.data) return client.post('/run/batch', null, { timeout: 180_000 }).then((r) => r.data)
} }

View File

@@ -3,9 +3,11 @@ import { nextTick, onMounted, reactive, ref, watch } from 'vue'
import { ElMessage } from 'element-plus' import { ElMessage } from 'element-plus'
import { import {
runPipeline, runPipeline,
runRange,
getActiveContract, getActiveContract,
type ActiveContract, type ActiveContract,
type RunResponse, type RunResponse,
type RunRangeResponse,
} from '@/api/run' } from '@/api/run'
import { parseTsCode } from '@/utils/contract' import { parseTsCode } from '@/utils/contract'
import { useMobile } from '@/composables/useMobile' import { useMobile } from '@/composables/useMobile'
@@ -14,6 +16,8 @@ const { isMobile } = useMobile()
const SYMBOLS = ['FG', 'SA', 'RB', 'MA', 'CF', 'M'] const SYMBOLS = ['FG', 'SA', 'RB', 'MA', 'CF', 'M']
const mode = ref<'single' | 'range'>('single')
const form = reactive<{ const form = reactive<{
symbol: string symbol: string
trade_date: string trade_date: string
@@ -22,20 +26,32 @@ const form = reactive<{
trade_date: '', trade_date: '',
}) })
const range = reactive<{
dates: [string, string] | []
}>({
dates: [],
})
const active = ref<ActiveContract | null>(null) const active = ref<ActiveContract | null>(null)
const activeLoading = ref(false) const activeLoading = ref(false)
const loading = ref(false) const loading = ref(false)
const result = ref<RunResponse | null>(null) const result = ref<RunResponse | null>(null)
const rangeResult = ref<RunRangeResponse | null>(null)
const resultRef = ref<HTMLElement | null>(null) const resultRef = ref<HTMLElement | null>(null)
async function loadActive() { async function loadActive() {
activeLoading.value = true activeLoading.value = true
try { try {
active.value = await getActiveContract(form.symbol) active.value = await getActiveContract(form.symbol)
// 切换品种后,如果原日期落在新合约的可选范围之外,清空它
if (form.trade_date && !isDateAllowed(toDate(form.trade_date))) { if (form.trade_date && !isDateAllowed(toDate(form.trade_date))) {
form.trade_date = '' form.trade_date = ''
} }
if (Array.isArray(range.dates) && range.dates.length === 2) {
const [s, e] = range.dates
if (!isDateAllowed(toDate(s)) || !isDateAllowed(toDate(e))) {
range.dates = []
}
}
} catch (err: any) { } catch (err: any) {
active.value = null active.value = null
ElMessage.error(err?.response?.data?.error || '加载主力合约失败') ElMessage.error(err?.response?.data?.error || '加载主力合约失败')
@@ -45,17 +61,15 @@ async function loadActive() {
} }
function toDate(s: string) { function toDate(s: string) {
// s 形如 'YYYY-MM-DD'
const [y, m, d] = s.split('-').map(Number) const [y, m, d] = s.split('-').map(Number)
return new Date(y, m - 1, d) return new Date(y, m - 1, d)
} }
function isDateAllowed(d: Date): boolean { function isDateAllowed(d: Date): boolean {
if (!active.value) return true if (!active.value) return true
const min = toDate(active.value.min_date).getTime()
const max = toDate(active.value.max_date).getTime() const max = toDate(active.value.max_date).getTime()
const t = d.getTime() const t = d.getTime()
return t >= min && t <= max return t <= max
} }
function disabledDate(d: Date) { function disabledDate(d: Date) {
@@ -69,12 +83,29 @@ async function submit() {
} }
loading.value = true loading.value = true
result.value = null result.value = null
rangeResult.value = null
try { try {
const req: { symbol: string; trade_date?: string } = { symbol: form.symbol } if (mode.value === 'single') {
if (form.trade_date) req.trade_date = form.trade_date.replace(/-/g, '') const req: { symbol: string; trade_date?: string } = { symbol: form.symbol }
const resp = await runPipeline(req) if (form.trade_date) req.trade_date = form.trade_date.replace(/-/g, '')
result.value = resp const resp = await runPipeline(req)
ElMessage.success('打分完成') result.value = resp
ElMessage.success('打分完成')
} else {
if (!Array.isArray(range.dates) || range.dates.length !== 2) {
ElMessage.warning('请选择日期区间')
loading.value = false
return
}
const [start, end] = range.dates
const resp = await runRange({
symbol: form.symbol,
start_date: start.replace(/-/g, ''),
end_date: end.replace(/-/g, ''),
})
rangeResult.value = resp
ElMessage.success(`区间打分完成,成功 ${resp.scored}`)
}
await nextTick() await nextTick()
resultRef.value?.scrollIntoView({ behavior: 'smooth', block: 'start' }) resultRef.value?.scrollIntoView({ behavior: 'smooth', block: 'start' })
} catch (err: any) { } catch (err: any) {
@@ -104,21 +135,18 @@ onMounted(loadActive)
<span>手动打分</span> <span>手动打分</span>
</template> </template>
<el-form :model="form" label-width="100px" style="max-width: 480px"> <el-form :model="form" label-width="100px" style="max-width: 480px">
<el-form-item label="模式">
<el-radio-group v-model="mode">
<el-radio-button label="single">单日打分</el-radio-button>
<el-radio-button label="range">区间打分</el-radio-button>
</el-radio-group>
</el-form-item>
<el-form-item label="品种"> <el-form-item label="品种">
<el-select v-model="form.symbol" :loading="activeLoading" style="width: 100%"> <el-select v-model="form.symbol" :loading="activeLoading" style="width: 100%">
<el-option v-for="s in SYMBOLS" :key="s" :label="s" :value="s" /> <el-option v-for="s in SYMBOLS" :key="s" :label="s" :value="s" />
</el-select> </el-select>
</el-form-item> </el-form-item>
<el-form-item label="主力合约"> <el-form-item v-if="mode === 'single'" label="打分日期">
<span v-if="active">
{{ parseTsCode(active.ts_code).contract }}
<el-text type="info" size="small" style="margin-left: 8px">
({{ active.ts_code }})
</el-text>
</span>
<el-text v-else type="info">加载中</el-text>
</el-form-item>
<el-form-item label="打分日期">
<el-date-picker <el-date-picker
v-model="form.trade_date" v-model="form.trade_date"
type="date" type="date"
@@ -129,9 +157,23 @@ onMounted(loadActive)
style="width: 100%" style="width: 100%"
/> />
</el-form-item> </el-form-item>
<el-form-item v-else label="日期区间">
<el-date-picker
v-model="range.dates"
type="daterange"
:placeholder="active ? `${active.min_date} ~ ${active.max_date}` : '加载中…'"
value-format="YYYY-MM-DD"
:disabled-date="disabledDate"
:disabled="!active"
range-separator=""
start-placeholder="开始"
end-placeholder="结束"
style="width: 100%"
/>
</el-form-item>
<el-form-item> <el-form-item>
<el-button type="primary" :loading="loading" :disabled="!active" @click="submit"> <el-button type="primary" :loading="loading" :disabled="!active" @click="submit">
执行打分 {{ mode === 'single' ? '执行打分' : '批量打分' }}
</el-button> </el-button>
</el-form-item> </el-form-item>
</el-form> </el-form>
@@ -160,6 +202,41 @@ onMounted(loadActive)
</el-descriptions> </el-descriptions>
</el-card> </el-card>
</div> </div>
<div v-if="rangeResult" ref="resultRef">
<el-card shadow="never" class="result-card">
<template #header>
<span>区间打分结果</span>
</template>
<el-descriptions :column="isMobile ? 1 : 2" border>
<el-descriptions-item label="合约">{{ parseTsCode(rangeResult.ts_code).symbol }}</el-descriptions-item>
<el-descriptions-item label="区间">{{ rangeResult.start_date }} ~ {{ rangeResult.end_date }}</el-descriptions-item>
<el-descriptions-item label="成功">{{ rangeResult.scored }} </el-descriptions-item>
<el-descriptions-item label="跳过">{{ rangeResult.skipped }} </el-descriptions-item>
</el-descriptions>
<el-alert
v-if="rangeResult.warnings.length > 0"
:title="`警告 (${rangeResult.warnings.length} 条)`"
type="warning"
:closable="false"
style="margin-top: 12px"
>
<div style="max-height: 120px; overflow-y: auto">
<div v-for="(w, i) in rangeResult.warnings" :key="i" style="font-size: 12px">{{ w }}</div>
</div>
</el-alert>
<el-table :data="rangeResult.results" stripe style="margin-top: 16px" max-height="400">
<el-table-column prop="trade_date" label="日期" width="110" />
<el-table-column prop="close" label="收盘" width="90" />
<el-table-column prop="composite" label="综合" width="80" />
<el-table-column label="信号">
<template #default="{ row }">
<el-tag :type="signalTagType(row.signal)" size="small">{{ row.signal }}</el-tag>
</template>
</el-table-column>
</el-table>
</el-card>
</div>
</div> </div>
</template> </template>
@@ -169,4 +246,7 @@ onMounted(loadActive)
flex-direction: column; flex-direction: column;
gap: 16px; gap: 16px;
} }
.result-card {
margin-top: 8px;
}
</style> </style>