| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243 |
from __future__ import annotations

import json
import sqlite3
from contextlib import closing
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any
from urllib.parse import urlparse

from news_mcp.entity_normalize import normalize_entities
from news_mcp.trends_resolution import resolve_entity_via_trends
- @dataclass
- class ClusterRow:
- cluster_id: str
- topic: str
- payload: dict
- updated_at: datetime
- def _article_key(article: dict[str, Any]) -> str:
- url = str(article.get("url") or "").strip()
- if not url:
- return str(article.get("title") or "")
- try:
- parsed = urlparse(url)
- parts = [p for p in parsed.path.split("/") if p]
- if parts:
- return parts[-1]
- except Exception:
- pass
- return url
- def _dedup_articles(articles: list[dict[str, Any]]) -> list[dict[str, Any]]:
- seen: set[str] = set()
- out: list[dict[str, Any]] = []
- for article in articles:
- key = _article_key(article)
- if key in seen:
- continue
- seen.add(key)
- out.append(article)
- return out
- def _has_valid_entity_resolutions(resolutions: Any, entities: list[str]) -> bool:
- if not isinstance(resolutions, list):
- return False
- if len(resolutions) != len(entities):
- return False
- for res in resolutions:
- if not isinstance(res, dict):
- return False
- if not res.get("normalized") or not res.get("canonical_label"):
- return False
- return True
- def sanitize_cluster_payload(cluster: dict[str, Any], *, include_resolutions: bool = True) -> dict[str, Any]:
- """Normalize cluster payload so every stored payload is internally consistent."""
- out = dict(cluster)
- raw_articles = out.get("articles", []) or []
- articles = [a for a in raw_articles if isinstance(a, dict)]
- out["articles"] = _dedup_articles(articles)
- raw_entities = out.get("entities", []) or []
- entities = normalize_entities(raw_entities)
- out["entities"] = entities
- if not include_resolutions:
- return out
- resolutions = out.get("entityResolutions", None)
- if entities:
- if not _has_valid_entity_resolutions(resolutions, entities):
- out["entityResolutions"] = [resolve_entity_via_trends(e) for e in entities]
- else:
- # Keep the empty case explicit and stable.
- out["entityResolutions"] = []
- return out
- class SQLiteClusterStore:
- def __init__(self, db_path: str | Path):
- self.db_path = str(db_path)
- self._init_db()
- def _conn(self) -> sqlite3.Connection:
- return sqlite3.connect(self.db_path)
- def _init_db(self) -> None:
- Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
- with self._conn() as conn:
- conn.execute(
- """
- CREATE TABLE IF NOT EXISTS clusters (
- cluster_id TEXT PRIMARY KEY,
- topic TEXT NOT NULL,
- payload TEXT NOT NULL,
- updated_at TEXT NOT NULL,
- summary_payload TEXT,
- summary_updated_at TEXT
- )
- """
- )
- # If the table already exists without the summary columns,
- # add them (SQLite-friendly incremental migrations).
- for col_def in [
- "summary_payload TEXT",
- "summary_updated_at TEXT",
- ]:
- col = col_def.split()[0]
- try:
- conn.execute(f"ALTER TABLE clusters ADD COLUMN {col_def}")
- except sqlite3.OperationalError:
- pass
- conn.execute(
- "CREATE INDEX IF NOT EXISTS idx_clusters_topic ON clusters(topic)"
- )
- conn.execute(
- """
- CREATE TABLE IF NOT EXISTS feed_state (
- feed_key TEXT PRIMARY KEY,
- last_hash TEXT NOT NULL,
- updated_at TEXT NOT NULL
- )
- """
- )
- def upsert_clusters(self, clusters: list[dict], topic: str) -> None:
- now = datetime.now(timezone.utc)
- with self._conn() as conn:
- for c in clusters:
- c = sanitize_cluster_payload(c)
- cluster_id = c["cluster_id"]
- payload = json.dumps(c, ensure_ascii=False)
- conn.execute(
- "INSERT INTO clusters(cluster_id, topic, payload, updated_at) VALUES(?,?,?,?) "
- "ON CONFLICT(cluster_id) DO UPDATE SET topic=excluded.topic, payload=excluded.payload, updated_at=excluded.updated_at",
- (cluster_id, topic, payload, now.isoformat()),
- )
- def upsert_cluster_summary(
- self,
- cluster_id: str,
- summary_payload: dict,
- ) -> None:
- now = datetime.now(timezone.utc).isoformat()
- with self._conn() as conn:
- conn.execute(
- "INSERT INTO clusters(cluster_id, topic, payload, updated_at, summary_payload, summary_updated_at) "
- "VALUES(?,?,?,?,?,?) "
- "ON CONFLICT(cluster_id) DO UPDATE SET "
- "summary_payload=excluded.summary_payload, summary_updated_at=excluded.summary_updated_at",
- (
- cluster_id,
- "", # topic not used for update
- json.dumps({}, ensure_ascii=False),
- now,
- json.dumps(summary_payload, ensure_ascii=False),
- now,
- ),
- )
- def get_cluster_summary(self, cluster_id: str, ttl_hours: float) -> dict | None:
- cutoff = datetime.now(timezone.utc) - timedelta(hours=ttl_hours)
- cutoff_iso = cutoff.isoformat()
- with self._conn() as conn:
- cur = conn.execute(
- "SELECT summary_payload, summary_updated_at FROM clusters "
- "WHERE cluster_id=? AND summary_updated_at >= ?",
- (cluster_id, cutoff_iso),
- )
- row = cur.fetchone()
- if not row or not row[0]:
- return None
- return json.loads(row[0])
- def get_latest_clusters(self, topic: str, ttl_hours: float, limit: int) -> list[dict]:
- cutoff = datetime.now(timezone.utc) - timedelta(hours=ttl_hours)
- cutoff_iso = cutoff.isoformat()
- with self._conn() as conn:
- cur = conn.execute(
- "SELECT payload FROM clusters WHERE topic=? AND updated_at >= ? ORDER BY updated_at DESC LIMIT ?",
- (topic, cutoff_iso, int(limit)),
- )
- rows = [json.loads(r[0]) for r in cur.fetchall()]
- return rows
- def get_latest_clusters_all_topics(self, ttl_hours: float, limit: int) -> list[dict]:
- cutoff = datetime.now(timezone.utc) - timedelta(hours=ttl_hours)
- cutoff_iso = cutoff.isoformat()
- with self._conn() as conn:
- cur = conn.execute(
- "SELECT payload FROM clusters WHERE updated_at >= ? ORDER BY updated_at DESC LIMIT ?",
- (cutoff_iso, int(limit)),
- )
- return [json.loads(r[0]) for r in cur.fetchall()]
- def get_cluster_by_id(self, cluster_id: str) -> dict | None:
- with self._conn() as conn:
- cur = conn.execute(
- "SELECT payload FROM clusters WHERE cluster_id=?",
- (cluster_id,),
- )
- row = cur.fetchone()
- return json.loads(row[0]) if row else None
- def get_feed_hash(self, feed_key: str) -> str | None:
- with self._conn() as conn:
- cur = conn.execute(
- "SELECT last_hash FROM feed_state WHERE feed_key=?",
- (feed_key,),
- )
- row = cur.fetchone()
- return row[0] if row else None
- def set_feed_hash(self, feed_key: str, last_hash: str) -> None:
- now = datetime.now(timezone.utc).isoformat()
- with self._conn() as conn:
- conn.execute(
- "INSERT INTO feed_state(feed_key, last_hash, updated_at) VALUES(?,?,?) "
- "ON CONFLICT(feed_key) DO UPDATE SET last_hash=excluded.last_hash, updated_at=excluded.updated_at",
- (feed_key, last_hash, now),
- )
- def get_feed_state(self, feed_key: str) -> dict | None:
- with self._conn() as conn:
- cur = conn.execute(
- "SELECT last_hash, updated_at FROM feed_state WHERE feed_key=?",
- (feed_key,),
- )
- row = cur.fetchone()
- if not row:
- return None
- return {"last_hash": row[0], "updated_at": row[1]}
|