← Back
"""SQLite database schema and helpers"""
import sqlite3
import os
import threading
from contextlib import contextmanager
from datetime import datetime, timezone
from loguru import logger

DB_PATH = os.path.join(os.path.dirname(__file__), "data", "weather-bot.db")

_local = threading.local()


def get_conn():
    """Return this thread's SQLite connection, opening one if needed.

    The connection is cached on the module-level thread-local ``_local`` and
    reused by later calls from the same thread. A cached connection is probed
    with ``SELECT 1`` first; if the probe fails (for example because the
    connection has been closed), the cached reference is dropped and a new
    connection is opened.

    New connections are opened on ``DB_PATH`` with ``timeout=10``, use
    ``sqlite3.Row`` as the row factory, and have foreign-key enforcement
    switched on. The parent directory of ``DB_PATH`` is created if missing.

    Returns:
        sqlite3.Connection: an open connection to ``DB_PATH``.

    Raises:
        OSError: if the database directory cannot be created.
        sqlite3.Error: if the database cannot be opened or configured.
    """
    cached = getattr(_local, "conn", None)
    if cached is not None:
        try:
            cached.execute("SELECT 1")  # cheap liveness probe
            return cached
        except sqlite3.Error:
            # ProgrammingError means the handle was closed; any other sqlite
            # failure on a trivial probe means the handle is unusable as well,
            # so discard it instead of re-raising on every future call.
            _local.conn = None
            try:
                cached.close()
            except sqlite3.Error:
                pass  # best-effort cleanup of an already-broken handle

    # sqlite3.connect() creates the database file but not its directory.
    db_dir = os.path.dirname(DB_PATH)
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)

    conn = sqlite3.connect(DB_PATH, timeout=10)
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA foreign_keys=ON")
    _local.conn = conn
    return conn


@contextmanager
def _get_conn_safe():
    """Yield the thread-local connection inside a commit-or-rollback scope.

    When the ``with`` body finishes normally the pending work is committed.
    If the body, or the commit itself, raises an ``Exception``, the
    transaction is rolled back and the exception is re-raised.
    """
    db = get_conn()
    try:
        yield db
        # Reached only when the caller's body completed without raising.
        db.commit()
    except Exception:
        db.rollback()
        raise


