# sqlite_store.py
  1. from __future__ import annotations
  2. import json
  3. import sqlite3
  4. from dataclasses import dataclass
  5. from datetime import datetime, timezone, timedelta
  6. from pathlib import Path
  7. from typing import Any
  8. from urllib.parse import urlparse
  9. from news_mcp.entity_normalize import normalize_entities
  10. from news_mcp.trends_resolution import resolve_entity_via_trends
  11. @dataclass
  12. class ClusterRow:
  13. cluster_id: str
  14. topic: str
  15. payload: dict
  16. updated_at: datetime
  17. def _article_key(article: dict[str, Any]) -> str:
  18. url = str(article.get("url") or "").strip()
  19. if not url:
  20. return str(article.get("title") or "")
  21. try:
  22. parsed = urlparse(url)
  23. parts = [p for p in parsed.path.split("/") if p]
  24. if parts:
  25. return parts[-1]
  26. except Exception:
  27. pass
  28. return url
  29. def _dedup_articles(articles: list[dict[str, Any]]) -> list[dict[str, Any]]:
  30. seen: set[str] = set()
  31. out: list[dict[str, Any]] = []
  32. for article in articles:
  33. key = _article_key(article)
  34. if key in seen:
  35. continue
  36. seen.add(key)
  37. out.append(article)
  38. return out
  39. def _has_valid_entity_resolutions(resolutions: Any, entities: list[str]) -> bool:
  40. if not isinstance(resolutions, list):
  41. return False
  42. if len(resolutions) != len(entities):
  43. return False
  44. for res in resolutions:
  45. if not isinstance(res, dict):
  46. return False
  47. if not res.get("normalized") or not res.get("canonical_label"):
  48. return False
  49. return True
  50. def sanitize_cluster_payload(cluster: dict[str, Any], *, include_resolutions: bool = True) -> dict[str, Any]:
  51. """Normalize cluster payload so every stored payload is internally consistent."""
  52. out = dict(cluster)
  53. raw_articles = out.get("articles", []) or []
  54. articles = [a for a in raw_articles if isinstance(a, dict)]
  55. out["articles"] = _dedup_articles(articles)
  56. raw_entities = out.get("entities", []) or []
  57. entities = normalize_entities(raw_entities)
  58. out["entities"] = entities
  59. if not include_resolutions:
  60. return out
  61. resolutions = out.get("entityResolutions", None)
  62. if entities:
  63. if not _has_valid_entity_resolutions(resolutions, entities):
  64. out["entityResolutions"] = [resolve_entity_via_trends(e) for e in entities]
  65. else:
  66. # Keep the empty case explicit and stable.
  67. out["entityResolutions"] = []
  68. return out
  69. class SQLiteClusterStore:
  70. def __init__(self, db_path: str | Path):
  71. self.db_path = str(db_path)
  72. self._init_db()
  73. def _conn(self) -> sqlite3.Connection:
  74. return sqlite3.connect(self.db_path)
  75. def _init_db(self) -> None:
  76. Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
  77. with self._conn() as conn:
  78. conn.execute(
  79. """
  80. CREATE TABLE IF NOT EXISTS clusters (
  81. cluster_id TEXT PRIMARY KEY,
  82. topic TEXT NOT NULL,
  83. payload TEXT NOT NULL,
  84. updated_at TEXT NOT NULL,
  85. summary_payload TEXT,
  86. summary_updated_at TEXT
  87. )
  88. """
  89. )
  90. # If the table already exists without the summary columns,
  91. # add them (SQLite-friendly incremental migrations).
  92. for col_def in [
  93. "summary_payload TEXT",
  94. "summary_updated_at TEXT",
  95. ]:
  96. col = col_def.split()[0]
  97. try:
  98. conn.execute(f"ALTER TABLE clusters ADD COLUMN {col_def}")
  99. except sqlite3.OperationalError:
  100. pass
  101. conn.execute(
  102. "CREATE INDEX IF NOT EXISTS idx_clusters_topic ON clusters(topic)"
  103. )
  104. conn.execute(
  105. """
  106. CREATE TABLE IF NOT EXISTS feed_state (
  107. feed_key TEXT PRIMARY KEY,
  108. last_hash TEXT NOT NULL,
  109. updated_at TEXT NOT NULL
  110. )
  111. """
  112. )
  113. def upsert_clusters(self, clusters: list[dict], topic: str) -> None:
  114. now = datetime.now(timezone.utc)
  115. with self._conn() as conn:
  116. for c in clusters:
  117. c = sanitize_cluster_payload(c)
  118. cluster_id = c["cluster_id"]
  119. payload = json.dumps(c, ensure_ascii=False)
  120. conn.execute(
  121. "INSERT INTO clusters(cluster_id, topic, payload, updated_at) VALUES(?,?,?,?) "
  122. "ON CONFLICT(cluster_id) DO UPDATE SET topic=excluded.topic, payload=excluded.payload, updated_at=excluded.updated_at",
  123. (cluster_id, topic, payload, now.isoformat()),
  124. )
  125. def upsert_cluster_summary(
  126. self,
  127. cluster_id: str,
  128. summary_payload: dict,
  129. ) -> None:
  130. now = datetime.now(timezone.utc).isoformat()
  131. with self._conn() as conn:
  132. conn.execute(
  133. "INSERT INTO clusters(cluster_id, topic, payload, updated_at, summary_payload, summary_updated_at) "
  134. "VALUES(?,?,?,?,?,?) "
  135. "ON CONFLICT(cluster_id) DO UPDATE SET "
  136. "summary_payload=excluded.summary_payload, summary_updated_at=excluded.summary_updated_at",
  137. (
  138. cluster_id,
  139. "", # topic not used for update
  140. json.dumps({}, ensure_ascii=False),
  141. now,
  142. json.dumps(summary_payload, ensure_ascii=False),
  143. now,
  144. ),
  145. )
  146. def get_cluster_summary(self, cluster_id: str, ttl_hours: float) -> dict | None:
  147. cutoff = datetime.now(timezone.utc) - timedelta(hours=ttl_hours)
  148. cutoff_iso = cutoff.isoformat()
  149. with self._conn() as conn:
  150. cur = conn.execute(
  151. "SELECT summary_payload, summary_updated_at FROM clusters "
  152. "WHERE cluster_id=? AND summary_updated_at >= ?",
  153. (cluster_id, cutoff_iso),
  154. )
  155. row = cur.fetchone()
  156. if not row or not row[0]:
  157. return None
  158. return json.loads(row[0])
  159. def get_latest_clusters(self, topic: str, ttl_hours: float, limit: int) -> list[dict]:
  160. cutoff = datetime.now(timezone.utc) - timedelta(hours=ttl_hours)
  161. cutoff_iso = cutoff.isoformat()
  162. with self._conn() as conn:
  163. cur = conn.execute(
  164. "SELECT payload FROM clusters WHERE topic=? AND updated_at >= ? ORDER BY updated_at DESC LIMIT ?",
  165. (topic, cutoff_iso, int(limit)),
  166. )
  167. rows = [json.loads(r[0]) for r in cur.fetchall()]
  168. return rows
  169. def get_latest_clusters_all_topics(self, ttl_hours: float, limit: int) -> list[dict]:
  170. cutoff = datetime.now(timezone.utc) - timedelta(hours=ttl_hours)
  171. cutoff_iso = cutoff.isoformat()
  172. with self._conn() as conn:
  173. cur = conn.execute(
  174. "SELECT payload FROM clusters WHERE updated_at >= ? ORDER BY updated_at DESC LIMIT ?",
  175. (cutoff_iso, int(limit)),
  176. )
  177. return [json.loads(r[0]) for r in cur.fetchall()]
  178. def get_cluster_by_id(self, cluster_id: str) -> dict | None:
  179. with self._conn() as conn:
  180. cur = conn.execute(
  181. "SELECT payload FROM clusters WHERE cluster_id=?",
  182. (cluster_id,),
  183. )
  184. row = cur.fetchone()
  185. return json.loads(row[0]) if row else None
  186. def get_feed_hash(self, feed_key: str) -> str | None:
  187. with self._conn() as conn:
  188. cur = conn.execute(
  189. "SELECT last_hash FROM feed_state WHERE feed_key=?",
  190. (feed_key,),
  191. )
  192. row = cur.fetchone()
  193. return row[0] if row else None
  194. def set_feed_hash(self, feed_key: str, last_hash: str) -> None:
  195. now = datetime.now(timezone.utc).isoformat()
  196. with self._conn() as conn:
  197. conn.execute(
  198. "INSERT INTO feed_state(feed_key, last_hash, updated_at) VALUES(?,?,?) "
  199. "ON CONFLICT(feed_key) DO UPDATE SET last_hash=excluded.last_hash, updated_at=excluded.updated_at",
  200. (feed_key, last_hash, now),
  201. )
  202. def get_feed_state(self, feed_key: str) -> dict | None:
  203. with self._conn() as conn:
  204. cur = conn.execute(
  205. "SELECT last_hash, updated_at FROM feed_state WHERE feed_key=?",
  206. (feed_key,),
  207. )
  208. row = cur.fetchone()
  209. if not row:
  210. return None
  211. return {"last_hash": row[0], "updated_at": row[1]}