"""One-off backfill for news-mcp cluster entity metadata.

This reprocesses every stored cluster so older rows pick up the current
entity normalization and trends resolution fields used by lookup tools.
"""

from __future__ import annotations

import argparse
import asyncio
import json
import logging
import sys
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from news_mcp.config import DB_PATH
from news_mcp.entity_normalize import normalize_entities
from news_mcp.storage.sqlite_store import SQLiteClusterStore, sanitize_cluster_payload


def _backfill_reason(cluster: dict) -> str | None:
    """Return a reason string if the cluster needs backfill, else None."""
    raw_entities = cluster.get("entities", []) or []
    normalized_entities = normalize_entities(raw_entities)
    resolutions = cluster.get("entityResolutions")

    if normalized_entities != raw_entities:
        return "entities_not_normalized"
    if normalized_entities and not resolutions:
        return "missing_entityResolutions"
    if resolutions is not None and not isinstance(resolutions, list):
        return "invalid_resolution_shape"
    if isinstance(resolutions, list) and len(resolutions) != len(normalized_entities):
        return "resolution_length_mismatch"
    for res in resolutions or []:
        if not isinstance(res, dict):
            return "invalid_resolution_shape"
        if not res.get("normalized") or not res.get("canonical_label"):
            return "resolution_missing_fields"

    sanitized = sanitize_cluster_payload(cluster)
    if sanitized != cluster:
        return "payload_sanitized"
    return None


async def backfill(
    db_path: Path,
    limit: int | None = None,
    dry_run: bool = False,
    scan_only: bool = False,
    progress_every: int = 20,
    write_batch_size: int = 100,
) -> dict[str, int]:
    store = SQLiteClusterStore(db_path)
    with store._conn() as conn:  # noqa: SLF001 - intentional one-off maintenance script
        cur = conn.execute(
            "SELECT cluster_id, topic, payload FROM clusters ORDER BY updated_at ASC"
        )
        rows = cur.fetchall()

    total = 0
    updated = 0
    skipped = 0
    reason_counts: dict[str, int] = {}
    first_samples: dict[str, list[str]] = {}
    sample_limit_per_reason = 5
    pending_writes: list[tuple[dict, str]] = []
    logger = logging.getLogger("news_mcp.backfill")

    def _flush_writes() -> None:
        """Group pending clusters by topic and upsert each group in one call."""
        topic_groups: dict[str, list[dict]] = {}
        for item_cluster, item_topic in pending_writes:
            topic_groups.setdefault(item_topic, []).append(item_cluster)
        for batch_topic, batch_clusters in topic_groups.items():
            store.upsert_clusters(batch_clusters, topic=batch_topic)
        pending_writes.clear()

    print(
        f"starting backfill: total_rows={len(rows)} limit={limit or 'all'} "
        f"dry_run={dry_run} scan_only={scan_only}",
        flush=True,
    )

    for cluster_id, topic, payload_json in rows:
        total += 1
        try:
            cluster = json.loads(payload_json)
        except Exception:
            skipped += 1
            continue

        reason = _backfill_reason(cluster)
        if reason is None:
            continue
        reason_counts[reason] = reason_counts.get(reason, 0) + 1

        # Record a few example cluster_ids per reason (helpful for root cause).
        if sample_limit_per_reason:
            samples = first_samples.setdefault(reason, [])
            if len(samples) < sample_limit_per_reason:
                samples.append(str(cluster.get("cluster_id") or cluster_id))

        sanitized_cluster = sanitize_cluster_payload(cluster, include_resolutions=not scan_only)
        if not dry_run:
            pending_writes.append((sanitized_cluster, topic or sanitized_cluster.get("topic", "other")))
            if write_batch_size > 0 and len(pending_writes) >= write_batch_size:
                _flush_writes()

        updated += 1
        if limit is not None and updated >= limit:
            break

        if progress_every and total % progress_every == 0:
            logger.info("backfill progress total=%s updated=%s skipped=%s", total, updated, skipped)
            print(".", end="", flush=True)

    if not dry_run and pending_writes:
        _flush_writes()

    # Terminate the dot progress line started inside the loop.
    if progress_every and total:
        print(flush=True)

    # Print debug summary by reason.
    if reason_counts:
        print("reason_counts=" + json.dumps(reason_counts, ensure_ascii=False), flush=True)
        for reason, samples in first_samples.items():
            print(f"sample_{reason}=" + ",".join(str(s) for s in samples), flush=True)

    return {"total": total, "updated": updated, "skipped": skipped}


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Backfill news-mcp entity normalization and resolutions"
    )
    parser.add_argument("--db", type=Path, default=DB_PATH, help="Path to the news sqlite DB")
    parser.add_argument("--limit", type=int, default=None, help="Optional cap on clusters backfilled")
    parser.add_argument("--dry-run", action="store_true", help="Compute changes without writing them")
    parser.add_argument("--scan-only", action="store_true", help="Only detect stale rows; skip trends lookups")
    parser.add_argument(
        "--progress-every", type=int, default=20, help="Emit a progress line every N processed rows"
    )
    parser.add_argument(
        "--write-batch-size", type=int, default=100, help="Flush writes every N updated clusters"
    )
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(message)s")
    result = asyncio.run(
        backfill(
            args.db,
            limit=args.limit,
            dry_run=args.dry_run,
            scan_only=args.scan_only,
            progress_every=max(0, int(args.progress_every)),
            write_batch_size=max(0, int(args.write_batch_size)),
        )
    )
    mode = "DRY RUN" if args.dry_run else "DONE"
    if args.scan_only:
        mode = f"{mode} SCAN-ONLY"
    print(f"{mode} {result}")


if __name__ == "__main__":
    main()
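# Example invocations (a sketch; the script filename and DB path shown here are
# illustrative, while the flags are the ones this parser defines):
#
#   python backfill_entities.py --dry-run --limit 50
#   python backfill_entities.py --db /path/to/news.sqlite --scan-only
#   python backfill_entities.py --write-batch-size 200 --progress-every 100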