"""SQLite database schema and helpers"""
import sqlite3
import os
import threading
from contextlib import contextmanager
from datetime import datetime, timezone
from loguru import logger
# Absolute path to the SQLite file, stored beside this module in ./data/.
DB_PATH = os.path.join(os.path.dirname(__file__), "data", "weather-bot.db")
# Thread-local storage so each thread reuses its own connection —
# sqlite3 connections must not be shared across threads by default.
_local = threading.local()
def get_conn():
    """Return a thread-local SQLite connection, creating it on demand.

    The connection is cached on ``_local`` so repeated calls within the
    same thread reuse it.  A cheap ``SELECT 1`` probe detects a
    connection that was closed elsewhere (e.g. by ``init_db``) and
    transparently replaces it.

    Returns:
        sqlite3.Connection with ``Row`` factory and foreign keys enabled.
    """
    conn = getattr(_local, 'conn', None)
    if conn is not None:
        try:
            conn.execute("SELECT 1")  # liveness probe: raises if closed
            return conn
        except sqlite3.ProgrammingError:
            # Stale/closed connection — drop it and fall through to reconnect.
            _local.conn = None
    # FIX: ensure the data/ directory exists so connect() doesn't fail on a
    # fresh checkout — sqlite3 cannot create intermediate directories itself.
    os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
    conn = sqlite3.connect(DB_PATH, timeout=10)
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA foreign_keys=ON")
    _local.conn = conn
    return conn
@contextmanager
def _get_conn_safe():
    """Yield the thread-local connection with transactional semantics.

    Commits when the managed block (including the commit itself) finishes
    normally; rolls back and re-raises when an Exception escapes.
    """
    db = get_conn()
    try:
        yield db
        db.commit()
    except Exception:
        db.rollback()
        raise
def init_db():
    """Create tables if not exist, set WAL mode (persistent, once).

    Safe to call repeatedly: all DDL uses IF NOT EXISTS, and column
    backfills for older databases are handled by _run_migrations().
    """
    conn = get_conn()
    # WAL journaling persists in the database file once set.
    conn.execute("PRAGMA journal_mode=WAL")
    conn.executescript("""
    CREATE TABLE IF NOT EXISTS markets (
        id TEXT PRIMARY KEY,
        question TEXT NOT NULL,
        city TEXT,
        date TEXT,
        metric TEXT,
        threshold REAL,
        threshold_unit TEXT DEFAULT 'F',
        operator TEXT DEFAULT 'gte',
        yes_token_id TEXT,
        no_token_id TEXT,
        yes_price REAL,
        no_price REAL,
        volume REAL DEFAULT 0,
        end_date TEXT,
        resolution_source TEXT,
        resolved INTEGER DEFAULT 0,
        outcome TEXT,
        threshold_low REAL,
        threshold_high REAL,
        first_seen TEXT DEFAULT (datetime('now')),
        updated_at TEXT DEFAULT (datetime('now')),
        status TEXT DEFAULT 'active'
    );
    CREATE TABLE IF NOT EXISTS forecasts (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        market_id TEXT NOT NULL,
        forecast_value REAL,
        model_probability REAL,
        sigma REAL,
        days_ahead INTEGER,
        source TEXT DEFAULT 'open-meteo',
        fetched_at TEXT DEFAULT (datetime('now')),
        FOREIGN KEY (market_id) REFERENCES markets(id)
    );
    CREATE TABLE IF NOT EXISTS trades (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        market_id TEXT NOT NULL,
        order_id TEXT,
        side TEXT NOT NULL,
        token_id TEXT NOT NULL,
        price REAL NOT NULL,
        size REAL NOT NULL,
        edge REAL,
        edge_tier TEXT,
        model_prob REAL,
        market_prob REAL,
        status TEXT DEFAULT 'pending',
        filled_at TEXT,
        resolved_at TEXT,
        outcome TEXT,
        pnl REAL DEFAULT 0,
        dry_run INTEGER DEFAULT 1,
        created_at TEXT DEFAULT (datetime('now')),
        FOREIGN KEY (market_id) REFERENCES markets(id)
    );
    CREATE TABLE IF NOT EXISTS daily_stats (
        date TEXT PRIMARY KEY,
        total_trades INTEGER DEFAULT 0,
        wins INTEGER DEFAULT 0,
        losses INTEGER DEFAULT 0,
        pnl REAL DEFAULT 0,
        bankroll REAL DEFAULT 0,
        avg_edge REAL DEFAULT 0,
        best_edge REAL DEFAULT 0,
        markets_scanned INTEGER DEFAULT 0,
        opportunities_found INTEGER DEFAULT 0
    );
    CREATE TABLE IF NOT EXISTS bot_state (
        key TEXT PRIMARY KEY,
        value TEXT,
        updated_at TEXT DEFAULT (datetime('now'))
    );
    CREATE TABLE IF NOT EXISTS forecast_cache (
        cache_key TEXT PRIMARY KEY,
        city TEXT NOT NULL,
        date TEXT NOT NULL,
        metric TEXT NOT NULL,
        forecast_value REAL NOT NULL,
        sigma REAL NOT NULL,
        days_ahead INTEGER,
        source TEXT,
        cached_at TEXT DEFAULT (datetime('now'))
    );
    CREATE TABLE IF NOT EXISTS calibration (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        city TEXT NOT NULL,
        date TEXT NOT NULL,
        metric TEXT NOT NULL,
        forecast_value REAL NOT NULL,
        actual_value REAL NOT NULL,
        model_prob REAL,
        outcome_binary INTEGER,
        error_f REAL,
        brier_component REAL,
        days_ahead INTEGER,
        source TEXT,
        created_at TEXT DEFAULT (datetime('now'))
    );
    CREATE INDEX IF NOT EXISTS idx_calibration_city ON calibration(city);
    """)
    conn.commit()
    # NOTE(review): this closes the connection still cached in _local;
    # get_conn()'s SELECT 1 probe recovers by reconnecting on next use.
    conn.close()
    # Run migrations for existing databases
    _run_migrations()
    logger.info(f"Database initialized: {DB_PATH}")
