# backfill_news_entities.py
  1. from __future__ import annotations
  2. """One-off backfill for news-mcp cluster entity metadata.
  3. This reprocesses every stored cluster so older rows pick up the current
  4. entity normalization and trends resolution fields used by lookup tools.
  5. """
  6. import argparse
  7. import asyncio
  8. import json
  9. import sys
  10. import logging
  11. from pathlib import Path
  12. from typing import Any
  13. ROOT = Path(__file__).resolve().parents[1]
  14. if str(ROOT) not in sys.path:
  15. sys.path.insert(0, str(ROOT))
  16. from news_mcp.config import DB_PATH
  17. from news_mcp.entity_normalize import normalize_entities
  18. from news_mcp.storage.sqlite_store import SQLiteClusterStore
  19. from news_mcp.trends_resolution import resolve_entity_via_trends
  20. def _compute_entity_resolutions(entities: list[str]) -> list[dict[str, Any]]:
  21. return [resolve_entity_via_trends(ent) for ent in entities]
  22. def _needs_backfill(cluster: dict) -> bool:
  23. raw_entities = cluster.get("entities", []) or []
  24. normalized_entities = normalize_entities(raw_entities)
  25. resolutions = cluster.get("entityResolutions", []) or []
  26. # Clearly missing or stale metadata.
  27. if not resolutions:
  28. return bool(normalized_entities or raw_entities)
  29. if normalized_entities != raw_entities:
  30. return True
  31. if len(resolutions) != len(normalized_entities):
  32. return True
  33. for res in resolutions:
  34. if not isinstance(res, dict):
  35. return True
  36. if not res.get("normalized") or not res.get("canonical_label"):
  37. return True
  38. return False
  39. async def backfill(
  40. db_path: Path,
  41. limit: int | None = None,
  42. dry_run: bool = False,
  43. scan_only: bool = False,
  44. progress_every: int = 20,
  45. ) -> dict[str, int]:
  46. store = SQLiteClusterStore(db_path)
  47. with store._conn() as conn: # noqa: SLF001 - intentional one-off maintenance script
  48. cur = conn.execute("SELECT cluster_id, topic, payload FROM clusters ORDER BY updated_at ASC")
  49. rows = cur.fetchall()
  50. total = 0
  51. updated = 0
  52. skipped = 0
  53. logger = logging.getLogger("news_mcp.backfill")
  54. print(f"starting backfill: total_rows={len(rows)} limit={limit or 'all'} dry_run={dry_run} scan_only={scan_only}", flush=True)
  55. for cluster_id, topic, payload_json in rows:
  56. total += 1
  57. try:
  58. cluster = json.loads(payload_json)
  59. except Exception:
  60. skipped += 1
  61. continue
  62. if not _needs_backfill(cluster):
  63. continue
  64. raw_entities = cluster.get("entities", []) or []
  65. normalized_entities = normalize_entities(raw_entities)
  66. entity_resolutions = _compute_entity_resolutions(normalized_entities) if not scan_only else []
  67. cluster = dict(cluster)
  68. cluster["entities"] = normalized_entities
  69. if not scan_only:
  70. cluster["entityResolutions"] = entity_resolutions
  71. if not dry_run:
  72. store.upsert_clusters([cluster], topic=topic or cluster.get("topic", "other"))
  73. updated += 1
  74. if limit is not None and updated >= limit:
  75. break
  76. if progress_every and total % progress_every == 0:
  77. logger.info("backfill progress total=%s updated=%s skipped=%s", total, updated, skipped)
  78. print(".", end="", flush=True)
  79. if progress_every and total % progress_every == 0:
  80. print(flush=True)
  81. if total % max(1, progress_every) != 0:
  82. print(flush=True)
  83. return {"total": total, "updated": updated, "skipped": skipped}
  84. def main() -> None:
  85. parser = argparse.ArgumentParser(description="Backfill news-mcp entity normalization and resolutions")
  86. parser.add_argument("--db", type=Path, default=DB_PATH, help="Path to the news sqlite DB")
  87. parser.add_argument("--limit", type=int, default=None, help="Optional cap on clusters processed")
  88. parser.add_argument("--dry-run", action="store_true", help="Compute changes without writing them")
  89. parser.add_argument("--scan-only", action="store_true", help="Only detect stale rows; skip trends lookups")
  90. parser.add_argument("--progress-every", type=int, default=20, help="Emit a progress line every N processed rows")
  91. args = parser.parse_args()
  92. logging.basicConfig(level=logging.INFO, format="%(message)s")
  93. result = asyncio.run(
  94. backfill(
  95. args.db,
  96. limit=args.limit,
  97. dry_run=args.dry_run,
  98. scan_only=args.scan_only,
  99. progress_every=max(0, int(args.progress_every)),
  100. )
  101. )
  102. mode = "DRY RUN" if args.dry_run else "DONE"
  103. if args.scan_only:
  104. mode = f"{mode} SCAN-ONLY"
  105. print(f"{mode} {result}")
  106. if __name__ == "__main__":
  107. main()