trader_client.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. from __future__ import annotations
  2. from typing import Any
  3. from datetime import timedelta
  4. import json
  5. from mcp import ClientSession
  6. from mcp.client.sse import sse_client
  7. def _normalize_url(base_url: str) -> str:
  8. url = (base_url or "").strip()
  9. if not url:
  10. return url
  11. if not url.endswith("/mcp/sse"):
  12. url = url.rstrip("/") + "/mcp/sse"
  13. return url
  14. async def _call_tool(base_url: str, tool: str, arguments: dict[str, Any]) -> dict[str, Any]:
  15. url = _normalize_url(base_url)
  16. if not url:
  17. return {}
  18. async with sse_client(url, timeout=8.0, sse_read_timeout=8.0) as streams:
  19. async with ClientSession(*streams, read_timeout_seconds=timedelta(seconds=8)) as session:
  20. await session.initialize()
  21. result = await session.call_tool(tool, arguments)
  22. content = getattr(result, "content", None) or []
  23. if not content:
  24. return {}
  25. first = content[0]
  26. text = getattr(first, "text", None)
  27. if text is None and isinstance(first, dict):
  28. text = first.get("text")
  29. if text is None:
  30. return {}
  31. try:
  32. payload = json.loads(text)
  33. return payload if isinstance(payload, dict) else {}
  34. except Exception:
  35. return {"raw": text}
  36. async def list_strategies(base_url: str) -> list[dict[str, Any]]:
  37. payload = await _call_tool(base_url, "list_strategies", {})
  38. strategies = payload.get("strategies", payload.get("configured", [])) or []
  39. return [s for s in strategies if isinstance(s, dict)]
  40. async def get_strategy(base_url: str, instance_id: str, *, include_state: bool = True, include_report: bool = True) -> dict[str, Any]:
  41. payload = await _call_tool(
  42. base_url,
  43. "get_strategy",
  44. {
  45. "instance_id": instance_id,
  46. "include_state": include_state,
  47. "include_report": include_report,
  48. },
  49. )
  50. return payload if isinstance(payload, dict) else {}
  51. async def list_accounts(base_url: str) -> list[dict[str, Any]]:
  52. payload = await _call_tool(base_url, "list_accounts", {})
  53. accounts = payload.get("accounts", []) or []
  54. return [a for a in accounts if isinstance(a, dict)]
  55. async def cancel_all_orders(base_url: str, account_id: str, client_id: str | None = None) -> dict[str, Any]:
  56. payload = await _call_tool(base_url, "cancel_all_orders", {"account_id": account_id, "client_id": client_id})
  57. return payload if isinstance(payload, dict) else {}
  58. async def control_strategy(base_url: str, instance_id: str, action: str) -> dict[str, Any]:
  59. payload = await _call_tool(base_url, "control_strategy", {"instance_id": instance_id, "action": action})
  60. return payload if isinstance(payload, dict) else {}
  61. async def apply_control_decision(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
  62. response = await _call_tool(base_url, "apply_control_decision", {"payload": payload})
  63. return response if isinstance(response, dict) else {}