storage.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. from __future__ import annotations
  2. import sqlite3
  3. from datetime import datetime, timezone, timedelta
  4. from pathlib import Path
  5. from typing import Any
  # DDL for the candle store: one row per OHLC bar of a (symbol, timeframe) series.
  # start_ts / end_ts are INTEGER columns; prune_candles_older_than compares end_ts
  # against an epoch-millisecond cutoff, so presumably both hold epoch ms -- confirm
  # against the writer. The UNIQUE(symbol, timeframe, start_ts) constraint is the
  # conflict target that upsert_candle's ON CONFLICT clause depends on.
  SCHEMA = """
  CREATE TABLE IF NOT EXISTS candles (
      id INTEGER PRIMARY KEY AUTOINCREMENT,
      symbol TEXT NOT NULL,
      timeframe TEXT NOT NULL,
      open REAL NOT NULL,
      high REAL NOT NULL,
      low REAL NOT NULL,
      close REAL NOT NULL,
      start_ts INTEGER NOT NULL,
      end_ts INTEGER NOT NULL,
      UNIQUE(symbol, timeframe, start_ts)
  );
  """
  def connect(db_path: str | Path) -> sqlite3.Connection:
      """Open the SQLite database at *db_path* with ``sqlite3.Row`` rows.

      Rows can then be read by column name (``row["close"]``) or turned into
      plain dicts with ``dict(row)``. The caller owns the returned connection
      and is responsible for closing it.
      """
      handle = sqlite3.connect(str(db_path))
      handle.row_factory = sqlite3.Row
      return handle
  def init_db(db_path: str | Path) -> None:
      """Create the database (and its parent directory) and apply the schema.

      Idempotent: the table and index statements both use ``IF NOT EXISTS``.

      Args:
          db_path: Filesystem path of the SQLite database file.
      """
      path = Path(db_path)
      path.parent.mkdir(parents=True, exist_ok=True)
      conn = connect(path)
      try:
          # The connection context manager commits on success and rolls back on
          # error, but it never closes the connection -- hence the outer finally.
          with conn:
              conn.executescript(SCHEMA)
              # CREATE TABLE IF NOT EXISTS never alters an existing table, so a
              # table created before the UNIQUE constraint existed would lack it
              # and upsert_candle's ON CONFLICT target would be rejected. This
              # explicit unique index guarantees the constraint either way
              # (presumably its purpose; it is redundant on a fresh table).
              conn.execute(
                  "CREATE UNIQUE INDEX IF NOT EXISTS idx_candles_symbol_timeframe_start_ts "
                  "ON candles(symbol, timeframe, start_ts)"
              )
      finally:
          conn.close()
  def upsert_candle(db_path: str | Path, candle: dict[str, Any]) -> None:
      """Insert one candle, or update the existing row with the same key.

      The key is ``(symbol, timeframe, start_ts)``. On conflict, ``open``,
      ``high``, ``low``, ``close`` and ``end_ts`` are overwritten; the row's
      ``id`` and key columns are kept.

      Args:
          db_path: Path of the SQLite database (the ``candles`` table must
              already exist, e.g. created by ``init_db``).
          candle: Mapping providing ``symbol``, ``timeframe``, ``open``,
              ``high``, ``low``, ``close``, ``start_ts`` and ``end_ts``.

      Raises:
          KeyError: If ``candle`` lacks one of the required keys.
      """
      # Same order as the INSERT column list below. Building the params first
      # means a malformed candle fails before any connection is opened.
      columns = ("symbol", "timeframe", "open", "high", "low", "close", "start_ts", "end_ts")
      params = tuple(candle[name] for name in columns)
      conn = connect(db_path)
      try:
          # Commits on success, rolls back on error; closing is done in finally.
          with conn:
              conn.execute(
                  """
                  INSERT INTO candles(symbol, timeframe, open, high, low, close, start_ts, end_ts)
                  VALUES(?, ?, ?, ?, ?, ?, ?, ?)
                  ON CONFLICT(symbol, timeframe, start_ts) DO UPDATE SET
                      open=excluded.open,
                      high=excluded.high,
                      low=excluded.low,
                      close=excluded.close,
                      end_ts=excluded.end_ts
                  """,
                  params,
              )
      finally:
          conn.close()
  def latest_candles(
      db_path: str | Path, symbol: str, timeframe: str, limit: int = 100
  ) -> list[dict[str, Any]]:
      """Return the most recent candles of one series, oldest first.

      Args:
          db_path: Path of the SQLite database.
          symbol: Series symbol to match exactly.
          timeframe: Series timeframe to match exactly.
          limit: Maximum number of candles to return. Values <= 0 yield an
              empty list (SQLite itself treats a negative LIMIT as "no limit",
              which would silently return the whole series).

      Returns:
          Up to ``limit`` dicts with keys ``symbol``, ``timeframe``, ``open``,
          ``high``, ``low``, ``close``, ``start_ts`` and ``end_ts``, in
          ascending ``start_ts`` order.
      """
      if limit <= 0:
          return []
      conn = connect(db_path)
      try:
          # Newest-first so LIMIT keeps the latest rows; flipped below.
          rows = conn.execute(
              """
              SELECT symbol, timeframe, open, high, low, close, start_ts, end_ts
              FROM candles
              WHERE symbol = ? AND timeframe = ?
              ORDER BY start_ts DESC
              LIMIT ?
              """,
              (symbol, timeframe, limit),
          ).fetchall()
      finally:
          conn.close()
      return [dict(row) for row in reversed(rows)]
  def last_candle(db_path: str | Path, symbol: str, timeframe: str) -> dict[str, Any] | None:
      """Return the candle with the greatest ``start_ts`` for one series.

      Args:
          db_path: Path of the SQLite database.
          symbol: Series symbol to match exactly.
          timeframe: Series timeframe to match exactly.

      Returns:
          A dict with keys ``symbol``, ``timeframe``, ``open``, ``high``,
          ``low``, ``close``, ``start_ts`` and ``end_ts``, or ``None`` when the
          series has no rows.
      """
      conn = connect(db_path)
      try:
          row = conn.execute(
              """
              SELECT symbol, timeframe, open, high, low, close, start_ts, end_ts
              FROM candles
              WHERE symbol = ? AND timeframe = ?
              ORDER BY start_ts DESC
              LIMIT 1
              """,
              (symbol, timeframe),
          ).fetchone()
      finally:
          # Read-only query: nothing to commit, but the connection must be closed.
          conn.close()
      return None if row is None else dict(row)
  def stats(db_path: str | Path) -> dict[str, Any]:
      """Return summary counts for the database.

      Args:
          db_path: Path of the SQLite database.

      Returns:
          ``{"candles": <total number of rows in the candles table>}``.
      """
      conn = connect(db_path)
      try:
          # Column is aliased so it can be read by name via sqlite3.Row.
          count = conn.execute("SELECT COUNT(*) AS n FROM candles").fetchone()["n"]
      finally:
          conn.close()
      return {"candles": count}
  def prune_candles_older_than(db_path: str | Path, days: int) -> int:
      """Delete candles whose ``end_ts`` lies more than *days* days in the past.

      The cutoff is computed as epoch milliseconds (UTC), so this assumes
      ``end_ts`` is stored in epoch ms -- NOTE(review): confirm with the writer.

      Args:
          db_path: Path of the SQLite database.
          days: Retention window in days. Values <= 0 delete nothing.

      Returns:
          Number of rows deleted.
      """
      if days <= 0:
          # A non-positive window would put the cutoff at or after "now" and
          # wipe (nearly) the whole table, so treat it as a no-op instead.
          return 0
      cutoff_ms = int((datetime.now(timezone.utc) - timedelta(days=days)).timestamp() * 1000)
      conn = connect(db_path)
      try:
          # Commits the DELETE on success, rolls back on error.
          with conn:
              cursor = conn.execute("DELETE FROM candles WHERE end_ts < ?", (cutoff_ms,))
          deleted = int(cursor.rowcount or 0)
      finally:
          conn.close()
      return deleted