← Back
"""
Risk Management — position sizing, exposure limits, daily loss tracking.
"""

import logging
from datetime import datetime

from src.config import (
    MAX_OPEN_POSITIONS,
    MAX_DAILY_LOSS_USDT,
    MAX_TOTAL_EXPOSURE_PCT,
    TRADE_SIZE_USDT,
)
from src.exchange.client import BybitFuturesClient
from src.core.position import PositionTracker
from src.core.trade_log import load_trade_log, now_van

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class RiskManager:
    """
    Checks risk limits before allowing new trades.

    ``can_open_position`` runs a fixed sequence of checks against the local
    ``PositionTracker``, the exchange client and the trade log, and returns
    ``(allowed, reason)`` for the first check that fails.
    ``calculate_position_size`` provides risk-based sizing.
    """

    # Multiplier applied to TRADE_SIZE_USDT to estimate the notional of a new
    # position in the exposure check. NOTE(review): presumably the leverage used
    # when opening trades — confirm against the order-placement code.
    NEW_POSITION_NOTIONAL_FACTOR = 3

    def __init__(self, client: BybitFuturesClient, tracker: PositionTracker):
        # Exchange client: queried for live positions, balance and equity.
        self.client = client
        # Locally tracked open positions.
        self.tracker = tracker

    def can_open_position(self, symbol: str, strategy: str = "") -> tuple[bool, str]:
        """
        Check if a new position can be opened.

        Args:
            symbol: Instrument to trade.
            strategy: When non-empty, the max-positions limit is applied to this
                strategy's tracked positions only; otherwise to all of them.

        Returns:
            (allowed, reason). ``reason`` is ``"OK"`` when allowed, otherwise a
            message describing the first check that failed.
        """
        # 1. Already in this symbol?
        if self.tracker.has(symbol):
            return False, f"Already in {symbol}"

        # 2. Position exists on exchange but is not tracked locally?
        existing = self.client.get_position(symbol)
        if existing:
            return False, f"Untracked position exists on Bybit for {symbol}"

        # 3. Max positions (per strategy when one is given).
        count = self.tracker.count(strategy) if strategy else self.tracker.count()
        if count >= MAX_OPEN_POSITIONS:
            return False, f"Max positions ({MAX_OPEN_POSITIONS}) reached"

        # 4. Balance must cover at least one trade.
        balance = self.client.get_account_balance()
        if balance < TRADE_SIZE_USDT:
            return False, f"Low balance: ${balance:.2f} < ${TRADE_SIZE_USDT}"

        # 5. Daily loss limit.
        daily_loss = self._get_daily_loss()
        if daily_loss >= MAX_DAILY_LOSS_USDT:
            return False, f"Daily loss limit hit: ${daily_loss:.2f} >= ${MAX_DAILY_LOSS_USDT}"

        # 6. Total exposure check. Skipped entirely when equity is not positive
        # (equity is the divisor below). NOTE(review): this fails open if
        # get_total_equity() reports 0 on an API error — confirm its behavior.
        equity = self.client.get_total_equity()
        if equity > 0:
            total_exposure = sum(
                p.quantity * p.entry_price for p in self.tracker.all()
            )
            # Add a rough estimate of the new position's notional.
            proposed_exposure = (
                total_exposure + TRADE_SIZE_USDT * self.NEW_POSITION_NOTIONAL_FACTOR
            )
            exposure_pct = (proposed_exposure / equity) * 100
            if exposure_pct > MAX_TOTAL_EXPOSURE_PCT:
                return False, f"Exposure {exposure_pct:.0f}% > max {MAX_TOTAL_EXPOSURE_PCT}%"

        return True, "OK"

    def _get_daily_loss(self) -> float:
        """
        Sum of losses (absolute value of negative PnL) logged since midnight.

        Midnight is derived from ``now_van()``. The PnL of an entry is read from
        ``pnl_usdt``, falling back to ``realized_pnl_usdt`` when the former is
        missing or falsy. Only negative values count; profits do not offset
        losses.

        Malformed entries (missing/unparseable timestamp, non-numeric PnL) are
        skipped, but are counted and reported with a warning instead of being
        dropped silently: every ignored loss weakens the daily limit.
        """
        log = load_trade_log()
        today = now_van().replace(hour=0, minute=0, second=0, microsecond=0)
        daily_loss = 0.0
        skipped = 0

        for e in log:
            try:
                ts = datetime.fromisoformat(e["timestamp"])
                # Comparing naive and aware datetimes raises TypeError, which used
                # to discard the entry silently (and with it the loss).
                # NOTE(review): assumes naive log timestamps are in the same zone
                # as now_van() — confirm against the trade_log writer.
                if ts.tzinfo is None and today.tzinfo is not None:
                    ts = ts.replace(tzinfo=today.tzinfo)
                if ts < today:
                    continue
                # Trailing `or 0` maps an explicit None to zero; float() also
                # accepts numeric strings that would otherwise fail the `< 0` test.
                pnl = float(e.get("pnl_usdt", 0) or e.get("realized_pnl_usdt", 0) or 0)
            except (KeyError, TypeError, ValueError, AttributeError):
                skipped += 1
                continue
            if pnl < 0:
                daily_loss += abs(pnl)

        if skipped:
            logger.warning(
                "Skipped %d malformed trade-log entries while computing daily loss",
                skipped,
            )
        return daily_loss

    def calculate_position_size(self, balance: float, risk_pct: float,
                                entry: float, stop_loss: float) -> float:
        """
        Risk-based position sizing.

        size = (balance * risk_pct / 100) / |entry - stop_loss|

        Args:
            balance: Amount from which the risked fraction is taken.
            risk_pct: Percentage of ``balance`` to risk (e.g. 1.0 for 1%).
            entry: Planned entry price.
            stop_loss: Stop-loss price.

        Returns:
            The size (risk amount divided by price distance; presumably a
            base-asset quantity), or 0.0 when the stop distance is zero or NaN.
        """
        distance = abs(entry - stop_loss)
        # `not distance > 0` also rejects NaN, which `distance <= 0` let through
        # (NaN compares False to everything) and turned into a NaN size.
        if not distance > 0:
            return 0.0
        risk_amount = balance * (risk_pct / 100)
        return risk_amount / distance