Compare commits
4 Commits
7aa74dc9bc
...
9d2997a3cb
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9d2997a3cb | ||
|
|
cdf793608d | ||
|
|
c54ba5a470 | ||
|
|
01edda923a |
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from datetime import date
|
||||
from datetime import date, datetime, timedelta
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
@@ -16,6 +16,12 @@ class RunRequest(BaseModel):
|
||||
trade_date: Optional[str] = None
|
||||
|
||||
|
||||
class RunRangeRequest(BaseModel):
|
||||
symbol: str = "FG"
|
||||
start_date: str
|
||||
end_date: str
|
||||
|
||||
|
||||
class RunResponse(BaseModel):
|
||||
ts_code: str
|
||||
trade_date: str
|
||||
@@ -41,7 +47,8 @@ def health():
|
||||
|
||||
@app.post("/api/v1/run", response_model=RunResponse)
|
||||
def run_pipeline(req: RunRequest):
|
||||
ts_code = req.ts_code or contracts.active_contract(req.symbol)
|
||||
ref_date = datetime.strptime(req.trade_date, "%Y%m%d").date() if req.trade_date else None
|
||||
ts_code = req.ts_code or contracts.active_contract(req.symbol, ref_date)
|
||||
if not req.ts_code:
|
||||
print(f"[AUTO] {req.symbol} 当月主力 -> {ts_code}")
|
||||
|
||||
@@ -88,6 +95,42 @@ def run_batch():
|
||||
return {"results": results, "errors": errors}
|
||||
|
||||
|
||||
@app.post("/api/v1/run/range")
|
||||
def run_range(req: RunRangeRequest):
|
||||
"""对指定日期区间内的每一天分别打分。"""
|
||||
start_dt = datetime.strptime(req.start_date, "%Y%m%d").date()
|
||||
ts_code = contracts.active_contract(req.symbol, start_dt)
|
||||
|
||||
# 为确保区间开始日有足够前置数据,拉取时 start_date 前推 60 天
|
||||
fetch_start = (start_dt - timedelta(days=60)).strftime("%Y%m%d")
|
||||
|
||||
df = fetcher.fetch_contract(ts_code, start_date=fetch_start, end_date=req.end_date)
|
||||
storage.save_candles(df)
|
||||
|
||||
results, warnings = scorer.score_range(df, req.start_date, req.end_date)
|
||||
|
||||
for r in results:
|
||||
storage.save_score(r)
|
||||
|
||||
return {
|
||||
"ts_code": ts_code,
|
||||
"start_date": req.start_date,
|
||||
"end_date": req.end_date,
|
||||
"scored": len(results),
|
||||
"skipped": len(warnings),
|
||||
"warnings": warnings,
|
||||
"results": [
|
||||
{
|
||||
"trade_date": r.trade_date,
|
||||
"close": r.close,
|
||||
"composite": r.composite,
|
||||
"signal": r.signal,
|
||||
}
|
||||
for r in results
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/v1/scores")
|
||||
def list_scores(
|
||||
ts_code: Optional[str] = Query(None),
|
||||
|
||||
@@ -11,10 +11,19 @@ def _init_api():
|
||||
return ts.pro_api()
|
||||
|
||||
|
||||
def fetch_contract(ts_code: str, limit: int = 100) -> pd.DataFrame:
|
||||
def fetch_contract(
|
||||
ts_code: str,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
) -> pd.DataFrame:
|
||||
"""拉取指定期货合约的日线数据,返回按 trade_date 升序排列的 DataFrame。"""
|
||||
pro = _init_api()
|
||||
df = pro.fut_daily(ts_code=ts_code)
|
||||
kwargs: dict = {"ts_code": ts_code}
|
||||
if start_date:
|
||||
kwargs["start_date"] = start_date
|
||||
if end_date:
|
||||
kwargs["end_date"] = end_date
|
||||
df = pro.fut_daily(**kwargs)
|
||||
|
||||
if df.empty:
|
||||
raise RuntimeError(f"未返回 {ts_code} 的任何数据,可能合约代码错误或 token 积分不足")
|
||||
|
||||
@@ -81,6 +81,52 @@ def run(ts_code: str, trade_date: Optional[str] = None) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
def run_range(ts_code: str, start_date: str, end_date: str) -> int:
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
storage.init_db()
|
||||
|
||||
print(f"[1/4] 拉取 {ts_code} 数据 ({start_date} ~ {end_date})...")
|
||||
start_dt = datetime.strptime(start_date, "%Y%m%d")
|
||||
fetch_start = (start_dt - timedelta(days=60)).strftime("%Y%m%d")
|
||||
df = fetcher.fetch_contract(ts_code, start_date=fetch_start, end_date=end_date)
|
||||
print(f" 返回 {len(df)} 行")
|
||||
|
||||
print(f"[2/4] 写入/更新 PostgreSQL...")
|
||||
storage.save_candles(df)
|
||||
|
||||
print(f"[3/4] 批量计算打分...")
|
||||
results, warnings = scorer.score_range(df, start_date, end_date)
|
||||
|
||||
print(f"[4/4] 保存打分结果...")
|
||||
for r in results:
|
||||
storage.save_score(r)
|
||||
|
||||
print("\n" + "=" * 65)
|
||||
print(f"合约: {ts_code}")
|
||||
print(f"区间: {start_date} ~ {end_date}")
|
||||
print(f"成功打分: {len(results)} 条")
|
||||
if warnings:
|
||||
print(f"跳过: {len(warnings)} 条")
|
||||
for w in warnings[:5]:
|
||||
print(f" - {w}")
|
||||
if len(warnings) > 5:
|
||||
print(f" ... 还有 {len(warnings) - 5} 条")
|
||||
print("=" * 65)
|
||||
|
||||
quadrant_names = {
|
||||
"accumulation": "增仓上涨", "distribution": "增仓下跌",
|
||||
"covering": "减仓上涨", "liquidation": "减仓下跌", "flat": "持仓持平",
|
||||
}
|
||||
print(f"\n{'日期':<12} {'收盘':>10} {'综合':>8} {'信号':<20}")
|
||||
print("-" * 55)
|
||||
for r in results:
|
||||
print(f"{r.trade_date:<12} {r.close:>10.2f} {r.composite:>8.1f} {r.signal:<20}")
|
||||
|
||||
print(f"\n[OK] {len(results)} 条打分已持久化到 PostgreSQL")
|
||||
return 0
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="期货合约三层打分模型")
|
||||
parser.add_argument(
|
||||
@@ -97,11 +143,26 @@ def main() -> int:
|
||||
"--date",
|
||||
help="指定打分日期,格式 YYYYMMDD,不传则对最新日期打分",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--start-date",
|
||||
dest="start_date",
|
||||
help="区间打分开始日期,格式 YYYYMMDD(与 --end-date 同时使用时生效)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--end-date",
|
||||
dest="end_date",
|
||||
help="区间打分结束日期,格式 YYYYMMDD(与 --start-date 同时使用时生效)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
ts_code = args.ts_code or contracts.active_contract(args.symbol)
|
||||
if not args.ts_code:
|
||||
print(f"[AUTO] {args.symbol} 当月主力 -> {ts_code}")
|
||||
|
||||
if args.start_date and args.end_date:
|
||||
print(f"[RANGE] 区间打分: {args.start_date} ~ {args.end_date}")
|
||||
return run_range(ts_code, args.start_date, args.end_date)
|
||||
|
||||
if args.date:
|
||||
print(f"[DATE] 指定打分日期: {args.date}")
|
||||
return run(ts_code, args.date)
|
||||
|
||||
@@ -229,3 +229,28 @@ def score_daily(df: pd.DataFrame, trade_date: Optional[str] = None) -> ScoreResu
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def score_range(
|
||||
df: pd.DataFrame, start_date: str, end_date: str
|
||||
) -> tuple[list[ScoreResult], list[str]]:
|
||||
"""对日期区间内的每一天分别打分,返回 (结果列表, 警告列表)。"""
|
||||
if len(df) < 31:
|
||||
raise ValueError(f"数据量不足(仅 {len(df)} 行),需要至少 31 行")
|
||||
|
||||
results: list[ScoreResult] = []
|
||||
warnings: list[str] = []
|
||||
|
||||
target_dates = df[
|
||||
(df["trade_date"].astype(str) >= str(start_date))
|
||||
& (df["trade_date"].astype(str) <= str(end_date))
|
||||
]["trade_date"].astype(str).tolist()
|
||||
|
||||
for td in target_dates:
|
||||
try:
|
||||
result = score_daily(df, td)
|
||||
results.append(result)
|
||||
except ValueError as e:
|
||||
warnings.append(str(e))
|
||||
|
||||
return results, warnings
|
||||
|
||||
@@ -56,6 +56,38 @@ func (d *Deps) RunBatch(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.Copy(w, resp.Body)
|
||||
}
|
||||
|
||||
type runRangeRequest struct {
|
||||
Symbol string `json:"symbol,omitempty"`
|
||||
StartDate string `json:"start_date,omitempty"`
|
||||
EndDate string `json:"end_date,omitempty"`
|
||||
}
|
||||
|
||||
func (d *Deps) RunRange(w http.ResponseWriter, r *http.Request) {
|
||||
var req runRangeRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusInternalServerError, "encode request failed")
|
||||
return
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 180 * time.Second}
|
||||
resp, err := client.Post(d.TushareURL+"/api/v1/run/range", "application/json", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadGateway, fmt.Sprintf("tushare service unavailable: %v", err))
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
_, _ = io.Copy(w, resp.Body)
|
||||
}
|
||||
|
||||
func (d *Deps) GetActiveContract(w http.ResponseWriter, r *http.Request) {
|
||||
symbol := r.URL.Query().Get("symbol")
|
||||
if symbol == "" {
|
||||
|
||||
@@ -32,6 +32,7 @@ func New(d *handlers.Deps, dist fs.FS) http.Handler {
|
||||
r.Get("/candles", d.ListCandles)
|
||||
r.Post("/run", d.RunPipeline)
|
||||
r.Post("/run/batch", d.RunBatch)
|
||||
r.Post("/run/range", d.RunRange)
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(mw.RequireAdmin)
|
||||
|
||||
@@ -30,6 +30,33 @@ export function runPipeline(req: RunRequest) {
|
||||
return client.post<RunResponse>('/run', req).then((r) => r.data)
|
||||
}
|
||||
|
||||
export interface RunRangeRequest {
|
||||
symbol: string
|
||||
start_date: string
|
||||
end_date: string
|
||||
}
|
||||
|
||||
export interface RunRangeResult {
|
||||
trade_date: string
|
||||
close: number
|
||||
composite: number
|
||||
signal: string
|
||||
}
|
||||
|
||||
export interface RunRangeResponse {
|
||||
ts_code: string
|
||||
start_date: string
|
||||
end_date: string
|
||||
scored: number
|
||||
skipped: number
|
||||
warnings: string[]
|
||||
results: RunRangeResult[]
|
||||
}
|
||||
|
||||
export function runRange(req: RunRangeRequest) {
|
||||
return client.post<RunRangeResponse>('/run/range', req, { timeout: 180_000 }).then((r) => r.data)
|
||||
}
|
||||
|
||||
export function runBatch() {
|
||||
return client.post('/run/batch', null, { timeout: 180_000 }).then((r) => r.data)
|
||||
}
|
||||
|
||||
@@ -3,9 +3,11 @@ import { nextTick, onMounted, reactive, ref, watch } from 'vue'
|
||||
import { ElMessage } from 'element-plus'
|
||||
import {
|
||||
runPipeline,
|
||||
runRange,
|
||||
getActiveContract,
|
||||
type ActiveContract,
|
||||
type RunResponse,
|
||||
type RunRangeResponse,
|
||||
} from '@/api/run'
|
||||
import { parseTsCode } from '@/utils/contract'
|
||||
import { useMobile } from '@/composables/useMobile'
|
||||
@@ -14,6 +16,8 @@ const { isMobile } = useMobile()
|
||||
|
||||
const SYMBOLS = ['FG', 'SA', 'RB', 'MA', 'CF', 'M']
|
||||
|
||||
const mode = ref<'single' | 'range'>('single')
|
||||
|
||||
const form = reactive<{
|
||||
symbol: string
|
||||
trade_date: string
|
||||
@@ -22,20 +26,32 @@ const form = reactive<{
|
||||
trade_date: '',
|
||||
})
|
||||
|
||||
const range = reactive<{
|
||||
dates: [string, string] | []
|
||||
}>({
|
||||
dates: [],
|
||||
})
|
||||
|
||||
const active = ref<ActiveContract | null>(null)
|
||||
const activeLoading = ref(false)
|
||||
const loading = ref(false)
|
||||
const result = ref<RunResponse | null>(null)
|
||||
const rangeResult = ref<RunRangeResponse | null>(null)
|
||||
const resultRef = ref<HTMLElement | null>(null)
|
||||
|
||||
async function loadActive() {
|
||||
activeLoading.value = true
|
||||
try {
|
||||
active.value = await getActiveContract(form.symbol)
|
||||
// 切换品种后,如果原日期落在新合约的可选范围之外,清空它
|
||||
if (form.trade_date && !isDateAllowed(toDate(form.trade_date))) {
|
||||
form.trade_date = ''
|
||||
}
|
||||
if (Array.isArray(range.dates) && range.dates.length === 2) {
|
||||
const [s, e] = range.dates
|
||||
if (!isDateAllowed(toDate(s)) || !isDateAllowed(toDate(e))) {
|
||||
range.dates = []
|
||||
}
|
||||
}
|
||||
} catch (err: any) {
|
||||
active.value = null
|
||||
ElMessage.error(err?.response?.data?.error || '加载主力合约失败')
|
||||
@@ -45,17 +61,15 @@ async function loadActive() {
|
||||
}
|
||||
|
||||
function toDate(s: string) {
|
||||
// s 形如 'YYYY-MM-DD'
|
||||
const [y, m, d] = s.split('-').map(Number)
|
||||
return new Date(y, m - 1, d)
|
||||
}
|
||||
|
||||
function isDateAllowed(d: Date): boolean {
|
||||
if (!active.value) return true
|
||||
const min = toDate(active.value.min_date).getTime()
|
||||
const max = toDate(active.value.max_date).getTime()
|
||||
const t = d.getTime()
|
||||
return t >= min && t <= max
|
||||
return t <= max
|
||||
}
|
||||
|
||||
function disabledDate(d: Date) {
|
||||
@@ -69,12 +83,29 @@ async function submit() {
|
||||
}
|
||||
loading.value = true
|
||||
result.value = null
|
||||
rangeResult.value = null
|
||||
try {
|
||||
const req: { symbol: string; trade_date?: string } = { symbol: form.symbol }
|
||||
if (form.trade_date) req.trade_date = form.trade_date.replace(/-/g, '')
|
||||
const resp = await runPipeline(req)
|
||||
result.value = resp
|
||||
ElMessage.success('打分完成')
|
||||
if (mode.value === 'single') {
|
||||
const req: { symbol: string; trade_date?: string } = { symbol: form.symbol }
|
||||
if (form.trade_date) req.trade_date = form.trade_date.replace(/-/g, '')
|
||||
const resp = await runPipeline(req)
|
||||
result.value = resp
|
||||
ElMessage.success('打分完成')
|
||||
} else {
|
||||
if (!Array.isArray(range.dates) || range.dates.length !== 2) {
|
||||
ElMessage.warning('请选择日期区间')
|
||||
loading.value = false
|
||||
return
|
||||
}
|
||||
const [start, end] = range.dates
|
||||
const resp = await runRange({
|
||||
symbol: form.symbol,
|
||||
start_date: start.replace(/-/g, ''),
|
||||
end_date: end.replace(/-/g, ''),
|
||||
})
|
||||
rangeResult.value = resp
|
||||
ElMessage.success(`区间打分完成,成功 ${resp.scored} 条`)
|
||||
}
|
||||
await nextTick()
|
||||
resultRef.value?.scrollIntoView({ behavior: 'smooth', block: 'start' })
|
||||
} catch (err: any) {
|
||||
@@ -104,21 +135,18 @@ onMounted(loadActive)
|
||||
<span>手动打分</span>
|
||||
</template>
|
||||
<el-form :model="form" label-width="100px" style="max-width: 480px">
|
||||
<el-form-item label="模式">
|
||||
<el-radio-group v-model="mode">
|
||||
<el-radio-button label="single">单日打分</el-radio-button>
|
||||
<el-radio-button label="range">区间打分</el-radio-button>
|
||||
</el-radio-group>
|
||||
</el-form-item>
|
||||
<el-form-item label="品种">
|
||||
<el-select v-model="form.symbol" :loading="activeLoading" style="width: 100%">
|
||||
<el-option v-for="s in SYMBOLS" :key="s" :label="s" :value="s" />
|
||||
</el-select>
|
||||
</el-form-item>
|
||||
<el-form-item label="主力合约">
|
||||
<span v-if="active">
|
||||
{{ parseTsCode(active.ts_code).contract }}
|
||||
<el-text type="info" size="small" style="margin-left: 8px">
|
||||
({{ active.ts_code }})
|
||||
</el-text>
|
||||
</span>
|
||||
<el-text v-else type="info">加载中…</el-text>
|
||||
</el-form-item>
|
||||
<el-form-item label="打分日期">
|
||||
<el-form-item v-if="mode === 'single'" label="打分日期">
|
||||
<el-date-picker
|
||||
v-model="form.trade_date"
|
||||
type="date"
|
||||
@@ -129,9 +157,23 @@ onMounted(loadActive)
|
||||
style="width: 100%"
|
||||
/>
|
||||
</el-form-item>
|
||||
<el-form-item v-else label="日期区间">
|
||||
<el-date-picker
|
||||
v-model="range.dates"
|
||||
type="daterange"
|
||||
:placeholder="active ? `${active.min_date} ~ ${active.max_date}` : '加载中…'"
|
||||
value-format="YYYY-MM-DD"
|
||||
:disabled-date="disabledDate"
|
||||
:disabled="!active"
|
||||
range-separator="→"
|
||||
start-placeholder="开始"
|
||||
end-placeholder="结束"
|
||||
style="width: 100%"
|
||||
/>
|
||||
</el-form-item>
|
||||
<el-form-item>
|
||||
<el-button type="primary" :loading="loading" :disabled="!active" @click="submit">
|
||||
执行打分
|
||||
{{ mode === 'single' ? '执行打分' : '批量打分' }}
|
||||
</el-button>
|
||||
</el-form-item>
|
||||
</el-form>
|
||||
@@ -160,6 +202,41 @@ onMounted(loadActive)
|
||||
</el-descriptions>
|
||||
</el-card>
|
||||
</div>
|
||||
|
||||
<div v-if="rangeResult" ref="resultRef">
|
||||
<el-card shadow="never" class="result-card">
|
||||
<template #header>
|
||||
<span>区间打分结果</span>
|
||||
</template>
|
||||
<el-descriptions :column="isMobile ? 1 : 2" border>
|
||||
<el-descriptions-item label="合约">{{ parseTsCode(rangeResult.ts_code).symbol }}</el-descriptions-item>
|
||||
<el-descriptions-item label="区间">{{ rangeResult.start_date }} ~ {{ rangeResult.end_date }}</el-descriptions-item>
|
||||
<el-descriptions-item label="成功">{{ rangeResult.scored }} 条</el-descriptions-item>
|
||||
<el-descriptions-item label="跳过">{{ rangeResult.skipped }} 条</el-descriptions-item>
|
||||
</el-descriptions>
|
||||
<el-alert
|
||||
v-if="rangeResult.warnings.length > 0"
|
||||
:title="`警告 (${rangeResult.warnings.length} 条)`"
|
||||
type="warning"
|
||||
:closable="false"
|
||||
style="margin-top: 12px"
|
||||
>
|
||||
<div style="max-height: 120px; overflow-y: auto">
|
||||
<div v-for="(w, i) in rangeResult.warnings" :key="i" style="font-size: 12px">{{ w }}</div>
|
||||
</div>
|
||||
</el-alert>
|
||||
<el-table :data="rangeResult.results" stripe style="margin-top: 16px" max-height="400">
|
||||
<el-table-column prop="trade_date" label="日期" width="110" />
|
||||
<el-table-column prop="close" label="收盘" width="90" />
|
||||
<el-table-column prop="composite" label="综合" width="80" />
|
||||
<el-table-column label="信号">
|
||||
<template #default="{ row }">
|
||||
<el-tag :type="signalTagType(row.signal)" size="small">{{ row.signal }}</el-tag>
|
||||
</template>
|
||||
</el-table-column>
|
||||
</el-table>
|
||||
</el-card>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
@@ -169,4 +246,7 @@ onMounted(loadActive)
|
||||
flex-direction: column;
|
||||
gap: 16px;
|
||||
}
|
||||
.result-card {
|
||||
margin-top: 8px;
|
||||
}
|
||||
</style>
|
||||
|
||||
Reference in New Issue
Block a user