def _run_migrations():
    """Add missing columns to existing tables (safe to run multiple times).

    Uses PRAGMA table_info to detect which columns already exist, so
    re-running is a no-op on an up-to-date database.
    """
    conn = get_conn()
    # Get current columns in markets table
    cursor = conn.execute("PRAGMA table_info(markets)")
    existing_cols = {row["name"] for row in cursor.fetchall()}
    # (table, column, SQL type) triples to backfill on markets.
    migrations_markets = [
        ("markets", "threshold_low", "REAL"),
        ("markets", "threshold_high", "REAL"),
    ]
    for table, col, col_type in migrations_markets:
        if col not in existing_cols:
            # Identifiers come from the hard-coded lists above, never from
            # external input, so f-string interpolation is safe here.
            conn.execute(f"ALTER TABLE {table} ADD COLUMN {col} {col_type}")
            logger.info(f"Migration: added column {table}.{col} ({col_type})")
    # Forecast cache migrations (ensemble support)
    cursor2 = conn.execute("PRAGMA table_info(forecast_cache)")
    fc_cols = {row["name"] for row in cursor2.fetchall()}
    fc_migrations = [
        ("forecast_cache", "member_values", "TEXT"),  # JSON array of floats
        ("forecast_cache", "n_members", "INTEGER"),
        ("forecast_cache", "ensemble_std", "REAL"),
    ]
    for table, col, col_type in fc_migrations:
        if col not in fc_cols:
            conn.execute(f"ALTER TABLE {table} ADD COLUMN {col} {col_type}")
            logger.info(f"Migration: added column {table}.{col} ({col_type})")
    conn.commit()
    # Close and let get_conn() lazily reopen on next use.
    conn.close()
