@@ -0,0 +1,148 @@
+"""One-off backfill for news-mcp cluster entity metadata.
+
+This reprocesses every stored cluster so older rows pick up the current
+entity normalization and trends-resolution fields used by the lookup tools.
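+
+Typical invocations (the script filename below is illustrative; the flags are
+the ones defined in main()):
+
+    python backfill_clusters.py --dry-run    # report stale rows, write nothing
+    python backfill_clusters.py --scan-only  # count stale rows, skip trends lookups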
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import logging
+import sys
+from pathlib import Path
+from typing import Any
+
+
+ROOT = Path(__file__).resolve().parents[1]
+if str(ROOT) not in sys.path:
+    sys.path.insert(0, str(ROOT))
+
+from news_mcp.config import DB_PATH
+from news_mcp.entity_normalize import normalize_entities
+from news_mcp.storage.sqlite_store import SQLiteClusterStore
+from news_mcp.trends_resolution import resolve_entity_via_trends
+
+
+def _compute_entity_resolutions(entities: list[str]) -> list[dict[str, Any]]:
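+    # Presumably one trends lookup per entity (this is what --scan-only skips);
+    # likely the slow part of a full run.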
+    return [resolve_entity_via_trends(ent) for ent in entities]
+
+
+def _needs_backfill(cluster: dict) -> bool:
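+    """Return True when a cluster's stored entity metadata predates the current pipeline."""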
+    raw_entities = cluster.get("entities", []) or []
+    normalized_entities = normalize_entities(raw_entities)
+    resolutions = cluster.get("entityResolutions", []) or []
+
+    # Clearly missing or stale metadata.
+    if not resolutions:
+        return bool(normalized_entities or raw_entities)
+    if normalized_entities != raw_entities:
+        return True
+    if len(resolutions) != len(normalized_entities):
+        return True
+    for res in resolutions:
+        if not isinstance(res, dict):
+            return True
+        if not res.get("normalized") or not res.get("canonical_label"):
+            return True
+    return False
+
+
+def backfill(
+    db_path: Path,
+    limit: int | None = None,
+    dry_run: bool = False,
+    scan_only: bool = False,
+    progress_every: int = 20,
+) -> dict[str, int]:
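+    """Re-normalize stale clusters in db_path; returns total/updated/skipped counts."""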
+    store = SQLiteClusterStore(db_path)
+
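+    # Load every cluster row up front; acceptable for a one-off maintenance pass.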
+    with store._conn() as conn:  # noqa: SLF001 - intentional one-off maintenance script
+        cur = conn.execute("SELECT cluster_id, topic, payload FROM clusters ORDER BY updated_at ASC")
+        rows = cur.fetchall()
+
+    total = 0
+    updated = 0
+    skipped = 0  # payloads that failed to parse as JSON
+    logger = logging.getLogger("news_mcp.backfill")
+
+    print(
+        f"starting backfill: total_rows={len(rows)} limit={limit or 'all'} "
+        f"dry_run={dry_run} scan_only={scan_only}",
+        flush=True,
+    )
+
+    for cluster_id, topic, payload_json in rows:
+        total += 1
+        try:
+            cluster = json.loads(payload_json)
+        except Exception:
+            skipped += 1
+            continue
+
+        if not _needs_backfill(cluster):
+            continue
+
+        raw_entities = cluster.get("entities", []) or []
+        normalized_entities = normalize_entities(raw_entities)
+        entity_resolutions = [] if scan_only else _compute_entity_resolutions(normalized_entities)
+
+        cluster = dict(cluster)  # shallow copy; keep the parsed payload pristine
+        cluster["entities"] = normalized_entities
+        if not scan_only:
+            cluster["entityResolutions"] = entity_resolutions
+
+        # Scan-only mode reports stale rows without writing anything back.
+        if not dry_run and not scan_only:
+            store.upsert_clusters([cluster], topic=topic or cluster.get("topic", "other"))
+        updated += 1
+
+        if limit is not None and updated >= limit:
+            break
+
+        # Dot-per-rewrite progress, with a periodic summary line.
+        print(".", end="", flush=True)
+        if progress_every and total % progress_every == 0:
+            print(flush=True)
+            logger.info("backfill progress total=%s updated=%s skipped=%s", total, updated, skipped)
+
+    print(flush=True)  # terminate any trailing run of dots
+    return {"total": total, "updated": updated, "skipped": skipped}
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Backfill news-mcp entity normalization and resolutions")
+    parser.add_argument("--db", type=Path, default=DB_PATH, help="Path to the news sqlite DB")
+    parser.add_argument("--limit", type=int, default=None, help="Optional cap on clusters updated")
+    parser.add_argument("--dry-run", action="store_true", help="Compute changes without writing them")
+    parser.add_argument("--scan-only", action="store_true", help="Only detect stale rows; skip trends lookups and writes")
+    parser.add_argument("--progress-every", type=int, default=20, help="Emit a progress line every N processed rows")
+    args = parser.parse_args()
+
+    logging.basicConfig(level=logging.INFO, format="%(message)s")
+    result = backfill(
+        args.db,
+        limit=args.limit,
+        dry_run=args.dry_run,
+        scan_only=args.scan_only,
+        progress_every=max(0, int(args.progress_every)),
+    )
+    mode = "DRY RUN" if args.dry_run else "DONE"
+    if args.scan_only:
+        mode = f"{mode} SCAN-ONLY"
+    print(f"{mode} {result}")
+
+
+if __name__ == "__main__":
+    main()