@@ -12,7 +12,6 @@ import json
 import sys
 import logging
 from pathlib import Path
-from typing import Any


 ROOT = Path(__file__).resolve().parents[1]
@@ -21,32 +20,39 @@ if str(ROOT) not in sys.path:

 from news_mcp.config import DB_PATH
 from news_mcp.entity_normalize import normalize_entities
-from news_mcp.storage.sqlite_store import SQLiteClusterStore
-from news_mcp.trends_resolution import resolve_entity_via_trends
+from news_mcp.storage.sqlite_store import SQLiteClusterStore, sanitize_cluster_payload


-def _compute_entity_resolutions(entities: list[str]) -> list[dict[str, Any]]:
-    return [resolve_entity_via_trends(ent) for ent in entities]
-
-
-def _needs_backfill(cluster: dict) -> bool:
+def _backfill_reason(cluster: dict) -> str | None:
+    """Return a reason string if the cluster needs backfill, else None."""
     raw_entities = cluster.get("entities", []) or []
     normalized_entities = normalize_entities(raw_entities)
-    resolutions = cluster.get("entityResolutions", []) or []
+    resolutions = cluster.get("entityResolutions", None)

-    # Clearly missing or stale metadata.
-    if not resolutions:
-        return bool(normalized_entities or raw_entities)
     if normalized_entities != raw_entities:
-        return True
-    if len(resolutions) != len(normalized_entities):
-        return True
-    for res in resolutions:
+        return "entities_not_normalized"
+
+    if normalized_entities and not resolutions:
+        return "missing_entityResolutions"
+
+    if resolutions is not None and not isinstance(resolutions, list):
+        return "invalid_resolution_shape"
+
+    if isinstance(resolutions, list) and len(resolutions) != len(normalized_entities):
+        return "resolution_length_mismatch"
+
+    for res in resolutions or []:
         if not isinstance(res, dict):
-            return True
+            return "invalid_resolution_shape"
         if not res.get("normalized") or not res.get("canonical_label"):
-            return True
-    return False
+            return "resolution_missing_fields"
+
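+    # Fallback: if sanitizing would change the stored payload, it needs a rewrite.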
+    sanitized = sanitize_cluster_payload(cluster)
+    if sanitized != cluster:
+        return "payload_sanitized"
+
+    return None


 async def backfill(
@@ -55,6 +61,7 @@ async def backfill(
     dry_run: bool = False,
     scan_only: bool = False,
     progress_every: int = 20,
+    write_batch_size: int = 100,
 ) -> dict[str, int]:
     store = SQLiteClusterStore(db_path)

@@ -65,6 +72,11 @@ async def backfill(
     total = 0
     updated = 0
     skipped = 0
+    reason_counts: dict[str, int] = {}
+    first_samples: dict[str, list[str]] = {}
+    sample_limit_per_reason = 5
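+    # Buffered (cluster, topic) pairs waiting for a batched upsert.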
+    pending_writes: list[tuple[dict, str]] = []
     logger = logging.getLogger("news_mcp.backfill")

     print(f"starting backfill: total_rows={len(rows)} limit={limit or 'all'} dry_run={dry_run} scan_only={scan_only}", flush=True)
@@ -77,20 +88,29 @@
             skipped += 1
             continue

-        if not _needs_backfill(cluster):
+        reason = _backfill_reason(cluster)
+        if reason is None:
             continue
+        reason_counts[reason] = reason_counts.get(reason, 0) + 1

-        raw_entities = cluster.get("entities", []) or []
-        normalized_entities = normalize_entities(raw_entities)
-        entity_resolutions = _compute_entity_resolutions(normalized_entities) if not scan_only else []
+        # Record a few example cluster_ids per reason (helpful for root cause).
+        if sample_limit_per_reason and reason in reason_counts and reason not in first_samples:
+            first_samples[reason] = []
+        if reason in first_samples and len(first_samples[reason]) < sample_limit_per_reason:
+            first_samples[reason].append(cluster.get("cluster_id") or cluster_id)

-        cluster = dict(cluster)
-        cluster["entities"] = normalized_entities
-        if not scan_only:
-            cluster["entityResolutions"] = entity_resolutions
+        sanitized_cluster = sanitize_cluster_payload(cluster, include_resolutions=not scan_only)

         if not dry_run:
-            store.upsert_clusters([cluster], topic=topic or cluster.get("topic", "other"))
+            pending_writes.append((sanitized_cluster, topic or sanitized_cluster.get("topic", "other")))
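+            # Once the buffer fills, group it by topic so each upsert call stays single-topic.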
+            if write_batch_size > 0 and len(pending_writes) >= write_batch_size:
+                topic_groups: dict[str, list[dict]] = {}
+                for item_cluster, item_topic in pending_writes:
+                    topic_groups.setdefault(item_topic, []).append(item_cluster)
+                for batch_topic, batch_clusters in topic_groups.items():
+                    store.upsert_clusters(batch_clusters, topic=batch_topic)
+                pending_writes.clear()
         updated += 1

         if limit is not None and updated >= limit:
@@ -103,9 +123,22 @@
         if progress_every and total % progress_every == 0:
             print(flush=True)

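+    # Flush any partial batch left over once the loop finishes.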
+    if not dry_run and pending_writes:
+        topic_groups: dict[str, list[dict]] = {}
+        for item_cluster, item_topic in pending_writes:
+            topic_groups.setdefault(item_topic, []).append(item_cluster)
+        for batch_topic, batch_clusters in topic_groups.items():
+            store.upsert_clusters(batch_clusters, topic=batch_topic)
+
     if total % max(1, progress_every) != 0:
         print(flush=True)

+    # Print debug summary by reason.
+    if reason_counts:
+        print("reason_counts=" + json.dumps(reason_counts, ensure_ascii=False), flush=True)
+        for reason, samples in first_samples.items():
+            print(f"sample_{reason}=" + ",".join(samples), flush=True)
     return {"total": total, "updated": updated, "skipped": skipped}


@@ -116,6 +149,7 @@ def main() -> None:
     parser.add_argument("--dry-run", action="store_true", help="Compute changes without writing them")
     parser.add_argument("--scan-only", action="store_true", help="Only detect stale rows; skip trends lookups")
     parser.add_argument("--progress-every", type=int, default=20, help="Emit a progress line every N processed rows")
+    parser.add_argument("--write-batch-size", type=int, default=100, help="Flush writes every N updated clusters")
     args = parser.parse_args()

     logging.basicConfig(level=logging.INFO, format="%(message)s")
@@ -126,6 +160,7 @@ def main() -> None:
             dry_run=args.dry_run,
             scan_only=args.scan_only,
             progress_every=max(0, int(args.progress_every)),
+            write_batch_size=max(0, int(args.write_batch_size)),
         )
     )
     mode = "DRY RUN" if args.dry_run else "DONE"
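
Quick smoke test for the new flow (flag names come from the argparse changes above; the script filename is an assumption, not part of this diff):

    # Report stale rows and per-reason counts without writing anything:
    python backfill_clusters.py --dry-run --scan-only

    # Real run, flushing topic-grouped writes every 50 updated clusters:
    python backfill_clusters.py --write-batch-size 50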