| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176 |
- from __future__ import annotations
- import time
- from dataclasses import dataclass
- from typing import Any, Optional
- import logging
- from .config import DB_PATH, METALS_CANDLE_RETENTION_DAYS, METALS_PAIRS, POLL_INTERVAL_SECONDS
- from .storage import init_db, prune_candles_older_than, upsert_candle, latest_candles
- from .swissquote import SwissquoteClient
# Duration of every supported candle timeframe, expressed in seconds.
TIMEFRAME_SECONDS = {
    "5m": 5 * 60,
    "15m": 15 * 60,
    "1h": 60 * 60,
    "4h": 4 * 60 * 60,
    "1d": 24 * 60 * 60,
}

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@dataclass
class CandleState:
    """Mutable OHLC candle for one symbol and one timeframe.

    ``to_row`` adds the timeframe length converted to milliseconds to
    ``start_ts``, so ``start_ts`` is treated as a millisecond timestamp.
    """

    symbol: str
    timeframe: str
    start_ts: int
    open: float
    high: float
    low: float
    close: float

    def update(self, price: float) -> None:
        """Fold a new price into the candle.

        The high/low range is widened when ``price`` falls outside it, and
        ``price`` always becomes the new close. ``open`` is never touched.
        """
        if price > self.high:
            self.high = price
        if price < self.low:
            self.low = price
        self.close = price

    def to_row(self) -> dict[str, Any]:
        """Return the candle as a flat dict, including a computed ``end_ts``.

        ``end_ts`` is ``start_ts`` plus the timeframe length in milliseconds.
        A timeframe missing from ``TIMEFRAME_SECONDS`` falls back to 300
        seconds.
        """
        duration_ms = TIMEFRAME_SECONDS.get(self.timeframe, 300) * 1000
        copied_fields = (
            "symbol",
            "timeframe",
            "open",
            "high",
            "low",
            "close",
            "start_ts",
        )
        row: dict[str, Any] = {name: getattr(self, name) for name in copied_fields}
        row["end_ts"] = self.start_ts + duration_ms
        return row
def _bucket_start(ts_ms: int, timeframe: str) -> int:
    """Round ``ts_ms`` down to the start of its ``timeframe`` bucket.

    The bucket width is the timeframe's length from ``TIMEFRAME_SECONDS``
    converted to milliseconds; an unknown timeframe falls back to 300
    seconds.
    """
    bucket_ms = TIMEFRAME_SECONDS.get(timeframe, 300) * 1000
    whole_buckets = ts_ms // bucket_ms
    return whole_buckets * bucket_ms
def _derive_higher_timeframe_candles(
    symbol: str,
    higher_timeframe: str,
    lower_timeframe: str,
    db_path: str,
) -> Optional[CandleState]:
    """Aggregate the latest stored lower-timeframe candles into a higher one.

    Fetches the most recent ``higher_secs // lower_secs`` candles of
    ``lower_timeframe`` for ``symbol`` and merges them into a single
    ``higher_timeframe`` candle when they all end within the higher bucket
    that contains the earliest of them.

    Args:
        symbol: Instrument whose candles are read.
        higher_timeframe: Timeframe of the candle to build; falls back to
            3600 seconds when missing from ``TIMEFRAME_SECONDS``.
        lower_timeframe: Timeframe of the stored source candles; falls back
            to 300 seconds when missing from ``TIMEFRAME_SECONDS``.
        db_path: Database path forwarded to ``latest_candles``.

    Returns:
        A ``CandleState`` for the higher timeframe, or ``None`` when the
        higher timeframe is not at least twice the lower one, when fewer
        source candles than required are stored, or when the source candles
        run past the end of the higher bucket.
    """
    lower_secs = TIMEFRAME_SECONDS.get(lower_timeframe, 300)
    higher_secs = TIMEFRAME_SECONDS.get(higher_timeframe, 3600)

    needed = higher_secs // lower_secs
    if needed <= 1:
        return None

    candles = latest_candles(db_path, symbol, lower_timeframe, limit=needed)
    if len(candles) < needed:
        return None

    # The row order returned by latest_candles is not visible from here, so
    # sort oldest-first explicitly: open must come from the earliest candle
    # and close from the latest one.
    ordered = sorted(candles, key=lambda candle: candle["start_ts"])
    first, last = ordered[0], ordered[-1]

    # The higher bucket we are trying to fill is the one holding the earliest
    # lower candle. Because higher_start is first["start_ts"] rounded down,
    # the first candle can never start before it, so no separate check for
    # that is needed.
    higher_start = _bucket_start(first["start_ts"], higher_timeframe)
    higher_end = higher_start + higher_secs * 1000

    # If the newest lower candle ends after the bucket does, the fetched
    # candles straddle a higher-timeframe boundary and cannot form one candle.
    if last["end_ts"] > higher_end:
        return None

    return CandleState(
        symbol=symbol,
        timeframe=higher_timeframe,
        start_ts=higher_start,
        open=first["open"],
        high=max(candle["high"] for candle in ordered),
        low=min(candle["low"] for candle in ordered),
        close=last["close"],
    )
