← Back
"""
Zatochki (Knife Catcher) — Trade Manager
==========================================
Signal → immediate entry (no watchlist — knife catching is time-sensitive).
TP1 partial (50% at 0.7%) + trailing (0.5% callback) on rest.
Dynamic SL from spike wick, capped at 1.2%.

TMM description: vol_spike, vwap_dist, rsi, sl_pct, natr_5m, vol_24h, oi_change
"""

import json
import time
import logging
import tempfile
import os
from datetime import datetime, timezone

from src.zatochki_config import (
    TRADE_SIZE_USD, LEVERAGE, MAX_POSITIONS,
    TP1_PCT, TP1_CLOSE_RATIO, TRAIL_CALLBACK_PCT,
    SL_CAP_PCT, MAX_BARS_IN_TRADE, TAKER_FEE,
    Z_POSITIONS_FILE, Z_TRADE_LOG_FILE,
    TMM_TAG, TELEGRAM_PREFIX, STRATEGY_NAME,
)

logger = logging.getLogger("zatochki.manager")


class ZatochkiPosition:
    """State of a single open Zatochki trade.

    Holds the entry parameters, the protective-order flags and the TP1 /
    trailing bookkeeping that the manager updates while the trade is live,
    and converts to and from a JSON-friendly dict for persistence
    (``symbol_info`` is never persisted).
    """

    def __init__(self, symbol, side, entry_price, qty, sl_price,
                 symbol_info, signal_data, opened_at=None):
        # Identity and entry parameters
        self.symbol = symbol
        self.side = side  # "LONG" or "SHORT"
        self.entry_price = entry_price
        self.qty = qty
        self.original_qty = qty  # qty at entry; self.qty shrinks after TP1
        self.sl_price = sl_price
        self.symbol_info = symbol_info  # exchange metadata, not persisted
        self.signal_data = signal_data  # full screener signal, kept for TMM / trade log

        if not opened_at:
            opened_at = datetime.now(timezone.utc).isoformat()
        self.opened_at = opened_at

        # Protective order flags
        self.sl_order_placed = False
        self.tp_order_placed = False

        # TP1 partial + trailing state
        self.tp1_hit = False
        self.trail_high = 0.0  # best price seen since TP1 (the manager tracks a low for SHORT)

        # Unix seconds at construction; compared against time.time() for the timeout
        self.entry_bar = int(time.time())

    def to_dict(self):
        """Return a JSON-serialisable snapshot of the position (without symbol_info)."""
        persisted = (
            "symbol", "side", "entry_price", "qty", "original_qty",
            "sl_price", "signal_data", "opened_at", "sl_order_placed",
            "tp_order_placed", "tp1_hit", "trail_high", "entry_bar",
        )
        return {field: getattr(self, field) for field in persisted}

    @classmethod
    def from_dict(cls, d, exchange=None):
        """Rebuild a position from a ``to_dict`` snapshot.

        Fields missing from older snapshots fall back to defaults. When an
        exchange client is given, symbol info is fetched on a best-effort
        basis (any error leaves ``symbol_info`` as None).
        """
        restored = cls(
            symbol=d["symbol"],
            side=d["side"],
            entry_price=d["entry_price"],
            qty=d["qty"],
            sl_price=d["sl_price"],
            symbol_info=None,
            signal_data=d.get("signal_data", {}),
            opened_at=d.get("opened_at"),
        )

        fallbacks = (
            ("original_qty", d["qty"]),
            ("sl_order_placed", False),
            ("tp_order_placed", False),
            ("tp1_hit", False),
            ("trail_high", 0),
            ("entry_bar", int(time.time())),
        )
        for field, default in fallbacks:
            setattr(restored, field, d.get(field, default))

        if exchange:
            try:
                restored.symbol_info = exchange.get_symbol_info(restored.symbol)
            except Exception:
                # Best-effort: the manager fetches symbol info again when it is missing
                pass

        return restored


