"""
Position — dataclass and tracker for managing open positions.
"""
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone, timedelta
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
logger = logging.getLogger(__name__)

# Time zone used for timestamps in this module (e.g. Position.age_minutes).
# A fixed UTC-7 offset is only right while Pacific Daylight Time is in effect;
# the IANA zone follows DST, so the offset is UTC-8 in winter. If no tz
# database is installed (e.g. Windows without the `tzdata` package), fall back
# to the previous fixed offset rather than failing at import time.
try:
    VANCOUVER_TZ = ZoneInfo("America/Vancouver")
except ZoneInfoNotFoundError:
    logger.warning(
        "IANA zone 'America/Vancouver' unavailable; using fixed UTC-7 offset"
    )
    VANCOUVER_TZ = timezone(timedelta(hours=-7))
@dataclass
class Position:
"""Generic trading position — used by all strategies."""
symbol: str
side: str # "BUY" or "SELL"
entry_price: float
quantity: float
sl_price: float
tp_prices: list[float] # Can be single or multiple TP levels
tp_quantities: list[float] # Qty per TP level
opened_at: datetime
trade_id: str
strategy: str # Strategy name (e.g. "FUND", "SCALP")
# TP tracking
tp_hit: list[bool] = field(default_factory=list)
remaining_qty: float = 0.0
# Exchange order IDs
sl_order_id: str | None = None
tp_order_ids: list[str | None] = field(default_factory=list)
use_exchange_orders: bool = False
moved_to_be: bool = False
# Optional signal metadata
signal_data: dict = field(default_factory=dict)
def __post_init__(self):
if not self.tp_hit:
self.tp_hit = [False] * len(self.tp_prices)
if not self.tp_order_ids:
self.tp_order_ids = [None] * len(self.tp_prices)
if self.remaining_qty == 0:
self.remaining_qty = self.quantity
@property
def age_minutes(self) -> float:
now = datetime.now(VANCOUVER_TZ)
return (now - self.opened_at).total_seconds() / 60
@property
def is_long(self) -> bool:
return self.side.upper() == "BUY"
def unrealized_pnl_pct(self, current_price: float) -> float:
"""Calculate unrealized PnL percentage."""
if self.is_long:
return ((current_price - self.entry_price) / self.entry_price) * 100
else:
return ((self.entry_price - current_price) / self.entry_price) * 100
class PositionTracker:
    """Registry of active positions across all strategies, keyed by symbol.

    At most one position is tracked per symbol, whichever strategy owns it.
    """

    def __init__(self):
        self.positions: dict[str, Position] = {}  # symbol -> Position

    def add(self, pos: Position):
        """Start tracking ``pos`` under ``pos.symbol``.

        Any different position already tracked for the same symbol is
        replaced. That would otherwise silently drop the old position (and
        the exchange order ids stored on it), so the replacement is logged
        as a warning.
        """
        existing = self.positions.get(pos.symbol)
        if existing is not None and existing is not pos:
            logger.warning(
                "Replacing tracked %s position on %s (trade %s) "
                "with %s position (trade %s)",
                existing.strategy,
                pos.symbol,
                existing.trade_id,
                pos.strategy,
                pos.trade_id,
            )
        self.positions[pos.symbol] = pos

    def remove(self, symbol: str) -> Position | None:
        """Stop tracking ``symbol``.

        Returns the removed position, or None if nothing was tracked for it.
        Removing an untracked symbol is a no-op, not an error.
        """
        return self.positions.pop(symbol, None)

    def get(self, symbol: str) -> Position | None:
        """Return the position tracked for ``symbol``, or None."""
        return self.positions.get(symbol)

    def has(self, symbol: str) -> bool:
        """True if a position is tracked for ``symbol``."""
        return symbol in self.positions

    def count(self, strategy: str = "") -> int:
        """Number of tracked positions; only those of ``strategy`` if given.

        An empty ``strategy`` (the default) counts every position.
        """
        if not strategy:
            return len(self.positions)
        return sum(1 for p in self.positions.values() if p.strategy == strategy)

    def symbols(self) -> set[str]:
        """Snapshot of the tracked symbols."""
        return set(self.positions.keys())

    def all(self) -> list[Position]:
        """Snapshot list of every tracked position."""
        return list(self.positions.values())

    def by_strategy(self, strategy: str) -> list[Position]:
        """Tracked positions whose ``strategy`` equals ``strategy``."""
        return [p for p in self.positions.values() if p.strategy == strategy]