# grid_trader.py (19 KB)
from __future__ import annotations

import time

from src.trader_mcp.strategy_sdk import Strategy
  4. class Strategy(Strategy):
  5. LABEL = "Grid Trader"
  6. TICK_MINUTES = 0.2
  7. # NOTE:
  8. # This strategy is currently using a protective workaround for stale order state,
  9. # because exec-mcp can temporarily report order records that do not reflect the
  10. # clean post-reset strategy state. The grid prefers its own fresh persisted state
  11. # first, so the real exchange behavior stays testable while exec-mcp is improved.
  12. # Expect the reconciliation behavior to change again once exec-mcp is fixed.
  13. CONFIG_SCHEMA = {
  14. "grid_levels": {"type": "int", "default": 6, "min": 1, "max": 20},
  15. "grid_step_pct": {"type": "float", "default": 0.012, "min": 0.001, "max": 0.1},
  16. "volatility_timeframe": {"type": "string", "default": "1h"},
  17. "volatility_multiplier": {"type": "float", "default": 0.5, "min": 0.0, "max": 10.0},
  18. "grid_step_min_pct": {"type": "float", "default": 0.005, "min": 0.0001, "max": 0.5},
  19. "grid_step_max_pct": {"type": "float", "default": 0.03, "min": 0.0001, "max": 1.0},
  20. "order_size": {"type": "float", "default": 0.0, "min": 0.0},
  21. "inventory_cap_pct": {"type": "float", "default": 0.7, "min": 0.0, "max": 1.0},
  22. "recenter_pct": {"type": "float", "default": 0.05, "min": 0.0, "max": 0.5},
  23. "fee_rate": {"type": "float", "default": 0.0025, "min": 0.0, "max": 0.05},
  24. "trade_sides": {"type": "string", "default": "both"},
  25. "max_notional_per_order": {"type": "float", "default": 0.0, "min": 0.0},
  26. "order_call_delay_ms": {"type": "int", "default": 250, "min": 0, "max": 10000},
  27. "debug_orders": {"type": "bool", "default": True},
  28. "use_all_available": {"type": "bool", "default": True},
  29. }
  30. STATE_SCHEMA = {
  31. "center_price": {"type": "float", "default": 0.0},
  32. "last_price": {"type": "float", "default": 0.0},
  33. "seeded": {"type": "bool", "default": False},
  34. "last_action": {"type": "string", "default": "idle"},
  35. "last_error": {"type": "string", "default": ""},
  36. "orders": {"type": "list", "default": []},
  37. "order_ids": {"type": "list", "default": []},
  38. "debug_log": {"type": "list", "default": []},
  39. "base_available": {"type": "float", "default": 0.0},
  40. "counter_available": {"type": "float", "default": 0.0},
  41. }
  42. def init(self):
  43. return {
  44. "center_price": 0.0,
  45. "last_price": 0.0,
  46. "seeded": False,
  47. "last_action": "idle",
  48. "last_error": "",
  49. "orders": [],
  50. "order_ids": [],
  51. "debug_log": ["init cancel all orders"],
  52. "base_available": 0.0,
  53. "counter_available": 0.0,
  54. }
  55. def _log(self, message: str) -> None:
  56. state = getattr(self, "state", {}) or {}
  57. log = list(state.get("debug_log") or [])
  58. log.append(message)
  59. state["debug_log"] = log[-12:]
  60. self.state = state
  61. def _base_symbol(self) -> str:
  62. return (self.context.base_currency or self.context.market_symbol or "XRP").split("/")[0].upper()
  63. def _market_symbol(self) -> str:
  64. return self.context.market_symbol or f"{self._base_symbol().lower()}usd"
  65. def _mode(self) -> str:
  66. return getattr(self.context, "mode", "active") or "active"
  67. def _price(self) -> float:
  68. payload = self.context.get_price(self._base_symbol())
  69. return float(payload.get("price") or 0.0)
  70. def _regime_snapshot(self) -> dict:
  71. timeframes = ["1d", "4h", "1h", "15m"]
  72. snapshot = {}
  73. for tf in timeframes:
  74. try:
  75. snapshot[tf] = self.context.get_regime(self._base_symbol(), tf)
  76. except Exception as exc:
  77. snapshot[tf] = {"error": str(exc)}
  78. return snapshot
  79. def _grid_step_pct(self) -> float:
  80. base_step = float(self.config.get("grid_step_pct", 0.012) or 0.012)
  81. tf = str(self.config.get("volatility_timeframe", "1h") or "1h")
  82. multiplier = float(self.config.get("volatility_multiplier", 0.5) or 0.0)
  83. min_step = float(self.config.get("grid_step_min_pct", 0.005) or 0.0)
  84. max_step = float(self.config.get("grid_step_max_pct", 0.03) or 1.0)
  85. try:
  86. regime = self.context.get_regime(self._base_symbol(), tf)
  87. short_regime = self.context.get_regime(self._base_symbol(), "15m")
  88. atr_pct = float((regime or {}).get("volatility", {}).get("atr_percent") or 0.0)
  89. short_atr_pct = float((short_regime or {}).get("volatility", {}).get("atr_percent") or 0.0)
  90. atr_pct = max(atr_pct, short_atr_pct)
  91. self.state["regimes"] = self._regime_snapshot()
  92. except Exception as exc:
  93. self._log(f"regime fetch failed: {exc}")
  94. atr_pct = 0.0
  95. adaptive = (atr_pct / 100.0) * multiplier if atr_pct > 0 else base_step
  96. step = adaptive if atr_pct > 0 else base_step
  97. step = max(step, min_step)
  98. step = min(step, max_step)
  99. self.state["grid_step_pct"] = step
  100. self.state["atr_percent"] = atr_pct
  101. return step
  102. def _available_balance(self, asset_code: str) -> float:
  103. try:
  104. info = self.context.get_account_info()
  105. except Exception as exc:
  106. self._log(f"account info failed: {exc}")
  107. return 0.0
  108. balances = info.get("balances") if isinstance(info, dict) else []
  109. if not isinstance(balances, list):
  110. return 0.0
  111. wanted = str(asset_code or "").upper()
  112. for balance in balances:
  113. if not isinstance(balance, dict):
  114. continue
  115. if str(balance.get("asset_code") or "").upper() != wanted:
  116. continue
  117. try:
  118. return float(balance.get("available") if balance.get("available") is not None else balance.get("total") or 0.0)
  119. except Exception:
  120. return 0.0
  121. return 0.0
  122. def _supported_levels(self, side: str, price: float, min_notional: float) -> int:
  123. if min_notional <= 0 or price <= 0:
  124. return 0
  125. safety = 0.995
  126. fee_rate = float(self.config.get("fee_rate", 0.0025) or 0.0)
  127. if side == "buy":
  128. quote = self.context.counter_currency or "USD"
  129. quote_available = self._available_balance(quote)
  130. self.state["counter_available"] = quote_available
  131. usable_notional = quote_available * safety
  132. return max(int(usable_notional / min_notional), 0)
  133. base = self._base_symbol()
  134. base_available = self._available_balance(base)
  135. self.state["base_available"] = base_available
  136. usable_notional = base_available * safety * price / (1 + fee_rate)
  137. return max(int(usable_notional / min_notional), 0)
  138. def _side_allowed(self, side: str) -> bool:
  139. selected = str(self.config.get("trade_sides", "both") or "both").strip().lower()
  140. if selected == "both":
  141. return True
  142. return selected == side
  143. def _suggest_amount(self, side: str, price: float, levels: int, min_notional: float) -> float:
  144. if levels <= 0 or price <= 0:
  145. return 0.0
  146. safety = 0.995
  147. fee_rate = float(self.config.get("fee_rate", 0.0025) or 0.0)
  148. max_notional = float(self.config.get("max_notional_per_order", 0.0) or 0.0)
  149. manual = float(self.config.get("order_size", 0.0) or 0.0)
  150. if side == "buy":
  151. quote = self.context.counter_currency or "USD"
  152. quote_available = self._available_balance(quote)
  153. self.state["counter_available"] = quote_available
  154. spendable_quote = quote_available * safety
  155. amount = spendable_quote / (max(levels, 1) * price * (1 + fee_rate))
  156. else:
  157. base = self._base_symbol()
  158. base_available = self._available_balance(base)
  159. self.state["base_available"] = base_available
  160. spendable_base = (base_available * safety) / (1 + fee_rate)
  161. amount = spendable_base / max(levels, 1)
  162. min_size = (min_notional / price) if price > 0 else 0.0
  163. amount = max(amount, min_size * 1.05)
  164. if max_notional > 0 and price > 0:
  165. amount = min(amount, max_notional / (price * (1 + fee_rate)))
  166. if manual > 0:
  167. amount = min(amount, manual)
  168. return max(amount, 0.0)
  169. def _place_grid(self, center: float) -> None:
  170. mode = self._mode()
  171. levels = int(self.config.get("grid_levels", 6) or 6)
  172. step = self._grid_step_pct()
  173. min_notional = float(self.context.minimum_order_value or 0.0)
  174. market = self._market_symbol()
  175. orders = []
  176. order_ids = []
  177. def _capture_order_id(result):
  178. if isinstance(result, dict):
  179. return result.get("bitstamp_order_id") or result.get("order_id") or result.get("id") or result.get("client_order_id")
  180. return None
  181. buy_levels = min(levels, self._supported_levels("buy", center, min_notional)) if (mode == "active" and self._side_allowed("buy")) else (levels if self._side_allowed("buy") else 0)
  182. sell_levels = min(levels, self._supported_levels("sell", center, min_notional)) if (mode == "active" and self._side_allowed("sell")) else (levels if self._side_allowed("sell") else 0)
  183. buy_amount = self._suggest_amount("buy", center, max(buy_levels, 1), min_notional)
  184. sell_amount = self._suggest_amount("sell", center, max(sell_levels, 1), min_notional)
  185. for i in range(1, levels + 1):
  186. buy_price = round(center * (1 - (step * i)), 8)
  187. sell_price = round(center * (1 + (step * i)), 8)
  188. if mode != "active":
  189. orders.append({"side": "buy", "price": buy_price, "amount": buy_amount, "result": {"simulated": True}})
  190. orders.append({"side": "sell", "price": sell_price, "amount": sell_amount, "result": {"simulated": True}})
  191. self._log(f"plan level {i}: buy {buy_price} amount {buy_amount:.6g} / sell {sell_price} amount {sell_amount:.6g}")
  192. continue
  193. if i > buy_levels and i > sell_levels:
  194. self._log(f"skip level {i}: no capacity on either side")
  195. continue
  196. min_size_buy = (min_notional / buy_price) if buy_price > 0 else 0.0
  197. min_size_sell = (min_notional / sell_price) if sell_price > 0 else 0.0
  198. try:
  199. if i <= buy_levels and buy_amount >= min_size_buy:
  200. buy = self.context.place_order(side="buy", order_type="limit", amount=buy_amount, price=buy_price, market=market)
  201. orders.append({"side": "buy", "price": buy_price, "amount": buy_amount, "result": buy})
  202. buy_id = _capture_order_id(buy)
  203. if buy_id is not None:
  204. order_ids.append(str(buy_id))
  205. if i <= sell_levels and sell_amount >= min_size_sell:
  206. sell = self.context.place_order(side="sell", order_type="limit", amount=sell_amount, price=sell_price, market=market)
  207. orders.append({"side": "sell", "price": sell_price, "amount": sell_amount, "result": sell})
  208. sell_id = _capture_order_id(sell)
  209. if sell_id is not None:
  210. order_ids.append(str(sell_id))
  211. self._log(f"seed level {i}: buy {buy_price} amount {buy_amount:.6g} / sell {sell_price} amount {sell_amount:.6g}")
  212. except Exception as exc: # best effort for first draft
  213. self.state["last_error"] = str(exc)
  214. self._log(f"seed level {i} failed: {exc}")
  215. continue
  216. delay = max(int(self.config.get("order_call_delay_ms", 250) or 0), 0) / 1000.0
  217. if delay > 0:
  218. time.sleep(delay)
  219. self.state["orders"] = orders
  220. self.state["order_ids"] = order_ids
  221. self.state["last_action"] = "seeded grid"
  222. def _cancel_orders(self, order_ids) -> None:
  223. for order_id in order_ids or []:
  224. self._log(f"dropping stale order {order_id} from state")
  225. def on_tick(self, tick):
  226. price = self._price()
  227. self.state["last_price"] = price
  228. self.state["last_error"] = ""
  229. try:
  230. open_orders = self.context.get_open_orders()
  231. live_ids = []
  232. if isinstance(open_orders, list):
  233. for order in open_orders:
  234. if isinstance(order, dict):
  235. live_ids.append(str(order.get("bitstamp_order_id") or order.get("order_id") or order.get("id") or order.get("client_order_id") or ""))
  236. live_ids = [oid for oid in live_ids if oid]
  237. open_order_count = len(live_ids)
  238. expected_ids = [str(oid) for oid in (self.state.get("order_ids") or []) if oid]
  239. stale_ids = [oid for oid in live_ids if oid not in expected_ids]
  240. missing_ids = [oid for oid in expected_ids if oid not in live_ids]
  241. except Exception as exc:
  242. open_order_count = -1
  243. live_ids = []
  244. expected_ids = []
  245. stale_ids = []
  246. missing_ids = []
  247. self.state["last_error"] = str(exc)
  248. self._log(f"open orders check failed: {exc}")
  249. # Workaround: after a reset, trust the fresh strategy state first.
  250. # This prevents stale exec-mcp records from blocking the next clean test.
  251. if not (self.state.get("order_ids") or []):
  252. live_ids = []
  253. open_order_count = 0
  254. expected_ids = []
  255. stale_ids = []
  256. missing_ids = []
  257. self.state["open_order_count"] = open_order_count
  258. mode = self._mode()
  259. if mode != "active":
  260. if not self.state.get("seeded") or not self.state.get("center_price"):
  261. self.state["center_price"] = price
  262. self._place_grid(price)
  263. self.state["seeded"] = True
  264. self._log(f"planned grid at {price}")
  265. return {"action": "plan", "price": price}
  266. center = float(self.state.get("center_price") or price)
  267. recenter_pct = float(self.config.get("recenter_pct", 0.05) or 0.05)
  268. deviation = abs(price - center) / center if center else 0.0
  269. if deviation >= recenter_pct:
  270. self.state["center_price"] = price
  271. self._place_grid(price)
  272. self._log(f"planned recenter to {price}")
  273. return {"action": "plan", "price": price, "deviation": deviation}
  274. self.state["last_action"] = "observe monitor"
  275. self._log(f"observe at {price} dev {deviation:.4f}")
  276. return {"action": "observe", "price": price, "deviation": deviation}
  277. if stale_ids:
  278. self._log(f"stale live orders: {stale_ids}")
  279. self._cancel_orders(stale_ids)
  280. live_ids = [oid for oid in live_ids if oid not in stale_ids]
  281. if missing_ids:
  282. self._log(f"missing tracked orders: {missing_ids}")
  283. self.state["order_ids"] = live_ids
  284. if not self.state.get("seeded") or not self.state.get("center_price"):
  285. self.state["center_price"] = price
  286. self._place_grid(price)
  287. self.state["seeded"] = True
  288. mode = self._mode()
  289. self._log(f"{'seeded' if mode == 'active' else 'planned'} grid at {price}")
  290. return {"action": "seed" if mode == "active" else "plan", "price": price}
  291. if open_order_count == 0 or (expected_ids and not set(expected_ids).intersection(set(live_ids))):
  292. self._log("no open orders, reseeding grid")
  293. self.state["center_price"] = price
  294. self._place_grid(price)
  295. mode = self._mode()
  296. self.state["last_action"] = "reseeded" if mode == "active" else f"{mode} monitor"
  297. return {"action": "reseed" if mode == "active" else "plan", "price": price}
  298. center = float(self.state.get("center_price") or price)
  299. recenter_pct = float(self.config.get("recenter_pct", 0.05) or 0.05)
  300. deviation = abs(price - center) / center if center else 0.0
  301. if deviation >= recenter_pct:
  302. try:
  303. self.context.cancel_all_orders()
  304. except Exception as exc:
  305. self.state["last_error"] = str(exc)
  306. self.state["center_price"] = price
  307. self._place_grid(price)
  308. mode = self._mode()
  309. self.state["last_action"] = "recentered" if mode == "active" else f"{mode} monitor"
  310. self._log(f"recentered grid to {price}")
  311. return {"action": "recenter" if mode == "active" else "plan", "price": price, "deviation": deviation}
  312. mode = self._mode()
  313. self.state["last_action"] = "hold" if mode == "active" else f"{mode} monitor"
  314. self._log(f"hold at {price} dev {deviation:.4f}")
  315. return {"action": "hold" if mode == "active" else "plan", "price": price, "deviation": deviation}
  316. def render(self):
  317. return {
  318. "widgets": [
  319. {"type": "metric", "label": "market", "value": self._market_symbol()},
  320. {"type": "metric", "label": "center", "value": round(float(self.state.get("center_price") or 0.0), 6)},
  321. {"type": "metric", "label": "last price", "value": round(float(self.state.get("last_price") or 0.0), 6)},
  322. {"type": "metric", "label": "state", "value": self.state.get("last_action", "idle")},
  323. {"type": "metric", "label": "orders", "value": len(self.state.get("orders") or [])},
  324. {"type": "metric", "label": "open orders", "value": self.state.get("open_order_count", 0)},
  325. {"type": "metric", "label": "ATR %", "value": round(float(self.state.get("atr_percent") or 0.0), 4)},
  326. {"type": "metric", "label": "grid step %", "value": round(float(self.state.get("grid_step_pct") or 0.0) * 100.0, 4)},
  327. {"type": "metric", "label": "1d", "value": ((self.state.get('regimes') or {}).get('1d') or {}).get('trend', {}).get('state', 'n/a')},
  328. {"type": "metric", "label": "4h", "value": ((self.state.get('regimes') or {}).get('4h') or {}).get('trend', {}).get('state', 'n/a')},
  329. {"type": "metric", "label": "1h", "value": ((self.state.get('regimes') or {}).get('1h') or {}).get('trend', {}).get('state', 'n/a')},
  330. {"type": "metric", "label": "15m", "value": ((self.state.get('regimes') or {}).get('15m') or {}).get('trend', {}).get('state', 'n/a')},
  331. {"type": "metric", "label": f"{self._base_symbol()} avail", "value": round(float(self.state.get("base_available") or 0.0), 8)},
  332. {"type": "metric", "label": f"{self.context.counter_currency or 'USD'} avail", "value": round(float(self.state.get("counter_available") or 0.0), 8)},
  333. {"type": "text", "label": "error", "value": self.state.get("last_error", "") or "none"},
  334. {"type": "log", "label": "debug log", "lines": self.state.get("debug_log") or []},
  335. ]
  336. }