@@ -0,0 +1,229 @@
+"""Merge embedding-similar clusters in SQLite.
+
+This is a maintenance script for the embedding-first clustering rollout.
+In dry-run mode it only reports candidate merge groups; otherwise it merges
+clusters within the same topic whose embedding similarity meets the threshold.
+
+Usage:
+    ./.venv/bin/python scripts/merge_cluster_embeddings.py --dry-run --threshold 0.90
+    ./.venv/bin/python scripts/merge_cluster_embeddings.py --threshold 0.88
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import sys
+from collections import defaultdict
+from pathlib import Path
+from typing import Any
+
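+# Make the repo root importable when this script is run straight from a checkout.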
+ROOT = Path(__file__).resolve().parents[1]
+sys.path.insert(0, str(ROOT))
+
+from news_mcp.config import DB_PATH
+from news_mcp.dedup.embedding_support import cosine_similarity
+from news_mcp.storage.sqlite_store import SQLiteClusterStore
+
+
+def _embedding(cluster: dict[str, Any]) -> list[float] | None:
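+    """Return the cluster's embedding as a list of floats, or None if missing or invalid."""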
+    emb = cluster.get("embedding")
+    if isinstance(emb, list) and emb:
+        try:
+            return [float(x) for x in emb]
+        except Exception:
+            return None
+    return None
+
+
+def _avg_vectors(vectors: list[list[float]]) -> list[float] | None:
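+    """Element-wise mean of equal-length vectors; None for empty or ragged input."""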
+    if not vectors:
+        return None
+    size = len(vectors[0])
+    if any(len(v) != size for v in vectors):
+        return None
+    out = [0.0] * size
+    for v in vectors:
+        for i, x in enumerate(v):
+            out[i] += x
+    n = float(len(vectors))
+    return [x / n for x in out]
+
+
+def _uniq_by_url(items: list[dict[str, Any]]) -> list[dict[str, Any]]:
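+    """De-duplicate articles by url, falling back to title when url is absent."""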
+    seen: set[str | None] = set()
+    out: list[dict[str, Any]] = []
+    for item in items:
+        key = item.get("url") or item.get("title")
+        if key in seen:
+            continue
+        seen.add(key)
+        out.append(item)
+    return out
+
+
+def _merge_payloads(clusters: list[dict[str, Any]]) -> dict[str, Any]:
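+    """Fold a bucket of similar clusters into a single merged payload."""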
+    # Choose the most populated cluster as the base.
+    base = max(clusters, key=lambda c: len(c.get("articles", []) or []))
+    merged = dict(base)
+
+    all_articles: list[dict[str, Any]] = []
+    all_sources: list[str] = []
+    all_entities: list[str] = []
+    all_keywords: list[str] = []
+    embeddings: list[list[float]] = []
+    sent_scores: list[float] = []
+
+    first_seen = None
+    last_updated = None
+
+    for c in clusters:
+        all_articles.extend([a for a in (c.get("articles", []) or []) if isinstance(a, dict)])
+        all_sources.extend([str(s) for s in (c.get("sources", []) or [])])
+        all_entities.extend([str(e) for e in (c.get("entities", []) or [])])
+        all_keywords.extend([str(k) for k in (c.get("keywords", []) or [])])
+
+        emb = _embedding(c)
+        if emb:
+            embeddings.append(emb)
+
+        if c.get("sentimentScore") is not None:
+            try:
+                sent_scores.append(float(c["sentimentScore"]))
+            except Exception:
+                pass
+
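+        # String comparison below assumes ISO-8601 timestamps, whose lexicographic order is chronological.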
+        fs = str(c.get("first_seen") or c.get("timestamp") or "")
+        lu = str(c.get("last_updated") or c.get("timestamp") or "")
+        if fs and (first_seen is None or fs < first_seen):
+            first_seen = fs
+        if lu and (last_updated is None or lu > last_updated):
+            last_updated = lu
+
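+    # dict.fromkeys de-duplicates while preserving first-seen order.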
+    merged["articles"] = _uniq_by_url(all_articles)
+    merged["sources"] = list(dict.fromkeys(all_sources))
+    merged["entities"] = list(dict.fromkeys(e for e in all_entities if e))
+    merged["keywords"] = list(dict.fromkeys(k for k in all_keywords if k))
+    merged["first_seen"] = first_seen or merged.get("first_seen")
+    merged["last_updated"] = last_updated or merged.get("last_updated")
+    merged["importance"] = max(float(c.get("importance", 0.0) or 0.0) for c in clusters)
+    if sent_scores:
+        merged["sentimentScore"] = sum(sent_scores) / len(sent_scores)
+    merged["embedding"] = _avg_vectors(embeddings) or merged.get("embedding")
+    merged["embedding_model"] = merged.get("embedding_model") or "ollama:nomic-embed-text"
+    return merged
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Merge embedding-similar news clusters")
+    parser.add_argument("--db", type=Path, default=DB_PATH)
+    parser.add_argument("--threshold", type=float, default=0.9)
+    parser.add_argument("--dry-run", action="store_true")
+    parser.add_argument("--limit", type=int, default=None, help="Optional maximum number of rows to scan")
+    parser.add_argument("--top", type=int, default=50, help="Max groups to print in dry-run mode")
+    args = parser.parse_args()
+
+    store = SQLiteClusterStore(args.db)
+    with store._conn() as conn:  # noqa: SLF001 - maintenance script
+        rows = conn.execute("SELECT cluster_id, topic, payload FROM clusters ORDER BY updated_at ASC").fetchall()
+
+    if args.limit is not None:
+        rows = rows[: args.limit]
+
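+    # Bucket clusters by topic; merges never cross topic boundaries.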
+    by_topic: dict[str, list[dict[str, Any]]] = defaultdict(list)
+    for cluster_id, topic, payload_json in rows:
+        try:
+            cluster = json.loads(payload_json)
+        except Exception:
+            continue
+        emb = _embedding(cluster)
+        if not emb:
+            continue
+        cluster.setdefault("cluster_id", cluster_id)  # payload must carry its row id for the upsert/delete below
+        cluster["_topic_key"] = topic or cluster.get("topic", "other")
+        by_topic[cluster["_topic_key"]].append(cluster)
+
+    groups: list[tuple[str, list[dict[str, Any]]]] = []
+    for topic, clusters in by_topic.items():
+        if len(clusters) < 2:
+            continue
+
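+        # Union-find with path halving; clusters whose embeddings are similar enough get unioned.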
+        parent = list(range(len(clusters)))
+
+        def find(x: int) -> int:
+            while parent[x] != x:
+                parent[x] = parent[parent[x]]
+                x = parent[x]
+            return x
+
+        def union(a: int, b: int) -> None:
+            ra, rb = find(a), find(b)
+            if ra != rb:
+                parent[rb] = ra
+
+        embs = [_embedding(c) for c in clusters]
+        for i in range(len(clusters)):
+            ei = embs[i]
+            if not ei:
+                continue
+            for j in range(i + 1, len(clusters)):
+                ej = embs[j]
+                if not ej:
+                    continue
+                if cosine_similarity(ei, ej) >= args.threshold:
+                    union(i, j)
+
+        buckets: dict[int, list[dict[str, Any]]] = defaultdict(list)
+        for idx, c in enumerate(clusters):
+            buckets[find(idx)].append(c)
+
+        for bucket in buckets.values():
+            if len(bucket) < 2:
+                continue
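+            # Newest-updated first, so bucket[0] is the representative in the wet run.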
+            bucket = sorted(bucket, key=lambda c: str(c.get("updated_at") or c.get("last_updated") or ""), reverse=True)
+            groups.append((topic, bucket))
+
+    print({"threshold": args.threshold, "merge_groups": len(groups), "dry_run": args.dry_run})
+
+    if args.dry_run:
+        for topic, bucket in groups[: args.top]:
+            print(json.dumps({
+                "topic": topic,
+                "size": len(bucket),
+                "cluster_ids": [c.get("cluster_id") for c in bucket],
+                "headlines": [c.get("headline") for c in bucket],
+            }, ensure_ascii=False))
+        return
+
+    with store._conn() as conn:  # noqa: SLF001
+        for topic, bucket in groups:
+            rep = bucket[0]
+            rep_id = rep.get("cluster_id")
+            merged = _merge_payloads(bucket)
+            # _merge_payloads bases the payload on the largest cluster, whose id could
+            # otherwise end up in dup_ids below; pin the merged row to the representative.
+            merged["cluster_id"] = rep_id
+            merged.pop("_topic_key", None)  # grouping key is internal; keep it out of the stored payload
+            store.upsert_clusters([merged], topic=topic)
+            # Delete absorbed duplicates; keep the representative row.
+            dup_ids = [c.get("cluster_id") for c in bucket[1:] if c.get("cluster_id")]
+            for cid in dup_ids:
+                conn.execute("DELETE FROM clusters WHERE cluster_id=?", (cid,))
+            conn.commit()
+
+    print({"merged_groups": len(groups), "threshold": args.threshold})
+
+
+if __name__ == "__main__":
+    main()