| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138 |
"""One-off backfill for news-mcp cluster entity metadata.

This reprocesses every stored cluster so older rows pick up the current
entity normalization and trends resolution fields used by lookup tools.
"""
# NOTE: the docstring must precede the __future__ import; only the module
# docstring may appear before a future statement, and placing it after the
# import would leave __doc__ unset.
from __future__ import annotations

import argparse
import asyncio
import json
import logging
import sys
from pathlib import Path
from typing import Any

# Make the repository root importable when this script is run directly
# (i.e. without installing the package).
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

# Local imports must come after the sys.path manipulation above.
from news_mcp.config import DB_PATH
from news_mcp.entity_normalize import normalize_entities
from news_mcp.storage.sqlite_store import SQLiteClusterStore
from news_mcp.trends_resolution import resolve_entity_via_trends
def _compute_entity_resolutions(entities: list[str]) -> list[dict[str, Any]]:
    """Resolve every entity name through the trends lookup, preserving order."""
    return list(map(resolve_entity_via_trends, entities))
def _needs_backfill(cluster: dict) -> bool:
    """Return True when *cluster* lacks up-to-date normalization/resolution metadata."""
    raw = cluster.get("entities", []) or []
    normalized = normalize_entities(raw)
    resolutions = cluster.get("entityResolutions", []) or []

    # No resolution metadata at all: backfill unless the row has no entities.
    if not resolutions:
        return bool(normalized or raw)

    # Stored entities are stale relative to the current normalization, or the
    # resolution list is out of sync with the entity list.
    if normalized != raw or len(resolutions) != len(normalized):
        return True

    # Each resolution entry must be a dict carrying both required fields.
    return not all(
        isinstance(res, dict) and res.get("normalized") and res.get("canonical_label")
        for res in resolutions
    )
async def backfill(
    db_path: Path,
    limit: int | None = None,
    dry_run: bool = False,
    scan_only: bool = False,
    progress_every: int = 20,
) -> dict[str, int]:
    """Reprocess stored clusters so stale rows gain current entity metadata.

    Args:
        db_path: Path to the sqlite database holding the ``clusters`` table.
        limit: Optional cap on the number of clusters *updated* (not scanned).
        dry_run: Compute changes without writing them back to the store.
        scan_only: Skip the (slow) trends resolution lookups; only the
            normalized entity list is refreshed.
        progress_every: Emit a log line and newline every N scanned rows;
            0 disables progress output entirely.

    Returns:
        Counters: ``{"total": scanned, "updated": changed, "skipped": unparseable}``.
    """
    store = SQLiteClusterStore(db_path)
    with store._conn() as conn:  # noqa: SLF001 - intentional one-off maintenance script
        cur = conn.execute("SELECT cluster_id, topic, payload FROM clusters ORDER BY updated_at ASC")
        rows = cur.fetchall()

    total = 0
    updated = 0
    skipped = 0
    logger = logging.getLogger("news_mcp.backfill")
    print(f"starting backfill: total_rows={len(rows)} limit={limit or 'all'} dry_run={dry_run} scan_only={scan_only}", flush=True)

    for _cluster_id, topic, payload_json in rows:
        total += 1
        # Progress is reported per *scanned* row (not per updated row), so it
        # must run before any `continue` below: one dot per row, plus a log
        # line and newline every `progress_every` rows.
        if progress_every:
            print(".", end="", flush=True)
            if total % progress_every == 0:
                print(flush=True)
                logger.info("backfill progress total=%s updated=%s skipped=%s", total, updated, skipped)

        try:
            cluster = json.loads(payload_json)
        except (TypeError, ValueError):
            # Unparseable payloads are counted but otherwise left untouched.
            skipped += 1
            continue
        if not _needs_backfill(cluster):
            continue

        raw_entities = cluster.get("entities", []) or []
        normalized_entities = normalize_entities(raw_entities)
        entity_resolutions = _compute_entity_resolutions(normalized_entities) if not scan_only else []

        # Work on a copy so a failed write never leaves a half-mutated payload.
        cluster = dict(cluster)
        cluster["entities"] = normalized_entities
        if not scan_only:
            cluster["entityResolutions"] = entity_resolutions
        if not dry_run:
            store.upsert_clusters([cluster], topic=topic or cluster.get("topic", "other"))
        updated += 1
        if limit is not None and updated >= limit:
            break

    # Terminate a partially filled dot line so the summary starts on a fresh line.
    if progress_every and total % progress_every != 0:
        print(flush=True)
    return {"total": total, "updated": updated, "skipped": skipped}
def main() -> None:
    """Parse CLI arguments, run the async backfill, and report the outcome."""
    parser = argparse.ArgumentParser(description="Backfill news-mcp entity normalization and resolutions")
    parser.add_argument("--db", type=Path, default=DB_PATH, help="Path to the news sqlite DB")
    parser.add_argument("--limit", type=int, default=None, help="Optional cap on clusters processed")
    parser.add_argument("--dry-run", action="store_true", help="Compute changes without writing them")
    parser.add_argument("--scan-only", action="store_true", help="Only detect stale rows; skip trends lookups")
    parser.add_argument("--progress-every", type=int, default=20, help="Emit a progress line every N processed rows")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(message)s")

    # Negative values would break the modulo-based progress logic; clamp to 0.
    progress = max(0, int(args.progress_every))
    stats = asyncio.run(
        backfill(
            args.db,
            limit=args.limit,
            dry_run=args.dry_run,
            scan_only=args.scan_only,
            progress_every=progress,
        )
    )

    label = "DRY RUN" if args.dry_run else "DONE"
    if args.scan_only:
        label = f"{label} SCAN-ONLY"
    print(f"{label} {stats}")


if __name__ == "__main__":
    main()
|