Browse Source

Add order recovery and file logging

Lukas Goldschmidt 1 month ago
parent
commit
57ae87e344

+ 7 - 0
src/trader_mcp/exec_client.py

@@ -38,6 +38,13 @@ def list_open_orders(account_id: str, client_id: str | None = None) -> Any:
     return _mcp.call_tool("get_open_orders", args)
 
 
+def query_order(account_id: str, order_id: str, client_id: str | None = None) -> Any:
+    args: dict[str, Any] = {"account_id": account_id, "order_id": order_id}
+    if client_id is not None:
+        args["client_id"] = client_id
+    return _mcp.call_tool("query_order", args)
+
+
 def get_account_info(account_id: str) -> Any:
     return _mcp.call_tool("get_account_info", {"account_id": account_id})
 

+ 30 - 0
src/trader_mcp/logging_utils.py

@@ -0,0 +1,30 @@
+from __future__ import annotations
+
+import logging
+from functools import lru_cache
+from pathlib import Path
+
+_LOG_PATH = Path(__file__).resolve().parents[2] / "logs" / "trader_mcp.log"
+
+
+@lru_cache(maxsize=1)
+def get_trader_logger() -> logging.Logger:
+    logger = logging.getLogger("trader_mcp")
+    if logger.handlers:
+        return logger
+
+    _LOG_PATH.parent.mkdir(parents=True, exist_ok=True)
+    logger.setLevel(logging.INFO)
+    logger.propagate = False
+
+    formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
+
+    file_handler = logging.FileHandler(_LOG_PATH)
+    file_handler.setFormatter(formatter)
+    logger.addHandler(file_handler)
+
+    return logger
+
+
+def log_event(source: str, message: str) -> None:
+    get_trader_logger().info("%s %s", source, message)

+ 4 - 1
src/trader_mcp/strategy_context.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 from dataclasses import dataclass, field
 from typing import Any
 
-from .exec_client import list_open_orders, cancel_all_orders, cancel_order, place_order, get_account_info
+from .exec_client import list_open_orders, query_order, cancel_all_orders, cancel_order, place_order, get_account_info
 from .news_client import call_news_tool
 from .crypto_client import call_crypto_tool
 
@@ -30,6 +30,9 @@ class StrategyContext:
             return payload["orders"]
         return payload
 
+    def query_order(self, order_id: str) -> Any:
+        return query_order(self.account_id, order_id, self.client_id)
+
     def cancel_all_orders(self) -> Any:
         return cancel_all_orders(self.account_id, self.client_id)
 

+ 116 - 22
strategies/grid_trader.py

@@ -4,6 +4,7 @@ import time
 from datetime import datetime, timezone
 
 from src.trader_mcp.strategy_sdk import Strategy
+from src.trader_mcp.logging_utils import log_event
 
 
 class Strategy(Strategy):
@@ -43,6 +44,8 @@ class Strategy(Strategy):
         "regimes_updated_at": {"type": "string", "default": ""},
         "account_snapshot_updated_at": {"type": "string", "default": ""},
         "grid_refresh_pending_until": {"type": "string", "default": ""},
+        "mismatch_ticks": {"type": "int", "default": 0},
+        "recovery_cooldown_until": {"type": "string", "default": ""},
     }
 
     def init(self):
@@ -61,6 +64,8 @@ class Strategy(Strategy):
             "regimes_updated_at": "",
             "account_snapshot_updated_at": "",
             "grid_refresh_pending_until": "",
+            "mismatch_ticks": 0,
+            "recovery_cooldown_until": "",
         }
 
     def _log(self, message: str) -> None:
@@ -69,6 +74,7 @@ class Strategy(Strategy):
         log.append(message)
         state["debug_log"] = log[-12:]
         self.state = state
+        log_event("grid", message)
 
     def _set_grid_refresh_pause(self, seconds: float = 30.0) -> None:
         self.state["grid_refresh_pending_until"] = (datetime.now(timezone.utc).timestamp() + max(seconds, 0.0))
@@ -80,6 +86,40 @@ class Strategy(Strategy):
             until = 0.0
         return until > datetime.now(timezone.utc).timestamp()
 
