import json
import os
from typing import Optional

import pandas as pd
import psycopg
from psycopg.rows import dict_row

from .models import ScoreResult

# Connection string, read once from $DATABASE_URL at import time, falling back
# to a hard-coded default when the variable is unset.
DEFAULT_DB_URL = os.environ.get(
    "DATABASE_URL", "postgresql://trade:trade@postgres:5432/futures"
)


def _get_conn(db_url: str = DEFAULT_DB_URL):
    """Open and return a new psycopg connection to *db_url*."""
    return psycopg.connect(db_url)


def init_db(db_url: str = DEFAULT_DB_URL):
    """Create the ``candles`` and ``scores`` tables if they do not exist yet.

    NOTE: ``uuidv7()`` is a built-in function only from PostgreSQL 18 onward;
    on older servers the ``scores`` CREATE TABLE fails unless an equivalent
    function has been installed.
    """
    # The psycopg 3 connection context manager commits on normal exit, rolls
    # back if an exception escapes, and always closes the connection.
    with _get_conn(db_url) as conn, conn.cursor() as cur:
        cur.execute("""
            CREATE TABLE IF NOT EXISTS candles (
                ts_code     TEXT NOT NULL,
                trade_date  TEXT NOT NULL,
                open        REAL,
                high        REAL,
                low         REAL,
                close       REAL,
                vol         REAL,
                amount      REAL,
                oi          REAL,
                oi_chg      REAL,
                pre_close   REAL,
                PRIMARY KEY (ts_code, trade_date)
            )
        """)
        cur.execute("""
            CREATE TABLE IF NOT EXISTS scores (
                id          UUID DEFAULT uuidv7() PRIMARY KEY,
                ts_code     TEXT NOT NULL,
                trade_date  TEXT NOT NULL,
                close       REAL,
                oi          REAL,
                oi_chg      REAL,
                short_term  REAL,
                medium_term REAL,
                long_term   REAL,
                composite   REAL,
                signal      TEXT,
                detail_json TEXT,
                created_at  TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                UNIQUE (ts_code, trade_date)
            )
        """)


def save_candles(df: pd.DataFrame, db_url: str = DEFAULT_DB_URL):
    """Insert or update daily candle rows in the ``candles`` table.

    Every row of *df* must provide ``ts_code``, ``trade_date``, ``open``,
    ``high``, ``low``, ``close``, ``vol``, ``amount``, ``oi``, ``oi_chg`` and
    ``pre_close``. A row whose ``(ts_code, trade_date)`` already exists is
    overwritten with the new values. Missing values (NaN / NA / NaT) are
    written as SQL NULL. An empty frame is a no-op (no connection is opened).
    """
    if df.empty:
        return
    # Cast to object dtype *before* masking. On a float64 column,
    # ``where(mask, None)`` cannot store None and puts NaN back (pandas >= 1.3).
    # psycopg would then send float('nan'), which PostgreSQL stores as 'NaN'
    # rather than NULL. astype() also returns a new frame, so the caller's df
    # is never mutated.
    cleaned = df.astype(object).where(pd.notna(df), None)
    records = cleaned.to_dict(orient="records")
    with _get_conn(db_url) as conn, conn.cursor() as cur:
        cur.executemany(
            """
            INSERT INTO candles (ts_code, trade_date, open, high, low, close,
                                 vol, amount, oi, oi_chg, pre_close)
            VALUES (%(ts_code)s, %(trade_date)s, %(open)s, %(high)s, %(low)s,
                    %(close)s, %(vol)s, %(amount)s, %(oi)s, %(oi_chg)s,
                    %(pre_close)s)
            ON CONFLICT (ts_code, trade_date) DO UPDATE SET
                open = EXCLUDED.open,
                high = EXCLUDED.high,
                low = EXCLUDED.low,
                close = EXCLUDED.close,
                vol = EXCLUDED.vol,
                amount = EXCLUDED.amount,
                oi = EXCLUDED.oi,
                oi_chg = EXCLUDED.oi_chg,
                pre_close = EXCLUDED.pre_close
            """,
            records,
        )


def _score_detail_json(score: ScoreResult) -> str:
    """Serialize the ``short_details``/``medium_detail``/``long_detail`` fields of ``score.detail`` to JSON."""
    return json.dumps(
        {
            "short_details": score.detail.short_details,
            "medium_detail": score.detail.medium_detail,
            "long_detail": score.detail.long_detail,
        },
        ensure_ascii=False,
        # Any value json cannot encode natively is stored as its str() form.
        default=str,
    )


def save_score(score: ScoreResult, db_url: str = DEFAULT_DB_URL):
    """Insert or update one scoring result in the ``scores`` table.

    A row with the same ``(ts_code, trade_date)`` is overwritten, and its
    ``created_at`` is reset to the current timestamp. The detail object is
    stored as a JSON string in ``detail_json``.
    """
    with _get_conn(db_url) as conn, conn.cursor() as cur:
        cur.execute(
            """
            INSERT INTO scores (ts_code, trade_date, close, oi, oi_chg,
                                short_term, medium_term, long_term,
                                composite, signal, detail_json)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            ON CONFLICT (ts_code, trade_date) DO UPDATE SET
                close = EXCLUDED.close,
                oi = EXCLUDED.oi,
                oi_chg = EXCLUDED.oi_chg,
                short_term = EXCLUDED.short_term,
                medium_term = EXCLUDED.medium_term,
                long_term = EXCLUDED.long_term,
                composite = EXCLUDED.composite,
                signal = EXCLUDED.signal,
                detail_json = EXCLUDED.detail_json,
                created_at = CURRENT_TIMESTAMP
            """,
            (
                score.ts_code,
                score.trade_date,
                score.close,
                score.oi,
                score.oi_chg,
                score.short_term,
                score.medium_term,
                score.long_term,
                score.composite,
                score.signal,
                _score_detail_json(score),
            ),
        )


def get_latest_score(ts_code: str, db_url: str = DEFAULT_DB_URL) -> Optional[dict]:
    """Return the ``scores`` row with the greatest ``trade_date`` for *ts_code*.

    Returns a dict keyed by column name, or None if *ts_code* has no rows.
    ``trade_date`` is a TEXT column, so "greatest" means string order. That
    matches chronological order only for fixed-width date strings such as
    YYYYMMDD (assumed; confirm against the data source).
    """
    with _get_conn(db_url) as conn, conn.cursor(row_factory=dict_row) as cur:
        cur.execute(
            "SELECT * FROM scores WHERE ts_code = %s ORDER BY trade_date DESC LIMIT 1",
            (ts_code,),
        )
        row = cur.fetchone()
    return dict(row) if row else None