class ZatochkiManager:
    """Trade manager for the Zatochki (knife-catcher) strategy.

    Opens positions immediately from screener signals, keeps exchange-side
    SL / TP1 orders in place, moves the SL to break-even after the TP1
    partial fill, trails the remainder in software, enforces a time limit,
    and logs every closed trade (JSON file, notifier, optional TMM tagging).

    Open positions are persisted to ``Z_POSITIONS_FILE`` so ``recovery()``
    can reconcile them with the exchange after a restart.
    """

    def __init__(self, exchange, screener, notifier=None, tmm=None):
        """
        Args:
            exchange: exchange client (orders, positions, prices, symbol info).
            screener: signal source; gets ``get_open_positions`` wired in and
                is asked to ``set_cooldown`` on symbols whose trade closed.
            notifier: optional callable taking a message string.
            tmm: optional TMM client used to tag opened trades.
        """
        self.exchange = exchange
        self.screener = screener
        self.notifier = notifier
        self.tmm = tmm
        self.positions = {}  # symbol → ZatochkiPosition
        self._load_positions()

        # Expose our open positions to the screener (presumably so it can
        # skip symbols already held — confirm in the screener)
        self.screener.get_open_positions = lambda: self.positions

    # ============================================================
    # PROCESS SIGNALS → OPEN POSITIONS
    # ============================================================

    def process_signals(self, signals):
        """Process signals from screener — immediate entry.

        Skips symbols already held and stops once MAX_POSITIONS is reached.
        A failure on one signal is logged and does not stop the others.
        """
        if not signals:
            return

        if len(self.positions) >= MAX_POSITIONS:
            return

        for sig in signals:
            symbol = sig["symbol"]

            if symbol in self.positions:
                continue
            if len(self.positions) >= MAX_POSITIONS:
                break

            try:
                self._open_position(sig)
                time.sleep(0.2)
            except Exception as e:
                logger.error(f"Open failed {symbol}: {e}")

    def _open_position(self, signal):
        """Open a market position for ``signal`` and place its SL / TP1 orders.

        The signal is skipped when the exchange already holds a position on
        the symbol, symbol info is unavailable, the mark price is not
        positive, or the rounded quantity is zero / below the exchange
        minimum. When TMM is enabled, tagging blocks for a few seconds
        (see ``_tmm_tag_trade``).
        """
        symbol = signal["symbol"]
        side = signal["direction"]

        # Safety check: no existing position on exchange
        for ep in self.exchange.get_positions():
            if ep["symbol"] == symbol and float(ep.get("positionAmt", 0)) != 0:
                logger.warning(f"Already have position on {symbol}, skipping")
                return

        sym_info = self._get_symbol_info(symbol)
        if not sym_info:
            return

        actual_leverage = self.exchange.set_leverage(symbol) or LEVERAGE
        self.exchange.set_margin_type(symbol)

        # Calculate qty (guard the division below against a bad price feed)
        mark_price = self.exchange.get_mark_price(symbol)
        if not mark_price or mark_price <= 0:
            logger.warning(f"Invalid mark price {mark_price} for {symbol}, skipping")
            return
        target_notional = TRADE_SIZE_USD * LEVERAGE
        qty = self.exchange.round_qty(sym_info, target_notional / mark_price)

        if qty <= 0 or (sym_info.get("min_qty") and qty < sym_info["min_qty"]):
            logger.warning(f"Qty {qty} below min for {symbol}")
            return

        # Market order
        order_side = "BUY" if side == "LONG" else "SELL"
        _order, fill_price = self.exchange.open_market(symbol, order_side, qty)

        if fill_price == 0:
            logger.error(f"Fill price 0 for {symbol}!")
            try:
                close_side = "SELL" if order_side == "BUY" else "BUY"
                self.exchange.close_position(symbol, close_side, qty)
            except Exception as e:
                # The position may now be open and untracked — make that visible
                logger.error(f"Emergency close failed {symbol}: {e}")
            return

        # Re-anchor the signal's SL on the actual fill price and cap its
        # distance at SL_CAP_PCT. A stop on the wrong side of the fill (price
        # already moved through it) would fire immediately or be rejected, so
        # it is replaced by the capped stop as well.
        sl_price = signal["sl_price"]
        if side == "LONG":
            sl_pct = (fill_price - sl_price) / fill_price
            if sl_pct <= 0 or sl_pct > SL_CAP_PCT:
                sl_price = fill_price * (1 - SL_CAP_PCT)
            tp1_price = fill_price * (1 + TP1_PCT)
        else:
            sl_pct = (sl_price - fill_price) / fill_price
            if sl_pct <= 0 or sl_pct > SL_CAP_PCT:
                sl_price = fill_price * (1 + SL_CAP_PCT)
            tp1_price = fill_price * (1 - TP1_PCT)

        pos = ZatochkiPosition(
            symbol=symbol, side=side,
            entry_price=fill_price, qty=qty,
            sl_price=sl_price, symbol_info=sym_info,
            signal_data=signal,
        )

        # TP1 closes TP1_CLOSE_RATIO of the position
        tp1_qty = self.exchange.round_qty(sym_info, qty * TP1_CLOSE_RATIO)

        # Place SL (full) + TP1 (partial)
        self._place_sl_tp(pos, tp1_price, tp1_qty)

        self.positions[symbol] = pos
        self._save_positions()

        # Telegram notification
        sl_pct_actual = abs(fill_price - sl_price) / fill_price * 100
        msg = (
            f"{TELEGRAM_PREFIX} *{side} {symbol}*\n"
            f"Entry: {fill_price}\n"
            f"SL: {sl_price:.6f} (\\-{sl_pct_actual:.1f}%)\n"
            f"TP1: {tp1_price:.6f} (\\+{TP1_PCT*100:.1f}%) \\→ 50% close\n"
            f"Trail: {TRAIL_CALLBACK_PCT*100}% callback\n"
            f"Vol spike: {signal['vol_spike_ratio']}x | "
            f"VWAP: {signal['vwap_dist_pct']}%\n"
            f"RSI: {signal['rsi']} | "
            f"OI: {signal.get('oi_change_pct', 'N/A')}%\n"
            f"Qty: {qty} (${target_notional:.0f}, {actual_leverage}x)"
        )
        logger.info(f"OPEN {side} {symbol} vol={signal['vol_spike_ratio']}x rsi={signal['rsi']}")
        self._notify(msg)

        # TMM: tag trade
        if self.tmm:
            try:
                self._tmm_tag_trade(symbol, side, signal)
            except Exception as te:
                logger.warning(f"TMM tag error {symbol}: {te}")

    def _place_sl_tp(self, pos, tp1_price, tp1_qty):
        """Place the SL (full qty) and, unless TP1 already filled, the TP1 order.

        Failures are logged and leave the matching ``*_order_placed`` flag unset.
        """
        sym_info = pos.symbol_info or self._get_symbol_info(pos.symbol)
        close_side = "SELL" if pos.side == "LONG" else "BUY"

        # SL on full qty
        try:
            self.exchange.place_sl(pos.symbol, close_side, pos.qty, pos.sl_price, sym_info)
            pos.sl_order_placed = True
        except Exception as e:
            logger.error(f"SL failed {pos.symbol}: {e}")

        # TP1 on partial qty
        if not pos.tp1_hit:
            try:
                self.exchange.place_tp(pos.symbol, close_side, tp1_qty, tp1_price, sym_info)
                pos.tp_order_placed = True
            except Exception as e:
                logger.error(f"TP1 failed {pos.symbol}: {e}")

    # ============================================================
    # MONITOR POSITIONS
    # ============================================================

    def check_positions(self):
        """
        Called every 5 sec:
        1. Position closed? → log
        2. TP1 partial fill → SL→BE, start trailing
        3. Trailing: track peak, callback → close rest
        4. Timeout

        A position is only dropped from tracking once it is actually closed;
        if a software close (trail / timeout) fails, it is retried next cycle.
        """
        if not self.positions:
            return

        exchange_positions = self.exchange.get_positions()
        exchange_map = {p["symbol"]: p for p in exchange_positions}

        to_remove = []

        for symbol, pos in list(self.positions.items()):
            try:
                ep = exchange_map.get(symbol)
                current_qty = abs(float(ep.get("positionAmt", 0))) if ep else 0

                # === 1. Position fully closed ===
                if current_qty == 0:
                    result = self._determine_close_result(pos)
                    self._log_trade(pos, result)
                    self.screener.set_cooldown(symbol)
                    to_remove.append(symbol)
                    continue

                # === 2. TP1 detection (qty dropped ~50%) ===
                if not pos.tp1_hit and current_qty < pos.original_qty * 0.8:
                    # Fetch the mark before touching any state: if this call
                    # fails, TP1 is simply re-detected on the next cycle
                    # instead of being half-applied (no SL→BE, no notice).
                    mark_price = self.exchange.get_mark_price(symbol)

                    pos.tp1_hit = True
                    pos.qty = current_qty

                    # Move SL to BE
                    pos.sl_price = pos.entry_price
                    pos.trail_high = mark_price

                    # Cancel old orders, place new SL at BE
                    self.exchange.cancel_all_orders(symbol)
                    sym_info = pos.symbol_info or self._get_symbol_info(symbol)
                    close_side = "SELL" if pos.side == "LONG" else "BUY"

                    try:
                        self.exchange.place_sl(symbol, close_side, pos.qty, pos.sl_price, sym_info)
                        pos.sl_order_placed = True
                    except Exception as e:
                        logger.error(f"BE SL failed {symbol}: {e}")

                    self._save_positions()
                    logger.info(f"TP1 hit {symbol} — SL→BE, trailing started")
                    self._notify(
                        f"\U0001f3af *TP1 {pos.side} {symbol}*\n"
                        f"50% closed at +{TP1_PCT*100:.1f}%\n"
                        f"SL \\→ BE | Trailing {TRAIL_CALLBACK_PCT*100}%"
                    )
                    continue

                # === 3. Trailing stop (after TP1) ===
                if pos.tp1_hit:
                    mark_price = self.exchange.get_mark_price(symbol)

                    if pos.side == "LONG":
                        # trail_high is the highest mark seen since TP1
                        if mark_price > pos.trail_high:
                            pos.trail_high = mark_price
                        trail_sl = pos.trail_high * (1 - TRAIL_CALLBACK_PCT)
                        triggered = mark_price <= trail_sl
                    else:
                        # for SHORT, trail_high holds the lowest mark seen since TP1
                        if mark_price < pos.trail_high or pos.trail_high == 0:
                            pos.trail_high = mark_price
                        trail_sl = pos.trail_high * (1 + TRAIL_CALLBACK_PCT)
                        triggered = mark_price >= trail_sl

                    if triggered:
                        if self._close_trailing(pos, mark_price, "TRAIL"):
                            to_remove.append(symbol)
                        continue

                # === 4. Timeout (MAX_BARS_IN_TRADE minutes) ===
                elapsed_sec = time.time() - pos.entry_bar
                if elapsed_sec > MAX_BARS_IN_TRADE * 60:
                    mark_price = self.exchange.get_mark_price(symbol)
                    if pos.tp1_hit:
                        closed = self._close_trailing(pos, mark_price, "TIMEOUT")
                    else:
                        closed = self._close_position_market(pos, "TIMEOUT", mark_price)
                    if closed:
                        to_remove.append(symbol)
                    continue

            except Exception as e:
                logger.error(f"Check error {symbol}: {e}")

        for symbol in to_remove:
            self.positions.pop(symbol, None)

        if to_remove:
            self._save_positions()

    # ============================================================
    # CLOSE HELPERS
    # ============================================================

    def _restore_sl(self, pos, close_side):
        """Best-effort re-placement of the stop-loss after a failed close."""
        try:
            sym_info = pos.symbol_info or self._get_symbol_info(pos.symbol)
            self.exchange.place_sl(pos.symbol, close_side, pos.qty, pos.sl_price, sym_info)
            pos.sl_order_placed = True
        except Exception as e:
            pos.sl_order_placed = False
            logger.error(f"SL restore failed {pos.symbol}: {e}")

    def _cancel_and_close(self, pos):
        """Cancel resting orders, then market-close ``pos.qty``.

        Returns ``(closed, fill_price)``. ``closed`` is False when the close
        order was not executed. The position is then still open, so the caller
        must keep tracking it. If orders were already cancelled, the SL is
        re-placed first so the position is not left unprotected.
        """
        close_side = "SELL" if pos.side == "LONG" else "BUY"
        try:
            self.exchange.cancel_all_orders(pos.symbol)
        except Exception as e:
            # Exchange orders are untouched; just retry on the next cycle
            logger.error(f"Cancel orders failed {pos.symbol}, close postponed: {e}")
            return False, None
        try:
            fill_price = self.exchange.close_position(pos.symbol, close_side, pos.qty)
        except Exception as e:
            logger.error(f"Close order failed {pos.symbol}: {e}")
            self._restore_sl(pos, close_side)
            return False, None
        return True, fill_price

    def _close_trailing(self, pos, mark_price, reason="TRAIL"):
        """Close the remainder of a position after TP1 and log the combined trade.

        Returns True when the position was closed (even if the PnL bookkeeping
        afterwards failed; that case is logged as ``<reason>_ERROR``), and
        False when the close order could not be executed.
        """
        closed, fill_price = self._cancel_and_close(pos)
        if not closed:
            return False

        try:
            # Fall back to the mark price if the exchange reported no fill price
            exit_price = fill_price or mark_price

            if pos.side == "LONG":
                trail_pnl_pct = (exit_price - pos.entry_price) / pos.entry_price * 100
            else:
                trail_pnl_pct = (pos.entry_price - exit_price) / pos.entry_price * 100

            remaining_ratio = 1 - TP1_CLOSE_RATIO
            trail_pnl_usd = TRADE_SIZE_USD * LEVERAGE * remaining_ratio * (trail_pnl_pct / 100)
            tp1_pnl_usd = TRADE_SIZE_USD * LEVERAGE * TP1_CLOSE_RATIO * TP1_PCT
            total_pnl_usd = tp1_pnl_usd + trail_pnl_usd
            total_pnl_pct = TP1_PCT * 100 * TP1_CLOSE_RATIO + trail_pnl_pct * remaining_ratio

            trade = self._build_trade_dict(pos, reason, total_pnl_pct, total_pnl_usd, exit_price)
            trade["tp1_hit"] = True
            self._append_trade_log(trade)

            emoji = "\U0001f3c3" if total_pnl_usd > 0 else "\u26a0\ufe0f"
            msg = (
                f"{emoji} *{reason} {pos.side} {pos.symbol}*\n"
                f"TP1: +{TP1_PCT*100:.1f}% (${tp1_pnl_usd:+.2f})\n"
                f"Rest: {trail_pnl_pct:+.1f}% (${trail_pnl_usd:+.2f})\n"
                f"*Total: ${total_pnl_usd:+.2f}*\n"
                f"Peak: {pos.trail_high:.6f} \\→ Exit: {exit_price:.6f}"
            )
            logger.info(f"{reason} {pos.side} {pos.symbol} total=${total_pnl_usd:+.2f}")
            self._notify(msg)
        except Exception as e:
            logger.error(f"Trail close failed {pos.symbol}: {e}")
            self._log_trade(pos, f"{reason}_ERROR")
        return True

    def _close_position_market(self, pos, reason, mark_price):
        """Close the full position at market (timeout, etc.) and log the trade.

        Returns True when the position was closed (even if bookkeeping
        failed), False when the close order could not be executed.
        """
        closed, fill_price = self._cancel_and_close(pos)
        if not closed:
            return False

        try:
            # Fall back to the mark price if the exchange reported no fill price
            exit_price = fill_price or mark_price

            if pos.side == "LONG":
                pnl_pct = (exit_price - pos.entry_price) / pos.entry_price * 100
            else:
                pnl_pct = (pos.entry_price - exit_price) / pos.entry_price * 100

            pnl_usd = TRADE_SIZE_USD * LEVERAGE * (pnl_pct / 100)

            trade = self._build_trade_dict(pos, reason, pnl_pct, pnl_usd, exit_price)
            self._append_trade_log(trade)

            emoji = "\u23f1\ufe0f"
            msg = (
                f"{emoji} *{reason} {pos.side} {pos.symbol}*\n"
                f"Entry: {pos.entry_price} \\→ {exit_price}\n"
                f"PnL: {pnl_pct:+.1f}% (${pnl_usd:+.2f})"
            )
            logger.info(f"{reason} {pos.side} {pos.symbol} PnL={pnl_pct:+.1f}%")
            self._notify(msg)
        except Exception as e:
            logger.error(f"Market close failed {pos.symbol}: {e}")
            self._log_trade(pos, reason)
        return True

    def _determine_close_result(self, pos):
        """Heuristically classify why the exchange closed the position.

        If a LIMIT (TP) order is still resting, the SL must have fired.
        Otherwise the result is whichever of the TP1 price and the SL price is
        nearer to the current mark. Returns "TP", "SL" or "UNKNOWN" (open
        orders unreadable). If the mark itself cannot be read, returns "TP".
        """
        try:
            open_orders = self.exchange.get_open_orders(pos.symbol)
            has_tp = any(o.get("type") == "LIMIT" for o in open_orders)
            if has_tp:
                return "SL"

            try:
                mark = self.exchange.get_mark_price(pos.symbol)
                tp1_mult = 1 + TP1_PCT if pos.side == "LONG" else 1 - TP1_PCT
                dist_tp = abs(mark - pos.entry_price * tp1_mult)
                dist_sl = abs(mark - pos.sl_price)
                return "TP" if dist_tp < dist_sl else "SL"
            except Exception:
                return "TP"
        except Exception:
            return "UNKNOWN"

    # ============================================================
    # TRADE LOG
    # ============================================================

    def _log_trade(self, pos, result):
        """Log a trade closed by the exchange (or in error) with estimated PnL.

        PnL is estimated from the configured TP1 / stop levels rather than
        fills. Once TP1 has filled, the stop sits at ``pos.sl_price`` (moved
        to break-even), so the remainder is valued at that stop.
        """
        sig = pos.signal_data or {}

        if pos.tp1_hit:
            tp1_pnl_usd = TRADE_SIZE_USD * LEVERAGE * TP1_CLOSE_RATIO * TP1_PCT
            remaining_ratio = 1 - TP1_CLOSE_RATIO
            if result == "SL":
                # Value the remainder at the actual stop (BE after TP1),
                # not at the initial SL cap.
                if pos.side == "LONG":
                    rest_pct = (pos.sl_price - pos.entry_price) / pos.entry_price
                else:
                    rest_pct = (pos.entry_price - pos.sl_price) / pos.entry_price
                rest_pnl_usd = TRADE_SIZE_USD * LEVERAGE * remaining_ratio * rest_pct
                total_pnl_usd = tp1_pnl_usd + rest_pnl_usd
                total_pnl_pct = TP1_PCT * 100 * TP1_CLOSE_RATIO + rest_pct * 100 * remaining_ratio
                result = "TP1+SL"
            else:
                total_pnl_usd = tp1_pnl_usd
                total_pnl_pct = TP1_PCT * 100 * TP1_CLOSE_RATIO
                result = "TP1+BE"
        elif result == "TP":
            # Fast TP1 detected (exchange closed TP1+SL before we saw it)
            total_pnl_pct = TP1_PCT * 100 * TP1_CLOSE_RATIO
            total_pnl_usd = TRADE_SIZE_USD * LEVERAGE * TP1_CLOSE_RATIO * TP1_PCT
            result = "TP1+BE (fast)"
            logger.info(f"Fast TP1 detected {pos.symbol}")
        elif result == "SL":
            sl_pct = abs(pos.entry_price - pos.sl_price) / pos.entry_price
            total_pnl_pct = -sl_pct * 100
            total_pnl_usd = -(TRADE_SIZE_USD * LEVERAGE * sl_pct)
        else:
            total_pnl_pct = 0
            total_pnl_usd = 0

        trade = self._build_trade_dict(pos, result, total_pnl_pct, total_pnl_usd)
        self._append_trade_log(trade)

        emoji = "\u2705" if "TP" in result else "\u274c" if result == "SL" else "\u26a0\ufe0f"
        msg = (
            f"{emoji} *{result} {pos.side} {pos.symbol}*\n"
            f"Entry: {pos.entry_price}\n"
            f"PnL: {total_pnl_pct:+.1f}% (${total_pnl_usd:+.2f})\n"
            f"Vol: {sig.get('vol_spike_ratio', '?')}x | RSI: {sig.get('rsi', '?')}"
        )
        logger.info(f"{result} {pos.side} {pos.symbol} PnL={total_pnl_pct:+.1f}%")
        self._notify(msg)

    def _build_trade_dict(self, pos, result, pnl_pct, pnl_usd, close_price=None):
        """Build trade dict for log + TMM."""
        sig = pos.signal_data or {}
        return {
            "symbol": pos.symbol,
            "side": pos.side,
            "strategy": STRATEGY_NAME,
            "entry_price": pos.entry_price,
            "close_price": close_price,
            "sl_price": pos.sl_price,
            "qty": pos.original_qty,
            "result": result,
            "pnl_pct": round(pnl_pct, 2),
            "pnl_usd": round(pnl_usd, 2),
            "tp1_hit": pos.tp1_hit,
            "trail_high": pos.trail_high,
            # Signal params for TMM description
            "vol_spike_ratio": sig.get("vol_spike_ratio"),
            "vwap_dist_pct": sig.get("vwap_dist_pct"),
            "rsi": sig.get("rsi"),
            "sl_pct": sig.get("sl_pct"),
            "natr_5m": sig.get("natr_5m"),
            "volume_24h": sig.get("volume_24h"),
            "oi_change_pct": sig.get("oi_change_pct"),
            "opened_at": pos.opened_at,
            "closed_at": datetime.now(timezone.utc).isoformat(),
        }

    @staticmethod
    def _atomic_write_json(path, data):
        """Write ``data`` as indented JSON to ``path`` via temp file + os.replace.

        Creates the parent directory if the path has one (a bare filename is
        written in the current directory). On any failure the temp file is
        removed and the exception is re-raised; ``path`` is left untouched.
        """
        dir_name = os.path.dirname(path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        # Temp file must live next to the target so os.replace stays atomic
        fd, tmp_path = tempfile.mkstemp(dir=dir_name or ".", suffix=".tmp")
        try:
            with os.fdopen(fd, "w") as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_path, path)
        except BaseException:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise

    def _quarantine_trade_log(self, err):
        """Move an unreadable trade log aside so its history is not overwritten."""
        backup = f"{Z_TRADE_LOG_FILE}.corrupt-{int(time.time())}"
        logger.error(f"Trade log unreadable ({err}), moving it to {backup}")
        try:
            os.replace(Z_TRADE_LOG_FILE, backup)
        except OSError as move_err:
            logger.error(f"Could not move unreadable trade log: {move_err}")

    def _append_trade_log(self, trade):
        """Atomic append to trade log.

        A missing log starts a new list. An unreadable or non-list log is
        moved aside (``<file>.corrupt-<unix-ts>``) rather than being silently
        replaced by a one-entry list.
        """
        try:
            with open(Z_TRADE_LOG_FILE, "r") as f:
                log = json.load(f)
            if not isinstance(log, list):
                raise ValueError("trade log is not a JSON list")
        except FileNotFoundError:
            log = []
        except Exception as e:
            self._quarantine_trade_log(e)
            log = []

        log.append(trade)
        try:
            self._atomic_write_json(Z_TRADE_LOG_FILE, log)
        except Exception as e:
            logger.error(f"Error saving trade log: {e}")

    # ============================================================
    # TMM INTEGRATION
    # ============================================================

    def _tmm_tag_trade(self, symbol, side, signal):
        """Tag trade in TMM with Zatochki tag + full signal params in description.

        Sleeps 5 s first (blocking the caller). If the trade is not found yet,
        a retry entry is queued on the TMM client's ``_pending_tags`` list.
        """
        if not self.tmm or not self.tmm.enabled:
            return

        time.sleep(5)  # Wait for TMM to import from Bybit

        order_side = "BUY" if side == "LONG" else "SELL"
        trade_id = self.tmm.find_recent_trade(symbol, order_side)

        if not trade_id:
            logger.warning(f"TMM: trade not found {symbol} {side}, will retry")
            self.tmm._pending_tags.append({
                "symbol": symbol, "side": order_side,
                "score": 0, "z_score": 0,
                "reasons": [TMM_TAG],
                "attempts": 1, "next_retry": time.time() + 15,
                "_zatochki_signal": signal,
            })
            return

        self._apply_tmm_tags(trade_id, signal)

    def _apply_tmm_tags(self, trade_id, signal):
        """Apply Zatochki tag + detailed description to TMM trade."""
        self.tmm.tag_trade(trade_id, TMM_TAG)

        # Description with ALL params for future optimization
        vol_24h_m = signal.get("volume_24h", 0) / 1e6 if signal.get("volume_24h") else 0
        desc = (
            f"Zatochki Bot\n"
            f"vol_spike: {signal.get('vol_spike_ratio', '?')}x | "
            f"vwap_dist: {signal.get('vwap_dist_pct', '?')}% | "
            f"rsi: {signal.get('rsi', '?')}\n"
            f"sl_pct: {signal.get('sl_pct', '?')}% | "
            f"natr_5m: {signal.get('natr_5m', '?')}% | "
            f"vol_24h: ${vol_24h_m:.0f}M\n"
            f"oi_change: {signal.get('oi_change_pct', 'N/A')}%"
        )
        self.tmm.update_description(trade_id, desc)

    # ============================================================
    # POSITIONS PERSISTENCE
    # ============================================================

    def _save_positions(self):
        """Save positions to file (atomically; errors are logged, not raised)."""
        data = {s: p.to_dict() for s, p in self.positions.items()}
        try:
            self._atomic_write_json(Z_POSITIONS_FILE, data)
        except Exception as e:
            logger.error(f"Save positions error: {e}")

    def _load_positions(self):
        """Load positions from file; a missing or malformed file means none."""
        try:
            with open(Z_POSITIONS_FILE, "r") as f:
                data = json.load(f)
            for symbol, d in data.items():
                self.positions[symbol] = ZatochkiPosition.from_dict(d, self.exchange)
            if self.positions:
                logger.info(f"Loaded {len(self.positions)} Zatochki positions")
        except (FileNotFoundError, json.JSONDecodeError):
            self.positions = {}

    # ============================================================
    # RECOVERY (after restart)
    # ============================================================

    def recovery(self):
        """Check saved positions vs exchange after restart.

        Each symbol is recovered independently: an error on one is logged and
        leaves it tracked (``check_positions`` handles it later) without
        stopping the others or the final save.
        """
        if not self.positions:
            return

        exchange_positions = self.exchange.get_positions()
        exchange_map = {p["symbol"]: p for p in exchange_positions}

        to_remove = []
        for symbol, pos in self.positions.items():
            try:
                if self._recover_position(symbol, pos, exchange_map.get(symbol)):
                    to_remove.append(symbol)
            except Exception as e:
                logger.error(f"Recovery error {symbol}: {e}")

        for symbol in to_remove:
            self.positions.pop(symbol, None)

        self._save_positions()

    def _recover_position(self, symbol, pos, ep):
        """Reconcile one saved position with exchange data ``ep``.

        Returns True if the position closed while offline (and was logged),
        False if it is still open (TP1 state updated, SL/TP re-placed).
        """
        current_qty = abs(float(ep.get("positionAmt", 0))) if ep else 0

        if current_qty == 0:
            logger.info(f"Recovery: {symbol} closed while offline")
            self._log_trade(pos, "OFFLINE_CLOSE")
            return True

        # Detect TP1 (qty dropped)
        if current_qty < pos.original_qty * 0.8 and not pos.tp1_hit:
            mark_price = self.exchange.get_mark_price(symbol)
            pos.tp1_hit = True
            pos.qty = current_qty
            pos.sl_price = pos.entry_price
            pos.trail_high = mark_price
            logger.info(f"Recovery: {symbol} TP1 detected, SL→BE")

        # Re-place orders
        sym_info = pos.symbol_info or self._get_symbol_info(symbol)
        pos.symbol_info = sym_info
        self.exchange.cancel_all_orders(symbol)

        close_side = "SELL" if pos.side == "LONG" else "BUY"
        try:
            self.exchange.place_sl(symbol, close_side, pos.qty, pos.sl_price, sym_info)
            pos.sl_order_placed = True
        except Exception as e:
            logger.error(f"Recovery SL failed {symbol}: {e}")

        if not pos.tp1_hit:
            tp1_price = pos.entry_price * (1 + TP1_PCT) if pos.side == "LONG" else pos.entry_price * (1 - TP1_PCT)
            tp1_qty = self.exchange.round_qty(sym_info, pos.qty * TP1_CLOSE_RATIO)
            try:
                self.exchange.place_tp(symbol, close_side, tp1_qty, tp1_price, sym_info)
                pos.tp_order_placed = True
            except Exception as e:
                logger.error(f"Recovery TP failed {symbol}: {e}")

        return False

    # ============================================================
    # HELPERS
    # ============================================================

    def _get_symbol_info(self, symbol):
        """Get symbol trading info (tick size, qty step, etc.); None on error."""
        try:
            return self.exchange.get_symbol_info(symbol)
        except Exception as e:
            logger.error(f"Symbol info error {symbol}: {e}")
            return None

    def _notify(self, msg):
        """Send a notification through the configured notifier, if any."""
        if self.notifier:
            self.notifier(msg)

    # ============================================================
    # STATUS (for Telegram commands)
    # ============================================================

    def get_positions_info(self):
        """Get formatted positions info for /zstatus command."""
        if not self.positions:
            return "No open Zatochki positions"

        lines = [f"{TELEGRAM_PREFIX} *Zatochki Positions ({len(self.positions)})*\n"]

        for symbol, pos in self.positions.items():
            try:
                mark = self.exchange.get_mark_price(symbol)
                if pos.side == "LONG":
                    pnl_pct = (mark - pos.entry_price) / pos.entry_price * 100
                else:
                    pnl_pct = (pos.entry_price - mark) / pos.entry_price * 100

                tp1_flag = "\U0001f3af" if pos.tp1_hit else ""
                elapsed = (time.time() - pos.entry_bar) / 60

                sig = pos.signal_data or {}
                # Built outside the f-string: a backslash inside an f-string
                # expression is a SyntaxError before Python 3.12.
                sign = "\\-" if pos.side == "SHORT" else "\\+"
                lines.append(
                    f"{sign} "
                    f"`{symbol}` {pos.side} {tp1_flag}\n"
                    f"  PnL: {pnl_pct:+.1f}% | {elapsed:.0f}min\n"
                    f"  vol: {sig.get('vol_spike_ratio', '?')}x rsi: {sig.get('rsi', '?')}"
                )
            except Exception:
                lines.append(f"`{symbol}` {pos.side} (error getting price)")

        return "\n".join(lines)

    def get_stats(self):
        """Get trade stats for /zstats command."""
        try:
            with open(Z_TRADE_LOG_FILE, "r") as f:
                trades = json.load(f)
        except Exception:
            return "No trade history"

        if not trades:
            return "No trades yet"

        n = len(trades)
        wins = [t for t in trades if t.get("pnl_usd", 0) > 0]
        losses = [t for t in trades if t.get("pnl_usd", 0) <= 0]
        wr = len(wins) / n * 100
        total_pnl = sum(t.get("pnl_usd", 0) for t in trades)

        return (
            f"{TELEGRAM_PREFIX} *Zatochki Stats*\n"
            f"Trades: {n} | WR: {wr:.0f}%\n"
            f"Wins: {len(wins)} | Losses: {len(losses)}\n"
            f"PnL: ${total_pnl:+.2f}"
        )

📜 Git History

c6f6bd5chore: initial commit — version control setup5 weeks ago
Show last diff
Loading...