""" SQLite cache layer for ephemeris-mcp. Provides TTL-based caching to avoid redundant computations. Ephemeris data files themselves are ~50-100 MB and loaded by swisseph directly. This cache stores computed RESULTS (positions, events, etc.). """ from __future__ import annotations import json import sqlite3 import time from pathlib import Path from typing import Optional from .config import DB_PATH class EphemerisCache: """Thread-safe TTL cache backed by SQLite.""" def __init__(self, db_path: Path = DB_PATH): self.db_path = db_path self.db_path.parent.mkdir(parents=True, exist_ok=True) self._local = __import__("threading").local() self._init_db() def _get_conn(self) -> sqlite3.Connection: if not hasattr(self._local, "conn") or self._local.conn is None: self._local.conn = sqlite3.connect(str(self.db_path)) self._local.conn.execute("PRAGMA journal_mode=WAL") self._local.conn.execute("PRAGMA busy_timeout=5000") return self._local.conn def _init_db(self) -> None: conn = self._get_conn() conn.executescript(""" CREATE TABLE IF NOT EXISTS cache ( key TEXT PRIMARY KEY, value TEXT NOT NULL, expires_at REAL NOT NULL ); CREATE INDEX IF NOT EXISTS idx_cache_expires ON cache(expires_at); CREATE TABLE IF NOT EXISTS tle_data ( norad_id INTEGER PRIMARY KEY, name TEXT, line1 TEXT NOT NULL, line2 TEXT NOT NULL, last_fetched REAL NOT NULL ); """) conn.commit() def get(self, key: str) -> Optional[dict]: """Retrieve a cached value. Returns None if expired or missing.""" conn = self._get_conn() row = conn.execute( "SELECT value, expires_at FROM cache WHERE key = ?", (key,) ).fetchone() if row is None: return None if row[1] < time.time(): conn.execute("DELETE FROM cache WHERE key = ?", (key,)) conn.commit() return None return __import__("json").loads(row[0]) def set(self, key: str, value: dict, ttl: float) -> None: """Store a value with TTL in seconds.""" conn = self._get_conn() self.prune() conn.execute( "INSERT OR REPLACE INTO cache (key, value, expires_at) VALUES (?, ?, ?)", (key, json.dumps(value), time.time() + ttl), ) conn.commit() def delete(self, key: str) -> None: """Remove a single cache entry.""" conn = self._get_conn() conn.execute("DELETE FROM cache WHERE key = ?", (key,)) conn.commit() def prune(self) -> int: """Remove all expired entries. Returns count of deleted rows.""" conn = self._get_conn() deleted = conn.execute( "DELETE FROM cache WHERE expires_at < ?", (time.time(),) ).rowcount conn.commit() return deleted # --- TLE-specific methods --- def get_tle(self, norad_id: int) -> Optional[tuple[str, str]]: """Retrieve cached TLE by NORAD ID.""" conn = self._get_conn() row = conn.execute( "SELECT line1, line2 FROM tle_data WHERE norad_id = ?", (norad_id,) ).fetchone() if row is None: return None return (row[0], row[1]) def set_tle(self, norad_id: int, name: str, line1: str, line2: str) -> None: """Cache a TLE entry.""" conn = self._get_conn() conn.execute( "INSERT OR REPLACE INTO tle_data (norad_id, name, line1, line2, last_fetched) " "VALUES (?, ?, ?, ?, ?)", (norad_id, name, line1, line2, time.time()), ) conn.commit() def get_stale_tles(self, max_age_seconds: float = 3600) -> list[int]: """Get NORAD IDs of TLE entries older than max_age_seconds.""" conn = self._get_conn() cutoff = time.time() - max_age_seconds rows = conn.execute( "SELECT norad_id FROM tle_data WHERE last_fetched < ?", (cutoff,) ).fetchall() return [r[0] for r in rows] # Module-level singleton _cache: Optional[EphemerisCache] = None def get_cache() -> EphemerisCache: global _cache if _cache is None: _cache = EphemerisCache() return _cache def cache_key(tool: str, **kwargs) -> str: """Generate a deterministic cache key from tool name and params.""" parts = [tool] for k in sorted(kwargs): v = kwargs[k] if isinstance(v, float): v = f"{v:.6f}" parts.append(f"{k}={v}") return "|".join(parts)