
Fix entityResolutions drift and harden article dedup

Lukas Goldschmidt 1 month ago
parent
commit
89e141466f

+ 2 - 0
README.md

@@ -218,4 +218,6 @@ underlying article id / URL path. To clean existing rows:
 ```
 
 The live clustering path also deduplicates article entries when new data comes in.
+
+As of the latest hardening, the server/storage write path also self-heals `payload.articles` by deduplicating before persisting (so historical rows can be fixed via the cleanup script, and future writes won’t reintroduce duplicates).
 ```

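The self-healing write path described above is easiest to see with a small, illustrative sketch. It uses the `sanitize_cluster_payload` helper added to `news_mcp/storage/sqlite_store.py` in this commit; the URLs and ids are invented, and it assumes `normalize_entities` returns an empty list for empty input.

```
from news_mcp.storage.sqlite_store import sanitize_cluster_payload

# Two syndicated copies of the same story: different hosts, but the same
# trailing URL path segment ("some-story"), so they share one dedup key.
payload = {
    "cluster_id": "demo-cluster",  # hypothetical id
    "articles": [
        {"url": "https://example.com/world/some-story", "title": "Some story"},
        {"url": "https://mirror.example.net/rss/some-story", "title": "Some story"},
    ],
    "entities": [],
}

cleaned = sanitize_cluster_payload(payload)
assert len(cleaned["articles"]) == 1       # duplicate slug collapsed before persisting
assert cleaned["entityResolutions"] == []  # assumes normalize_entities([]) == []
```

Because `upsert_clusters` runs the same sanitizer on every write, the cleanup script is only needed for rows written before this change.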
+ 11 - 0
news_mcp/jobs/poller.py

@@ -7,6 +7,7 @@ from news_mcp.config import CLUSTERS_TTL_HOURS, DB_PATH, NEWS_FEED_URL, NEWS_FEE
 from news_mcp.dedup.cluster import dedup_and_cluster_articles
 from news_mcp.enrichment.enrich import enrich_cluster
 from news_mcp.enrichment.llm_enrich import classify_cluster_groq
+from news_mcp.trends_resolution import resolve_entity_via_trends
 from news_mcp.sources.news_feeds import fetch_news_articles
 from news_mcp.storage.sqlite_store import SQLiteClusterStore
 
 
@@ -57,6 +58,16 @@ async def refresh_clusters(topic: str | None = None, limit: int = 80) -> None:
                     c2 = dict(c2)
                     # Keep existing enriched fields.
                     c2["entities"] = existing.get("entities", [])
+
+                    # IMPORTANT: entityResolutions must stay consistent with entities.
+                    # Older rows may have entities but missing/malformed resolutions.
+                    existing_resolutions = existing.get("entityResolutions", None)
+                    if isinstance(existing_resolutions, list) and existing_resolutions:
+                        c2["entityResolutions"] = existing_resolutions
+                    else:
+                        # Recompute resolutions deterministically from the stored entities.
+                        c2["entityResolutions"] = [resolve_entity_via_trends(e) for e in c2["entities"]]
+
                     if existing.get("sentiment"):
                         c2["sentiment"] = existing.get("sentiment")
                     if existing.get("sentimentScore") is not None:

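The recompute branch above relies on `entityResolutions` staying aligned one-to-one with `entities`. A rough sketch of that expectation follows; the entity names are hypothetical, the call may perform a trends lookup, and the two field names come from the validator added in `sqlite_store.py` below rather than from any documented contract of `resolve_entity_via_trends`.

```
from news_mcp.trends_resolution import resolve_entity_via_trends

entities = ["OpenAI", "European Union"]  # hypothetical stored entities
resolutions = [resolve_entity_via_trends(e) for e in entities]

# Expected shape: one resolution dict per entity, in the same order, each
# carrying non-empty "normalized" and "canonical_label" fields (exactly what
# _has_valid_entity_resolutions checks on the storage side).
assert len(resolutions) == len(entities)
assert all(r.get("normalized") and r.get("canonical_label") for r in resolutions)
```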
+ 70 - 0
news_mcp/storage/sqlite_store.py

@@ -6,6 +6,10 @@ from dataclasses import dataclass
 from datetime import datetime, timezone, timedelta
 from pathlib import Path
 from typing import Any
+from urllib.parse import urlparse
+
+from news_mcp.entity_normalize import normalize_entities
+from news_mcp.trends_resolution import resolve_entity_via_trends


 @dataclass
@@ -16,6 +20,71 @@ class ClusterRow:
     updated_at: datetime
 
 
 
 
+def _article_key(article: dict[str, Any]) -> str:
+    url = str(article.get("url") or "").strip()
+    if not url:
+        return str(article.get("title") or "")
+    try:
+        parsed = urlparse(url)
+        parts = [p for p in parsed.path.split("/") if p]
+        if parts:
+            return parts[-1]
+    except Exception:
+        pass
+    return url
+
+
+def _dedup_articles(articles: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    seen: set[str] = set()
+    out: list[dict[str, Any]] = []
+    for article in articles:
+        key = _article_key(article)
+        if key in seen:
+            continue
+        seen.add(key)
+        out.append(article)
+    return out
+
+
+def _has_valid_entity_resolutions(resolutions: Any, entities: list[str]) -> bool:
+    if not isinstance(resolutions, list):
+        return False
+    if len(resolutions) != len(entities):
+        return False
+    for res in resolutions:
+        if not isinstance(res, dict):
+            return False
+        if not res.get("normalized") or not res.get("canonical_label"):
+            return False
+    return True
+
+
+def sanitize_cluster_payload(cluster: dict[str, Any], *, include_resolutions: bool = True) -> dict[str, Any]:
+    """Normalize cluster payload so every stored payload is internally consistent."""
+    out = dict(cluster)
+
+    raw_articles = out.get("articles", []) or []
+    articles = [a for a in raw_articles if isinstance(a, dict)]
+    out["articles"] = _dedup_articles(articles)
+
+    raw_entities = out.get("entities", []) or []
+    entities = normalize_entities(raw_entities)
+    out["entities"] = entities
+
+    if not include_resolutions:
+        return out
+
+    resolutions = out.get("entityResolutions", None)
+    if entities:
+        if not _has_valid_entity_resolutions(resolutions, entities):
+            out["entityResolutions"] = [resolve_entity_via_trends(e) for e in entities]
+    else:
+        # Keep the empty case explicit and stable.
+        out["entityResolutions"] = []
+
+    return out
+
+
 class SQLiteClusterStore:
     def __init__(self, db_path: str | Path):
         self.db_path = str(db_path)
@@ -69,6 +138,7 @@ class SQLiteClusterStore:
         now = datetime.now(timezone.utc)
         with self._conn() as conn:
             for c in clusters:
+                c = sanitize_cluster_payload(c)
                 cluster_id = c["cluster_id"]
                 payload = json.dumps(c, ensure_ascii=False)
                 conn.execute(

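With `upsert_clusters` funnelling every payload through `sanitize_cluster_payload`, a malformed row is repaired at write time instead of waiting for the backfill script. A minimal sketch of that repair path, with invented values; recomputing resolutions may trigger a trends lookup per entity.

```
from news_mcp.storage.sqlite_store import sanitize_cluster_payload

cluster = {
    "cluster_id": "abc123",                     # hypothetical id
    "articles": [{"url": "https://example.com/tech/story-slug"}],
    "entities": ["OpenAI"],                     # hypothetical entity
    "entityResolutions": [{"normalized": ""}],  # malformed: empty field, no canonical_label
}

fixed = sanitize_cluster_payload(cluster)
# The stored list fails _has_valid_entity_resolutions, so resolutions are
# recomputed from the (normalized) entities; a valid list would be kept as-is.
assert len(fixed["entityResolutions"]) == len(fixed["entities"])
```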
+ 60 - 28
scripts/backfill_news_entities.py

@@ -12,7 +12,6 @@ import json
 import sys
 import logging
 from pathlib import Path
-from typing import Any
 
 
 
 
 ROOT = Path(__file__).resolve().parents[1]
@@ -21,32 +20,38 @@ if str(ROOT) not in sys.path:
 
 
 from news_mcp.config import DB_PATH
 from news_mcp.entity_normalize import normalize_entities
-from news_mcp.storage.sqlite_store import SQLiteClusterStore
-from news_mcp.trends_resolution import resolve_entity_via_trends
+from news_mcp.storage.sqlite_store import SQLiteClusterStore, sanitize_cluster_payload
 
 
 
 
-def _compute_entity_resolutions(entities: list[str]) -> list[dict[str, Any]]:
-    return [resolve_entity_via_trends(ent) for ent in entities]
-
-
-def _needs_backfill(cluster: dict) -> bool:
+def _backfill_reason(cluster: dict) -> str | None:
+    """Return a reason string if the cluster needs backfill, else None."""
     raw_entities = cluster.get("entities", []) or []
     normalized_entities = normalize_entities(raw_entities)
-    resolutions = cluster.get("entityResolutions", []) or []
+    resolutions = cluster.get("entityResolutions", None)
 
 
-    # Clearly missing or stale metadata.
-    if not resolutions:
-        return bool(normalized_entities or raw_entities)
     if normalized_entities != raw_entities:
-        return True
-    if len(resolutions) != len(normalized_entities):
-        return True
-    for res in resolutions:
+        return "entities_not_normalized"
+
+    if normalized_entities and not resolutions:
+        return "missing_entityResolutions"
+
+    if resolutions is not None and not isinstance(resolutions, list):
+        return "invalid_resolution_shape"
+
+    if isinstance(resolutions, list) and len(resolutions) != len(normalized_entities):
+        return "resolution_length_mismatch"
+
+    for res in resolutions or []:
         if not isinstance(res, dict):
-            return True
+            return "invalid_resolution_shape"
         if not res.get("normalized") or not res.get("canonical_label"):
-            return True
-    return False
+            return "resolution_missing_fields"
+
+    sanitized = sanitize_cluster_payload(cluster)
+    if sanitized != cluster:
+        return "payload_sanitized"
+
+    return None
 
 
 
 
 async def backfill(
@@ -55,6 +60,7 @@ async def backfill(
     dry_run: bool = False,
     scan_only: bool = False,
     progress_every: int = 20,
+    write_batch_size: int = 100,
 ) -> dict[str, int]:
     store = SQLiteClusterStore(db_path)
 
 
@@ -65,6 +71,10 @@ async def backfill(
     total = 0
     updated = 0
     skipped = 0
+    reason_counts: dict[str, int] = {}
+    first_samples: dict[str, list[str]] = {}
+    sample_limit_per_reason = 5
+    pending_writes: list[tuple[dict, str]] = []
     logger = logging.getLogger("news_mcp.backfill")

     print(f"starting backfill: total_rows={len(rows)} limit={limit or 'all'} dry_run={dry_run} scan_only={scan_only}", flush=True)
@@ -77,20 +87,28 @@ async def backfill(
             skipped += 1
             continue
 
 
-        if not _needs_backfill(cluster):
+        reason = _backfill_reason(cluster)
+        if reason is None:
             continue
+        reason_counts[reason] = reason_counts.get(reason, 0) + 1
 
 
-        raw_entities = cluster.get("entities", []) or []
-        normalized_entities = normalize_entities(raw_entities)
-        entity_resolutions = _compute_entity_resolutions(normalized_entities) if not scan_only else []
+        # Record a few example cluster_ids per reason (helpful for root cause).
+        if sample_limit_per_reason and reason in reason_counts and reason not in first_samples:
+            first_samples[reason] = []
+        if reason in first_samples and len(first_samples[reason]) < sample_limit_per_reason:
+            first_samples[reason].append(cluster.get("cluster_id") or cluster_id)
 
 
-        cluster = dict(cluster)
-        cluster["entities"] = normalized_entities
-        if not scan_only:
-            cluster["entityResolutions"] = entity_resolutions
+        sanitized_cluster = sanitize_cluster_payload(cluster, include_resolutions=not scan_only)
 
 
         if not dry_run:
-            store.upsert_clusters([cluster], topic=topic or cluster.get("topic", "other"))
+            pending_writes.append((sanitized_cluster, topic or sanitized_cluster.get("topic", "other")))
+            if write_batch_size > 0 and len(pending_writes) >= write_batch_size:
+                topic_groups: dict[str, list[dict]] = {}
+                for item_cluster, item_topic in pending_writes:
+                    topic_groups.setdefault(item_topic, []).append(item_cluster)
+                for batch_topic, batch_clusters in topic_groups.items():
+                    store.upsert_clusters(batch_clusters, topic=batch_topic)
+                pending_writes.clear()
         updated += 1
 
 
         if limit is not None and updated >= limit:
@@ -103,9 +121,21 @@ async def backfill(
         if progress_every and total % progress_every == 0:
             print(flush=True)
 
 
+    if not dry_run and pending_writes:
+        topic_groups: dict[str, list[dict]] = {}
+        for item_cluster, item_topic in pending_writes:
+            topic_groups.setdefault(item_topic, []).append(item_cluster)
+        for batch_topic, batch_clusters in topic_groups.items():
+            store.upsert_clusters(batch_clusters, topic=batch_topic)
+
     if total % max(1, progress_every) != 0:
         print(flush=True)
 
 
+    # Print debug summary by reason.
+    if reason_counts:
+        print("reason_counts=" + json.dumps(reason_counts, ensure_ascii=False), flush=True)
+        for reason, samples in first_samples.items():
+            print(f"sample_{reason}=" + ",".join(samples), flush=True)
     return {"total": total, "updated": updated, "skipped": skipped}
 
 
 
 
@@ -116,6 +146,7 @@ def main() -> None:
     parser.add_argument("--dry-run", action="store_true", help="Compute changes without writing them")
     parser.add_argument("--scan-only", action="store_true", help="Only detect stale rows; skip trends lookups")
     parser.add_argument("--progress-every", type=int, default=20, help="Emit a progress line every N processed rows")
+    parser.add_argument("--write-batch-size", type=int, default=100, help="Flush writes every N updated clusters")
     args = parser.parse_args()

     logging.basicConfig(level=logging.INFO, format="%(message)s")
@@ -126,6 +157,7 @@ def main() -> None:
             dry_run=args.dry_run,
             scan_only=args.scan_only,
             progress_every=max(0, int(args.progress_every)),
+            write_batch_size=max(0, int(args.write_batch_size)),
         )
     )
     mode = "DRY RUN" if args.dry_run else "DONE"

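The reason-based detection makes dry runs and scans more informative: every stale row is tagged with why it needs work, per-reason counts are printed at the end, and up to five sample cluster ids are kept per reason. A quick illustrative check of the helper (hypothetical row; assumes the entity value is already in normalized form so `normalize_entities` leaves it unchanged, and that the `scripts/` directory is on `sys.path`):

```
from backfill_news_entities import _backfill_reason  # assumes scripts/ is importable

cluster = {
    "cluster_id": "demo",    # hypothetical id
    "entities": ["openai"],  # assumed already normalized
    "entityResolutions": None,
}

# Entities exist but resolutions were never written, so the row is flagged.
print(_backfill_reason(cluster))  # expected: "missing_entityResolutions"
```

`--write-batch-size` then groups the pending rewrites by topic and flushes them every N updated rows, with a final flush after the loop, so an interrupted run loses at most one batch.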
+ 6 - 34
scripts/dedup_articles_in_clusters.py

@@ -14,40 +14,12 @@ import argparse
 import json
 import sys
 from pathlib import Path
-from typing import Any
-from urllib.parse import urlparse
 
 
 ROOT = Path(__file__).resolve().parents[1]
 sys.path.insert(0, str(ROOT))
 
 
 from news_mcp.config import DB_PATH
-from news_mcp.storage.sqlite_store import SQLiteClusterStore
-
-
-def _article_key(article: dict[str, Any]) -> str:
-    url = str(article.get("url") or "").strip()
-    if not url:
-        return str(article.get("title") or "")
-    try:
-        parsed = urlparse(url)
-        parts = [p for p in parsed.path.split("/") if p]
-        if parts:
-            return parts[-1]
-    except Exception:
-        pass
-    return url
-
-
-def _dedup_articles(articles: list[dict[str, Any]]) -> list[dict[str, Any]]:
-    seen = set()
-    out = []
-    for article in articles:
-        key = _article_key(article)
-        if key in seen:
-            continue
-        seen.add(key)
-        out.append(article)
-    return out
+from news_mcp.storage.sqlite_store import SQLiteClusterStore, sanitize_cluster_payload
 
 
 
 
 def main() -> None:
@@ -76,13 +48,13 @@ def main() -> None:
         except Exception:
             continue
 
 
-        articles = cluster.get("articles", []) or []
-        deduped = _dedup_articles([a for a in articles if isinstance(a, dict)])
-        if len(deduped) == len(articles):
+        sanitized = sanitize_cluster_payload(cluster)
+        original_articles = cluster.get("articles", []) or []
+        deduped = sanitized.get("articles", []) or []
+        if deduped == original_articles:
             continue
 
 
-        cluster = dict(cluster)
-        cluster["articles"] = deduped
+        cluster = sanitized
         if not args.dry_run:
             store.upsert_clusters([cluster], topic=topic or cluster.get("topic", "other"))
         updated += 1
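Since the cleanup script and the storage write path now share the same sanitizer, re-running the script over rows that were already fixed should be a no-op. A small sketch of that expectation, assuming `normalize_entities` returns an empty list for empty input:

```
from news_mcp.storage.sqlite_store import sanitize_cluster_payload

cluster = {
    "cluster_id": "demo",  # hypothetical, already-clean payload
    "articles": [{"url": "https://example.com/news/some-slug", "title": "Some slug"}],
    "entities": [],
}

once = sanitize_cluster_payload(cluster)
twice = sanitize_cluster_payload(once)
# A second pass finds nothing to repair, so `deduped == original_articles`
# holds in the script above and the row is skipped rather than rewritten.
assert once == twice
```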