Ver código fonte

news-mcp: add embedding merge script

Lukas Goldschmidt 1 mês atrás
pai
commit
f2700984de
3 arquivos alterados com 232 adições e 0 exclusões
  1. 2 0
      PROJECT.md
  2. 18 0
      README.md
  3. 212 0
      scripts/merge_cluster_embeddings.py

+ 2 - 0
PROJECT.md

@@ -11,6 +11,7 @@ Provide a signal-extraction MCP server that converts RSS into **deduplicated, en
 - optional Ollama embeddings path for clustering (when `NEWS_EMBEDDINGS_ENABLED=true`)
 - optional embeddings backfill script for precomputing cluster vectors in SQLite
 - optional merge-analysis script for threshold experiments before any DB rewrite
+- optional merge pass for destructive consolidation after threshold review
 - Groq enrichment (topic/entities/sentiment/keywords)
 - Tools expose semantic queries over cached clusters
 
@@ -34,3 +35,4 @@ Provide a signal-extraction MCP server that converts RSS into **deduplicated, en
 - Embeddings remain optional: Ollama is tried first when enabled, otherwise the heuristic path stays active
 - Embeddings backfill script exists for older cluster rows before the server restart
 - Merge-analysis script exists to inspect candidate cluster pairs at multiple thresholds
+- Merge pass exists for destructive consolidation once thresholds look sane

+ 18 - 0
README.md

@@ -187,4 +187,22 @@ anything back to the DB:
 
 This prints candidate pairs per threshold so you can decide whether a merge
 script is worth adding next.
