Jelajahi Sumber

Add stop loss rebalancer strategy

Lukas Goldschmidt 1 bulan lalu
induk
melakukan
feb4ec9bde
3 mengubah file dengan 255 tambahan dan 4 penghapusan
  1. 1 1
      src/trader_mcp/dashboard.py
  2. 228 0
      strategies/stop_loss_trader.py
  3. 26 3
      tests/test_strategies.py

+ 1 - 1
src/trader_mcp/dashboard.py

@@ -353,7 +353,7 @@ def dashboard_strategies_add(
     name: str = Form(...),
     strategy_type: str = Form(...),
     account_id: str = Form(...),
-    market_symbol: str = Form(...),
+    market_symbol: str = Form(""),
 ):
     strategy_id = str(uuid4())
     default_config = get_strategy_default_config(strategy_type.strip())

+ 228 - 0
strategies/stop_loss_trader.py

@@ -0,0 +1,228 @@
+from __future__ import annotations
+
+from datetime import datetime, timezone
+
+from src.trader_mcp.strategy_sdk import Strategy
+
+
+class Strategy(Strategy):
+    LABEL = "Stop Loss Rebalancer"
+    TICK_MINUTES = 0.2
+    CONFIG_SCHEMA = {
+        "regime_timeframes": {"type": "list", "default": ["1d", "4h", "1h", "15m"]},
+        "trend_enter_threshold": {"type": "float", "default": 0.7, "min": 0.0, "max": 1.0},
+        "trend_exit_threshold": {"type": "float", "default": 0.45, "min": 0.0, "max": 1.0},
+        "trail_distance_pct": {"type": "float", "default": 0.03, "min": 0.0, "max": 1.0},
+        "rebalance_target_ratio": {"type": "float", "default": 0.5, "min": 0.0, "max": 1.0},
+        "rebalance_step_ratio": {"type": "float", "default": 0.15, "min": 0.0, "max": 1.0},
+        "min_order_size": {"type": "float", "default": 0.0, "min": 0.0},
+        "max_order_size": {"type": "float", "default": 0.0, "min": 0.0},
+        "order_spacing_ticks": {"type": "int", "default": 1, "min": 0, "max": 1000},
+        "cooldown_ticks": {"type": "int", "default": 2, "min": 0, "max": 1000},
+        "balance_tolerance": {"type": "float", "default": 0.05, "min": 0.0, "max": 1.0},
+        "fee_rate": {"type": "float", "default": 0.0025, "min": 0.0, "max": 0.05},
+        "debug_orders": {"type": "bool", "default": True},
+    }
+    STATE_SCHEMA = {
+        "last_price": {"type": "float", "default": 0.0},
+        "last_action": {"type": "string", "default": "idle"},
+        "last_error": {"type": "string", "default": ""},
+        "debug_log": {"type": "list", "default": []},
+        "regimes": {"type": "dict", "default": {}},
+        "regimes_updated_at": {"type": "string", "default": ""},
+        "base_available": {"type": "float", "default": 0.0},
+        "counter_available": {"type": "float", "default": 0.0},
+        "trailing_anchor": {"type": "float", "default": 0.0},
+        "cooldown_remaining": {"type": "int", "default": 0},
+    }
+
+    def init(self):
+        return {
+            "last_price": 0.0,
+            "last_action": "idle",
+            "last_error": "",
+            "debug_log": ["init stop loss rebalancer"],
+            "regimes": {},
+            "regimes_updated_at": "",
+            "base_available": 0.0,
+            "counter_available": 0.0,
+            "trailing_anchor": 0.0,
+            "cooldown_remaining": 0,
+        }
+
+    def _log(self, message: str) -> None:
+        state = getattr(self, "state", {}) or {}
+        log = list(state.get("debug_log") or [])
+        log.append(message)
+        state["debug_log"] = log[-12:]
+        self.state = state
+
+    def _base_symbol(self) -> str:
+        return (self.context.base_currency or self.context.market_symbol or "XRP").split("/")[0].upper()
+
+    def _market_symbol(self) -> str:
+        return self.context.market_symbol or f"{self._base_symbol().lower()}usd"
+
+    def _price(self) -> float:
+        payload = self.context.get_price(self._base_symbol())
+        return float(payload.get("price") or 0.0)
+
+    def _refresh_regimes(self) -> None:
+        regimes: dict[str, dict] = {}
+        for tf in self.config.get("regime_timeframes") or ["1d", "4h", "1h", "15m"]:
+            try:
+                regimes[str(tf)] = self.context.get_regime(self._base_symbol(), str(tf))
+            except Exception as exc:
+                regimes[str(tf)] = {"error": str(exc)}
+        self.state["regimes"] = regimes
+        self.state["regimes_updated_at"] = datetime.now(timezone.utc).isoformat()
+
+    def _refresh_balance_snapshot(self) -> None:
+        try:
+            info = self.context.get_account_info()
+        except Exception as exc:
+            self._log(f"balance refresh failed: {exc}")
+            return
+        balances = info.get("balances") if isinstance(info, dict) else []
+        if not isinstance(balances, list):
+            return
+        base = self._base_symbol()
+        quote = str(self.context.counter_currency or "USD").upper()
+        for balance in balances:
+            if not isinstance(balance, dict):
+                continue
+            asset = str(balance.get("asset_code") or "").upper()
+            try:
+                available = float(balance.get("available") if balance.get("available") is not None else balance.get("total") or 0.0)
+            except Exception:
+                continue
+            if asset == base:
+                self.state["base_available"] = available
+            if asset == quote:
+                self.state["counter_available"] = available
+
+    def _regime_strength(self) -> float:
+        regimes = self.state.get("regimes") or {}
+        strengths = []
+        for tf in self.config.get("regime_timeframes") or []:
+            regime = regimes.get(str(tf)) or {}
+            trend = regime.get("trend") or {}
+            strengths.append(float(trend.get("strength") or 0.0))
+        return max(strengths) if strengths else 0.0
+
+    def _is_trending(self) -> bool:
+        strength = self._regime_strength()
+        return strength >= float(self.config.get("trend_enter_threshold", 0.7) or 0.7)
+
+    def _account_value_ratio(self, price: float) -> float:
+        base_value = float(self.state.get("base_available") or 0.0) * price
+        counter_value = float(self.state.get("counter_available") or 0.0)
+        total = base_value + counter_value
+        if total <= 0:
+            return 0.5
+        return base_value / total
+
+    def _desired_side(self, price: float) -> str:
+        # If base dominates, sell some into strength, otherwise buy some back.
+        ratio = self._account_value_ratio(price)
+        target = float(self.config.get("rebalance_target_ratio", 0.5) or 0.5)
+        return "sell" if ratio > target else "buy"
+
+    def _suggest_amount(self, side: str, price: float) -> float:
+        fee_rate = float(self.config.get("fee_rate", 0.0025) or 0.0)
+        step_ratio = float(self.config.get("rebalance_step_ratio", 0.15) or 0.0)
+        target = float(self.config.get("rebalance_target_ratio", 0.5) or 0.5)
+        min_order = float(self.config.get("min_order_size", 0.0) or 0.0)
+        max_order = float(self.config.get("max_order_size", 0.0) or 0.0)
+        balance_tolerance = float(self.config.get("balance_tolerance", 0.05) or 0.0)
+        base_value = float(self.state.get("base_available") or 0.0) * price
+        counter_value = float(self.state.get("counter_available") or 0.0)
+        total = base_value + counter_value
+        if total <= 0 or price <= 0:
+            return 0.0
+        current = base_value / total
+        drift = abs(current - target)
+        if drift <= balance_tolerance:
+            return 0.0
+
+        notional = total * min(drift, step_ratio)
+        if side == "sell":
+            amount = notional / (price * (1 + fee_rate))
+            amount = min(amount, float(self.state.get("base_available") or 0.0))
+        else:
+            amount = notional / (price * (1 + fee_rate))
+            amount = min(amount, float(self.state.get("counter_available") or 0.0) / price if price > 0 else 0.0)
+
+        if min_order > 0:
+            amount = max(amount, min_order)
+        if max_order > 0:
+            amount = min(amount, max_order)
+        return max(amount, 0.0)
+
+    def on_tick(self, tick):
+        self.state["last_error"] = ""
+        self._refresh_balance_snapshot()
+        self._refresh_regimes()
+        price = self._price()
+        self.state["last_price"] = price
+
+        if int(self.state.get("cooldown_remaining") or 0) > 0:
+            self.state["cooldown_remaining"] = int(self.state.get("cooldown_remaining") or 0) - 1
+            self.state["last_action"] = "cooldown"
+            return {"action": "cooldown", "price": price}
+
+        if not self._is_trending():
+            self.state["last_action"] = "standby"
+            return {"action": "standby", "price": price}
+
+        side = self._desired_side(price)
+        amount = self._suggest_amount(side, price)
+        trail_distance = float(self.config.get("trail_distance_pct", 0.03) or 0.03)
+
+        if amount <= 0:
+            self.state["last_action"] = "hold"
+            return {"action": "hold", "price": price}
+
+        try:
+            market = self._market_symbol()
+            if side == "sell":
+                self.state["trailing_anchor"] = max(float(self.state.get("trailing_anchor") or 0.0), price)
+                order_price = round(price * (1 + trail_distance), 8)
+            else:
+                self.state["trailing_anchor"] = min(float(self.state.get("trailing_anchor") or price), price) if self.state.get("trailing_anchor") else price
+                order_price = round(price * (1 - trail_distance), 8)
+
+            if self.config.get("debug_orders", True):
+                self._log(f"{side} rebalance amount={amount:.6g} price={order_price} ratio={self._account_value_ratio(price):.4f}")
+
+            result = self.context.place_order(
+                side=side,
+                order_type="limit",
+                amount=amount,
+                price=order_price,
+                market=market,
+            )
+            self.state["cooldown_remaining"] = int(self.config.get("cooldown_ticks", 2) or 2)
+            self.state["last_action"] = f"{side}_rebalance"
+            return {"action": side, "price": order_price, "amount": amount, "result": result}
+        except Exception as exc:
+            self.state["last_error"] = str(exc)
+            self._log(f"rebalance failed: {exc}")
+            self.state["last_action"] = "error"
+            return {"action": "error", "price": price, "error": str(exc)}
+
+    def render(self):
+        return {
+            "widgets": [
+                {"type": "metric", "label": "market", "value": self._market_symbol()},
+                {"type": "metric", "label": "price", "value": round(float(self.state.get("last_price") or 0.0), 6)},
+                {"type": "metric", "label": "state", "value": self.state.get("last_action", "idle")},
+                {"type": "metric", "label": "base avail", "value": round(float(self.state.get("base_available") or 0.0), 8)},
+                {"type": "metric", "label": "counter avail", "value": round(float(self.state.get("counter_available") or 0.0), 8)},
+                {"type": "metric", "label": "ratio", "value": round(self._account_value_ratio(float(self.state.get("last_price") or 0.0) or 1.0), 4)},
+                {"type": "metric", "label": "trailing anchor", "value": round(float(self.state.get("trailing_anchor") or 0.0), 6)},
+                {"type": "metric", "label": "cooldown", "value": int(self.state.get("cooldown_remaining") or 0)},
+                {"type": "text", "label": "error", "value": self.state.get("last_error", "") or "none"},
+                {"type": "log", "label": "debug log", "lines": self.state.get("debug_log") or []},
+            ]
+        }