+    def _recovery_paused(self) -> bool:
+        try:
+            until = float(self.state.get("recovery_cooldown_until") or 0.0)
+        except Exception:
+            until = 0.0
+        return until > datetime.now(timezone.utc).timestamp()
+
+    def _trip_recovery_pause(self, seconds: float = 30.0) -> None:
+        self.state["recovery_cooldown_until"] = (datetime.now(timezone.utc).timestamp() + max(seconds, 0.0))
+
+    def _recover_grid(self, price: float) -> None:
+        self._log(f"recovery mode: cancel all and rebuild from {price}")
+        try:
+            self.context.cancel_all_orders()
+        except Exception as exc:
+            self.state["last_error"] = str(exc)
+            self._log(f"recovery cancel-all failed: {exc}")
+        self.state["orders"] = []
+        self.state["order_ids"] = []
+        self.state["open_order_count"] = 0
+        self.state["center_price"] = price
+        self.state["seeded"] = True
+        self._place_grid(price)
+        self._sync_open_orders_state()
+        self.state["mismatch_ticks"] = 0
+        self._trip_recovery_pause()
+
+    def _order_count_mismatch(self, tracked_ids: list[str], live_orders: list[dict]) -> bool:
+        live_ids = [str(order.get("bitstamp_order_id") or order.get("order_id") or order.get("id") or order.get("client_order_id") or "") for order in live_orders if isinstance(order, dict)]
+        live_ids = [oid for oid in live_ids if oid]
+        if len(live_ids) != len([oid for oid in tracked_ids if oid]):
+            return True
+        return False
+
     def _base_symbol(self) -> str:
         return (self.context.base_currency or self.context.market_symbol or "XRP").split("/")[0].upper()
 
@@ -530,6 +570,20 @@ class Strategy(Strategy):
                 self._log(f"replacement order failed for {side}→{opposite} at {price}: {exc}")
         return placed
 
+    def _recenter_and_rebuild_from_fill(self, fill_price: float) -> None:
+        """Treat a fill as the new market anchor and rebuild the grid from there."""
+        if fill_price <= 0:
+            return
+        try:
+            self.context.cancel_all_orders()
+        except Exception as exc:
+            self.state["last_error"] = str(exc)
+            self._log(f"fill rebuild cancel-all failed: {exc}")
+        self.state["center_price"] = fill_price
+        self.state["seeded"] = True
+        self._place_grid(fill_price)
+        self._set_grid_refresh_pause()
+
     def _sync_open_orders_state(self) -> list[dict]:
         try:
             open_orders = self.context.get_open_orders()
@@ -580,11 +634,40 @@ class Strategy(Strategy):
             and str(order.get("bitstamp_order_id") or order.get("order_id") or order.get("id") or order.get("client_order_id") or "") in (previous_ids - current_ids)
         ]
         if vanished_orders and not self._grid_refresh_paused():
-            replaced_ids = self._place_replacement_orders(vanished_orders, price)
-            if replaced_ids:
-                live_orders = self._sync_open_orders_state()
-                live_ids = list(self.state.get("order_ids") or [])
-                open_order_count = len(live_ids)
+            for order in vanished_orders:
+                order_id = str(order.get("bitstamp_order_id") or order.get("order_id") or order.get("id") or order.get("client_order_id") or "")
+                if not order_id:
+                    continue
+                try:
+                    payload = self.context.query_order(order_id)
+                except Exception as exc:
+                    self._log(f"order status query failed for {order_id}: {exc}")
+                    continue
+
+                raw = payload.get("raw") if isinstance(payload, dict) else {}
+                if not isinstance(raw, dict):
+                    raw = {}
+                status = str(payload.get("status") or raw.get("status") or order.get("status") or "").strip().lower()
+                if status in {"finished", "filled", "closed"}:
+                    fill_price = 0.0
+                    for candidate in (raw.get("price"), order.get("price"), price):
+                        try:
+                            fill_price = float(candidate or 0.0)
+                        except Exception:
+                            fill_price = 0.0
+                        if fill_price > 0:
+                            break
+                    if fill_price > 0:
+                        self._log(f"filled order {order_id} detected via exec status={status}, recentering at {fill_price}")
+                        self._recenter_and_rebuild_from_fill(fill_price)
+                        live_orders = self._sync_open_orders_state()
+                        live_ids = list(self.state.get("order_ids") or [])
+                        open_order_count = len(live_ids)
+                        return live_orders, live_ids, open_order_count
+
+                if status in {"cancelled", "expired", "missing"}:
+                    self._log(f"vanished order {order_id} resolved as {status}")
+                    continue
 
         surplus_cancelled = self._cancel_surplus_side_orders(live_orders, int(self.config.get("grid_levels", 6) or 6))
         duplicate_cancelled = self._cancel_duplicate_level_orders(live_orders)
