# storage.py
"""
SQLite cache layer for ephemeris-mcp.

Provides TTL-based caching of computed RESULTS (positions, events, etc.) to
avoid redundant computations, and stores fetched TLE sets. The ephemeris data
files themselves (~50-100 MB) are loaded directly by swisseph and are not
stored here.
"""
from __future__ import annotations

import json
import sqlite3
import threading
import time
from pathlib import Path
from typing import Optional

from .config import DB_PATH
  14. class EphemerisCache:
  15. """Thread-safe TTL cache backed by SQLite."""
  16. def __init__(self, db_path: Path = DB_PATH):
  17. self.db_path = db_path
  18. self.db_path.parent.mkdir(parents=True, exist_ok=True)
  19. self._local = __import__("threading").local()
  20. self._init_db()
  21. def _get_conn(self) -> sqlite3.Connection:
  22. if not hasattr(self._local, "conn") or self._local.conn is None:
  23. self._local.conn = sqlite3.connect(str(self.db_path))
  24. self._local.conn.execute("PRAGMA journal_mode=WAL")
  25. self._local.conn.execute("PRAGMA busy_timeout=5000")
  26. return self._local.conn
  27. def _init_db(self) -> None:
  28. conn = self._get_conn()
  29. conn.executescript("""
  30. CREATE TABLE IF NOT EXISTS cache (
  31. key TEXT PRIMARY KEY,
  32. value TEXT NOT NULL,
  33. expires_at REAL NOT NULL
  34. );
  35. CREATE INDEX IF NOT EXISTS idx_cache_expires ON cache(expires_at);
  36. CREATE TABLE IF NOT EXISTS tle_data (
  37. norad_id INTEGER PRIMARY KEY,
  38. name TEXT,
  39. line1 TEXT NOT NULL,
  40. line2 TEXT NOT NULL,
  41. last_fetched REAL NOT NULL
  42. );
  43. """)
  44. conn.commit()
  45. def get(self, key: str) -> Optional[dict]:
  46. """Retrieve a cached value. Returns None if expired or missing."""
  47. conn = self._get_conn()
  48. row = conn.execute(
  49. "SELECT value, expires_at FROM cache WHERE key = ?", (key,)
  50. ).fetchone()
  51. if row is None:
  52. return None
  53. if row[1] < time.time():
  54. conn.execute("DELETE FROM cache WHERE key = ?", (key,))
  55. conn.commit()
  56. return None
  57. return __import__("json").loads(row[0])
  58. def set(self, key: str, value: dict, ttl: float) -> None:
  59. """Store a value with TTL in seconds."""
  60. conn = self._get_conn()
  61. self.prune()
  62. conn.execute(
  63. "INSERT OR REPLACE INTO cache (key, value, expires_at) VALUES (?, ?, ?)",
  64. (key, json.dumps(value), time.time() + ttl),
  65. )
  66. conn.commit()
  67. def delete(self, key: str) -> None:
  68. """Remove a single cache entry."""
  69. conn = self._get_conn()
  70. conn.execute("DELETE FROM cache WHERE key = ?", (key,))
  71. conn.commit()
  72. def prune(self) -> int:
  73. """Remove all expired entries. Returns count of deleted rows."""
  74. conn = self._get_conn()
  75. deleted = conn.execute(
  76. "DELETE FROM cache WHERE expires_at < ?", (time.time(),)
  77. ).rowcount
  78. conn.commit()
  79. return deleted
  80. # --- TLE-specific methods ---
  81. def get_tle(self, norad_id: int) -> Optional[tuple[str, str]]:
  82. """Retrieve cached TLE by NORAD ID."""
  83. conn = self._get_conn()
  84. row = conn.execute(
  85. "SELECT line1, line2 FROM tle_data WHERE norad_id = ?", (norad_id,)
  86. ).fetchone()
  87. if row is None:
  88. return None
  89. return (row[0], row[1])
  90. def set_tle(self, norad_id: int, name: str, line1: str, line2: str) -> None:
  91. """Cache a TLE entry."""
  92. conn = self._get_conn()
  93. conn.execute(
  94. "INSERT OR REPLACE INTO tle_data (norad_id, name, line1, line2, last_fetched) "
  95. "VALUES (?, ?, ?, ?, ?)",
  96. (norad_id, name, line1, line2, time.time()),
  97. )
  98. conn.commit()
  99. def get_stale_tles(self, max_age_seconds: float = 3600) -> list[int]:
  100. """Get NORAD IDs of TLE entries older than max_age_seconds."""
  101. conn = self._get_conn()
  102. cutoff = time.time() - max_age_seconds
  103. rows = conn.execute(
  104. "SELECT norad_id FROM tle_data WHERE last_fetched < ?", (cutoff,)
  105. ).fetchall()
  106. return [r[0] for r in rows]
  107. # Module-level singleton
  108. _cache: Optional[EphemerisCache] = None
  109. def get_cache() -> EphemerisCache:
  110. global _cache
  111. if _cache is None:
  112. _cache = EphemerisCache()
  113. return _cache
  114. def cache_key(tool: str, **kwargs) -> str:
  115. """Generate a deterministic cache key from tool name and params."""
  116. parts = [tool]
  117. for k in sorted(kwargs):
  118. v = kwargs[k]
  119. if isinstance(v, float):
  120. v = f"{v:.6f}"
  121. parts.append(f"{k}={v}")
  122. return "|".join(parts)