# backfill_news_entities.py
  1. from __future__ import annotations
  2. """One-off backfill for news-mcp cluster entity metadata.
  3. This reprocesses every stored cluster so older rows pick up the current
  4. entity normalization and trends resolution fields used by lookup tools.
  5. """
  6. import argparse
  7. import asyncio
  8. import json
  9. import sys
  10. import logging
  11. from pathlib import Path
  12. ROOT = Path(__file__).resolve().parents[1]
  13. if str(ROOT) not in sys.path:
  14. sys.path.insert(0, str(ROOT))
  15. from news_mcp.config import DB_PATH
  16. from news_mcp.entity_normalize import normalize_entities
  17. from news_mcp.storage.sqlite_store import SQLiteClusterStore, sanitize_cluster_payload
  18. def _backfill_reason(cluster: dict) -> str | None:
  19. """Return a reason string if the cluster needs backfill, else None."""
  20. raw_entities = cluster.get("entities", []) or []
  21. normalized_entities = normalize_entities(raw_entities)
  22. resolutions = cluster.get("entityResolutions", None)
  23. if normalized_entities != raw_entities:
  24. return "entities_not_normalized"
  25. if normalized_entities and not resolutions:
  26. return "missing_entityResolutions"
  27. if resolutions is not None and not isinstance(resolutions, list):
  28. return "invalid_resolution_shape"
  29. if isinstance(resolutions, list) and len(resolutions) != len(normalized_entities):
  30. return "resolution_length_mismatch"
  31. for res in resolutions or []:
  32. if not isinstance(res, dict):
  33. return "invalid_resolution_shape"
  34. if not res.get("normalized") or not res.get("canonical_label"):
  35. return "resolution_missing_fields"
  36. sanitized = sanitize_cluster_payload(cluster)
  37. if sanitized != cluster:
  38. return "payload_sanitized"
  39. return None
  40. async def backfill(
  41. db_path: Path,
  42. limit: int | None = None,
  43. dry_run: bool = False,
  44. scan_only: bool = False,
  45. progress_every: int = 20,
  46. write_batch_size: int = 100,
  47. ) -> dict[str, int]:
  48. store = SQLiteClusterStore(db_path)
  49. with store._conn() as conn: # noqa: SLF001 - intentional one-off maintenance script
  50. cur = conn.execute("SELECT cluster_id, topic, payload FROM clusters ORDER BY updated_at ASC")
  51. rows = cur.fetchall()
  52. total = 0
  53. updated = 0
  54. skipped = 0
  55. reason_counts: dict[str, int] = {}
  56. first_samples: dict[str, list[str]] = {}
  57. sample_limit_per_reason = 5
  58. pending_writes: list[tuple[dict, str]] = []
  59. logger = logging.getLogger("news_mcp.backfill")
  60. print(f"starting backfill: total_rows={len(rows)} limit={limit or 'all'} dry_run={dry_run} scan_only={scan_only}", flush=True)
  61. for cluster_id, topic, payload_json in rows:
  62. total += 1
  63. try:
  64. cluster = json.loads(payload_json)
  65. except Exception:
  66. skipped += 1
  67. continue
  68. reason = _backfill_reason(cluster)
  69. if reason is None:
  70. continue
  71. reason_counts[reason] = reason_counts.get(reason, 0) + 1
  72. # Record a few example cluster_ids per reason (helpful for root cause).
  73. if sample_limit_per_reason and reason in reason_counts and reason not in first_samples:
  74. first_samples[reason] = []
  75. if reason in first_samples and len(first_samples[reason]) < sample_limit_per_reason:
  76. first_samples[reason].append(cluster.get("cluster_id") or cluster_id)
  77. sanitized_cluster = sanitize_cluster_payload(cluster, include_resolutions=not scan_only)
  78. if not dry_run:
  79. pending_writes.append((sanitized_cluster, topic or sanitized_cluster.get("topic", "other")))
  80. if write_batch_size > 0 and len(pending_writes) >= write_batch_size:
  81. topic_groups: dict[str, list[dict]] = {}
  82. for item_cluster, item_topic in pending_writes:
  83. topic_groups.setdefault(item_topic, []).append(item_cluster)
  84. for batch_topic, batch_clusters in topic_groups.items():
  85. store.upsert_clusters(batch_clusters, topic=batch_topic)
  86. pending_writes.clear()
  87. updated += 1
  88. if limit is not None and updated >= limit:
  89. break
  90. if progress_every and total % progress_every == 0:
  91. logger.info("backfill progress total=%s updated=%s skipped=%s", total, updated, skipped)
  92. print(".", end="", flush=True)
  93. if progress_every and total % progress_every == 0:
  94. print(flush=True)
  95. if not dry_run and pending_writes:
  96. topic_groups: dict[str, list[dict]] = {}
  97. for item_cluster, item_topic in pending_writes:
  98. topic_groups.setdefault(item_topic, []).append(item_cluster)
  99. for batch_topic, batch_clusters in topic_groups.items():
  100. store.upsert_clusters(batch_clusters, topic=batch_topic)
  101. if total % max(1, progress_every) != 0:
  102. print(flush=True)
  103. # Print debug summary by reason.
  104. if reason_counts:
  105. print("reason_counts=" + json.dumps(reason_counts, ensure_ascii=False), flush=True)
  106. for reason, samples in first_samples.items():
  107. print(f"sample_{reason}=" + ",".join(samples), flush=True)
  108. return {"total": total, "updated": updated, "skipped": skipped}
  109. def main() -> None:
  110. parser = argparse.ArgumentParser(description="Backfill news-mcp entity normalization and resolutions")
  111. parser.add_argument("--db", type=Path, default=DB_PATH, help="Path to the news sqlite DB")
  112. parser.add_argument("--limit", type=int, default=None, help="Optional cap on clusters processed")
  113. parser.add_argument("--dry-run", action="store_true", help="Compute changes without writing them")
  114. parser.add_argument("--scan-only", action="store_true", help="Only detect stale rows; skip trends lookups")
  115. parser.add_argument("--progress-every", type=int, default=20, help="Emit a progress line every N processed rows")
  116. parser.add_argument("--write-batch-size", type=int, default=100, help="Flush writes every N updated clusters")
  117. args = parser.parse_args()
  118. logging.basicConfig(level=logging.INFO, format="%(message)s")
  119. result = asyncio.run(
  120. backfill(
  121. args.db,
  122. limit=args.limit,
  123. dry_run=args.dry_run,
  124. scan_only=args.scan_only,
  125. progress_every=max(0, int(args.progress_every)),
  126. write_batch_size=max(0, int(args.write_batch_size)),
  127. )
  128. )
  129. mode = "DRY RUN" if args.dry_run else "DONE"
  130. if args.scan_only:
  131. mode = f"{mode} SCAN-ONLY"
  132. print(f"{mode} {result}")
  133. if __name__ == "__main__":
  134. main()