# sqlite_store.py (19 KB)
from __future__ import annotations

import json
import sqlite3
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime, timezone, timedelta
from email.utils import parsedate_to_datetime
from pathlib import Path
from typing import Any
from urllib.parse import urlparse

from news_mcp.entity_normalize import normalize_entities
from news_mcp.trends_resolution import resolve_entity_via_trends
@dataclass
class ClusterRow:
    """In-memory record for one cluster: id, topic, payload and update time.

    The field names match the first four columns of the ``clusters`` table
    created by ``SQLiteClusterStore._init_db``.
    NOTE(review): nothing in the visible part of this file constructs or
    returns ``ClusterRow``; presumably other modules use it -- confirm.
    """

    cluster_id: str  # same name as the clusters table's primary-key column
    topic: str
    payload: dict  # the table stores TEXT; presumably the decoded JSON -- TODO confirm
    updated_at: datetime  # the table stores TEXT; presumably the parsed value -- TODO confirm
# Key in the ``meta`` table under which ``prune_clusters`` records the time of
# the most recent prune run.
META_LAST_PRUNE_AT = "last_prune_at"
  19. def _article_key(article: dict[str, Any]) -> str:
  20. url = str(article.get("url") or "").strip()
  21. if not url:
  22. return str(article.get("title") or "")
  23. try:
  24. parsed = urlparse(url)
  25. parts = [p for p in parsed.path.split("/") if p]
  26. if parts:
  27. return parts[-1]
  28. except Exception:
  29. pass
  30. return url
  31. def _dedup_articles(articles: list[dict[str, Any]]) -> list[dict[str, Any]]:
  32. seen: set[str] = set()
  33. out: list[dict[str, Any]] = []
  34. for article in articles:
  35. key = _article_key(article)
  36. if key in seen:
  37. continue
  38. seen.add(key)
  39. out.append(article)
  40. return out
  41. def _has_valid_entity_resolutions(resolutions: Any, entities: list[str]) -> bool:
  42. if not isinstance(resolutions, list):
  43. return False
  44. if len(resolutions) != len(entities):
  45. return False
  46. for res in resolutions:
  47. if not isinstance(res, dict):
  48. return False
  49. if not res.get("normalized") or not res.get("canonical_label"):
  50. return False
  51. return True
  52. def sanitize_cluster_payload(cluster: dict[str, Any], *, include_resolutions: bool = True) -> dict[str, Any]:
  53. """Normalize cluster payload so every stored payload is internally consistent."""
  54. out = dict(cluster)
  55. raw_articles = out.get("articles", []) or []
  56. articles = [a for a in raw_articles if isinstance(a, dict)]
  57. out["articles"] = _dedup_articles(articles)
  58. raw_entities = out.get("entities", []) or []
  59. entities = normalize_entities(raw_entities)
  60. out["entities"] = entities
  61. if not include_resolutions:
  62. return out
  63. resolutions = out.get("entityResolutions", None)
  64. if entities:
  65. if not _has_valid_entity_resolutions(resolutions, entities):
  66. out["entityResolutions"] = [resolve_entity_via_trends(e) for e in entities]
  67. else:
  68. # Keep the empty case explicit and stable.
  69. out["entityResolutions"] = []
  70. return out
  71. class SQLiteClusterStore:
  72. def __init__(self, db_path: str | Path):
  73. self.db_path = str(db_path)
  74. self._init_db()
  75. def _conn(self) -> sqlite3.Connection:
  76. return sqlite3.connect(self.db_path)
  77. def _init_db(self) -> None:
  78. Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
  79. with self._conn() as conn:
  80. conn.execute("PRAGMA journal_mode=WAL")
  81. conn.execute("PRAGMA synchronous=NORMAL")
  82. conn.execute("PRAGMA busy_timeout=5000")
  83. conn.execute(
  84. """
  85. CREATE TABLE IF NOT EXISTS clusters (
  86. cluster_id TEXT PRIMARY KEY,
  87. topic TEXT NOT NULL,
  88. payload TEXT NOT NULL,
  89. updated_at TEXT NOT NULL,
  90. summary_payload TEXT,
  91. summary_updated_at TEXT
  92. )
  93. """
  94. )
  95. # If the table already exists without the summary columns,
  96. # add them (SQLite-friendly incremental migrations).
  97. for col_def in [
  98. "summary_payload TEXT",
  99. "summary_updated_at TEXT",
  100. ]:
  101. col = col_def.split()[0]
  102. try:
  103. conn.execute(f"ALTER TABLE clusters ADD COLUMN {col_def}")
  104. except sqlite3.OperationalError:
  105. pass
  106. conn.execute(
  107. "CREATE INDEX IF NOT EXISTS idx_clusters_topic ON clusters(topic)"
  108. )
  109. conn.execute(
  110. "CREATE INDEX IF NOT EXISTS idx_clusters_updated_at ON clusters(updated_at)"
  111. )
  112. try:
  113. cur = conn.execute("PRAGMA table_info(entity_metadata)")
  114. cols = [row[1] for row in cur.fetchall()]
  115. if cols and "entity_id" not in cols:
  116. conn.execute("DROP TABLE entity_metadata")
  117. except sqlite3.OperationalError:
  118. pass
  119. conn.execute(
  120. """
  121. CREATE TABLE IF NOT EXISTS entity_metadata (
  122. entity_id TEXT PRIMARY KEY,
  123. normalized_label TEXT NOT NULL,
  124. canonical_label TEXT,
  125. mid TEXT,
  126. sources_json TEXT,
  127. updated_at TEXT,
  128. last_requested_at TEXT
  129. )
  130. """
  131. )
  132. conn.execute(
  133. "CREATE UNIQUE INDEX IF NOT EXISTS idx_entity_metadata_mid ON entity_metadata(mid) WHERE mid IS NOT NULL"
  134. )
  135. conn.execute(
  136. """
  137. CREATE TABLE IF NOT EXISTS feed_state (
  138. feed_key TEXT PRIMARY KEY,
  139. last_hash TEXT NOT NULL,
  140. updated_at TEXT NOT NULL
  141. )
  142. """
  143. )
  144. conn.execute(
  145. """
  146. CREATE TABLE IF NOT EXISTS meta (
  147. key TEXT PRIMARY KEY,
  148. value TEXT NOT NULL
  149. )
  150. """
  151. )
  152. def upsert_clusters(self, clusters: list[dict], topic: str) -> None:
  153. now = datetime.now(timezone.utc)
  154. with self._conn() as conn:
  155. for c in clusters:
  156. c = sanitize_cluster_payload(c)
  157. cluster_id = c["cluster_id"]
  158. payload = json.dumps(c, ensure_ascii=False)
  159. conn.execute(
  160. "INSERT INTO clusters(cluster_id, topic, payload, updated_at) VALUES(?,?,?,?) "
  161. "ON CONFLICT(cluster_id) DO UPDATE SET topic=excluded.topic, payload=excluded.payload, updated_at=excluded.updated_at",
  162. (cluster_id, topic, payload, now.isoformat()),
  163. )
  164. def upsert_cluster_summary(
  165. self,
  166. cluster_id: str,
  167. summary_payload: dict,
  168. ) -> None:
  169. now = datetime.now(timezone.utc).isoformat()
  170. with self._conn() as conn:
  171. conn.execute(
  172. "INSERT INTO clusters(cluster_id, topic, payload, updated_at, summary_payload, summary_updated_at) "
  173. "VALUES(?,?,?,?,?,?) "
  174. "ON CONFLICT(cluster_id) DO UPDATE SET "
  175. "summary_payload=excluded.summary_payload, summary_updated_at=excluded.summary_updated_at",
  176. (
  177. cluster_id,
  178. "", # topic not used for update
  179. json.dumps({}, ensure_ascii=False),
  180. now,
  181. json.dumps(summary_payload, ensure_ascii=False),
  182. now,
  183. ),
  184. )
  185. def get_cluster_summary(self, cluster_id: str, ttl_hours: float) -> dict | None:
  186. cutoff = datetime.now(timezone.utc) - timedelta(hours=ttl_hours)
  187. cutoff_iso = cutoff.isoformat()
  188. with self._conn() as conn:
  189. cur = conn.execute(
  190. "SELECT summary_payload, summary_updated_at FROM clusters "
  191. "WHERE cluster_id=? AND summary_updated_at >= ?",
  192. (cluster_id, cutoff_iso),
  193. )
  194. row = cur.fetchone()
  195. if not row or not row[0]:
  196. return None
  197. return json.loads(row[0])
  198. def get_latest_clusters(self, topic: str, ttl_hours: float, limit: int) -> list[dict]:
  199. """Return newest clusters by *their own* timestamp.
  200. Filtering/sorting by the DB row's `updated_at` can drift away from the
  201. actual event time in `payload.timestamp`.
  202. """
  203. cutoff = datetime.now(timezone.utc) - timedelta(hours=float(ttl_hours))
  204. cutoff_ts = cutoff.timestamp()
  205. def _parse_payload_ts(ts: Any) -> float | None:
  206. if not ts:
  207. return None
  208. if isinstance(ts, (int, float)):
  209. return float(ts)
  210. text = str(ts).strip()
  211. try:
  212. dt = datetime.fromisoformat(text.replace('Z', '+00:00'))
  213. if dt.tzinfo is None:
  214. dt = dt.replace(tzinfo=timezone.utc)
  215. return dt.astimezone(timezone.utc).timestamp()
  216. except Exception:
  217. pass
  218. try:
  219. dt = parsedate_to_datetime(text)
  220. if dt.tzinfo is None:
  221. dt = dt.replace(tzinfo=timezone.utc)
  222. return dt.astimezone(timezone.utc).timestamp()
  223. except Exception:
  224. return None
  225. # Pull a wider candidate set, then filter by payload.timestamp.
  226. with self._conn() as conn:
  227. cur = conn.execute(
  228. "SELECT payload FROM clusters WHERE topic=? LIMIT ?",
  229. (topic, int(max(200, limit) * 10)),
  230. )
  231. candidates = [json.loads(r[0]) for r in cur.fetchall()]
  232. filtered: list[dict] = []
  233. for c in candidates:
  234. ts = _parse_payload_ts(c.get("timestamp"))
  235. if ts is None:
  236. continue
  237. if ts >= cutoff_ts:
  238. filtered.append(c)
  239. filtered.sort(key=lambda c: _parse_payload_ts(c.get("timestamp")) or 0.0, reverse=True)
  240. return filtered[: int(limit)]
  241. def get_latest_clusters_all_topics(self, ttl_hours: float, limit: int) -> list[dict]:
  242. cutoff = datetime.now(timezone.utc) - timedelta(hours=float(ttl_hours))
  243. cutoff_ts = cutoff.timestamp()
  244. def _parse_payload_ts(ts: Any) -> float | None:
  245. if not ts:
  246. return None
  247. if isinstance(ts, (int, float)):
  248. return float(ts)
  249. text = str(ts).strip()
  250. try:
  251. dt = datetime.fromisoformat(text.replace('Z', '+00:00'))
  252. if dt.tzinfo is None:
  253. dt = dt.replace(tzinfo=timezone.utc)
  254. return dt.astimezone(timezone.utc).timestamp()
  255. except Exception:
  256. pass
  257. try:
  258. dt = parsedate_to_datetime(text)
  259. if dt.tzinfo is None:
  260. dt = dt.replace(tzinfo=timezone.utc)
  261. return dt.astimezone(timezone.utc).timestamp()
  262. except Exception:
  263. return None
  264. with self._conn() as conn:
  265. cur = conn.execute(
  266. "SELECT payload FROM clusters LIMIT ?",
  267. (int(max(500, limit) * 10),),
  268. )
  269. candidates = [json.loads(r[0]) for r in cur.fetchall()]
  270. filtered: list[dict] = []
  271. for c in candidates:
  272. ts = _parse_payload_ts(c.get("timestamp"))
  273. if ts is None:
  274. continue
  275. if ts >= cutoff_ts:
  276. filtered.append(c)
  277. filtered.sort(key=lambda c: _parse_payload_ts(c.get("timestamp")) or 0.0, reverse=True)
  278. return filtered[: int(limit)]
  279. def get_cluster_by_id(self, cluster_id: str) -> dict | None:
  280. with self._conn() as conn:
  281. cur = conn.execute(
  282. "SELECT payload FROM clusters WHERE cluster_id=?",
  283. (cluster_id,),
  284. )
  285. row = cur.fetchone()
  286. return json.loads(row[0]) if row else None
  287. def get_feed_hash(self, feed_key: str) -> str | None:
  288. with self._conn() as conn:
  289. cur = conn.execute(
  290. "SELECT last_hash FROM feed_state WHERE feed_key=?",
  291. (feed_key,),
  292. )
  293. row = cur.fetchone()
  294. return row[0] if row else None
  295. def set_feed_hash(self, feed_key: str, last_hash: str) -> None:
  296. now = datetime.now(timezone.utc).isoformat()
  297. with self._conn() as conn:
  298. conn.execute(
  299. "INSERT INTO feed_state(feed_key, last_hash, updated_at) VALUES(?,?,?) "
  300. "ON CONFLICT(feed_key) DO UPDATE SET last_hash=excluded.last_hash, updated_at=excluded.updated_at",
  301. (feed_key, last_hash, now),
  302. )
  303. def get_feed_state(self, feed_key: str) -> dict | None:
  304. with self._conn() as conn:
  305. cur = conn.execute(
  306. "SELECT last_hash, updated_at FROM feed_state WHERE feed_key=?",
  307. (feed_key,),
  308. )
  309. row = cur.fetchone()
  310. if not row:
  311. return None
  312. return {"last_hash": row[0], "updated_at": row[1]}
  313. def get_meta(self, key: str) -> str | None:
  314. with self._conn() as conn:
  315. cur = conn.execute("SELECT value FROM meta WHERE key=?", (key,))
  316. row = cur.fetchone()
  317. return row[0] if row else None
  318. def set_meta(self, key: str, value: str) -> None:
  319. with self._conn() as conn:
  320. conn.execute(
  321. "INSERT INTO meta(key, value) VALUES(?, ?) "
  322. "ON CONFLICT(key) DO UPDATE SET value=excluded.value",
  323. (key, value),
  324. )
  325. def upsert_entity_metadata(
  326. self,
  327. normalized_label: str,
  328. canonical_label: str | None = None,
  329. mid: str | None = None,
  330. sources: list[str] | None = None,
  331. ) -> None:
  332. normalized_label = str(normalized_label or "").strip()
  333. if not normalized_label:
  334. return
  335. canonical_label = str(canonical_label).strip() if canonical_label else None
  336. mid = str(mid).strip() if mid else None
  337. entity_id = mid if mid else f"local:{normalized_label}"
  338. sources = sorted({s for s in (sources or []) if s})
  339. sources_json = json.dumps(sources, ensure_ascii=False)
  340. now = datetime.now(timezone.utc).isoformat()
  341. with self._conn() as conn:
  342. conn.execute(
  343. """
  344. INSERT INTO entity_metadata(entity_id, normalized_label, canonical_label, mid, sources_json, updated_at)
  345. VALUES(?,?,?,?,?,?)
  346. ON CONFLICT(entity_id) DO UPDATE SET
  347. canonical_label=excluded.canonical_label,
  348. mid=excluded.mid,
  349. sources_json=excluded.sources_json,
  350. updated_at=excluded.updated_at
  351. """,
  352. (entity_id, normalized_label, canonical_label, mid, sources_json, now),
  353. )
  354. def get_entity_metadata(self, normalized_label: str) -> dict[str, Any] | None:
  355. normalized_label = str(normalized_label or "").strip()
  356. if not normalized_label:
  357. return None
  358. with self._conn() as conn:
  359. cur = conn.execute(
  360. "SELECT entity_id, canonical_label, mid, sources_json, updated_at, last_requested_at FROM entity_metadata WHERE normalized_label=?",
  361. (normalized_label,),
  362. )
  363. row = cur.fetchone()
  364. if not row:
  365. return None
  366. sources = []
  367. if row[2]:
  368. try:
  369. sources = json.loads(row[2])
  370. except Exception:
  371. sources = []
  372. return {
  373. "entity_id": row[0],
  374. "normalized_label": normalized_label,
  375. "canonical_label": row[1],
  376. "mid": row[2],
  377. "sources": sources,
  378. "updated_at": row[4],
  379. "last_requested_at": row[5],
  380. }
  381. def record_entity_request(self, normalized_label: str, mid: str | None = None) -> None:
  382. normalized_label = str(normalized_label or "").strip()
  383. if not normalized_label:
  384. return
  385. mid = str(mid).strip() if mid else None
  386. entity_id = mid if mid else f"local:{normalized_label}"
  387. now = datetime.now(timezone.utc).isoformat()
  388. with self._conn() as conn:
  389. conn.execute(
  390. """
  391. INSERT INTO entity_metadata(entity_id, normalized_label, canonical_label, mid, sources_json, updated_at, last_requested_at)
  392. VALUES(?,?,?,?,?,?,?)
  393. ON CONFLICT(entity_id) DO UPDATE SET
  394. last_requested_at=excluded.last_requested_at
  395. """,
  396. (entity_id, normalized_label, None, mid, json.dumps([], ensure_ascii=False), now, now),
  397. )
  398. def prune_clusters(self, retention_days: float) -> int:
  399. retention_days = float(retention_days)
  400. if retention_days <= 0:
  401. return 0
  402. cutoff = datetime.now(timezone.utc) - timedelta(days=retention_days)
  403. cutoff_iso = cutoff.isoformat()
  404. pruned_at = datetime.now(timezone.utc).isoformat()
  405. with self._conn() as conn:
  406. cur = conn.execute("DELETE FROM clusters WHERE updated_at < ?", (cutoff_iso,))
  407. deleted = int(cur.rowcount or 0)
  408. conn.execute(
  409. "INSERT INTO meta(key, value) VALUES(?, ?) "
  410. "ON CONFLICT(key) DO UPDATE SET value=excluded.value",
  411. (META_LAST_PRUNE_AT, pruned_at),
  412. )
  413. return deleted
  414. def prune_if_due(self, pruning_enabled: bool, retention_days: float, interval_hours: float = 24.0) -> dict[str, Any]:
  415. retention_days = float(retention_days)
  416. interval_hours = float(interval_hours)
  417. if (not pruning_enabled) or retention_days <= 0:
  418. return {
  419. "enabled": bool(pruning_enabled),
  420. "deleted": 0,
  421. "due": False,
  422. "retention_days": retention_days,
  423. "interval_hours": interval_hours,
  424. "last_prune_at": self.get_meta(META_LAST_PRUNE_AT),
  425. }
  426. last_prune_at = self.get_meta(META_LAST_PRUNE_AT)
  427. now = datetime.now(timezone.utc)
  428. due = True
  429. if last_prune_at:
  430. try:
  431. last_dt = datetime.fromisoformat(last_prune_at)
  432. due = now - last_dt >= timedelta(hours=max(1.0, interval_hours))
  433. except Exception:
  434. due = True
  435. if not due:
  436. return {
  437. "enabled": True,
  438. "deleted": 0,
  439. "due": False,
  440. "retention_days": retention_days,
  441. "interval_hours": interval_hours,
  442. "last_prune_at": last_prune_at,
  443. }
  444. deleted = self.prune_clusters(retention_days)
  445. last_prune_at = self.get_meta(META_LAST_PRUNE_AT)
  446. return {
  447. "enabled": True,
  448. "deleted": deleted,
  449. "due": True,
  450. "retention_days": retention_days,
  451. "interval_hours": interval_hours,
  452. "last_prune_at": last_prune_at,
  453. }
  454. def get_prune_state(self, pruning_enabled: bool, retention_days: float, interval_hours: float = 24.0) -> dict[str, Any]:
  455. return {
  456. "enabled": bool(pruning_enabled),
  457. "retention_days": float(retention_days),
  458. "interval_hours": float(interval_hours),
  459. "last_prune_at": self.get_meta(META_LAST_PRUNE_AT),
  460. }