def init_db():
    """Initialise the database file: WAL mode, schema, indexes, migrations.

    Steps, in order:
      1. switch the journal mode to WAL;
      2. create every table (and the calibration city index) if absent;
      3. try to create the partial unique index on open live trades;
      4. close the connection and clear the thread-local reference;
      5. run ``_run_migrations()`` and log the database path.
    """
    schema_sql = """
        CREATE TABLE IF NOT EXISTS markets (
            id TEXT PRIMARY KEY,
            question TEXT NOT NULL,
            city TEXT,
            date TEXT,
            metric TEXT,
            threshold REAL,
            threshold_unit TEXT DEFAULT 'F',
            operator TEXT DEFAULT 'gte',
            yes_token_id TEXT,
            no_token_id TEXT,
            yes_price REAL,
            no_price REAL,
            volume REAL DEFAULT 0,
            end_date TEXT,
            resolution_source TEXT,
            resolved INTEGER DEFAULT 0,
            outcome TEXT,
            threshold_low REAL,
            threshold_high REAL,
            first_seen TEXT DEFAULT (datetime('now')),
            updated_at TEXT DEFAULT (datetime('now')),
            status TEXT DEFAULT 'active'
        );

        CREATE TABLE IF NOT EXISTS forecasts (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            market_id TEXT NOT NULL,
            forecast_value REAL,
            model_probability REAL,
            sigma REAL,
            days_ahead INTEGER,
            source TEXT DEFAULT 'open-meteo',
            fetched_at TEXT DEFAULT (datetime('now')),
            FOREIGN KEY (market_id) REFERENCES markets(id)
        );

        CREATE TABLE IF NOT EXISTS trades (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            market_id TEXT NOT NULL,
            order_id TEXT,
            side TEXT NOT NULL,
            token_id TEXT NOT NULL,
            price REAL NOT NULL,
            size REAL NOT NULL,
            edge REAL,
            edge_tier TEXT,
            model_prob REAL,
            market_prob REAL,
            status TEXT DEFAULT 'pending',
            filled_at TEXT,
            resolved_at TEXT,
            outcome TEXT,
            pnl REAL DEFAULT 0,
            dry_run INTEGER DEFAULT 1,
            created_at TEXT DEFAULT (datetime('now')),
            FOREIGN KEY (market_id) REFERENCES markets(id)
        );

        CREATE TABLE IF NOT EXISTS daily_stats (
            date TEXT PRIMARY KEY,
            total_trades INTEGER DEFAULT 0,
            wins INTEGER DEFAULT 0,
            losses INTEGER DEFAULT 0,
            pnl REAL DEFAULT 0,
            bankroll REAL DEFAULT 0,
            avg_edge REAL DEFAULT 0,
            best_edge REAL DEFAULT 0,
            markets_scanned INTEGER DEFAULT 0,
            opportunities_found INTEGER DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS bot_state (
            key TEXT PRIMARY KEY,
            value TEXT,
            updated_at TEXT DEFAULT (datetime('now'))
        );

        CREATE TABLE IF NOT EXISTS forecast_cache (
            cache_key TEXT PRIMARY KEY,
            city TEXT NOT NULL,
            date TEXT NOT NULL,
            metric TEXT NOT NULL,
            forecast_value REAL NOT NULL,
            sigma REAL NOT NULL,
            days_ahead INTEGER,
            source TEXT,
            cached_at TEXT DEFAULT (datetime('now'))
        );

        CREATE TABLE IF NOT EXISTS calibration (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            city TEXT NOT NULL,
            date TEXT NOT NULL,
            metric TEXT NOT NULL,
            forecast_value REAL NOT NULL,
            actual_value REAL NOT NULL,
            model_prob REAL,
            outcome_binary INTEGER,
            error_f REAL,
            brier_component REAL,
            days_ahead INTEGER,
            source TEXT,
            created_at TEXT DEFAULT (datetime('now'))
        );

        CREATE INDEX IF NOT EXISTS idx_calibration_city ON calibration(city);
    """

    db = get_conn()
    db.execute("PRAGMA journal_mode=WAL")
    db.executescript(schema_sql)
    db.commit()

    # Last line of defence against duplicate live positions (stale-cancel +
    # fill race, concurrent scans): a partial UNIQUE index that refuses a
    # second open live order for the same market+side. If rows already in the
    # table violate it, creation fails; that is logged as an error and startup
    # continues, leaving only the in-code idempotency guards in place.
    try:
        db.execute("""
            CREATE UNIQUE INDEX IF NOT EXISTS idx_trades_unique_open_live
            ON trades(market_id, side)
            WHERE status IN ('pending', 'filled') AND dry_run = 0
        """)
        db.commit()
    except Exception as err:
        logger.error(f"Could not create unique-open-live index (likely existing duplicates): {err}")

    db.close()
    # Forget the closed handle so the next get_conn() opens a fresh one.
    _local.conn = None

    # Bring databases created by older versions up to the current schema.
    _run_migrations()

    logger.info(f"Database initialized: {DB_PATH}")


def _run_migrations():
    """Add columns that older database files lack (idempotent).

    For each table in the plan, the existing column names are read with
    ``PRAGMA table_info`` and every planned column that is absent is added
    with ``ALTER TABLE ... ADD COLUMN``. Afterwards the connection is
    committed, closed, and removed from the thread-local cache.
    """
    # (table, ((column, column definition), ...)), applied in this order.
    plan = (
        ("markets", (
            ("threshold_low", "REAL"),
            ("threshold_high", "REAL"),
        )),
        # Ensemble support
        ("forecast_cache", (
            ("member_values", "TEXT"),  # JSON array of floats
            ("n_members", "INTEGER"),
            ("ensemble_std", "REAL"),
        )),
        # Redeem tracking
        ("trades", (
            ("redeemed", "INTEGER DEFAULT 0"),
        )),
    )

    db = get_conn()
    for table, wanted in plan:
        info_rows = db.execute(f"PRAGMA table_info({table})").fetchall()
        present = {info["name"] for info in info_rows}
        for col, definition in wanted:
            if col in present:
                continue
            db.execute(f"ALTER TABLE {table} ADD COLUMN {col} {definition}")
            # The log line reports only the bare type, not any DEFAULT clause.
            type_name = definition.split()[0]
            logger.info(f"Migration: added column {table}.{col} ({type_name})")

    db.commit()
    db.close()
    # Forget the closed handle so the next get_conn() opens a fresh one.
    _local.conn = None


# === Helper functions ===

