strategy_context.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. from __future__ import annotations
  2. from dataclasses import dataclass, field
  3. from typing import Any
  4. from .exec_client import list_open_orders, query_order, cancel_all_orders, cancel_order, place_order, get_account_info, get_account_fees
  5. from .news_client import call_news_tool
  6. from .crypto_client import call_crypto_tool
  7. def _order_ref(payload: Any) -> str | None:
  8. if not isinstance(payload, dict):
  9. return None
  10. for key in ("order_id", "bitstamp_order_id", "id", "client_order_id"):
  11. value = payload.get(key)
  12. if value is None:
  13. continue
  14. text = str(value).strip()
  15. if text:
  16. return text
  17. return None
  18. def _cancelled_order_ids(cancel_result: Any) -> list[str]:
  19. if not isinstance(cancel_result, dict):
  20. return []
  21. rows = cancel_result.get("cancelled")
  22. if not isinstance(rows, list):
  23. return []
  24. cancelled: list[str] = []
  25. for row in rows:
  26. if not isinstance(row, dict) or not bool(row.get("ok")):
  27. continue
  28. order_id = _order_ref(row)
  29. if order_id is not None:
  30. cancelled.append(order_id)
  31. return cancelled
  32. def _has_inconclusive_cancel_rows(cancel_result: Any) -> bool:
  33. if not isinstance(cancel_result, dict):
  34. return False
  35. rows = cancel_result.get("cancelled")
  36. if not isinstance(rows, list):
  37. return False
  38. return any(isinstance(row, dict) and not bool(row.get("ok")) for row in rows)
  39. @dataclass(frozen=True)
  40. class StrategyContext:
  41. id: str
  42. account_id: str
  43. client_id: str | None = field(default=None, repr=False)
  44. mode: str = "off"
  45. market_symbol: str | None = None
  46. base_currency: str | None = None
  47. counter_currency: str | None = None
  48. minimum_order_value: float | None = None
  49. def __getattr__(self, name: str):
  50. if name == "mode":
  51. return "active"
  52. raise AttributeError(name)
  53. def get_open_orders(self) -> Any:
  54. payload = list_open_orders(self.account_id, self.client_id)
  55. if isinstance(payload, dict) and isinstance(payload.get("orders"), list):
  56. return payload["orders"]
  57. return payload
  58. def query_order(self, order_id: str) -> Any:
  59. return query_order(self.account_id, order_id)
  60. def cancel_all_orders(self) -> Any:
  61. return cancel_all_orders(self.account_id, self.client_id)
  62. def cancel_all_orders_confirmed(self) -> dict[str, Any]:
  63. cancel_result = None
  64. cancel_error = None
  65. verification_error = None
  66. remaining_orders = None
  67. try:
  68. cancel_result = self.cancel_all_orders()
  69. except Exception as exc:
  70. cancel_error = str(exc)
  71. try:
  72. payload = self.get_open_orders()
  73. if isinstance(payload, list):
  74. remaining_orders = payload
  75. else:
  76. verification_error = f"unexpected open orders payload type: {type(payload).__name__}"
  77. except Exception as exc:
  78. verification_error = str(exc)
  79. cancelled_order_ids = _cancelled_order_ids(cancel_result)
  80. inconclusive_cancel = cancel_error is not None or _has_inconclusive_cancel_rows(cancel_result)
  81. conclusive = (
  82. not inconclusive_cancel
  83. and remaining_orders is not None
  84. and len(remaining_orders) == 0
  85. )
  86. cleanup_status = "cleanup_confirmed" if conclusive else ("cleanup_failed" if cancel_error else "cleanup_partial")
  87. error = None
  88. if not conclusive:
  89. if cancel_error:
  90. error = cancel_error
  91. elif _has_inconclusive_cancel_rows(cancel_result):
  92. error = "cancel-all reported uncancelled orders"
  93. elif verification_error:
  94. error = verification_error
  95. elif remaining_orders is not None:
  96. error = f"{len(remaining_orders)} open orders remain after cancel-all"
  97. else:
  98. error = "open order verification unavailable after cancel-all"
  99. return {
  100. "ok": cancel_error is None and verification_error is None and conclusive,
  101. "conclusive": conclusive,
  102. "cleanup_status": cleanup_status,
  103. "cancelled_order_ids": cancelled_order_ids,
  104. "remaining_orders": remaining_orders,
  105. "cancel_result": cancel_result,
  106. "cancel_error": cancel_error,
  107. "verification_error": verification_error,
  108. "error": error,
  109. }
  110. def cancel_order(self, order_id: str) -> Any:
  111. return cancel_order(self.account_id, order_id)
  112. def place_order(self, **kwargs: Any) -> Any:
  113. mode = getattr(self, "mode", "active")
  114. if mode != "active":
  115. raise RuntimeError(f"place_order not allowed in {mode} mode")
  116. kwargs.setdefault("account_id", self.account_id)
  117. kwargs.setdefault("client_id", self.client_id)
  118. return place_order(kwargs)
  119. def get_price(self, symbol: str) -> Any:
  120. return call_crypto_tool("get_price", {"symbol": symbol})
  121. def get_regime(self, symbol: str, timeframe: str = "1h") -> Any:
  122. return call_crypto_tool("get_regime", {"symbol": symbol, "timeframe": timeframe})
  123. def get_account_info(self) -> Any:
  124. return get_account_info(self.account_id)
  125. def get_balance_snapshot(self) -> dict[str, Any]:
  126. info = self.get_account_info()
  127. balances = info.get("balances") if isinstance(info, dict) else []
  128. if not isinstance(balances, list):
  129. balances = []
  130. return {
  131. "account_id": self.account_id,
  132. "market_symbol": self.market_symbol,
  133. "base_currency": self.base_currency,
  134. "counter_currency": self.counter_currency,
  135. "balances": balances,
  136. }
  137. def get_open_order_snapshot(self) -> dict[str, Any]:
  138. orders = self.get_open_orders()
  139. return {
  140. "account_id": self.account_id,
  141. "market_symbol": self.market_symbol,
  142. "open_orders": orders if isinstance(orders, list) else [],
  143. }
  144. def get_strategy_snapshot(self) -> dict[str, Any]:
  145. return {
  146. "identity": {
  147. "strategy_id": self.id,
  148. "account_id": self.account_id,
  149. "market": self.market_symbol,
  150. "base_currency": self.base_currency,
  151. "quote_currency": self.counter_currency,
  152. },
  153. "control": {
  154. "mode": getattr(self, "mode", "off"),
  155. },
  156. "position": self.get_balance_snapshot(),
  157. "orders": self.get_open_order_snapshot(),
  158. }
  159. def _available_balance(self, asset_code: str) -> float:
  160. try:
  161. info = self.get_account_info()
  162. except Exception:
  163. return 0.0
  164. balances = info.get("balances") if isinstance(info, dict) else []
  165. if not isinstance(balances, list):
  166. return 0.0
  167. wanted = str(asset_code or "").upper()
  168. for balance in balances:
  169. if not isinstance(balance, dict):
  170. continue
  171. if str(balance.get("asset_code") or "").upper() != wanted:
  172. continue
  173. try:
  174. return float(balance.get("available") if balance.get("available") is not None else balance.get("total") or 0.0)
  175. except Exception:
  176. return 0.0
  177. return 0.0
  178. def get_account_fees(self, market_symbol: str | None = None) -> Any:
  179. return get_account_fees(self.account_id, market_symbol)
  180. def get_fee_rates(self, market_symbol: str | None = None) -> dict[str, float]:
  181. payload = get_account_fees(self.account_id, market_symbol)
  182. if not isinstance(payload, dict):
  183. return {"maker": 0.0, "taker": 0.0}
  184. fees = payload.get("fees") if isinstance(payload.get("fees"), dict) else {}
  185. def _normalize_fee(value: object) -> float:
  186. try:
  187. rate = float(value or 0.0)
  188. except Exception:
  189. return 0.0
  190. if rate > 0.1:
  191. rate /= 100.0
  192. return rate
  193. try:
  194. maker = _normalize_fee(fees.get("maker"))
  195. except Exception:
  196. maker = 0.0
  197. try:
  198. taker = _normalize_fee(fees.get("taker"))
  199. except Exception:
  200. taker = 0.0
  201. return {"maker": maker, "taker": taker}
  202. def suggest_order_amount(
  203. self,
  204. *,
  205. side: str,
  206. price: float,
  207. levels: int,
  208. min_notional: float,
  209. fee_rate: float,
  210. quote_notional: float = 0.0,
  211. max_notional_per_order: float = 0.0,
  212. dust_collect: bool = False,
  213. order_size: float = 0.0,
  214. safety: float = 0.995,
  215. available_balances: dict[str, float] | None = None,
  216. ) -> float:
  217. """Return a conservative per-order amount for this venue/account.
  218. The returned amount is exchange-aware but strategy-agnostic, so other
  219. strategies can reuse the same sizing rules. `quote_notional` is the
  220. canonical quote-currency cap when a strategy wants quote-standard sizing.
  221. """
  222. if levels <= 0 or price <= 0:
  223. return 0.0
  224. side = str(side or "").strip().lower()
  225. fee_rate = max(float(fee_rate or 0.0), 0.0)
  226. quote_notional = float(quote_notional or 0.0)
  227. max_notional_per_order = float(max_notional_per_order or 0.0)
  228. order_size = float(order_size or 0.0)
  229. balance_overrides = {
  230. str(asset_code or "").upper(): max(float(amount or 0.0), 0.0)
  231. for asset_code, amount in dict(available_balances or {}).items()
  232. }
  233. min_amount = (min_notional / price) if min_notional > 0 else 0.0
  234. if side == "buy":
  235. quote = self.counter_currency or "USD"
  236. quote_available = balance_overrides.get(str(quote).upper())
  237. if quote_available is None:
  238. quote_available = self._available_balance(quote) if hasattr(self, "_available_balance") else 0.0
  239. spendable_quote = quote_available * safety
  240. quote_cap = spendable_quote if max_notional_per_order <= 0 else min(spendable_quote, max_notional_per_order)
  241. if quote_notional > 0:
  242. quote_cap = min(quote_cap, quote_notional)
  243. if dust_collect and max_notional_per_order > 0:
  244. leftover_quote = max(spendable_quote - max_notional_per_order, 0.0)
  245. if 0.0 < leftover_quote < min_notional:
  246. quote_cap = spendable_quote
  247. if quote_cap <= 0:
  248. return 0.0
  249. # `max_notional_per_order` is already a per-order cap in quote
  250. # currency, so do not dilute it by the number of grid levels.
  251. per_order_quote = quote_cap
  252. min_quote_needed = min_notional * (1 + fee_rate)
  253. if per_order_quote < min_quote_needed:
  254. return 0.0
  255. amount = per_order_quote / (price * (1 + fee_rate))
  256. else:
  257. base = self.base_currency or (self.market_symbol or "XRP")
  258. base_available = balance_overrides.get(str(base).upper())
  259. if base_available is None:
  260. base_available = self._available_balance(base) if hasattr(self, "_available_balance") else 0.0
  261. spendable_base = base_available * safety
  262. if quote_notional > 0 and price > 0:
  263. spendable_base = min(spendable_base, quote_notional / price)
  264. if max_notional_per_order > 0 and price > 0:
  265. base_cap = max_notional_per_order / price
  266. if dust_collect:
  267. leftover_base = max(spendable_base - base_cap, 0.0)
  268. if 0.0 < leftover_base * price < min_notional:
  269. spendable_base = spendable_base
  270. else:
  271. spendable_base = min(spendable_base, base_cap)
  272. else:
  273. spendable_base = min(spendable_base, base_cap)
  274. if spendable_base <= 0:
  275. return 0.0
  276. # Same rule as above, the cap is per order, not per grid level.
  277. amount = spendable_base
  278. if amount < min_amount:
  279. return 0.0
  280. if order_size > 0:
  281. if order_size < min_amount:
  282. return 0.0
  283. amount = min(amount, order_size)
  284. return max(amount, 0.0)
  285. def get_news(self, **kwargs: Any) -> Any:
  286. return call_news_tool("search", kwargs)