← Back
☆
"""
Base Strategy — abstract interface that all strategies must implement.

Each strategy:
  1. Has a name and event prefix (for trade log)
  2. Scans for opportunities (scan loop)
  3. Opens positions on signals
  4. Monitors positions (TP/SL/time stop)
  5. Reports its config

To add a new strategy:
  1. Create src/strategies/my_strategy.py
  2. Subclass BaseStrategy
  3. Register in main.py
"""

import asyncio
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Awaitable

from src.exchange.client import BybitFuturesClient
from src.exchange.order_manager import OrderManager
from src.core.position import Position, PositionTracker
from src.core.risk import RiskManager
from src.core.trade_log import log_event

logger = logging.getLogger(__name__)


@dataclass
class Signal:
    """Trade signal from a strategy scanner."""
    symbol: str
    side: str  # "BUY" or "SELL"
    reason: str  # Human-readable reason
    metadata: dict | None = None  # Strategy-specific data


class BaseStrategy(ABC):
    """Abstract base for all trading strategies.

    Subclasses override the class attributes below and implement ``scan``,
    ``on_signal``, ``check_position`` and ``get_config``. The default
    ``scan_loop`` and ``monitor_loop`` coroutines drive those methods forever
    on fixed intervals; both may be overridden.
    """

    # Subclasses must set these
    name: str = "BASE"             # e.g. "FUNDING", "OI_DIVERGENCE"
    event_prefix: str = "BASE_"   # e.g. "FUND_", "OI_"
    scan_interval: int = 60       # seconds between scans
    check_interval: int = 3       # seconds between position checks
    startup_delay: int = 5        # seconds scan_loop waits before its first scan

    def __init__(
        self,
        client: BybitFuturesClient,
        order_mgr: OrderManager,
        tracker: PositionTracker,
        risk_mgr: RiskManager,
        notify_fn: Callable[[str], Awaitable[None]],
    ):
        """Store the shared services the strategy works with.

        Args:
            client: exchange client; used here for mark prices.
            order_mgr: order manager for TP/SL placement (used by subclasses).
            tracker: position tracker; positions are looked up by ``self.name``.
            risk_mgr: risk manager, exposed as ``self.risk``.
            notify_fn: async callable taking a message string, exposed as
                ``self.notify``.
        """
        self.client = client
        self.order_mgr = order_mgr
        self.tracker = tracker
        self.risk = risk_mgr
        self.notify = notify_fn

    @abstractmethod
    async def scan(self) -> list[Signal]:
        """
        Scan market for trade opportunities.
        Returns list of Signals to potentially act on.
        """
        ...

    @abstractmethod
    async def on_signal(self, signal: Signal) -> bool:
        """
        Execute entry on a signal. Should:
        1. Check risk (self.risk.can_open_position)
        2. Open position via self.client
        3. Place TP/SL via self.order_mgr
        4. Add to self.tracker
        5. Log entry via log_event
        6. Notify via self.notify

        Returns True if position opened.
        """
        ...

    @abstractmethod
    async def check_position(self, pos: Position):
        """
        Monitor a single position. Should:
        1. Check TP/SL fills (exchange orders or polling)
        2. Handle breakeven moves
        3. Handle time stops
        4. Handle external closes
        """
        ...

    @abstractmethod
    def get_config(self) -> dict:
        """Return strategy-specific configuration dict."""
        ...

    # ── Default loops (can be overridden) ────────────────────

    async def scan_loop(self):
        """Run ``scan`` forever and hand each returned signal to ``on_signal``.

        Waits ``startup_delay`` seconds before the first scan, then sleeps
        ``scan_interval`` seconds after each pass (so the effective period is
        scan time + interval). Exceptions are logged and never stop the loop.
        Each ``on_signal`` call is isolated, so one failing signal does not
        drop the rest of the batch.
        """
        logger.info(f"{self.name} scanner started (interval={self.scan_interval}s)")
        await asyncio.sleep(self.startup_delay)  # Let bot initialize first

        while True:
            try:
                signals = await self.scan()
                for signal in signals:
                    # Per-signal isolation: a failure on one symbol must not
                    # skip the remaining signals from the same scan.
                    try:
                        await self.on_signal(signal)
                    except Exception as e:
                        logger.error(
                            f"{self.name} scan error on {signal.symbol}: {e}",
                            exc_info=True,
                        )
            except Exception as e:
                # scan() itself failed, or returned something non-iterable.
                logger.error(f"{self.name} scan error: {e}", exc_info=True)

            await asyncio.sleep(self.scan_interval)

    async def monitor_loop(self):
        """Call ``check_position`` for every position of this strategy, forever.

        Runs every ``check_interval`` seconds. Each position is checked in its
        own try/except. Without that, a position whose check keeps raising
        would stop every position after it from being monitored on every
        cycle.
        """
        logger.info(f"{self.name} monitor started (check every {self.check_interval}s)")

        while True:
            try:
                # Snapshot the collection: check_position may close positions,
                # and we don't want to iterate something the tracker mutates.
                # (by_strategy's return type isn't visible here — defensive.)
                positions = list(self.tracker.by_strategy(self.name))
                for pos in positions:
                    try:
                        await self.check_position(pos)
                    except Exception as e:
                        logger.error(
                            f"{self.name} monitor error on {pos.symbol}: {e}",
                            exc_info=True,
                        )
            except Exception as e:
                # Failure fetching positions from the tracker.
                logger.error(f"{self.name} monitor error: {e}", exc_info=True)

            await asyncio.sleep(self.check_interval)

    def format_positions(self) -> str:
        """Format strategy positions for Telegram display.

        One line per position: direction (L/S), symbol, unrealized PnL %,
        current price and age in minutes. The entry price is used when no
        mark price is available (a falsy result from ``get_mark_price``).
        """
        positions = self.tracker.by_strategy(self.name)
        if not positions:
            return f"No {self.name} positions"

        lines = [f"📊 {self.name} positions:\n━━━━━━━━━━━━━━━━━━━━"]
        for pos in positions:
            # NOTE(review): get_mark_price is called without await, so it is
            # assumed to be synchronous (e.g. a cached price) — confirm.
            price = self.client.get_mark_price(pos.symbol) or pos.entry_price
            pnl = pos.unrealized_pnl_pct(price)
            d = "L" if pos.is_long else "S"
            emoji = "🟢" if pnl >= 0 else "🔴"
            lines.append(
                f"{emoji} {d} {pos.symbol} | {pnl:+.2f}% | ${price:.4f} | {pos.age_minutes:.0f}min"
            )
        lines.append("━━━━━━━━━━━━━━━━━━━━")
        return "\n".join(lines)