+ 26 - 3
tests/test_strategies.py

@@ -5,7 +5,7 @@ from tempfile import TemporaryDirectory
 
 from fastapi.testclient import TestClient
 
-from src.trader_mcp import strategy_store
+from src.trader_mcp import strategy_registry, strategy_store
 from src.trader_mcp.server import app
 from src.trader_mcp.strategy_context import StrategyContext
 
@@ -60,7 +60,7 @@ def test_strategies_endpoints_roundtrip():
 def test_strategy_context_binds_identity(monkeypatch):
     calls = {}
 
-    def fake_place_order(**arguments):
+    def fake_place_order(arguments):
         calls["place_order"] = arguments
         return {"ok": True}
 
@@ -76,7 +76,7 @@ def test_strategy_context_binds_identity(monkeypatch):
     monkeypatch.setattr("src.trader_mcp.strategy_context.list_open_orders", fake_open_orders)
     monkeypatch.setattr("src.trader_mcp.strategy_context.cancel_all_orders", fake_cancel_all)
 
-    ctx = StrategyContext(id="inst-1", account_id="acct-1", client_id="client-1")
+    ctx = StrategyContext(id="inst-1", account_id="acct-1", client_id="client-1", mode="active")
 
     ctx.place_order(side="sell", market="xrpusd", order_type="limit", amount="10", price="2")
     ctx.get_open_orders()
