← Back
"""
Position — dataclass and tracker for managing open positions.
"""

import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone, timedelta

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Fixed UTC-7 offset; Position.age_minutes uses it to obtain the current time.
# NOTE(review): a fixed offset does not follow daylight-saving transitions —
# confirm this is intended wherever the constant is used for local wall time.
VANCOUVER_TZ = timezone(timedelta(hours=-7))


@dataclass
class Position:
    """Generic trading position — used by all strategies."""
    symbol: str
    side: str  # "BUY" or "SELL"
    entry_price: float
    quantity: float
    sl_price: float
    tp_prices: list[float]  # Can be single or multiple TP levels
    tp_quantities: list[float]  # Qty per TP level
    opened_at: datetime
    trade_id: str
    strategy: str  # Strategy name (e.g. "FUND", "SCALP")

    # TP tracking
    tp_hit: list[bool] = field(default_factory=list)
    remaining_qty: float = 0.0

    # Exchange order IDs
    sl_order_id: str | None = None
    tp_order_ids: list[str | None] = field(default_factory=list)
    use_exchange_orders: bool = False
    moved_to_be: bool = False

    # Optional signal metadata
    signal_data: dict = field(default_factory=dict)

    def __post_init__(self):
        if not self.tp_hit:
            self.tp_hit = [False] * len(self.tp_prices)
        if not self.tp_order_ids:
            self.tp_order_ids = [None] * len(self.tp_prices)
        if self.remaining_qty == 0:
            self.remaining_qty = self.quantity

    @property
    def age_minutes(self) -> float:
        now = datetime.now(VANCOUVER_TZ)
        return (now - self.opened_at).total_seconds() / 60

    @property
    def is_long(self) -> bool:
        return self.side.upper() == "BUY"

    def unrealized_pnl_pct(self, current_price: float) -> float:
        """Calculate unrealized PnL percentage."""
        if self.is_long:
            return ((current_price - self.entry_price) / self.entry_price) * 100
        else:
            return ((self.entry_price - current_price) / self.entry_price) * 100


class PositionTracker:
    """Registry of open positions, keyed by symbol, shared by all strategies."""

    def __init__(self):
        # symbol → Position; at most one entry per symbol
        self.positions: dict[str, Position] = {}

    def add(self, pos: Position):
        """Store ``pos`` under its symbol, replacing any existing entry."""
        key = pos.symbol
        self.positions[key] = pos

    def remove(self, symbol: str):
        """Forget the position for ``symbol``; unknown symbols are ignored."""
        if symbol in self.positions:
            del self.positions[symbol]

    def get(self, symbol: str) -> Position | None:
        """Return the position for ``symbol``, or ``None`` if not tracked."""
        try:
            return self.positions[symbol]
        except KeyError:
            return None

    def has(self, symbol: str) -> bool:
        """Tell whether a position is currently tracked for ``symbol``."""
        tracked = symbol in self.positions.keys()
        return tracked

    def count(self, strategy: str = "") -> int:
        """Count tracked positions, optionally only those of one strategy.

        An empty ``strategy`` counts every position.
        """
        if strategy:
            matching = [p for p in self.positions.values() if p.strategy == strategy]
            return len(matching)
        return len(self.positions)

    def symbols(self) -> set[str]:
        """Return the tracked symbols as a new set."""
        return {sym for sym in self.positions}

    def all(self) -> list[Position]:
        """Return every tracked position as a new list."""
        return [*self.positions.values()]

    def by_strategy(self, strategy: str) -> list[Position]:
        """Return the tracked positions whose strategy equals ``strategy``."""
        selected = []
        for candidate in self.positions.values():
            if candidate.strategy == strategy:
                selected.append(candidate)
        return selected