poller.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. from __future__ import annotations
  2. import time
  3. from dataclasses import dataclass
  4. from typing import Any, Optional
  5. import logging
  6. from .config import DB_PATH, METALS_CANDLE_RETENTION_DAYS, METALS_PAIRS, POLL_INTERVAL_SECONDS
  7. from .storage import init_db, prune_candles_older_than, upsert_candle, latest_candles
  8. from .swissquote import SwissquoteClient
  9. TIMEFRAME_SECONDS = {
  10. "5m": 300,
  11. "15m": 900,
  12. "1h": 3600,
  13. "4h": 14400,
  14. "1d": 86400,
  15. }
  16. logger = logging.getLogger(__name__)
  17. @dataclass
  18. class CandleState:
  19. symbol: str
  20. timeframe: str
  21. start_ts: int
  22. open: float
  23. high: float
  24. low: float
  25. close: float
  26. def update(self, price: float) -> None:
  27. self.high = max(self.high, price)
  28. self.low = min(self.low, price)
  29. self.close = price
  30. def to_row(self) -> dict[str, Any]:
  31. secs = TIMEFRAME_SECONDS.get(self.timeframe, 300)
  32. return {
  33. "symbol": self.symbol,
  34. "timeframe": self.timeframe,
  35. "open": self.open,
  36. "high": self.high,
  37. "low": self.low,
  38. "close": self.close,
  39. "start_ts": self.start_ts,
  40. "end_ts": self.start_ts + secs * 1000,
  41. }
  42. def _bucket_start(ts_ms: int, timeframe: str) -> int:
  43. secs = TIMEFRAME_SECONDS.get(timeframe, 300)
  44. return (ts_ms // (secs * 1000)) * (secs * 1000)
  45. def _derive_higher_timeframe_candles(
  46. symbol: str,
  47. higher_timeframe: str,
  48. lower_timeframe: str,
  49. db_path: str,
  50. ) -> Optional[CandleState]:
  51. lower_secs = TIMEFRAME_SECONDS.get(lower_timeframe, 300)
  52. higher_secs = TIMEFRAME_SECONDS.get(higher_timeframe, 3600)
  53. needed = higher_secs // lower_secs
  54. if needed <= 1:
  55. return None
  56. candles = latest_candles(db_path, symbol, lower_timeframe, limit=needed)
  57. if len(candles) < needed:
  58. return None
  59. first = candles[0]
  60. last = candles[-1]
  61. # Determine which higher timeframe bucket the FIRST candle belongs to
  62. # This is the bucket we're trying to fill
  63. higher_start = _bucket_start(first["start_ts"], higher_timeframe)
  64. higher_end = higher_start + higher_secs * 1000
  65. # Check if the span of lower candles fits within this higher bucket
  66. # The last lower candle must end at or before the higher bucket ends
  67. last_end = last["end_ts"]
  68. if last_end > higher_end:
  69. # The lower candles span across higher boundaries
  70. # Can't derive a complete higher candle yet
  71. return None
  72. # Check if we have enough lower candles that start within this higher bucket
  73. # (i.e., the first lower candle starts at or after the higher bucket start)
  74. if first["start_ts"] < higher_start:
  75. # First candle starts before the higher bucket - can't derive
  76. return None
  77. opens = [c["open"] for c in candles]
  78. highs = [c["high"] for c in candles]
  79. lows = [c["low"] for c in candles]
  80. closes = [c["close"] for c in candles]
  81. return CandleState(
  82. symbol=symbol,
  83. timeframe=higher_timeframe,
  84. start_ts=higher_start,
  85. open=opens[0],
  86. high=max(highs),
  87. low=min(lows),
  88. close=closes[-1],
  89. )
  90. def _finalize_candle(state: CandleState, db_path: str) -> None:
  91. upsert_candle(db_path, state.to_row())
  92. class CandlePoller:
  93. def __init__(self) -> None:
  94. self.client = SwissquoteClient()
  95. self.states_5m: dict[str, CandleState] = {}
  96. self._last_prune_ts = 0.0
  97. init_db(DB_PATH)
  98. def step(self) -> None:
  99. now_ms = int(time.time() * 1000)
  100. for symbol in METALS_PAIRS:
  101. quote = self.client.fetch_quote(symbol)
  102. if not quote:
  103. continue
  104. start_ts = _bucket_start(quote.timestamp, "5m")
  105. state = self.states_5m.get(symbol)
  106. if state is None or state.start_ts != start_ts:
  107. if state is not None:
  108. _finalize_candle(state, DB_PATH)
  109. self.states_5m[symbol] = CandleState(
  110. symbol=symbol,
  111. timeframe="5m",
  112. start_ts=start_ts,
  113. open=quote.mid,
  114. high=quote.mid,
  115. low=quote.mid,
  116. close=quote.mid,
  117. )
  118. else:
  119. state.update(quote.mid)
  120. for state in self.states_5m.values():
  121. upsert_candle(DB_PATH, state.to_row())
  122. higher_timeframes = ["15m", "1h", "4h", "1d"]
  123. for symbol in METALS_PAIRS:
  124. for higher_tf in higher_timeframes:
  125. derived = _derive_higher_timeframe_candles(
  126. symbol, higher_tf, "5m", DB_PATH
  127. )
  128. if derived is not None:
  129. upsert_candle(DB_PATH, derived.to_row())
  130. if now_ms - self._last_prune_ts >= 3600 * 1000:
  131. prune_candles_older_than(DB_PATH, METALS_CANDLE_RETENTION_DAYS)
  132. self._last_prune_ts = float(now_ms)
  133. def flush(self) -> None:
  134. for state in self.states_5m.values():
  135. upsert_candle(DB_PATH, state.to_row())
  136. def run_forever(self) -> None:
  137. init_db(DB_PATH)
  138. while True:
  139. try:
  140. self.step()
  141. except Exception as exc:
  142. logger.exception("metals poller cycle failed: %s", exc)
  143. time.sleep(POLL_INTERVAL_SECONDS)