test_strategies.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. from __future__ import annotations
  2. from pathlib import Path
  3. from tempfile import TemporaryDirectory
  4. from fastapi.testclient import TestClient
  5. from src.trader_mcp import strategy_store
  6. from src.trader_mcp.server import app
  7. from src.trader_mcp.strategy_context import StrategyContext
  8. STRATEGY_CODE = '''
  9. from src.trader_mcp.strategy_sdk import Strategy
  10. class Strategy(Strategy):
  11. def init(self):
  12. return {"started": True, "config_copy": dict(self.config)}
  13. '''
  14. def test_strategies_endpoints_roundtrip():
  15. with TemporaryDirectory() as tmpdir:
  16. strategy_store.DB_PATH = Path(tmpdir) / "trader_mcp.sqlite3"
  17. from src.trader_mcp import strategy_registry
  18. strategy_registry.STRATEGIES_DIR = Path(tmpdir) / "strategies"
  19. strategy_registry.STRATEGIES_DIR.mkdir()
  20. (strategy_registry.STRATEGIES_DIR / "demo.py").write_text(STRATEGY_CODE)
  21. client = TestClient(app)
  22. r = client.get("/strategies")
  23. assert r.status_code == 200
  24. body = r.json()
  25. assert "available" in body
  26. assert "configured" in body
  27. r = client.post(
  28. "/strategies",
  29. json={
  30. "id": "demo-1",
  31. "strategy_type": "demo",
  32. "account_id": "acct-1",
  33. "client_id": "strategy:test",
  34. "mode": "observe",
  35. "config": {"risk": 0.01},
  36. },
  37. )
  38. assert r.status_code == 200
  39. assert r.json()["id"] == "demo-1"
  40. r = client.get("/strategies")
  41. assert any(item["id"] == "demo-1" for item in r.json()["configured"])
  42. r = client.delete("/strategies/demo-1")
  43. assert r.status_code == 200
  44. assert r.json()["ok"] is True
  45. def test_strategy_context_binds_identity(monkeypatch):
  46. calls = {}
  47. def fake_place_order(**arguments):
  48. calls["place_order"] = arguments
  49. return {"ok": True}
  50. def fake_open_orders(account_id, client_id=None):
  51. calls["open_orders"] = {"account_id": account_id, "client_id": client_id}
  52. return {"ok": True}
  53. def fake_cancel_all(account_id, client_id=None):
  54. calls["cancel_all"] = {"account_id": account_id, "client_id": client_id}
  55. return {"ok": True}
  56. monkeypatch.setattr("src.trader_mcp.strategy_context.place_order", fake_place_order)
  57. monkeypatch.setattr("src.trader_mcp.strategy_context.list_open_orders", fake_open_orders)
  58. monkeypatch.setattr("src.trader_mcp.strategy_context.cancel_all_orders", fake_cancel_all)
  59. ctx = StrategyContext(id="inst-1", account_id="acct-1", client_id="client-1")
  60. ctx.place_order(side="sell", market="xrpusd", order_type="limit", amount="10", price="2")
  61. ctx.get_open_orders()
  62. ctx.cancel_all_orders()
  63. assert calls["place_order"]["account_id"] == "acct-1"
  64. assert calls["place_order"]["client_id"] == "client-1"
  65. assert calls["open_orders"] == {"account_id": "acct-1", "client_id": "client-1"}
  66. assert calls["cancel_all"] == {"account_id": "acct-1", "client_id": "client-1"}