# === Helper functions ===
def save_market(market: dict):
    """Insert or update a market row (upsert keyed on market ``id``).

    On conflict only the mutable fields are refreshed (question, prices,
    volume, threshold_low/high, updated_at); parse-time fields such as
    city/date/metric/operator keep their first-seen values.

    Required keys: ``id``, ``question``.  All others are optional and
    default to None (threshold_unit → 'F', operator → 'gte', volume → 0).
    """
    with _get_conn_safe() as conn:
        conn.execute("""
            INSERT INTO markets
            (id, question, city, date, metric, threshold, threshold_unit,
             operator, yes_token_id, no_token_id, yes_price, no_price,
             volume, end_date, resolution_source, threshold_low, threshold_high, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))
            ON CONFLICT(id) DO UPDATE SET
                question = excluded.question,
                yes_price = excluded.yes_price,
                no_price = excluded.no_price,
                volume = excluded.volume,
                threshold_low = excluded.threshold_low,
                threshold_high = excluded.threshold_high,
                updated_at = datetime('now')
        """, (
            market["id"], market["question"], market.get("city"),
            market.get("date"), market.get("metric"), market.get("threshold"),
            market.get("threshold_unit", "F"), market.get("operator", "gte"),
            market.get("yes_token_id"), market.get("no_token_id"),
            market.get("yes_price"), market.get("no_price"),
            market.get("volume", 0), market.get("end_date"),
            market.get("resolution_source"),
            market.get("threshold_low"), market.get("threshold_high"),
        ))
def save_forecast(forecast: dict):
    """Append one model forecast row for a market.

    Required keys: ``market_id``, ``forecast_value``,
    ``model_probability``, ``sigma``, ``days_ahead``.
    Optional: ``source`` (defaults to 'open-meteo').
    """
    with _get_conn_safe() as conn:
        conn.execute("""
            INSERT INTO forecasts
            (market_id, forecast_value, model_probability, sigma, days_ahead, source)
            VALUES (?, ?, ?, ?, ?, ?)
        """, (
            forecast["market_id"], forecast["forecast_value"],
            forecast["model_probability"], forecast["sigma"],
            forecast["days_ahead"], forecast.get("source", "open-meteo")
        ))
