# test_strategies.py (7.0 KB)
#
# Tests for the strategy HTTP endpoints, StrategyContext identity binding,
# bundled strategy default configs, and grid top-up order sizing.
  1. from __future__ import annotations
  2. from pathlib import Path
  3. from tempfile import TemporaryDirectory
  4. from fastapi.testclient import TestClient
  5. from src.trader_mcp import strategy_registry, strategy_store
  6. from src.trader_mcp.server import app
  7. from src.trader_mcp.strategy_context import StrategyContext
  8. from strategies.grid_trader import Strategy as GridStrategy
  9. STRATEGY_CODE = '''
  10. from src.trader_mcp.strategy_sdk import Strategy
  11. class Strategy(Strategy):
  12. def init(self):
  13. return {"started": True, "config_copy": dict(self.config)}
  14. '''
  15. def test_strategies_endpoints_roundtrip():
  16. with TemporaryDirectory() as tmpdir:
  17. strategy_store.DB_PATH = Path(tmpdir) / "trader_mcp.sqlite3"
  18. from src.trader_mcp import strategy_registry
  19. strategy_registry.STRATEGIES_DIR = Path(tmpdir) / "strategies"
  20. strategy_registry.STRATEGIES_DIR.mkdir()
  21. (strategy_registry.STRATEGIES_DIR / "demo.py").write_text(STRATEGY_CODE)
  22. client = TestClient(app)
  23. r = client.get("/strategies")
  24. assert r.status_code == 200
  25. body = r.json()
  26. assert "available" in body
  27. assert "configured" in body
  28. r = client.post(
  29. "/strategies",
  30. json={
  31. "id": "demo-1",
  32. "strategy_type": "demo",
  33. "account_id": "acct-1",
  34. "client_id": "strategy:test",
  35. "mode": "observe",
  36. "config": {"risk": 0.01},
  37. },
  38. )
  39. assert r.status_code == 200
  40. assert r.json()["id"] == "demo-1"
  41. r = client.get("/strategies")
  42. assert any(item["id"] == "demo-1" for item in r.json()["configured"])
  43. r = client.delete("/strategies/demo-1")
  44. assert r.status_code == 200
  45. assert r.json()["ok"] is True
  46. def test_strategy_context_binds_identity(monkeypatch):
  47. calls = {}
  48. def fake_place_order(arguments):
  49. calls["place_order"] = arguments
  50. return {"ok": True}
  51. def fake_open_orders(account_id, client_id=None):
  52. calls["open_orders"] = {"account_id": account_id, "client_id": client_id}
  53. return {"ok": True}
  54. def fake_cancel_all(account_id, client_id=None):
  55. calls["cancel_all"] = {"account_id": account_id, "client_id": client_id}
  56. return {"ok": True}
  57. monkeypatch.setattr("src.trader_mcp.strategy_context.place_order", fake_place_order)
  58. monkeypatch.setattr("src.trader_mcp.strategy_context.list_open_orders", fake_open_orders)
  59. monkeypatch.setattr("src.trader_mcp.strategy_context.cancel_all_orders", fake_cancel_all)
  60. ctx = StrategyContext(id="inst-1", account_id="acct-1", client_id="client-1", mode="active")
  61. ctx.place_order(side="sell", market="xrpusd", order_type="limit", amount="10", price="2")
  62. ctx.get_open_orders()
  63. ctx.cancel_all_orders()
  64. assert calls["place_order"]["account_id"] == "acct-1"
  65. assert calls["place_order"]["client_id"] == "client-1"
  66. assert calls["open_orders"] == {"account_id": "acct-1", "client_id": "client-1"}
  67. assert calls["cancel_all"] == {"account_id": "acct-1", "client_id": "client-1"}
  68. def test_stop_loss_strategy_loads_with_aligned_regime_config(tmp_path):
  69. original_db = strategy_store.DB_PATH
  70. original_dir = strategy_registry.STRATEGIES_DIR
  71. try:
  72. strategy_store.DB_PATH = tmp_path / "trader_mcp.sqlite3"
  73. strategy_registry.STRATEGIES_DIR = tmp_path / "strategies"
  74. strategy_registry.STRATEGIES_DIR.mkdir()
  75. (strategy_registry.STRATEGIES_DIR / "grid_trader.py").write_text((Path(__file__).resolve().parents[1] / "strategies" / "grid_trader.py").read_text())
  76. (strategy_registry.STRATEGIES_DIR / "stop_loss_trader.py").write_text((Path(__file__).resolve().parents[1] / "strategies" / "stop_loss_trader.py").read_text())
  77. grid_defaults = strategy_registry.get_strategy_default_config("grid_trader")
  78. stop_defaults = strategy_registry.get_strategy_default_config("stop_loss_trader")
  79. assert grid_defaults["trade_sides"] == "both"
  80. assert grid_defaults["trend_guard_reversal_max"] == 0.25
  81. assert stop_defaults["regime_timeframes"] == ["1d", "4h", "1h", "15m"]
  82. assert stop_defaults["trend_enter_threshold"] == 0.7
  83. assert stop_defaults["trend_exit_threshold"] == 0.45
  84. finally:
  85. strategy_store.DB_PATH = original_db
  86. strategy_registry.STRATEGIES_DIR = original_dir
  87. def test_grid_top_up_uses_missing_levels_budget():
  88. class FakeContext:
  89. base_currency = "XRP"
  90. counter_currency = "USD"
  91. market_symbol = "xrpusd"
  92. minimum_order_value = 10.0
  93. mode = "active"
  94. def __init__(self):
  95. self.placed_orders = []
  96. def get_fee_rates(self, market):
  97. return {"maker": 0.0, "taker": 0.004}
  98. def get_account_info(self):
  99. return {
  100. "balances": [
  101. {"asset_code": "USD", "available": 13.55},
  102. {"asset_code": "XRP", "available": 22.0103},
  103. ]
  104. }
  105. def suggest_order_amount(
  106. self,
  107. *,
  108. side,
  109. price,
  110. levels,
  111. min_notional,
  112. fee_rate,
  113. max_notional_per_order=0.0,
  114. dust_collect=False,
  115. inventory_cap_pct=0.0,
  116. order_size=0.0,
  117. safety=0.995,
  118. ):
  119. if side == "buy":
  120. quote_available = 13.55
  121. spendable_quote = quote_available * safety
  122. quote_cap = min(spendable_quote, max_notional_per_order) if max_notional_per_order > 0 else spendable_quote
  123. if quote_cap < min_notional * (1 + fee_rate):
  124. return 0.0
  125. return quote_cap / (price * (1 + fee_rate))
  126. return 0.0
  127. def place_order(self, **kwargs):
  128. self.placed_orders.append(kwargs)
  129. return {"status": "ok", "id": f"oid-{len(self.placed_orders)}"}
  130. ctx = FakeContext()
  131. strategy = GridStrategy(
  132. ctx,
  133. {
  134. "grid_levels": 2,
  135. "grid_step_pct": 0.0062,
  136. "grid_step_min_pct": 0.0033,
  137. "grid_step_max_pct": 0.012,
  138. "max_notional_per_order": 12,
  139. "order_call_delay_ms": 0,
  140. "trade_sides": "both",
  141. "debug_orders": True,
  142. "dust_collect": True,
  143. "enable_trend_guard": False,
  144. "fee_rate": 0.004,
  145. },
  146. )
  147. strategy.state["center_price"] = 1.3285
  148. strategy.state["orders"] = [
  149. {"side": "buy", "price": 1.3243993, "amount": 7.63, "id": "existing-buy"},
  150. {"side": "sell", "price": 1.3326007, "amount": 9.0, "id": "sell-1"},
  151. {"side": "sell", "price": 1.3367011, "amount": 9.0, "id": "sell-2"},
  152. ]
  153. strategy.state["order_ids"] = ["existing-buy", "sell-1", "sell-2"]
  154. strategy._top_up_missing_levels(strategy.state["center_price"], strategy.state["orders"])
  155. assert len(ctx.placed_orders) == 1
  156. assert ctx.placed_orders[0]["side"] == "buy"
  157. assert float(ctx.placed_orders[0]["amount"]) > 7.57