import json
import os
import sqlite3
from typing import Optional

import pandas as pd

from .models import ScoreResult

# Location of the SQLite database file; override with the DB_PATH env var.
DEFAULT_DB_PATH = os.environ.get("DB_PATH", "/app/data/futures.db")


def _get_conn(db_path: str = DEFAULT_DB_PATH) -> sqlite3.Connection:
    """Open a SQLite connection whose rows support access by column name."""
    conn = sqlite3.connect(db_path)
    conn.row_factory = sqlite3.Row
    return conn


def init_db(db_path: str = DEFAULT_DB_PATH) -> None:
    """Initialise the database, creating the ``candles`` and ``scores`` tables.

    The parent directory of *db_path* is created if it does not exist. A path
    without a directory component (e.g. ``"futures.db"``) is opened as is.
    Tables are created with ``IF NOT EXISTS``, so calling this repeatedly is safe.
    """
    db_dir = os.path.dirname(db_path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually names one.
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)
    conn = _get_conn(db_path)
    try:
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("""
            CREATE TABLE IF NOT EXISTS candles (
                ts_code TEXT NOT NULL,
                trade_date TEXT NOT NULL,
                open REAL,
                high REAL,
                low REAL,
                close REAL,
                vol REAL,
                amount REAL,
                oi REAL,
                oi_chg REAL,
                pre_close REAL,
                PRIMARY KEY (ts_code, trade_date)
            )
        """)
        conn.execute("""
            CREATE TABLE IF NOT EXISTS scores (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                ts_code TEXT NOT NULL,
                trade_date TEXT NOT NULL,
                close REAL,
                oi REAL,
                oi_chg REAL,
                short_term REAL,
                medium_term REAL,
                long_term REAL,
                composite REAL,
                signal TEXT,
                detail_json TEXT,
                created_at TEXT DEFAULT (datetime('now', 'localtime')),
                UNIQUE (ts_code, trade_date)
            )
        """)
        conn.commit()
    finally:
        conn.close()


def save_candles(df: pd.DataFrame, db_path: str = DEFAULT_DB_PATH) -> None:
    """Bulk insert or update daily candle rows.

    Rows are keyed by ``(ts_code, trade_date)``; an existing row with the same
    key is replaced. *df* must provide a column for every named parameter of
    the INSERT below (ts_code, trade_date, open, high, low, close, vol, amount,
    oi, oi_chg, pre_close); other columns are ignored. Missing values
    (NaN/NaT/None) are written as SQL NULL. An empty frame is a no-op.
    """
    if df.empty:
        return
    # Cast to object first: on float columns ``where(mask, None)`` keeps the
    # float dtype and turns None back into NaN, so the replacement would
    # otherwise silently not happen.
    records = df.astype(object).where(pd.notna(df), None).to_dict(orient="records")
    conn = _get_conn(db_path)
    try:
        conn.executemany(
            """
            INSERT OR REPLACE INTO candles
                (ts_code, trade_date, open, high, low, close,
                 vol, amount, oi, oi_chg, pre_close)
            VALUES
                (:ts_code, :trade_date, :open, :high, :low, :close,
                 :vol, :amount, :oi, :oi_chg, :pre_close)
            """,
            records,
        )
        conn.commit()
    finally:
        conn.close()


def save_score(score: ScoreResult, db_path: str = DEFAULT_DB_PATH) -> None:
    """Persist one scoring result, replacing any row for the same key.

    The per-horizon details (``score.detail.short_details``, ``medium_detail``
    and ``long_detail``) are serialised to JSON in ``detail_json``. Values that
    are not JSON-serialisable are stringified via ``default=str``.
    """
    # NOTE: INSERT OR REPLACE on the UNIQUE (ts_code, trade_date) key deletes
    # the old row and inserts a new one. As a result, ``id`` and ``created_at``
    # are regenerated on every re-save.
    detail_json = json.dumps(
        {
            "short_details": score.detail.short_details,
            "medium_detail": score.detail.medium_detail,
            "long_detail": score.detail.long_detail,
        },
        ensure_ascii=False,
        default=str,
    )
    conn = _get_conn(db_path)
    try:
        conn.execute(
            """
            INSERT OR REPLACE INTO scores
                (ts_code, trade_date, close, oi, oi_chg,
                 short_term, medium_term, long_term, composite,
                 signal, detail_json)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                score.ts_code,
                score.trade_date,
                score.close,
                score.oi,
                score.oi_chg,
                score.short_term,
                score.medium_term,
                score.long_term,
                score.composite,
                score.signal,
                detail_json,
            ),
        )
        conn.commit()
    finally:
        conn.close()


def get_latest_score(ts_code: str, db_path: str = DEFAULT_DB_PATH) -> Optional[dict]:
    """Return the most recent score row for *ts_code* as a dict, or ``None``.

    The returned dict holds every column of ``scores``. ``detail_json`` is
    returned as the raw JSON string.
    """
    conn = _get_conn(db_path)
    try:
        # NOTE(review): trade_date is TEXT, so "latest" is lexicographic order.
        # This assumes a zero-padded, sortable format such as YYYYMMDD; confirm.
        row = conn.execute(
            "SELECT * FROM scores WHERE ts_code = ? ORDER BY trade_date DESC LIMIT 1",
            (ts_code,),
        ).fetchone()
        return dict(row) if row else None
    finally:
        conn.close()