def save_market(market: dict):
    """Insert a market row, or refresh an existing row with the same id.

    ``market`` must contain ``"id"`` and ``"question"``; every other field is
    optional and falls back to the default listed below. When the id already
    exists, only question, prices, volume, the threshold range and
    ``updated_at`` are overwritten; the remaining columns keep their stored
    values.
    """
    # Optional fields in column order, paired with their fallback value.
    optional_fields = (
        ("city", None),
        ("date", None),
        ("metric", None),
        ("threshold", None),
        ("threshold_unit", "F"),
        ("operator", "gte"),
        ("yes_token_id", None),
        ("no_token_id", None),
        ("yes_price", None),
        ("no_price", None),
        ("volume", 0),
        ("end_date", None),
        ("resolution_source", None),
        ("threshold_low", None),
        ("threshold_high", None),
    )
    with _get_conn_safe() as conn:
        conn.execute("""
            INSERT INTO markets
            (id, question, city, date, metric, threshold, threshold_unit,
             operator, yes_token_id, no_token_id, yes_price, no_price,
             volume, end_date, resolution_source, threshold_low, threshold_high, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))
            ON CONFLICT(id) DO UPDATE SET
                question = excluded.question,
                yes_price = excluded.yes_price,
                no_price = excluded.no_price,
                volume = excluded.volume,
                threshold_low = excluded.threshold_low,
                threshold_high = excluded.threshold_high,
                updated_at = datetime('now')
        """, (
            market["id"],
            market["question"],
            *(market.get(name, fallback) for name, fallback in optional_fields),
        ))


def save_forecast(forecast: dict):
    """Append one forecast row for a market.

    ``forecast`` must provide ``market_id``, ``forecast_value``,
    ``model_probability``, ``sigma`` and ``days_ahead``; ``source`` is
    optional and defaults to ``"open-meteo"``.
    """
    required = ("market_id", "forecast_value", "model_probability", "sigma", "days_ahead")
    with _get_conn_safe() as conn:
        values = tuple(forecast[name] for name in required)
        values += (forecast.get("source", "open-meteo"),)
        conn.execute("""
            INSERT INTO forecasts
            (market_id, forecast_value, model_probability, sigma, days_ahead, source)
            VALUES (?, ?, ?, ?, ?, ?)
        """, values)