def _finalize_candle(state: CandleState, db_path: str) -> None:
    """Write ``state`` to the database at ``db_path`` via ``upsert_candle``."""
    row = state.to_row()
    upsert_candle(db_path, row)
class CandlePoller:
    """Builds 5m candles from polled quotes and derives higher timeframes.

    One in-progress ``CandleState`` is kept per symbol in ``states_5m``; all
    persistence goes to ``DB_PATH``.
    """

    def __init__(self) -> None:
        """Create the quote client and empty state, then initialise the DB."""
        self.client = SwissquoteClient()
        self.states_5m: dict[str, CandleState] = {}
        # Wall-clock milliseconds of the last prune; 0.0 makes the first
        # step() prune immediately.
        self._last_prune_ts = 0.0
        init_db(DB_PATH)

    def step(self) -> None:
        """Run one polling cycle.

        For each symbol in ``METALS_PAIRS``: fetch a quote, then either
        update the current 5m candle or, when the quote falls in a different
        5m bucket, persist the old candle and start a new one. Afterwards all
        in-progress 5m candles are written, higher-timeframe candles are
        derived from stored 5m candles, and old candles are pruned at most
        once per hour.
        """
        now_ms = int(time.time() * 1000)

        for symbol in METALS_PAIRS:
            quote = self.client.fetch_quote(symbol)
            if not quote:
                continue

            bucket_ts = _bucket_start(quote.timestamp, "5m")
            current = self.states_5m.get(symbol)

            # Same bucket as the candle in progress: just extend it.
            if current is not None and current.start_ts == bucket_ts:
                current.update(quote.mid)
                continue

            # Bucket changed: persist the previous candle before replacing it.
            if current is not None:
                _finalize_candle(current, DB_PATH)

            mid = quote.mid
            self.states_5m[symbol] = CandleState(
                symbol=symbol,
                timeframe="5m",
                start_ts=bucket_ts,
                open=mid,
                high=mid,
                low=mid,
                close=mid,
            )

        # Persist every in-progress 5m candle so derivation sees fresh data.
        self.flush()

        for symbol in METALS_PAIRS:
            for higher_tf in ("15m", "1h", "4h", "1d"):
                derived = _derive_higher_timeframe_candles(
                    symbol, higher_tf, "5m", DB_PATH
                )
                if derived is None:
                    continue
                upsert_candle(DB_PATH, derived.to_row())

        one_hour_ms = 3600 * 1000
        if now_ms - self._last_prune_ts >= one_hour_ms:
            prune_candles_older_than(DB_PATH, METALS_CANDLE_RETENTION_DAYS)
            self._last_prune_ts = float(now_ms)

    def flush(self) -> None:
        """Write every in-progress 5m candle to the database."""
        for candle in self.states_5m.values():
            upsert_candle(DB_PATH, candle.to_row())

    def run_forever(self) -> None:
        """Call ``step`` in an endless loop, sleeping between cycles.

        Any exception raised by a cycle is logged with its traceback and the
        loop carries on after the usual ``POLL_INTERVAL_SECONDS`` pause.
        """
        init_db(DB_PATH)
        while True:
            try:
                self.step()
            except Exception as exc:
                logger.exception("metals poller cycle failed: %s", exc)
            time.sleep(POLL_INTERVAL_SECONDS)
|