def save_trade(trade: dict) -> int:
    """Insert a trade row and return its autoincrement id.

    Required keys: ``market_id``, ``side``, ``token_id``, ``price``,
    ``size``, ``edge``, ``model_prob``, ``market_prob``.
    Optional: ``order_id``, ``edge_tier``, ``status`` (default
    'pending'), ``dry_run`` (default 1).
    """
    with _get_conn_safe() as conn:
        cursor = conn.execute("""
            INSERT INTO trades
            (market_id, order_id, side, token_id, price, size,
             edge, edge_tier, model_prob, market_prob, status, dry_run)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, (
            trade["market_id"], trade.get("order_id"), trade["side"],
            trade["token_id"], trade["price"], trade["size"],
            trade["edge"], trade.get("edge_tier"), trade["model_prob"],
            trade["market_prob"], trade.get("status", "pending"),
            trade.get("dry_run", 1)
        ))
        return cursor.lastrowid
def get_active_trades():
    """Return unresolved trades (joined with market question/city) as dicts."""
    query = (
        "SELECT t.*, m.question, m.city FROM trades t LEFT JOIN markets m ON t.market_id = m.id WHERE t.status IN ('pending', 'filled', 'unverified', 'simulated') AND t.outcome IS NULL ORDER BY t.created_at DESC"
    )
    db = get_conn()
    return [dict(record) for record in db.execute(query).fetchall()]
def get_recent_trades(limit=50):
    """Return the most recent *limit* trades, newest first, as dicts."""
    query = (
        "SELECT t.*, m.question, m.city FROM trades t LEFT JOIN markets m ON t.market_id = m.id ORDER BY t.created_at DESC LIMIT ?"
    )
    db = get_conn()
    return [dict(record) for record in db.execute(query, (limit,)).fetchall()]
def get_today_stats():
    """Return today's (UTC) trading statistics as a dict.

    Prefers the precomputed ``daily_stats`` row; when absent, aggregates
    directly from today's rows in ``trades``.  The result always
    includes a ``win_rate`` percentage computed over resolved trades
    only (0 when nothing has resolved yet).
    """
    today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
    conn = get_conn()
    row = conn.execute(
        "SELECT * FROM daily_stats WHERE date = ?", (today,)
    ).fetchone()
    if not row:
        # No rollup row yet — aggregate on the fly from raw trades.
        trades = conn.execute(
            "SELECT * FROM trades WHERE date(created_at) = ?", (today,)
        ).fetchall()
        total = len(trades)
        wins = sum(1 for t in trades if t["outcome"] == "win")
        losses = sum(1 for t in trades if t["outcome"] == "loss")
        pnl = sum(t["pnl"] for t in trades if t["pnl"])
        return {
            "date": today, "total_trades": total, "wins": wins,
            "losses": losses, "pnl": round(pnl, 2),
            "win_rate": round(wins / (wins + losses) * 100, 1) if (wins + losses) else 0,
        }
    result = dict(row)
    # FIX: removed unused local `total` (was assigned but never read).
    resolved = result.get("wins", 0) + result.get("losses", 0)
    result["win_rate"] = round(result["wins"] / resolved * 100, 1) if resolved else 0
    return result
def get_bot_state(key, default=None):
    """Look up a value in the bot_state key/value table, or *default*."""
    found = get_conn().execute(
        "SELECT value FROM bot_state WHERE key = ?", (key,)
    ).fetchone()
    if found is None:
        return default
    return found["value"]
def set_bot_state(key, value):
    """Upsert a bot_state entry; *value* is coerced to str before storing."""
    params = (key, str(value))
    with _get_conn_safe() as db:
        db.execute(
            "INSERT OR REPLACE INTO bot_state (key, value, updated_at) VALUES (?, ?, datetime('now'))",
            params,
        )
def get_cached_forecast(cache_key: str, ttl_seconds: int = 7200) -> dict | None:
    """Fetch a cached forecast if it is younger than *ttl_seconds*.

    Returns a dict whose ``member_values`` is decoded from its stored
    JSON string (None when absent or unparseable), or None on a cache
    miss / expired entry.
    """
    import json as _json
    db = get_conn()
    hit = db.execute(
        "SELECT * FROM forecast_cache WHERE cache_key = ? AND cached_at > datetime('now', ?)",
        (cache_key, f"-{ttl_seconds} seconds")
    ).fetchone()
    if hit is None:
        return None
    entry = dict(hit)
    raw = entry.get("member_values")
    if isinstance(raw, str) and raw:
        # Stored as a JSON array of floats; tolerate corrupt payloads.
        try:
            entry["member_values"] = _json.loads(raw)
        except Exception:
            entry["member_values"] = None
    elif not raw:
        entry["member_values"] = None
    return entry
def save_cached_forecast(cache_key: str, data: dict):
    """Upsert a forecast into the SQLite cache.

    ``member_values`` (optional ensemble members) is serialized to a
    JSON string; ``n_members`` defaults to 0 and ``ensemble_std`` to
    None when absent.  Required keys: city, date, metric,
    forecast_value, sigma, days_ahead, source.
    """
    import json as _json
    member_values_json = None
    if data.get("member_values"):
        member_values_json = _json.dumps(data["member_values"])
    with _get_conn_safe() as conn:
        conn.execute("""
            INSERT OR REPLACE INTO forecast_cache
            (cache_key, city, date, metric, forecast_value, sigma, days_ahead, source,
             member_values, n_members, ensemble_std, cached_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))
        """, (
            cache_key, data["city"], data["date"], data["metric"],
            data["forecast_value"], data["sigma"], data["days_ahead"], data["source"],
            member_values_json, data.get("n_members", 0), data.get("ensemble_std"),
        ))
def cleanup_forecast_cache():
    """Purge forecast_cache rows that are more than 24 hours old."""
    purge_sql = "DELETE FROM forecast_cache WHERE cached_at < datetime('now', '-24 hours')"
    with _get_conn_safe() as db:
        db.execute(purge_sql)
# ============================================================
# Calibration (self-learning)
# ============================================================
def save_calibration(data: dict):
    """Record one forecast-vs-actual observation for model calibration.

    Deduplicates on (city, date, metric) and silently skips rows where
    either the forecast or the actual value is missing.  The Brier
    component is only computed when a binary outcome is present
    (gte/lte markets, not between/eq).
    """
    with _get_conn_safe() as conn:
        # Avoid duplicates: same city+date+metric
        existing = conn.execute(
            "SELECT id FROM calibration WHERE city = ? AND date = ? AND metric = ?",
            (data["city"], data["date"], data["metric"])
        ).fetchone()
        if existing:
            return  # Already recorded
        fv = data.get("forecast_value")
        av = data.get("actual_value")
        if fv is None or av is None:
            return  # Can't compute calibration without both values
        error_f = abs(fv - av)
        # FIX: the previous `data.get("model_prob") or 0.5` treated a
        # legitimate probability of 0.0 as missing and substituted 0.5,
        # corrupting the Brier component; only default when truly absent.
        model_p = data.get("model_prob")
        if model_p is None:
            model_p = 0.5
        outcome_b = data.get("outcome_binary")
        # Brier score only meaningful for binary outcomes (gte/lte), not between/eq
        brier = round((model_p - outcome_b) ** 2, 4) if outcome_b is not None else None
        conn.execute("""
            INSERT INTO calibration
            (city, date, metric, forecast_value, actual_value, model_prob,
             outcome_binary, error_f, brier_component, days_ahead, source)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, (
            data["city"], data["date"], data["metric"],
            fv, av,
            data.get("model_prob"), outcome_b,
            round(error_f, 2), brier,
            data.get("days_ahead"), data.get("source"),
        ))
