โ† ะะฐะทะฐะด
"""SQLite database schema and helpers""" import sqlite3 import os import threading from contextlib import contextmanager from datetime import datetime, timezone from loguru import logger DB_PATH = os.path.join(os.path.dirname(__file__), "data", "weather-bot.db") _local = threading.local() def get_conn(): """Get thread-local SQLite connection (reused within same thread)""" conn = getattr(_local, 'conn', None) if conn is not None: try: conn.execute("SELECT 1") # check if still alive return conn except sqlite3.ProgrammingError: _local.conn = None conn = sqlite3.connect(DB_PATH, timeout=10) conn.row_factory = sqlite3.Row conn.execute("PRAGMA foreign_keys=ON") _local.conn = conn return conn @contextmanager def _get_conn_safe(): """Context manager for safe connection usage with auto-commit/rollback""" conn = get_conn() try: yield conn conn.commit() except Exception: conn.rollback() raise def init_db(): """Create tables if not exist, set WAL mode (persistent, once)""" conn = get_conn() conn.execute("PRAGMA journal_mode=WAL") conn.executescript(""" CREATE TABLE IF NOT EXISTS markets ( id TEXT PRIMARY KEY, question TEXT NOT NULL, city TEXT, date TEXT, metric TEXT, threshold REAL, threshold_unit TEXT DEFAULT 'F', operator TEXT DEFAULT 'gte', yes_token_id TEXT, no_token_id TEXT, yes_price REAL, no_price REAL, volume REAL DEFAULT 0, end_date TEXT, resolution_source TEXT, resolved INTEGER DEFAULT 0, outcome TEXT, threshold_low REAL, threshold_high REAL, first_seen TEXT DEFAULT (datetime('now')), updated_at TEXT DEFAULT (datetime('now')), status TEXT DEFAULT 'active' ); CREATE TABLE IF NOT EXISTS forecasts ( id INTEGER PRIMARY KEY AUTOINCREMENT, market_id TEXT NOT NULL, forecast_value REAL, model_probability REAL, sigma REAL, days_ahead INTEGER, source TEXT DEFAULT 'open-meteo', fetched_at TEXT DEFAULT (datetime('now')), FOREIGN KEY (market_id) REFERENCES markets(id) ); CREATE TABLE IF NOT EXISTS trades ( id INTEGER PRIMARY KEY AUTOINCREMENT, market_id TEXT NOT NULL, order_id TEXT, side TEXT NOT NULL, token_id TEXT NOT NULL, price REAL NOT NULL, size REAL NOT NULL, edge REAL, edge_tier TEXT, model_prob REAL, market_prob REAL, status TEXT DEFAULT 'pending', filled_at TEXT, resolved_at TEXT, outcome TEXT, pnl REAL DEFAULT 0, dry_run INTEGER DEFAULT 1, created_at TEXT DEFAULT (datetime('now')), FOREIGN KEY (market_id) REFERENCES markets(id) ); CREATE TABLE IF NOT EXISTS daily_stats ( date TEXT PRIMARY KEY, total_trades INTEGER DEFAULT 0, wins INTEGER DEFAULT 0, losses INTEGER DEFAULT 0, pnl REAL DEFAULT 0, bankroll REAL DEFAULT 0, avg_edge REAL DEFAULT 0, best_edge REAL DEFAULT 0, markets_scanned INTEGER DEFAULT 0, opportunities_found INTEGER DEFAULT 0 ); CREATE TABLE IF NOT EXISTS bot_state ( key TEXT PRIMARY KEY, value TEXT, updated_at TEXT DEFAULT (datetime('now')) ); CREATE TABLE IF NOT EXISTS forecast_cache ( cache_key TEXT PRIMARY KEY, city TEXT NOT NULL, date TEXT NOT NULL, metric TEXT NOT NULL, forecast_value REAL NOT NULL, sigma REAL NOT NULL, days_ahead INTEGER, source TEXT, cached_at TEXT DEFAULT (datetime('now')) ); CREATE TABLE IF NOT EXISTS calibration ( id INTEGER PRIMARY KEY AUTOINCREMENT, city TEXT NOT NULL, date TEXT NOT NULL, metric TEXT NOT NULL, forecast_value REAL NOT NULL, actual_value REAL NOT NULL, model_prob REAL, outcome_binary INTEGER, error_f REAL, brier_component REAL, days_ahead INTEGER, source TEXT, created_at TEXT DEFAULT (datetime('now')) ); CREATE INDEX IF NOT EXISTS idx_calibration_city ON calibration(city); """) conn.commit() conn.close() # Run migrations for existing databases _run_migrations() logger.info(f"Database initialized: {DB_PATH}") def _run_migrations(): """Add missing columns to existing tables (safe to run multiple times)""" conn = get_conn() # Get current columns in markets table cursor = conn.execute("PRAGMA table_info(markets)") existing_cols = {row["name"] for row in cursor.fetchall()} migrations_markets = [ ("markets", "threshold_low", "REAL"), ("markets", "threshold_high", "REAL"), ] for table, col, col_type in migrations_markets: if col not in existing_cols: conn.execute(f"ALTER TABLE {table} ADD COLUMN {col} {col_type}") logger.info(f"Migration: added column {table}.{col} ({col_type})") # Forecast cache migrations (ensemble support) cursor2 = conn.execute("PRAGMA table_info(forecast_cache)") fc_cols = {row["name"] for row in cursor2.fetchall()} fc_migrations = [ ("forecast_cache", "member_values", "TEXT"), # JSON array of floats ("forecast_cache", "n_members", "INTEGER"), ("forecast_cache", "ensemble_std", "REAL"), ] for table, col, col_type in fc_migrations: if col not in fc_cols: conn.execute(f"ALTER TABLE {table} ADD COLUMN {col} {col_type}") logger.info(f"Migration: added column {table}.{col} ({col_type})") conn.commit() conn.close() # === Helper functions === def save_market(market: dict): with _get_conn_safe() as conn: conn.execute(""" INSERT INTO markets (id, question, city, date, metric, threshold, threshold_unit, operator, yes_token_id, no_token_id, yes_price, no_price, volume, end_date, resolution_source, threshold_low, threshold_high, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now')) ON CONFLICT(id) DO UPDATE SET question = excluded.question, yes_price = excluded.yes_price, no_price = excluded.no_price, volume = excluded.volume, threshold_low = excluded.threshold_low, threshold_high = excluded.threshold_high, updated_at = datetime('now') """, ( market["id"], market["question"], market.get("city"), market.get("date"), market.get("metric"), market.get("threshold"), market.get("threshold_unit", "F"), market.get("operator", "gte"), market.get("yes_token_id"), market.get("no_token_id"), market.get("yes_price"), market.get("no_price"), market.get("volume", 0), market.get("end_date"), market.get("resolution_source"), market.get("threshold_low"), market.get("threshold_high"), )) def save_forecast(forecast: dict): with _get_conn_safe() as conn: conn.execute(""" INSERT INTO forecasts (market_id, forecast_value, model_probability, sigma, days_ahead, source) VALUES (?, ?, ?, ?, ?, ?) """, ( forecast["market_id"], forecast["forecast_value"], forecast["model_probability"], forecast["sigma"], forecast["days_ahead"], forecast.get("source", "open-meteo") )) def save_trade(trade: dict) -> int: with _get_conn_safe() as conn: cursor = conn.execute(""" INSERT INTO trades (market_id, order_id, side, token_id, price, size, edge, edge_tier, model_prob, market_prob, status, dry_run) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( trade["market_id"], trade.get("order_id"), trade["side"], trade["token_id"], trade["price"], trade["size"], trade["edge"], trade.get("edge_tier"), trade["model_prob"], trade["market_prob"], trade.get("status", "pending"), trade.get("dry_run", 1) )) return cursor.lastrowid def get_active_trades(): conn = get_conn() rows = conn.execute( "SELECT t.*, m.question, m.city FROM trades t LEFT JOIN markets m ON t.market_id = m.id WHERE t.status IN ('pending', 'filled', 'unverified', 'simulated') AND t.outcome IS NULL ORDER BY t.created_at DESC" ).fetchall() return [dict(r) for r in rows] def get_recent_trades(limit=50): conn = get_conn() rows = conn.execute( "SELECT t.*, m.question, m.city FROM trades t LEFT JOIN markets m ON t.market_id = m.id ORDER BY t.created_at DESC LIMIT ?", (limit,) ).fetchall() return [dict(r) for r in rows] def get_today_stats(): today = datetime.now(timezone.utc).strftime("%Y-%m-%d") conn = get_conn() row = conn.execute( "SELECT * FROM daily_stats WHERE date = ?", (today,) ).fetchone() if not row: trades = conn.execute( "SELECT * FROM trades WHERE date(created_at) = ?", (today,) ).fetchall() total = len(trades) wins = sum(1 for t in trades if t["outcome"] == "win") losses = sum(1 for t in trades if t["outcome"] == "loss") pnl = sum(t["pnl"] for t in trades if t["pnl"]) return { "date": today, "total_trades": total, "wins": wins, "losses": losses, "pnl": round(pnl, 2), "win_rate": round(wins/(wins+losses)*100, 1) if (wins+losses) else 0 } result = dict(row) total = result.get("total_trades", 0) resolved = result.get("wins", 0) + result.get("losses", 0) result["win_rate"] = round(result["wins"]/resolved*100, 1) if resolved else 0 return result def get_bot_state(key, default=None): conn = get_conn() row = conn.execute("SELECT value FROM bot_state WHERE key = ?", (key,)).fetchone() return row["value"] if row else default def set_bot_state(key, value): with _get_conn_safe() as conn: conn.execute( "INSERT OR REPLACE INTO bot_state (key, value, updated_at) VALUES (?, ?, datetime('now'))", (key, str(value)) ) def get_cached_forecast(cache_key: str, ttl_seconds: int = 7200) -> dict | None: """Get forecast from SQLite cache if not expired (with ensemble data)""" import json as _json conn = get_conn() row = conn.execute( "SELECT * FROM forecast_cache WHERE cache_key = ? AND cached_at > datetime('now', ?)", (cache_key, f"-{ttl_seconds} seconds") ).fetchone() if not row: return None result = dict(row) # Deserialize member_values from JSON string mv = result.get("member_values") if mv and isinstance(mv, str): try: result["member_values"] = _json.loads(mv) except Exception: result["member_values"] = None elif not mv: result["member_values"] = None return result def save_cached_forecast(cache_key: str, data: dict): """Save forecast to SQLite cache (with optional ensemble data)""" import json as _json member_values_json = None if data.get("member_values"): member_values_json = _json.dumps(data["member_values"]) with _get_conn_safe() as conn: conn.execute(""" INSERT OR REPLACE INTO forecast_cache (cache_key, city, date, metric, forecast_value, sigma, days_ahead, source, member_values, n_members, ensemble_std, cached_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now')) """, ( cache_key, data["city"], data["date"], data["metric"], data["forecast_value"], data["sigma"], data["days_ahead"], data["source"], member_values_json, data.get("n_members", 0), data.get("ensemble_std"), )) def cleanup_forecast_cache(): """Remove expired cache entries (older than 24h)""" with _get_conn_safe() as conn: conn.execute("DELETE FROM forecast_cache WHERE cached_at < datetime('now', '-24 hours')") # ============================================================ # Calibration (self-learning) # ============================================================ def save_calibration(data: dict): """Save forecast vs actual observation for calibration""" with _get_conn_safe() as conn: # Avoid duplicates: same city+date+metric existing = conn.execute( "SELECT id FROM calibration WHERE city = ? AND date = ? AND metric = ?", (data["city"], data["date"], data["metric"]) ).fetchone() if existing: return # Already recorded fv = data.get("forecast_value") av = data.get("actual_value") if fv is None or av is None: return # Can't compute calibration without both values error_f = abs(fv - av) model_p = data.get("model_prob") or 0.5 outcome_b = data.get("outcome_binary") # Brier score only meaningful for binary outcomes (gte/lte), not between/eq brier = round((model_p - outcome_b) ** 2, 4) if outcome_b is not None else None conn.execute(""" INSERT INTO calibration (city, date, metric, forecast_value, actual_value, model_prob, outcome_binary, error_f, brier_component, days_ahead, source) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( data["city"], data["date"], data["metric"], fv, av, data.get("model_prob"), outcome_b, round(error_f, 2), brier, data.get("days_ahead"), data.get("source"), )) def get_city_calibration(city: str, min_samples: int = 5) -> dict | None: """ Get calibration stats for a city. Returns: {brier_score, mae, sample_count, confidence_mult} """ conn = get_conn() rows = conn.execute( "SELECT error_f, brier_component FROM calibration WHERE city = ?", (city,) ).fetchall() if len(rows) < min_samples: return None # Not enough data to calibrate errors = [r["error_f"] for r in rows] briers = [r["brier_component"] for r in rows if r["brier_component"] is not None] mae = sum(errors) / len(errors) brier_score = sum(briers) / len(briers) if briers else 0.5 # Confidence multiplier: # Brier 0.0 = perfect โ†’ mult 1.2 (bonus) # Brier 0.25 = baseline (random coin) โ†’ mult 1.0 # Brier 0.5+ = terrible โ†’ mult 0.3 (heavy penalty) if brier_score < 0.15: confidence_mult = 1.2 # Good calibration โ€” slight bonus elif brier_score < 0.25: confidence_mult = 1.0 # Average โ€” no change elif brier_score < 0.35: confidence_mult = 0.6 # Below average โ€” reduce bets else: confidence_mult = 0.3 # Bad โ€” heavily reduce return { "city": city, "brier_score": round(brier_score, 4), "mae": round(mae, 2), "sample_count": len(rows), "confidence_mult": confidence_mult, } def get_all_city_calibrations(min_samples: int = 5) -> list[dict]: """Get calibration stats for all cities with enough data""" conn = get_conn() cities = conn.execute( "SELECT DISTINCT city FROM calibration GROUP BY city HAVING COUNT(*) >= ?", (min_samples,) ).fetchall() results = [] for row in cities: cal = get_city_calibration(row["city"], min_samples) if cal: results.append(cal) return sorted(results, key=lambda x: x["brier_score"])