harness.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. from __future__ import annotations
  2. from dataclasses import dataclass, asdict
  3. from datetime import datetime, timedelta, timezone
  4. from pathlib import Path
  5. from typing import Any
  6. import json
  7. import sys
  8. from .candles import Candle, load_candles_csv, resample_candles, slice_through, timeframe_seconds
  9. from .indicators import atr, bollinger, ema, macd_histogram, rsi, sma, vwap
  10. def _bootstrap_hermes_imports() -> None:
  11. import sys
  12. root = Path(__file__).resolve().parents[3]
  13. src = root / "src"
  14. if str(src) not in sys.path:
  15. sys.path.insert(0, str(src))
  16. _bootstrap_hermes_imports()
  17. from hermes_mcp.decision_engine import make_decision # noqa: E402
  18. from hermes_mcp.narrative_engine import build_narrative # noqa: E402
  19. from hermes_mcp.state_engine import synthesize_state # noqa: E402
  20. DEFAULT_TIMEFRAMES = ("1m", "5m", "15m", "1h", "4h", "1d")
  21. @dataclass(frozen=True)
  22. class ReplayConfig:
  23. market_symbol: str = "XRPUSD"
  24. base_currency: str = "XRP"
  25. quote_currency: str = "USD"
  26. account_id: str = "sim-account"
  27. fee_rate: float = 0.004
  28. horizon_bars: int = 30
  29. timeframes: tuple[str, ...] = DEFAULT_TIMEFRAMES
  30. base_balance: float = 500.0
  31. quote_balance: float = 500.0
  32. inventory_state: str = "balanced"
  33. rebalance_needed: bool = False
  34. @dataclass(frozen=True)
  35. class ReplayRow:
  36. timestamp: str
  37. close: float
  38. decision_mode: str
  39. decision_action: str
  40. target_strategy: str | None
  41. confidence: float
  42. future_return_pct: float | None
  43. fee_adjusted_future_return_pct: float | None
  44. score: float
  45. reason_summary: str
  46. payload: dict[str, Any]
  47. def _group_by_timeframe(candles: list[Candle], timeframe: str) -> list[Candle]:
  48. if timeframe == "1m":
  49. return candles
  50. return resample_candles(candles, timeframe)
  51. def _window_regime(candles: list[Candle], timeframe: str) -> dict[str, Any] | None:
  52. if not candles:
  53. return None
  54. closes = [c.close for c in candles]
  55. highs = [c.high for c in candles]
  56. lows = [c.low for c in candles]
  57. volumes = [c.volume for c in candles]
  58. ema_fast = ema(closes, 8)
  59. ema_slow = ema(closes, 21)
  60. sma_long = sma(closes, 50)
  61. price = closes[-1]
  62. rsi_value = rsi(closes, 14)
  63. macd_value = macd_histogram(closes)
  64. atr_value = atr(highs, lows, closes, 14)
  65. middle, upper, lower = bollinger(closes, 20, 2.0)
  66. vwap_value = vwap(highs, lows, closes, volumes)
  67. atr_percent = None
  68. if atr_value is not None and price:
  69. atr_percent = (atr_value / price) * 100.0
  70. reversal_direction = "none"
  71. reversal_score = 0.0
  72. if rsi_value is not None:
  73. if rsi_value >= 70:
  74. reversal_direction = "down"
  75. reversal_score = min((rsi_value - 70.0) * 2.0, 100.0)
  76. elif rsi_value <= 30:
  77. reversal_direction = "up"
  78. reversal_score = min((30.0 - rsi_value) * 2.0, 100.0)
  79. if ema_fast is not None and ema_slow is not None:
  80. if ema_fast > ema_slow:
  81. trend_strength = min(((ema_fast - ema_slow) / price) * 100.0 if price else 0.0, 5.0)
  82. trend_direction = "bullish"
  83. else:
  84. trend_strength = min(((ema_slow - ema_fast) / price) * 100.0 if price else 0.0, 5.0)
  85. trend_direction = "bearish"
  86. else:
  87. trend_strength = 0.0
  88. trend_direction = "mixed"
  89. if upper is not None and lower is not None and price:
  90. band_span = max(upper - lower, 1e-9)
  91. band_pos = (price - lower) / band_span
  92. if band_pos >= 0.85:
  93. price_location = "near_upper_band"
  94. elif band_pos <= 0.15:
  95. price_location = "near_lower_band"
  96. elif band_pos >= 0.6:
  97. price_location = "upper_half"
  98. elif band_pos <= 0.4:
  99. price_location = "lower_half"
  100. else:
  101. price_location = "centered"
  102. else:
  103. price_location = "unknown"
  104. regime = {
  105. "timeframe": timeframe,
  106. "price": round(price, 8),
  107. "trend": {
  108. "ema_fast": round(ema_fast, 8) if ema_fast is not None else None,
  109. "ema_slow": round(ema_slow, 8) if ema_slow is not None else None,
  110. "sma_long": round(sma_long, 8) if sma_long is not None else None,
  111. },
  112. "momentum": {
  113. "rsi": round(rsi_value, 4) if rsi_value is not None else None,
  114. "macd_histogram": round(macd_value, 8) if macd_value is not None else None,
  115. },
  116. "volatility": {
  117. "atr": round(atr_value, 8) if atr_value is not None else None,
  118. "atr_percent": round(atr_percent, 8) if atr_percent is not None else None,
  119. },
  120. "bands": {
  121. "bollinger": {
  122. "middle": round(middle, 8) if middle is not None else None,
  123. "upper": round(upper, 8) if upper is not None else None,
  124. "lower": round(lower, 8) if lower is not None else None,
  125. }
  126. },
  127. "vwap": round(vwap_value, 8) if vwap_value is not None else None,
  128. "reversal": {
  129. "direction": reversal_direction,
  130. "score": round(reversal_score, 4),
  131. },
  132. "meta": {
  133. "trend_direction": trend_direction,
  134. "trend_strength": round(trend_strength, 6),
  135. "price_location": price_location,
  136. "candle_count": len(candles),
  137. },
  138. }
  139. return regime
  140. def build_regimes(candles: list[Candle], timeframes: tuple[str, ...] = DEFAULT_TIMEFRAMES) -> list[dict[str, Any]]:
  141. regimes: list[dict[str, Any]] = []
  142. for timeframe in timeframes:
  143. if timeframe == "1m":
  144. tf_candles = candles
  145. else:
  146. tf_candles = _group_by_timeframe(candles, timeframe)
  147. if tf_candles:
  148. regimes.append(_window_regime(tf_candles, timeframe))
  149. return [r for r in regimes if r is not None]
  150. def _wallet_state(config: ReplayConfig, close: float) -> dict[str, Any]:
  151. base_value = config.base_balance * close
  152. quote_value = config.quote_balance
  153. total_value = base_value + quote_value
  154. base_ratio = base_value / total_value if total_value else 0.5
  155. quote_ratio = quote_value / total_value if total_value else 0.5
  156. imbalance = abs(base_ratio - 0.5)
  157. return {
  158. "inventory_state": config.inventory_state,
  159. "rebalance_needed": config.rebalance_needed,
  160. "grid_ready": config.inventory_state == "balanced",
  161. "base_ratio": round(base_ratio, 4),
  162. "quote_ratio": round(quote_ratio, 4),
  163. "imbalance_score": round(imbalance, 4),
  164. }
  165. def _strategies(config: ReplayConfig) -> list[dict[str, Any]]:
  166. return [
  167. {
  168. "id": "grid-1",
  169. "strategy_type": "grid_trader",
  170. "mode": "active",
  171. "account_id": config.account_id,
  172. "market_symbol": config.market_symbol,
  173. "state": {},
  174. "config": {},
  175. },
  176. {
  177. "id": "trend-1",
  178. "strategy_type": "trend_follower",
  179. "mode": "off",
  180. "account_id": config.account_id,
  181. "market_symbol": config.market_symbol,
  182. "state": {},
  183. "config": {"trade_side": "both"},
  184. },
  185. {
  186. "id": "protect-1",
  187. "strategy_type": "exposure_protector",
  188. "mode": "off",
  189. "account_id": config.account_id,
  190. "market_symbol": config.market_symbol,
  191. "state": {},
  192. "config": {},
  193. },
  194. ]
  195. def _future_return(candles: list[Candle], index: int, horizon_bars: int) -> float | None:
  196. future_index = index + horizon_bars
  197. if future_index >= len(candles):
  198. return None
  199. start = candles[index].close
  200. end = candles[future_index].close
  201. if start == 0:
  202. return None
  203. return ((end - start) / start) * 100.0
  204. def _fee_adjusted_return(future_return_pct: float | None, fee_rate: float) -> float | None:
  205. if future_return_pct is None:
  206. return None
  207. return future_return_pct - (fee_rate * 100.0 * 2.0)
  208. def _direction_from_decision(decision_action: str, narrative: dict[str, Any]) -> str | None:
  209. if "trend" in decision_action:
  210. breakout = narrative.get("grid_breakout_pressure") if isinstance(narrative.get("grid_breakout_pressure"), dict) else {}
  211. meso_bias = str(breakout.get("meso_bias") or "")
  212. if meso_bias in {"bullish", "bearish"}:
  213. return meso_bias
  214. stance = str(narrative.get("stance") or "")
  215. if "bullish" in stance:
  216. return "bullish"
  217. if "bearish" in stance:
  218. return "bearish"
  219. return None
  220. def _score_row(decision_action: str, future_return_pct: float | None, fee_adjusted_return_pct: float | None) -> float:
  221. if future_return_pct is None or fee_adjusted_return_pct is None:
  222. return 0.0
  223. if decision_action == "keep_grid":
  224. return 1.0 if abs(future_return_pct) < 0.25 else -abs(fee_adjusted_return_pct) / 10.0
  225. if "trend" in decision_action:
  226. return fee_adjusted_return_pct / 5.0
  227. if "protect" in decision_action or "rebalance" in decision_action:
  228. return max(0.0, 0.5 - abs(future_return_pct) / 10.0)
  229. return future_return_pct / 10.0
  230. def run_replay(*, candles: list[Candle], config: ReplayConfig, lookback_bars: int = 2000, progress_every: int = 2000) -> list[ReplayRow]:
  231. if len(candles) < 50:
  232. return []
  233. rows: list[ReplayRow] = []
  234. start_index = max(50, lookback_bars)
  235. total = max(0, (len(candles) - config.horizon_bars) - start_index)
  236. for i, index in enumerate(range(start_index, len(candles) - config.horizon_bars)):
  237. if progress_every and (i % progress_every == 0):
  238. print(f"replay {i}/{total}", file=sys.stderr)
  239. window = candles[max(0, index - lookback_bars + 1) : index + 1]
  240. current = candles[index]
  241. regimes = build_regimes(window, config.timeframes)
  242. concern = {
  243. "id": "sim-concern",
  244. "account_id": config.account_id,
  245. "market_symbol": config.market_symbol,
  246. "base_currency": config.base_currency,
  247. "quote_currency": config.quote_currency,
  248. }
  249. account_info = {
  250. "balances": [
  251. {"asset_code": config.base_currency, "available": config.base_balance},
  252. {"asset_code": config.quote_currency, "available": config.quote_balance},
  253. ]
  254. }
  255. state_payload = synthesize_state(concern=concern, regimes=regimes, account_info=account_info)
  256. narrative = build_narrative(concern=concern, state_payload=state_payload.payload)
  257. wallet_state = _wallet_state(config, current.close)
  258. decision = make_decision(
  259. concern=concern,
  260. narrative_payload=narrative.payload,
  261. wallet_state=wallet_state,
  262. strategies=_strategies(config),
  263. history_window={
  264. "window_seconds": timeframe_seconds("1m") * config.horizon_bars,
  265. "recent_states": [],
  266. },
  267. )
  268. future_return_pct = _future_return(candles, index, config.horizon_bars)
  269. fee_adjusted = _fee_adjusted_return(future_return_pct, config.fee_rate)
  270. score = _score_row(decision.action, future_return_pct, fee_adjusted)
  271. rows.append(
  272. ReplayRow(
  273. timestamp=current.timestamp.isoformat(),
  274. close=current.close,
  275. decision_mode=decision.mode,
  276. decision_action=decision.action,
  277. target_strategy=decision.target_strategy,
  278. confidence=decision.confidence,
  279. future_return_pct=future_return_pct,
  280. fee_adjusted_future_return_pct=fee_adjusted,
  281. score=score,
  282. reason_summary=decision.reason_summary,
  283. payload={
  284. "decision": decision.payload,
  285. "state": state_payload.payload,
  286. "narrative": narrative.payload,
  287. },
  288. )
  289. )
  290. return rows
  291. def rows_to_jsonl(rows: list[ReplayRow]) -> str:
  292. return "\n".join(json.dumps(asdict(row), ensure_ascii=False) for row in rows)