| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146 |
- """
- SQLite cache layer for ephemeris-mcp.
- Provides TTL-based caching to avoid redundant computations.
- Ephemeris data files themselves are ~50-100 MB and loaded by swisseph directly.
- This cache stores computed RESULTS (positions, events, etc.).
- """
from __future__ import annotations

import json
import sqlite3
import threading
import time
from pathlib import Path
from typing import Optional

from .config import DB_PATH
class EphemerisCache:
    """Thread-safe TTL cache backed by SQLite.

    Stores computed ephemeris RESULTS (positions, events, ...) plus fetched
    TLE sets.  Each thread gets its own SQLite connection held in a
    ``threading.local``, since sqlite3 connections are not safely shareable
    across threads by default.
    """

    def __init__(self, db_path: Path = DB_PATH):
        """Open (creating if needed) the cache database at *db_path*."""
        self.db_path = db_path
        # Make sure the containing directory exists before connecting.
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        # Per-thread connection holder; see _get_conn().  (Was an
        # obfuscated __import__("threading") — now a normal module import.)
        self._local = threading.local()
        self._init_db()

    def _get_conn(self) -> sqlite3.Connection:
        """Return the calling thread's connection, creating it lazily."""
        if getattr(self._local, "conn", None) is None:
            self._local.conn = sqlite3.connect(str(self.db_path))
            # WAL lets readers proceed while a writer commits; the busy
            # timeout retries for up to 5 s instead of raising
            # "database is locked" immediately under contention.
            self._local.conn.execute("PRAGMA journal_mode=WAL")
            self._local.conn.execute("PRAGMA busy_timeout=5000")
        return self._local.conn

    def _init_db(self) -> None:
        """Create the cache/tle_data tables and index if they are missing."""
        conn = self._get_conn()
        conn.executescript(
            """
            CREATE TABLE IF NOT EXISTS cache (
                key TEXT PRIMARY KEY,
                value TEXT NOT NULL,
                expires_at REAL NOT NULL
            );
            CREATE INDEX IF NOT EXISTS idx_cache_expires ON cache(expires_at);
            CREATE TABLE IF NOT EXISTS tle_data (
                norad_id INTEGER PRIMARY KEY,
                name TEXT,
                line1 TEXT NOT NULL,
                line2 TEXT NOT NULL,
                last_fetched REAL NOT NULL
            );
            """
        )
        conn.commit()

    def get(self, key: str) -> Optional[dict]:
        """Retrieve a cached value. Returns None if expired or missing.

        Expired rows are evicted lazily on read.
        """
        conn = self._get_conn()
        row = conn.execute(
            "SELECT value, expires_at FROM cache WHERE key = ?", (key,)
        ).fetchone()
        if row is None:
            return None
        if row[1] < time.time():
            # Entry has expired: delete it and report a miss.
            conn.execute("DELETE FROM cache WHERE key = ?", (key,))
            conn.commit()
            return None
        # json is imported at module level; no need for __import__ tricks.
        return json.loads(row[0])

    def set(self, key: str, value: dict, ttl: float) -> None:
        """Store *value* under *key* with a TTL in seconds.

        Also prunes expired rows opportunistically so the table does not
        grow without bound.
        """
        conn = self._get_conn()
        self.prune()
        conn.execute(
            "INSERT OR REPLACE INTO cache (key, value, expires_at) VALUES (?, ?, ?)",
            (key, json.dumps(value), time.time() + ttl),
        )
        conn.commit()

    def delete(self, key: str) -> None:
        """Remove a single cache entry (no-op if the key is absent)."""
        conn = self._get_conn()
        conn.execute("DELETE FROM cache WHERE key = ?", (key,))
        conn.commit()

    def prune(self) -> int:
        """Remove all expired entries. Returns count of deleted rows."""
        conn = self._get_conn()
        deleted = conn.execute(
            "DELETE FROM cache WHERE expires_at < ?", (time.time(),)
        ).rowcount
        conn.commit()
        return deleted

    # --- TLE-specific methods ---

    def get_tle(self, norad_id: int) -> Optional[tuple[str, str]]:
        """Return the cached (line1, line2) TLE for *norad_id*, or None."""
        conn = self._get_conn()
        row = conn.execute(
            "SELECT line1, line2 FROM tle_data WHERE norad_id = ?", (norad_id,)
        ).fetchone()
        if row is None:
            return None
        return (row[0], row[1])

    def set_tle(self, norad_id: int, name: str, line1: str, line2: str) -> None:
        """Insert or refresh a TLE entry, stamping the current fetch time."""
        conn = self._get_conn()
        conn.execute(
            "INSERT OR REPLACE INTO tle_data (norad_id, name, line1, line2, last_fetched) "
            "VALUES (?, ?, ?, ?, ?)",
            (norad_id, name, line1, line2, time.time()),
        )
        conn.commit()

    def get_stale_tles(self, max_age_seconds: float = 3600) -> list[int]:
        """Return NORAD IDs of TLE entries older than *max_age_seconds*."""
        conn = self._get_conn()
        cutoff = time.time() - max_age_seconds
        rows = conn.execute(
            "SELECT norad_id FROM tle_data WHERE last_fetched < ?", (cutoff,)
        ).fetchall()
        return [r[0] for r in rows]
# Module-level singleton
_cache: Optional[EphemerisCache] = None


def get_cache() -> EphemerisCache:
    """Return the shared EphemerisCache instance, constructing it lazily."""
    global _cache
    instance = _cache
    if instance is None:
        instance = EphemerisCache()
        _cache = instance
    return instance
def cache_key(tool: str, **kwargs) -> str:
    """Generate a deterministic cache key from tool name and params.

    Keyword arguments are joined in sorted-name order; float values are
    normalized to six decimal places so equal floats map to the same key.
    """
    def _fmt(value):
        # Fixed precision for floats; everything else is rendered via
        # str() by the f-string below.
        return f"{value:.6f}" if isinstance(value, float) else value

    pieces = [tool]
    pieces.extend(f"{name}={_fmt(kwargs[name])}" for name in sorted(kwargs))
    return "|".join(pieces)
|