@@ -602,6 +685,7 @@ class Strategy(Strategy):
 
     def on_tick(self, tick):
         previous_orders = list(self.state.get("orders") or [])
+        tracked_ids_before_sync = list(self.state.get("order_ids") or [])
         self._refresh_balance_snapshot()
         price = self._price()
         self.state["last_price"] = price
@@ -672,6 +756,33 @@ class Strategy(Strategy):
             self._log(f"missing tracked orders: {missing_ids}")
             self.state["order_ids"] = live_ids
 
+        if self._order_count_mismatch(tracked_ids_before_sync, live_orders):
+            self.state["mismatch_ticks"] = int(self.state.get("mismatch_ticks") or 0) + 1
+            self._log(f"order count mismatch detected: tracked={len(tracked_ids_before_sync)} live={len(live_orders)} ticks={self.state['mismatch_ticks']}")
+            if self.state["mismatch_ticks"] >= 2 and not self._recovery_paused() and self._mode() == "active":
+                self._recover_grid(price)
+                return {"action": "recovery", "price": price}
+        else:
+            self.state["mismatch_ticks"] = 0
+
+        center = float(self.state.get("center_price") or price)
+        recenter_pct = float(self.config.get("recenter_pct", 0.05) or 0.05)
+        deviation = abs(price - center) / center if center else 0.0
+        if mode == "active" and deviation >= recenter_pct and not self._grid_refresh_paused():
+            self._log(f"recenter needed at price={price} center={center} dev={deviation:.4f}")
+            try:
+                self.context.cancel_all_orders()
+            except Exception as exc:
+                self.state["last_error"] = str(exc)
+                self._log(f"recenter cancel-all failed: {exc}")
+            self.state["center_price"] = price
+            self._place_grid(price)
+            live_orders = self._sync_open_orders_state()
+            live_ids = list(self.state.get("order_ids") or [])
+            open_order_count = len(live_ids)
+            self.state["last_action"] = "recentered"
+            return {"action": "recenter", "price": price, "deviation": deviation}
+
         live_orders, live_ids, open_order_count = self._reconcile_after_sync(previous_orders, live_orders, desired_sides, price)
 
         if desired_sides != {"buy", "sell"}:
@@ -721,23 +832,6 @@ class Strategy(Strategy):
             self.state["last_action"] = "reseeded" if mode == "active" else f"{mode} monitor"
             return {"action": "reseed" if mode == "active" else "plan", "price": price}
 
-        center = float(self.state.get("center_price") or price)
-        recenter_pct = float(self.config.get("recenter_pct", 0.05) or 0.05)
-        deviation = abs(price - center) / center if center else 0.0
-
-        if deviation >= recenter_pct and not self._grid_refresh_paused():
-            try:
-                self.context.cancel_all_orders()
-            except Exception as exc:
-                self.state["last_error"] = str(exc)
-            self.state["center_price"] = price
-            self._place_grid(price)
-            live_orders = self._sync_open_orders_state()
-            mode = self._mode()
-            self.state["last_action"] = "recentered" if mode == "active" else f"{mode} monitor"
-            self._log(f"recentered grid to {price}")
-            return {"action": "recenter" if mode == "active" else "plan", "price": price, "deviation": deviation}
-
         mode = self._mode()
         self.state["last_action"] = "hold" if mode == "active" else f"{mode} monitor"
         self._log(f"hold at {price} dev {deviation:.4f}")

+ 2 - 0
strategies/stop_loss_trader.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 from datetime import datetime, timezone
 
 from src.trader_mcp.strategy_sdk import Strategy
+from src.trader_mcp.logging_utils import log_event
 
 
 class Strategy(Strategy):
@@ -56,6 +57,7 @@ class Strategy(Strategy):
         log.append(message)
         state["debug_log"] = log[-12:]
         self.state = state
+        log_event("stoploss", message)
 
     def _base_symbol(self) -> str:
         return (self.context.base_currency or self.context.market_symbol or "XRP").split("/")[0].upper()