def save_trade(trade: dict) -> int:
    """Insert a trade row and return its rowid, or -1 if the DB rejects it.

    ``trade`` must provide ``market_id``, ``side``, ``token_id``, ``price``,
    ``size``, ``edge``, ``model_prob`` and ``market_prob``. ``order_id`` and
    ``edge_tier`` are optional; ``status`` defaults to ``"pending"`` and
    ``dry_run`` to ``1``.

    Returns:
        The new row's id, or ``-1`` when SQLite raises ``IntegrityError``.
        A UNIQUE violation (the partial index ``idx_trades_unique_open_live``
        blocks a second open live position on the same market+side) is logged
        as a warning; any other constraint failure (NOT NULL, FOREIGN KEY) is
        logged as an error so it is not mistaken for a duplicate.

    Raises:
        KeyError: if a required key is missing from ``trade``.
    """
    try:
        with _get_conn_safe() as conn:
            cursor = conn.execute("""
                INSERT INTO trades
                (market_id, order_id, side, token_id, price, size,
                 edge, edge_tier, model_prob, market_prob, status, dry_run)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                trade["market_id"], trade.get("order_id"), trade["side"],
                trade["token_id"], trade["price"], trade["size"],
                trade["edge"], trade.get("edge_tier"), trade["model_prob"],
                trade["market_prob"], trade.get("status", "pending"),
                trade.get("dry_run", 1)
            ))
            return cursor.lastrowid
    except sqlite3.IntegrityError as e:
        # str() first: market_id may be None or non-string (that is exactly
        # what triggers a NOT NULL failure), and slicing it would raise
        # TypeError from inside this handler.
        market_ref = str(trade.get("market_id", "?"))[:8]
        side = trade.get("side")
        if "UNIQUE" in str(e).upper():
            # Final guard against duplicate orders.
            logger.warning(f"Duplicate live trade blocked by DB constraint ({market_ref}/{side}): {e}")
        else:
            logger.error(f"Trade insert rejected by DB constraint ({market_ref}/{side}): {e}")
        return -1


def has_cancelled_for_market(market_id: str, side: str, live_only=False) -> bool:
    """Report whether any cancelled trade is recorded for this market+side.

    Intended as a guard against duplicate orders from the stale-cancel +
    CLOB-fill race: the resolver marks a stale pending order as cancelled in
    the DB while the CLOB has already filled it on-chain.

    With ``live_only`` set, dry-run trades are ignored.

    NOTE(review): the query matches on ``status = 'cancelled'`` alone; it does
    not check whether the market or trade has resolved.
    """
    query = "SELECT COUNT(*) FROM trades WHERE market_id = ? AND side = ? AND status = 'cancelled'"
    if live_only:
        query += " AND dry_run = 0"
    count_row = get_conn().execute(query, (market_id, side)).fetchone()
    return count_row[0] > 0


def get_active_trades(live_only=False):
    """Return open trades (pending/filled/simulated, no outcome), newest first.

    Each trade is a plain dict of the trade columns plus the joined market's
    ``question`` and ``city``. With ``live_only`` set, dry-run trades are
    excluded.
    """
    db = get_conn()
    query = (
        "SELECT t.*, m.question, m.city FROM trades t"
        " LEFT JOIN markets m ON t.market_id = m.id"
        " WHERE t.status IN ('pending', 'filled', 'simulated') AND t.outcome IS NULL"
    )
    if live_only:
        query += " AND t.dry_run = 0"
    query += " ORDER BY t.created_at DESC"
    return list(map(dict, db.execute(query).fetchall()))


def get_recent_trades(limit=50, live_only=False):
    """Return up to ``limit`` trades created since STATS_START_DATE, newest first.

    Cancelled and unverified trades are left out. Each trade is a plain dict
    of the trade columns plus the joined market's ``question`` and ``city``.
    With ``live_only`` set, dry-run trades are excluded.
    """
    from config import STATS_START_DATE
    db = get_conn()
    query = (
        "SELECT t.*, m.question, m.city FROM trades t"
        " LEFT JOIN markets m ON t.market_id = m.id"
        " WHERE t.created_at >= ? AND t.status NOT IN ('cancelled', 'unverified')"
    )
    if live_only:
        query += " AND t.dry_run = 0"
    query += " ORDER BY t.created_at DESC LIMIT ?"
    found = db.execute(query, (STATS_START_DATE, limit)).fetchall()
    return list(map(dict, found))


def get_today_stats(live_only=False):
    """Summarise the trades created during the current Vancouver calendar day.

    ``created_at`` holds naive UTC text ('YYYY-MM-DD HH:MM:SS'), so the local
    day is converted into a UTC [start, end) window before querying. Bucketing
    by Vancouver time (DST-safe) keeps an evening loss streak from being split
    at UTC midnight, which would weaken the daily-loss kill switch.

    Trades before STATS_START_DATE and cancelled/unverified/failed trades are
    ignored; with ``live_only`` set, dry-run trades are ignored too.

    Returns a dict with ``date``, ``total_trades``, ``wins``, ``losses``,
    ``pnl`` (rounded to 2 places) and ``win_rate`` (percent, 1 decimal place;
    0 when nothing has been decided yet).
    """
    from config import STATS_START_DATE
    from zoneinfo import ZoneInfo
    from datetime import timedelta

    stamp_format = "%Y-%m-%d %H:%M:%S"
    local_now = datetime.now(ZoneInfo("America/Vancouver"))
    local_midnight = local_now.replace(hour=0, minute=0, second=0, microsecond=0)
    next_midnight = local_midnight + timedelta(days=1)
    window_start = local_midnight.astimezone(timezone.utc).strftime(stamp_format)
    window_end = next_midnight.astimezone(timezone.utc).strftime(stamp_format)

    db = get_conn()
    # Stats are always derived from trades (daily_stats may include dry-run).
    query = (
        "SELECT * FROM trades WHERE created_at >= ? AND created_at < ?"
        " AND created_at >= ? AND status NOT IN ('cancelled', 'unverified', 'failed')"
    )
    if live_only:
        query += " AND dry_run = 0"
    todays = db.execute(query, (window_start, window_end, STATS_START_DATE)).fetchall()

    wins = 0
    losses = 0
    pnl = 0
    for trade in todays:
        result = trade["outcome"]
        if result == "win":
            wins += 1
        elif result == "loss":
            losses += 1
        pnl += trade["pnl"] or 0

    decided = wins + losses
    if decided:
        win_rate = round(wins / decided * 100, 1)
    else:
        win_rate = 0

    return {
        "date": local_now.strftime("%Y-%m-%d"),
        "total_trades": len(todays),
        "wins": wins,
        "losses": losses,
        "pnl": round(pnl, 2),
        "win_rate": win_rate,
    }


def get_bot_state(key, default=None):
    """Return the stored value for ``key`` in bot_state, or ``default`` if absent."""
    hit = get_conn().execute(
        "SELECT value FROM bot_state WHERE key = ?", (key,)
    ).fetchone()
    if not hit:
        return default
    return hit["value"]


def set_bot_state(key, value):
    """Store ``str(value)`` under ``key`` in bot_state, replacing any existing row."""
    upsert = "INSERT OR REPLACE INTO bot_state (key, value, updated_at) VALUES (?, ?, datetime('now'))"
    with _get_conn_safe() as conn:
        conn.execute(upsert, (key, str(value)))


def get_cached_forecast(cache_key: str, ttl_seconds: int = 7200) -> dict | None:
    """Return the cached forecast for ``cache_key`` if newer than the TTL.

    The row is returned as a plain dict. ``member_values`` is stored as a
    JSON string and is decoded here; an empty, missing or undecodable value
    becomes ``None``. Returns ``None`` when there is no unexpired entry.
    """
    import json as _json

    db = get_conn()
    hit = db.execute(
        "SELECT * FROM forecast_cache WHERE cache_key = ? AND cached_at > datetime('now', ?)",
        (cache_key, f"-{ttl_seconds} seconds")
    ).fetchone()
    if not hit:
        return None

    entry = dict(hit)
    raw_members = entry.get("member_values")
    if not raw_members:
        entry["member_values"] = None
    elif isinstance(raw_members, str):
        # Stored as JSON text; fall back to None if it cannot be decoded.
        try:
            entry["member_values"] = _json.loads(raw_members)
        except Exception:
            entry["member_values"] = None
    return entry


def save_cached_forecast(cache_key: str, data: dict):
    """Write (or overwrite) the forecast cache entry for ``cache_key``.

    ``data`` must provide ``city``, ``date``, ``metric``, ``forecast_value``,
    ``sigma``, ``days_ahead`` and ``source``. Ensemble fields are optional:
    a non-empty ``member_values`` is stored as JSON (otherwise NULL),
    ``n_members`` defaults to 0, and ``ensemble_std`` to NULL.
    """
    import json as _json

    members = data.get("member_values")
    members_json = _json.dumps(members) if members else None

    with _get_conn_safe() as conn:
        base_fields = (
            "city", "date", "metric", "forecast_value",
            "sigma", "days_ahead", "source",
        )
        params = (cache_key,) + tuple(data[name] for name in base_fields)
        params += (members_json, data.get("n_members", 0), data.get("ensemble_std"))
        conn.execute("""
            INSERT OR REPLACE INTO forecast_cache
            (cache_key, city, date, metric, forecast_value, sigma, days_ahead, source,
             member_values, n_members, ensemble_std, cached_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))
        """, params)


def cleanup_forecast_cache():
    """Delete forecast_cache rows whose ``cached_at`` is more than 24 hours old."""
    purge_sql = "DELETE FROM forecast_cache WHERE cached_at < datetime('now', '-24 hours')"
    with _get_conn_safe() as conn:
        conn.execute(purge_sql)


def cleanup_old_forecasts(days: int = 7):
    """Delete forecasts fetched more than ``days`` days ago.

    Keeps the forecasts table from growing without bound. The number of
    removed rows is logged when at least one row was deleted.
    """
    with _get_conn_safe() as conn:
        age_modifier = f'-{days} days'
        removed = conn.execute(
            "DELETE FROM forecasts WHERE fetched_at < datetime('now', ?)",
            (age_modifier,)
        ).rowcount
        if removed > 0:
            logger.info(f"Cleaned up {removed} old forecast rows (>{days} days)")


# ============================================================
# Calibration (self-learning)
# ============================================================

def save_calibration(data: dict):
    """Record one forecast-vs-actual observation for calibration.

    ``data`` must provide ``city``, ``date`` and ``metric``. Nothing is
    written when a row for the same city+date+metric already exists, or when
    ``forecast_value`` or ``actual_value`` is missing.

    Stored derived values:
      * ``error_f``: absolute forecast error, rounded to 2 places.
      * ``brier_component``: ``(model_prob - outcome_binary) ** 2`` rounded
        to 4 places, or NULL when ``outcome_binary`` is not given. A missing
        ``model_prob`` is treated as 0.5 for this calculation only; the
        stored ``model_prob`` column keeps the value as supplied.

    Raises:
        KeyError: if ``city``, ``date`` or ``metric`` is missing.
    """
    with _get_conn_safe() as conn:
        # Avoid duplicates: same city+date+metric
        existing = conn.execute(
            "SELECT id FROM calibration WHERE city = ? AND date = ? AND metric = ?",
            (data["city"], data["date"], data["metric"])
        ).fetchone()
        if existing:
            return  # Already recorded

        fv = data.get("forecast_value")
        av = data.get("actual_value")
        if fv is None or av is None:
            return  # Can't compute calibration without both values
        error_f = abs(fv - av)

        model_prob = data.get("model_prob")
        # Substitute 0.5 only when no probability was supplied. A plain
        # ``or 0.5`` would also replace a legitimate 0.0 and distort the
        # Brier component.
        model_p = 0.5 if model_prob is None else model_prob
        outcome_b = data.get("outcome_binary")
        # Brier score only meaningful for binary outcomes (gte/lte), not between/eq
        brier = round((model_p - outcome_b) ** 2, 4) if outcome_b is not None else None

        conn.execute("""
            INSERT INTO calibration
            (city, date, metric, forecast_value, actual_value, model_prob,
             outcome_binary, error_f, brier_component, days_ahead, source)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, (
            data["city"], data["date"], data["metric"],
            fv, av,
            model_prob, outcome_b,
            round(error_f, 2), brier,
            data.get("days_ahead"), data.get("source"),
        ))