def get_city_calibration(city: str, min_samples: int = 5) -> dict | None:
    """Summarize calibration quality for a single city.

    Returns None when fewer than *min_samples* observations exist;
    otherwise a dict with ``brier_score``, ``mae``, ``sample_count`` and
    a ``confidence_mult`` used to scale bet sizing.
    """
    db = get_conn()
    samples = db.execute(
        "SELECT error_f, brier_component FROM calibration WHERE city = ?",
        (city,)
    ).fetchall()
    if len(samples) < min_samples:
        # Too few observations to say anything about this city.
        return None
    abs_errors = [s["error_f"] for s in samples]
    brier_parts = [s["brier_component"] for s in samples if s["brier_component"] is not None]
    mae = sum(abs_errors) / len(abs_errors)
    brier_score = sum(brier_parts) / len(brier_parts) if brier_parts else 0.5
    # Map Brier score to a bet-sizing multiplier:
    #   < 0.15 well calibrated        -> slight bonus
    #   < 0.25 coin-flip baseline     -> neutral
    #   < 0.35 below average          -> shrink bets
    #   else   poorly calibrated      -> heavy penalty
    if brier_score < 0.15:
        confidence_mult = 1.2
    elif brier_score < 0.25:
        confidence_mult = 1.0
    elif brier_score < 0.35:
        confidence_mult = 0.6
    else:
        confidence_mult = 0.3
    return {
        "city": city,
        "brier_score": round(brier_score, 4),
        "mae": round(mae, 2),
        "sample_count": len(samples),
        "confidence_mult": confidence_mult,
    }
def get_all_city_calibrations(min_samples: int = 5) -> list[dict]:
    """Return calibration summaries for every city with enough samples,
    sorted best-calibrated (lowest Brier score) first."""
    db = get_conn()
    city_rows = db.execute(
        "SELECT DISTINCT city FROM calibration GROUP BY city HAVING COUNT(*) >= ?",
        (min_samples,)
    ).fetchall()
    summaries = [get_city_calibration(r["city"], min_samples) for r in city_rows]
    return sorted((s for s in summaries if s), key=lambda c: c["brier_score"])