306 lines
9.3 KiB
Go
306 lines
9.3 KiB
Go
package store
|
||
|
||
import (
|
||
"database/sql"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"strings"
|
||
|
||
_ "github.com/lib/pq"
|
||
)
|
||
|
||
var (
	// ErrMissingTsCode reports that a required ts_code argument was empty
	// (ListCandles returns it, for example). The message means
	// "ts_code is required".
	ErrMissingTsCode = errors.New("ts_code 必填")
)
|
||
|
||
// FuturesStore is a thin query layer over a *sql.DB connection pool for the
// futures database. Construct it with OpenFutures and release it with Close.
type FuturesStore struct {
	db *sql.DB
}
|
||
|
||
func OpenFutures(databaseURL string) (*FuturesStore, error) {
|
||
db, err := sql.Open("postgres", databaseURL)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("open futures db: %w", err)
|
||
}
|
||
db.SetMaxOpenConns(8)
|
||
if err := db.Ping(); err != nil {
|
||
return nil, fmt.Errorf("ping futures db: %w", err)
|
||
}
|
||
return &FuturesStore{db: db}, nil
|
||
}
|
||
|
||
// Close shuts down the underlying *sql.DB connection pool and returns
// whatever error that reports.
func (s *FuturesStore) Close() error {
	return s.db.Close()
}
|
||
|
||
// Score is one row of the scores table: the short/medium/long-term and
// composite scores plus the derived signal for one contract (TsCode) on one
// trade date.
//
// Detail carries the raw detail_json payload. ListScores does not select it,
// GetScore fills it when non-blank, and it is omitted from JSON when empty.
type Score struct {
	ID        string `json:"id"`
	TsCode    string `json:"ts_code"`
	TradeDate string `json:"trade_date"`

	// Market fields stored alongside the score.
	Close float64 `json:"close"`
	OI    float64 `json:"oi"`
	OIChg float64 `json:"oi_chg"`

	// Per-horizon scores and their combination.
	ShortTerm  float64 `json:"short_term"`
	MediumTerm float64 `json:"medium_term"`
	LongTerm   float64 `json:"long_term"`
	Composite  float64 `json:"composite"`

	Signal    string          `json:"signal"`
	Detail    json.RawMessage `json:"detail,omitempty"`
	CreatedAt string          `json:"created_at"`
}
|
||
|
||
// ScoreFilter holds the optional criteria accepted by ListScores. Empty
// string fields are ignored.
type ScoreFilter struct {
	TsCode string // exact ts_code match
	Start  string // inclusive lower bound on trade_date
	End    string // inclusive upper bound on trade_date
	Signal string // substring match against signal
	Limit  int    // row cap; ListScores replaces values outside 1..1000 with 200
}
|
||
|
||
func (s *FuturesStore) ListScores(f ScoreFilter) ([]Score, error) {
|
||
q := `SELECT id, ts_code, trade_date, close, oi, oi_chg, short_term, medium_term, long_term,
|
||
composite, signal, created_at FROM scores WHERE 1=1`
|
||
args := []any{}
|
||
n := 0
|
||
next := func() string { n++; return fmt.Sprintf("$%d", n) }
|
||
if f.TsCode != "" {
|
||
q += " AND ts_code = " + next()
|
||
args = append(args, f.TsCode)
|
||
}
|
||
if f.Start != "" {
|
||
q += " AND trade_date >= " + next()
|
||
args = append(args, f.Start)
|
||
}
|
||
if f.End != "" {
|
||
q += " AND trade_date <= " + next()
|
||
args = append(args, f.End)
|
||
}
|
||
if f.Signal != "" {
|
||
q += " AND signal LIKE " + next()
|
||
args = append(args, "%"+f.Signal+"%")
|
||
}
|
||
q += " ORDER BY trade_date DESC, id DESC"
|
||
if f.Limit <= 0 || f.Limit > 1000 {
|
||
f.Limit = 200
|
||
}
|
||
q += " LIMIT " + next()
|
||
args = append(args, f.Limit)
|
||
|
||
rows, err := s.db.Query(q, args...)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
out := []Score{}
|
||
for rows.Next() {
|
||
var x Score
|
||
if err := rows.Scan(&x.ID, &x.TsCode, &x.TradeDate, &x.Close, &x.OI, &x.OIChg,
|
||
&x.ShortTerm, &x.MediumTerm, &x.LongTerm, &x.Composite, &x.Signal, &x.CreatedAt); err != nil {
|
||
return nil, err
|
||
}
|
||
out = append(out, x)
|
||
}
|
||
return out, rows.Err()
|
||
}
|
||
|
||
func (s *FuturesStore) GetScore(id string) (*Score, error) {
|
||
row := s.db.QueryRow(`SELECT id, ts_code, trade_date, close, oi, oi_chg, short_term, medium_term,
|
||
long_term, composite, signal, detail_json, created_at FROM scores WHERE id = $1`, id)
|
||
var x Score
|
||
var detail sql.NullString
|
||
if err := row.Scan(&x.ID, &x.TsCode, &x.TradeDate, &x.Close, &x.OI, &x.OIChg,
|
||
&x.ShortTerm, &x.MediumTerm, &x.LongTerm, &x.Composite, &x.Signal, &detail, &x.CreatedAt); err != nil {
|
||
return nil, err
|
||
}
|
||
if detail.Valid && strings.TrimSpace(detail.String) != "" {
|
||
x.Detail = json.RawMessage(detail.String)
|
||
}
|
||
return &x, nil
|
||
}
|
||
|
||
func (s *FuturesStore) ListContracts() ([]string, error) {
|
||
rows, err := s.db.Query(`SELECT DISTINCT ts_code FROM scores ORDER BY ts_code ASC`)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
out := []string{}
|
||
for rows.Next() {
|
||
var c string
|
||
if err := rows.Scan(&c); err != nil {
|
||
return nil, err
|
||
}
|
||
out = append(out, c)
|
||
}
|
||
return out, rows.Err()
|
||
}
|
||
|
||
// Candle is one row of the candles table for a contract and trade date, as
// returned by ListCandles (which reports NULL/NaN numeric columns as 0).
type Candle struct {
	TsCode    string `json:"ts_code"`
	TradeDate string `json:"trade_date"`

	// Price bar.
	Open  float64 `json:"open"`
	High  float64 `json:"high"`
	Low   float64 `json:"low"`
	Close float64 `json:"close"`

	// Activity.
	Vol    float64 `json:"vol"`
	Amount float64 `json:"amount"`
	OI     float64 `json:"oi"`
	OIChg  float64 `json:"oi_chg"`

	PreClose float64 `json:"pre_close"`
}
|
||
|
||
// ── Daily Direction ────────────────────────────────────────────────
|
||
|
||
// DailyDirection is one intraday direction-analysis result, stored in the
// daily_direction table and keyed by (Symbol, TradeDate).
//
// Support and Resistance hold the JSONB columns in their text form; they are
// therefore marshalled as JSON strings, not as nested JSON.
type DailyDirection struct {
	ID         string  `json:"id"`
	Symbol     string  `json:"symbol"`
	TradeDate  string  `json:"trade_date"`
	TargetDate string  `json:"target_date"`
	Direction  string  `json:"direction"`
	Confidence float64 `json:"confidence"`

	Support    string `json:"support"`    // JSONB column as text
	Resistance string `json:"resistance"` // JSONB column as text

	Reasoning      string `json:"reasoning"`
	RiskNote       string `json:"risk_note"`
	PromptSnapshot string `json:"prompt_snapshot,omitempty"`
	CreatedAt      string `json:"created_at"`
}
|
||
|
||
// EnsureDailyDirectionTable 建 daily_direction 表(幂等)。
|
||
func (s *FuturesStore) EnsureDailyDirectionTable() error {
|
||
_, err := s.db.Exec(`
|
||
CREATE TABLE IF NOT EXISTS daily_direction (
|
||
id UUID DEFAULT uuidv7() PRIMARY KEY,
|
||
symbol TEXT NOT NULL,
|
||
trade_date TEXT NOT NULL,
|
||
target_date TEXT NOT NULL,
|
||
direction TEXT NOT NULL,
|
||
confidence REAL,
|
||
support JSONB,
|
||
resistance JSONB,
|
||
reasoning TEXT,
|
||
risk_note TEXT,
|
||
prompt_snapshot TEXT,
|
||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||
UNIQUE (symbol, trade_date)
|
||
)
|
||
`)
|
||
return err
|
||
}
|
||
|
||
// SaveDailyDirection 写入(upsert)一条方向分析。
|
||
func (s *FuturesStore) SaveDailyDirection(dd *DailyDirection) error {
|
||
_, err := s.db.Exec(`
|
||
INSERT INTO daily_direction
|
||
(symbol, trade_date, target_date, direction, confidence, support, resistance, reasoning, risk_note, prompt_snapshot)
|
||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||
ON CONFLICT (symbol, trade_date) DO UPDATE SET
|
||
target_date = EXCLUDED.target_date,
|
||
direction = EXCLUDED.direction,
|
||
confidence = EXCLUDED.confidence,
|
||
support = EXCLUDED.support,
|
||
resistance = EXCLUDED.resistance,
|
||
reasoning = EXCLUDED.reasoning,
|
||
risk_note = EXCLUDED.risk_note,
|
||
prompt_snapshot = EXCLUDED.prompt_snapshot,
|
||
created_at = CURRENT_TIMESTAMP
|
||
`, dd.Symbol, dd.TradeDate, dd.TargetDate, dd.Direction, dd.Confidence,
|
||
dd.Support, dd.Resistance, dd.Reasoning, dd.RiskNote, dd.PromptSnapshot)
|
||
return err
|
||
}
|
||
|
||
// ListDailyDirections 查询方向分析列表。
|
||
func (s *FuturesStore) ListDailyDirections(symbol, start, end string, limit int) ([]DailyDirection, error) {
|
||
if limit <= 0 || limit > 500 {
|
||
limit = 50
|
||
}
|
||
q := `SELECT id, symbol, trade_date, target_date, direction, confidence,
|
||
COALESCE(support::text, '[]'), COALESCE(resistance::text, '[]'),
|
||
reasoning, risk_note, COALESCE(prompt_snapshot, ''),
|
||
COALESCE(created_at::text, '')
|
||
FROM daily_direction WHERE 1=1`
|
||
args := []any{}
|
||
n := 0
|
||
next := func() string { n++; return fmt.Sprintf("$%d", n) }
|
||
if symbol != "" {
|
||
q += " AND symbol = " + next()
|
||
args = append(args, symbol)
|
||
}
|
||
if start != "" {
|
||
q += " AND trade_date >= " + next()
|
||
args = append(args, start)
|
||
}
|
||
if end != "" {
|
||
q += " AND trade_date <= " + next()
|
||
args = append(args, end)
|
||
}
|
||
q += " ORDER BY trade_date DESC, symbol ASC LIMIT " + next()
|
||
args = append(args, limit)
|
||
|
||
rows, err := s.db.Query(q, args...)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
out := []DailyDirection{}
|
||
for rows.Next() {
|
||
var dd DailyDirection
|
||
if err := rows.Scan(&dd.ID, &dd.Symbol, &dd.TradeDate, &dd.TargetDate,
|
||
&dd.Direction, &dd.Confidence, &dd.Support, &dd.Resistance,
|
||
&dd.Reasoning, &dd.RiskNote, &dd.PromptSnapshot, &dd.CreatedAt); err != nil {
|
||
return nil, err
|
||
}
|
||
out = append(out, dd)
|
||
}
|
||
return out, rows.Err()
|
||
}
|
||
|
||
// GetActiveTsCode 通过 scores 表查找某品种在指定日期的活跃合约代码。
|
||
func (s *FuturesStore) GetActiveTsCode(symbol, tradeDate string) (string, error) {
|
||
var tsCode string
|
||
err := s.db.QueryRow(
|
||
`SELECT ts_code FROM scores WHERE trade_date = $1 AND ts_code LIKE $2 || '%' ORDER BY ts_code DESC LIMIT 1`,
|
||
tradeDate, symbol,
|
||
).Scan(&tsCode)
|
||
if err != nil {
|
||
return "", fmt.Errorf("no active contract for %s on %s: %w", symbol, tradeDate, err)
|
||
}
|
||
return tsCode, nil
|
||
}
|
||
|
||
func (s *FuturesStore) ListCandles(tsCode, start, end string) ([]Candle, error) {
|
||
if tsCode == "" {
|
||
return nil, ErrMissingTsCode
|
||
}
|
||
q := `SELECT ts_code, trade_date,
|
||
COALESCE(NULLIF(open, 'NaN'::real), 0), COALESCE(NULLIF(high, 'NaN'::real), 0),
|
||
COALESCE(NULLIF(low, 'NaN'::real), 0), COALESCE(NULLIF(close, 'NaN'::real), 0),
|
||
COALESCE(NULLIF(vol, 'NaN'::real), 0), COALESCE(NULLIF(amount, 'NaN'::real), 0),
|
||
COALESCE(NULLIF(oi, 'NaN'::real), 0), COALESCE(NULLIF(oi_chg, 'NaN'::real), 0),
|
||
COALESCE(NULLIF(pre_close, 'NaN'::real), 0)
|
||
FROM candles WHERE ts_code = $1`
|
||
args := []any{tsCode}
|
||
n := 1
|
||
next := func() string { n++; return fmt.Sprintf("$%d", n) }
|
||
if start != "" {
|
||
q += " AND trade_date >= " + next()
|
||
args = append(args, start)
|
||
}
|
||
if end != "" {
|
||
q += " AND trade_date <= " + next()
|
||
args = append(args, end)
|
||
}
|
||
q += " ORDER BY trade_date ASC LIMIT 1000"
|
||
rows, err := s.db.Query(q, args...)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
out := []Candle{}
|
||
for rows.Next() {
|
||
var c Candle
|
||
if err := rows.Scan(&c.TsCode, &c.TradeDate, &c.Open, &c.High, &c.Low, &c.Close,
|
||
&c.Vol, &c.Amount, &c.OI, &c.OIChg, &c.PreClose); err != nil {
|
||
return nil, err
|
||
}
|
||
out = append(out, c)
|
||
}
|
||
return out, rows.Err()
|
||
}
|