def get_city_calibration(city: str, min_samples: int = 5) -> dict | None:
    """Return calibration statistics for ``city``.

    Args:
        city: city name as stored in the calibration table.
        min_samples: minimum number of calibration rows required.

    Returns:
        ``None`` when the city has no rows or fewer than ``min_samples``.
        Otherwise a dict with ``city``, ``brier_score`` (mean of the non-NULL
        Brier components, or 0.5 when there are none), ``mae`` (mean absolute
        error), ``sample_count`` and ``confidence_mult``.
    """
    conn = get_conn()
    rows = conn.execute(
        "SELECT error_f, brier_component FROM calibration WHERE city = ?",
        (city,)
    ).fetchall()

    # The explicit emptiness check keeps min_samples <= 0 from reaching the
    # mean below with no rows (ZeroDivisionError).
    if not rows or len(rows) < min_samples:
        return None  # Not enough data to calibrate

    errors = [r["error_f"] for r in rows]
    briers = [r["brier_component"] for r in rows if r["brier_component"] is not None]

    mae = sum(errors) / len(errors)
    brier_score = sum(briers) / len(briers) if briers else 0.5

    # Confidence multiplier tiers (0.25 is the Brier score of a coin flip):
    #   brier < 0.15          -> 1.2 (bonus)
    #   0.15 <= brier < 0.25  -> 1.0 (no change)
    #   0.25 <= brier < 0.35  -> 0.6 (reduce bets)
    #   brier >= 0.35         -> 0.3 (heavy penalty)
    if brier_score < 0.15:
        confidence_mult = 1.2
    elif brier_score < 0.25:
        confidence_mult = 1.0
    elif brier_score < 0.35:
        confidence_mult = 0.6
    else:
        confidence_mult = 0.3

    return {
        "city": city,
        "brier_score": round(brier_score, 4),
        "mae": round(mae, 2),
        "sample_count": len(rows),
        "confidence_mult": confidence_mult,
    }


def get_all_city_calibrations(min_samples: int = 5) -> list[dict]:
    """Return calibration stats for every city with at least ``min_samples`` rows.

    Each entry is the dict produced by ``get_city_calibration``; the list is
    ordered by ascending ``brier_score``.
    """
    db = get_conn()
    city_rows = db.execute(
        "SELECT DISTINCT city FROM calibration GROUP BY city HAVING COUNT(*) >= ?",
        (min_samples,)
    ).fetchall()

    per_city = (get_city_calibration(entry["city"], min_samples) for entry in city_rows)
    kept = [stats for stats in per_city if stats]
    kept.sort(key=lambda stats: stats["brier_score"])
    return kept

📜 Git History

ddaa0a2fix(audit): chunk 2 - kill switch, auth, config safety5 weeks ago
16f2ea8fix(audit): chunk 1 - critical money/correctness bugs5 weeks ago
8fca132chore: initial commit — version control setup5 weeks ago
Show last diff
Loading...