"""One-off backfill for news-mcp cluster entity metadata.

This reprocesses every stored cluster so older rows pick up the current
entity normalization and trends resolution fields used by lookup tools.
"""
from __future__ import annotations

import argparse
import asyncio
import json
import logging
import sys
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from news_mcp.config import DB_PATH
from news_mcp.entity_normalize import normalize_entities
from news_mcp.storage.sqlite_store import SQLiteClusterStore, sanitize_cluster_payload


def _backfill_reason(cluster: dict) -> str | None:
    """Return a reason string if the cluster needs backfill, else None."""
    raw_entities = cluster.get("entities", []) or []
    normalized_entities = normalize_entities(raw_entities)
    resolutions = cluster.get("entityResolutions")
    if normalized_entities != raw_entities:
        return "entities_not_normalized"
    if normalized_entities and not resolutions:
        return "missing_entityResolutions"
    if resolutions is not None and not isinstance(resolutions, list):
        return "invalid_resolution_shape"
    if isinstance(resolutions, list) and len(resolutions) != len(normalized_entities):
        return "resolution_length_mismatch"
    for res in resolutions or []:
        if not isinstance(res, dict):
            return "invalid_resolution_shape"
        if not res.get("normalized") or not res.get("canonical_label"):
            return "resolution_missing_fields"
    sanitized = sanitize_cluster_payload(cluster)
    if sanitized != cluster:
        return "payload_sanitized"
    return None


async def backfill(
    db_path: Path,
    limit: int | None = None,
    dry_run: bool = False,
    scan_only: bool = False,
    progress_every: int = 20,
    write_batch_size: int = 100,
) -> dict[str, int]:
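    """Scan stored clusters and rewrite the ones flagged by _backfill_reason.

    Parameters mirror the CLI flags in main(): ``limit`` caps updated clusters,
    ``dry_run`` computes changes without writing, ``scan_only`` skips trends
    lookups, ``progress_every`` controls progress output, and
    ``write_batch_size`` controls how often pending writes are flushed.
    """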
    store = SQLiteClusterStore(db_path)
    with store._conn() as conn:  # noqa: SLF001 - intentional one-off maintenance script
        cur = conn.execute("SELECT cluster_id, topic, payload FROM clusters ORDER BY updated_at ASC")
        rows = cur.fetchall()
    total = 0
    updated = 0
    skipped = 0
    reason_counts: dict[str, int] = {}
    first_samples: dict[str, list[str]] = {}
    sample_limit_per_reason = 5
    pending_writes: list[tuple[dict, str]] = []
    logger = logging.getLogger("news_mcp.backfill")

    def _flush_pending() -> None:
        # Upsert pending clusters grouped by topic, one store call per topic.
        topic_groups: dict[str, list[dict]] = {}
        for item_cluster, item_topic in pending_writes:
            topic_groups.setdefault(item_topic, []).append(item_cluster)
        for batch_topic, batch_clusters in topic_groups.items():
            store.upsert_clusters(batch_clusters, topic=batch_topic)
        pending_writes.clear()

    print(f"starting backfill: total_rows={len(rows)} limit={limit or 'all'} dry_run={dry_run} scan_only={scan_only}", flush=True)
    for cluster_id, topic, payload_json in rows:
        total += 1
        # Tick progress on every processed row, including clean or skipped
        # ones, matching the --progress-every help text.
        if progress_every and total % progress_every == 0:
            logger.info("backfill progress total=%s updated=%s skipped=%s", total, updated, skipped)
            print(".", end="", flush=True)
        try:
            cluster = json.loads(payload_json)
        except Exception:
            skipped += 1
            continue
        reason = _backfill_reason(cluster)
        if reason is None:
            continue
        reason_counts[reason] = reason_counts.get(reason, 0) + 1
        # Record a few example cluster_ids per reason (helpful for root cause).
        samples = first_samples.setdefault(reason, [])
        if len(samples) < sample_limit_per_reason:
            samples.append(cluster.get("cluster_id") or cluster_id)
        sanitized_cluster = sanitize_cluster_payload(cluster, include_resolutions=not scan_only)
        if not dry_run:
            pending_writes.append((sanitized_cluster, topic or sanitized_cluster.get("topic", "other")))
            if write_batch_size > 0 and len(pending_writes) >= write_batch_size:
                _flush_pending()
        updated += 1
        if limit is not None and updated >= limit:
            break

    if not dry_run and pending_writes:
        _flush_pending()
    # Terminate the progress dot line, if any dots were printed.
    if progress_every and total >= progress_every:
        print(flush=True)
    # Print a debug summary by reason.
    if reason_counts:
        print("reason_counts=" + json.dumps(reason_counts, ensure_ascii=False), flush=True)
        for reason, samples in first_samples.items():
            print(f"sample_{reason}=" + ",".join(str(s) for s in samples), flush=True)
    return {"total": total, "updated": updated, "skipped": skipped}


def main() -> None:
    parser = argparse.ArgumentParser(description="Backfill news-mcp entity normalization and resolutions")
    parser.add_argument("--db", type=Path, default=DB_PATH, help="Path to the news sqlite DB")
    parser.add_argument("--limit", type=int, default=None, help="Optional cap on clusters processed")
    parser.add_argument("--dry-run", action="store_true", help="Compute changes without writing them")
    parser.add_argument("--scan-only", action="store_true", help="Only detect stale rows; skip trends lookups")
    parser.add_argument("--progress-every", type=int, default=20, help="Emit a progress line every N processed rows")
    parser.add_argument("--write-batch-size", type=int, default=100, help="Flush writes every N updated clusters")
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO, format="%(message)s")
    result = asyncio.run(
        backfill(
            args.db,
            limit=args.limit,
            dry_run=args.dry_run,
            scan_only=args.scan_only,
            progress_every=max(0, int(args.progress_every)),
            write_batch_size=max(0, int(args.write_batch_size)),
        )
    )
- mode = "DRY RUN" if args.dry_run else "DONE"
- if args.scan_only:
- mode = f"{mode} SCAN-ONLY"
- print(f"{mode} {result}")
- if __name__ == "__main__":
- main()