@@ -86,3 +86,26 @@ def test_strategy_context_binds_identity(monkeypatch):
     assert calls["place_order"]["client_id"] == "client-1"
     assert calls["open_orders"] == {"account_id": "acct-1", "client_id": "client-1"}
     assert calls["cancel_all"] == {"account_id": "acct-1", "client_id": "client-1"}
+
+
+def test_stop_loss_strategy_loads_with_aligned_regime_config(tmp_path):
+    original_db = strategy_store.DB_PATH
+    original_dir = strategy_registry.STRATEGIES_DIR
+    try:
+        strategy_store.DB_PATH = tmp_path / "trader_mcp.sqlite3"
+        strategy_registry.STRATEGIES_DIR = tmp_path / "strategies"
+        strategy_registry.STRATEGIES_DIR.mkdir()
+        (strategy_registry.STRATEGIES_DIR / "grid_trader.py").write_text((Path(__file__).resolve().parents[1] / "strategies" / "grid_trader.py").read_text())
+        (strategy_registry.STRATEGIES_DIR / "stop_loss_trader.py").write_text((Path(__file__).resolve().parents[1] / "strategies" / "stop_loss_trader.py").read_text())
+
+        grid_defaults = strategy_registry.get_strategy_default_config("grid_trader")
+        stop_defaults = strategy_registry.get_strategy_default_config("stop_loss_trader")
+
+        assert grid_defaults["trade_sides"] == "both"
+        assert grid_defaults["trend_guard_reversal_max"] == 0.25
+        assert stop_defaults["regime_timeframes"] == ["1d", "4h", "1h", "15m"]
+        assert stop_defaults["trend_enter_threshold"] == 0.7
+        assert stop_defaults["trend_exit_threshold"] == 0.45
+    finally:
+        strategy_store.DB_PATH = original_db
+        strategy_registry.STRATEGIES_DIR = original_dir