+
+## Embedding merge pass (optional, destructive)
+
+After inspecting the analysis output, you can merge clusters above a chosen
+threshold. Start with dry-run:
+
+```bash
+./.venv/bin/python scripts/merge_cluster_embeddings.py --dry-run --threshold 0.90
+```
+
+If the groupings look right, run wet:
+
+```bash
+./.venv/bin/python scripts/merge_cluster_embeddings.py --threshold 0.90
+```
+
+This merges embedding-similar clusters within the same topic and removes the
+absorbed duplicates from SQLite.
 ```

+ 212 - 0
scripts/merge_cluster_embeddings.py

@@ -0,0 +1,212 @@
+from __future__ import annotations
+
+"""Merge embedding-similar clusters in SQLite.
+
+This is a maintenance script for the embedding-first clustering rollout.
+It supports dry-run mode, reports candidate groups, and when run wet it merges
+clusters within the same topic whose embeddings are similar enough.
+
+Usage:
+  ./.venv/bin/python scripts/merge_cluster_embeddings.py --dry-run --threshold 0.90
+  ./.venv/bin/python scripts/merge_cluster_embeddings.py --threshold 0.88
+"""
+
+import argparse
+import json
+import sys
+from collections import defaultdict
+from pathlib import Path
+from typing import Any
+
+ROOT = Path(__file__).resolve().parents[1]
+sys.path.insert(0, str(ROOT))
+
+from news_mcp.config import DB_PATH
+from news_mcp.dedup.embedding_support import cosine_similarity
+from news_mcp.storage.sqlite_store import SQLiteClusterStore
+
+
+def _embedding(cluster: dict[str, Any]) -> list[float] | None:
+    emb = cluster.get("embedding")
+    if isinstance(emb, list) and emb:
+        try:
+            return [float(x) for x in emb]
+        except Exception:
+            return None
+    return None
+
+
+def _avg_vectors(vectors: list[list[float]]) -> list[float] | None:
+    if not vectors:
+        return None
+    size = len(vectors[0])
+    if any(len(v) != size for v in vectors):
+        return None
+    out = [0.0] * size
+    for v in vectors:
+        for i, x in enumerate(v):
+            out[i] += x
+    n = float(len(vectors))
+    return [x / n for x in out]
+
+
+def _uniq_by_url(items: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    seen = set()
+    out = []
+    for item in items:
+        key = item.get("url") or item.get("title")
+        if key in seen:
+            continue
+        seen.add(key)
+        out.append(item)
+    return out
+
+
+def _merge_payloads(clusters: list[dict[str, Any]]) -> dict[str, Any]:
+    # Choose the most populated cluster as the base.
+    base = max(clusters, key=lambda c: len(c.get("articles", []) or []))
+    merged = dict(base)
+
+    all_articles: list[dict[str, Any]] = []
+    all_sources: list[str] = []
+    all_entities: list[str] = []
+    all_keywords: list[str] = []
+    embeddings: list[list[float]] = []
+    sent_scores: list[float] = []
+
+    first_seen = None
+    last_updated = None
+
+    for c in clusters:
+        all_articles.extend([a for a in (c.get("articles", []) or []) if isinstance(a, dict)])
+        all_sources.extend([str(s) for s in (c.get("sources", []) or [])])
+        all_entities.extend([str(e) for e in (c.get("entities", []) or [])])
+        all_keywords.extend([str(k) for k in (c.get("keywords", []) or [])])
+
+        emb = _embedding(c)
+        if emb:
+            embeddings.append(emb)
+
+        if c.get("sentimentScore") is not None:
+            try:
+                sent_scores.append(float(c["sentimentScore"]))
+            except Exception:
+                pass
+
+        fs = str(c.get("first_seen") or c.get("timestamp") or "")
+        lu = str(c.get("last_updated") or c.get("timestamp") or "")
+        if fs and (first_seen is None or fs < first_seen):
+            first_seen = fs
+        if lu and (last_updated is None or lu > last_updated):
+            last_updated = lu
+
+    merged["articles"] = _uniq_by_url(all_articles)
+    merged["sources"] = list(dict.fromkeys(all_sources))
+    merged["entities"] = list(dict.fromkeys(e for e in all_entities if e))
+    merged["keywords"] = list(dict.fromkeys(k for k in all_keywords if k))
+    merged["first_seen"] = first_seen or merged.get("first_seen")
+    merged["last_updated"] = last_updated or merged.get("last_updated")
+    merged["importance"] = max(float(c.get("importance", 0.0) or 0.0) for c in clusters)
+    if sent_scores:
+        merged["sentimentScore"] = sum(sent_scores) / len(sent_scores)
+    merged["embedding"] = _avg_vectors(embeddings) or merged.get("embedding")
+    merged["embedding_model"] = merged.get("embedding_model") or "ollama:nomic-embed-text"
+    return merged
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Merge embedding-similar news clusters")
+    parser.add_argument("--db", type=Path, default=DB_PATH)
+    parser.add_argument("--threshold", type=float, default=0.9)
+    parser.add_argument("--dry-run", action="store_true")
+    parser.add_argument("--limit", type=int, default=None, help="Optional maximum number of rows to scan")
+    parser.add_argument("--top", type=int, default=50, help="Max groups to print in dry-run mode")
+    args = parser.parse_args()
+
+    store = SQLiteClusterStore(args.db)
+    with store._conn() as conn:  # noqa: SLF001 - maintenance script
+        rows = conn.execute("SELECT cluster_id, topic, payload FROM clusters ORDER BY updated_at ASC").fetchall()
+
+    if args.limit is not None:
+        rows = rows[: args.limit]
+
+    by_topic: dict[str, list[dict[str, Any]]] = defaultdict(list)
+    for cluster_id, topic, payload_json in rows:
+        try:
+            cluster = json.loads(payload_json)
+        except Exception:
+            continue
+        emb = _embedding(cluster)
+        if not emb:
+            continue
+        cluster["_topic_key"] = topic or cluster.get("topic", "other")
+        by_topic[cluster["_topic_key"]].append(cluster)
+
+    groups = []
+    for topic, clusters in by_topic.items():
+        if len(clusters) < 2:
+            continue
+
+        parent = list(range(len(clusters)))
+
+        def find(x: int) -> int:
+            while parent[x] != x:
+                parent[x] = parent[parent[x]]
+                x = parent[x]
+            return x
+
+        def union(a: int, b: int) -> None:
+            ra, rb = find(a), find(b)
+            if ra != rb:
+                parent[rb] = ra
+
+        for i in range(len(clusters)):
+            ei = _embedding(clusters[i])
+            if not ei:
+                continue
+            for j in range(i + 1, len(clusters)):
+                ej = _embedding(clusters[j])
+                if not ej:
+                    continue
+                if cosine_similarity(ei, ej) >= args.threshold:
+                    union(i, j)
+
+        buckets: dict[int, list[dict[str, Any]]] = defaultdict(list)
+        for idx, c in enumerate(clusters):
+            buckets[find(idx)].append(c)
+
+        for bucket in buckets.values():
+            if len(bucket) < 2:
+                continue
+            bucket = sorted(bucket, key=lambda c: str(c.get("updated_at") or c.get("last_updated") or ""), reverse=True)
+            groups.append((topic, bucket))
+
+    print({"threshold": args.threshold, "merge_groups": len(groups), "dry_run": args.dry_run})
+
+    if args.dry_run:
+        for topic, bucket in groups[: args.top]:
+            print(json.dumps({
+                "topic": topic,
+                "size": len(bucket),
+                "cluster_ids": [c.get("cluster_id") for c in bucket],
+                "headlines": [c.get("headline") for c in bucket],
+            }, ensure_ascii=False))
+        return
+
+    with store._conn() as conn:  # noqa: SLF001
+        for topic, bucket in groups:
+            rep = bucket[0]
+            rep_id = rep.get("cluster_id")
+            merged = _merge_payloads(bucket)
+            store.upsert_clusters([merged], topic=topic)
+            # Delete absorbed duplicates; keep the representative row.
+            dup_ids = [c.get("cluster_id") for c in bucket[1:] if c.get("cluster_id")]
+            for cid in dup_ids:
+                conn.execute("DELETE FROM clusters WHERE cluster_id=?", (cid,))
+        conn.commit()
+
+    print({"merged_groups": len(groups), "threshold": args.threshold})
+
+
+if __name__ == "__main__":
+    main()