|
|
@@ -0,0 +1,108 @@
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+"""Analyze stored cluster embeddings for possible merges.
|
|
|
+
|
|
|
+This script does not modify the DB. It scans clusters that already have an
|
|
|
+`embedding` field, compares clusters within the same topic, and reports likely
|
|
|
+merge candidates for one or more cosine-similarity thresholds.
|
|
|
+
|
|
|
+Usage:
|
|
|
+ ./.venv/bin/python scripts/analyze_cluster_embedding_merges.py --thresholds 0.82 0.85 0.88
|
|
|
+ ./.venv/bin/python scripts/analyze_cluster_embedding_merges.py --threshold 0.85 --limit 300
|
|
|
+"""
|
|
|
+
|
|
|
+import argparse
|
|
|
+import json
|
|
|
+import sys
|
|
|
+from collections import defaultdict
|
|
|
+from pathlib import Path
|
|
|
+from typing import Any
|
|
|
+
|
|
|
# Put the directory one level above this script's folder at the front of
# sys.path so the `news_mcp` imports that follow resolve when the script is
# run directly (presumably the repository root -- confirm against the layout).
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
|
|
|
+
|
|
|
+from news_mcp.config import DB_PATH
|
|
|
+from news_mcp.dedup.embedding_support import cosine_similarity
|
|
|
+from news_mcp.storage.sqlite_store import SQLiteClusterStore
|
|
|
+
|
|
|
+
|
|
|
+def _embedding(cluster: dict[str, Any]) -> list[float] | None:
|
|
|
+ emb = cluster.get("embedding")
|
|
|
+ if isinstance(emb, list) and emb:
|
|
|
+ try:
|
|
|
+ return [float(x) for x in emb]
|
|
|
+ except Exception:
|
|
|
+ return None
|
|
|
+ return None
|
|
|
+
|
|
|
+
|
|
|
+def _title(cluster: dict[str, Any]) -> str:
|
|
|
+ return str(cluster.get("headline") or "").strip()
|
|
|
+
|
|
|
+
|
|
|
def main() -> None:
    """Report likely cluster-merge candidates from stored embeddings.

    Parses the command line, reads every row of the ``clusters`` table,
    groups the clusters that carry a usable embedding by topic, and prints,
    for each requested threshold, the same-topic pairs whose cosine
    similarity is at or above that threshold (highest similarity first, at
    most ``--top`` pairs per threshold).

    Only a SELECT statement is issued against the database.
    """
    # Local import: only this entry point needs it.
    from itertools import combinations

    parser = argparse.ArgumentParser(description="Analyze news cluster embeddings for merge candidates")
    parser.add_argument("--db", type=Path, default=DB_PATH)
    parser.add_argument("--threshold", type=float, default=0.85, help="Single threshold to analyze")
    parser.add_argument("--thresholds", type=float, nargs="*", default=None, help="Multiple thresholds to analyze")
    parser.add_argument("--limit", type=int, default=None, help="Optional maximum number of clusters to scan")
    parser.add_argument("--top", type=int, default=50, help="Maximum candidate pairs to print per threshold")
    args = parser.parse_args()

    # An omitted or empty --thresholds falls back to the single --threshold.
    thresholds = sorted({float(t) for t in (args.thresholds or [args.threshold])})

    store = SQLiteClusterStore(args.db)
    with store._conn() as conn:  # noqa: SLF001 - maintenance script
        rows = conn.execute("SELECT topic, payload FROM clusters ORDER BY updated_at ASC").fetchall()

    if args.limit is not None:
        rows = rows[: args.limit]

    # Each entry keeps the cluster together with its already-converted
    # embedding, so the vector is parsed once instead of once per comparison.
    by_topic: dict[str, list[tuple[dict[str, Any], list[float]]]] = defaultdict(list)
    for topic, payload_json in rows:
        try:
            cluster = json.loads(payload_json)
        except (TypeError, ValueError, RecursionError):
            # Best effort: rows whose payload cannot be decoded are skipped.
            continue
        if not isinstance(cluster, dict):
            # Valid JSON that is not an object cannot carry an embedding.
            continue
        embedding = _embedding(cluster)
        if not embedding:
            continue
        by_topic[topic or cluster.get("topic", "other")].append((cluster, embedding))

    print({"topics": len(by_topic), "thresholds": thresholds, "clusters_scanned": sum(len(v) for v in by_topic.values())})

    # Pairwise similarity does not depend on the threshold, so every pair is
    # scored once and kept if it clears the lowest threshold; each threshold
    # then only filters this list.
    lowest = thresholds[0]
    scored: list[tuple[float, dict[str, Any]]] = []
    for topic, members in by_topic.items():
        # combinations() yields (i, j) pairs with i < j in the same order as
        # the equivalent nested index loops.
        for (a, ea), (b, eb) in combinations(members, 2):
            sim = cosine_similarity(ea, eb)
            if sim >= lowest:
                scored.append(
                    (
                        sim,
                        {
                            "topic": topic,
                            "similarity": round(sim, 4),
                            "a": a.get("cluster_id"),
                            "a_headline": _title(a),
                            "b": b.get("cluster_id"),
                            "b_headline": _title(b),
                        },
                    )
                )

    # Stable sort on the rounded similarity; filtering a stably sorted list
    # gives the same order as sorting each threshold's subset separately.
    scored.sort(key=lambda entry: entry[1]["similarity"], reverse=True)

    for threshold in thresholds:
        # Compare the unrounded similarity; the rounded value is display-only.
        candidates = [item for sim, item in scored if sim >= threshold]
        print(f"\n=== threshold {threshold:.3f} ===")
        print({"candidate_pairs": len(candidates)})
        for item in candidates[: args.top]:
            print(json.dumps(item, ensure_ascii=False))
|
|
|
+
|
|
|
+
|
|
|
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|