"""Analyze stored cluster embeddings for possible merges.

This script does not modify the DB. It scans clusters that already have an
`embedding` field, compares clusters within the same topic, and reports likely
merge candidates for one or more cosine-similarity thresholds.

Usage:
    ./.venv/bin/python scripts/analyze_cluster_embedding_merges.py --thresholds 0.82 0.85 0.88
    ./.venv/bin/python scripts/analyze_cluster_embedding_merges.py --threshold 0.85 --limit 300
"""

from __future__ import annotations

import argparse
import json
import sys
from collections import defaultdict
from itertools import combinations
from pathlib import Path
from typing import Any

# The directory one level above this file's directory is put first on sys.path
# before the project imports below (presumably so `news_mcp` resolves when the
# script is run directly from a checkout -- confirm against the repo layout).
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))

from news_mcp.config import DB_PATH  # noqa: E402
from news_mcp.dedup.embedding_support import cosine_similarity  # noqa: E402
from news_mcp.storage.sqlite_store import SQLiteClusterStore  # noqa: E402


def _embedding(cluster: dict[str, Any]) -> list[float] | None:
    """Return the cluster's ``embedding`` as a list of floats, or ``None``.

    ``None`` is returned when the field is missing, is not a non-empty list,
    or contains an element that cannot be converted with ``float()``.
    """
    emb = cluster.get("embedding")
    if not isinstance(emb, list) or not emb:
        return None
    try:
        return [float(x) for x in emb]
    except (TypeError, ValueError, OverflowError):
        # A single unconvertible element invalidates the whole vector.
        return None


def _title(cluster: dict[str, Any]) -> str:
    """Return the cluster's ``headline`` as a stripped string ("" if absent)."""
    return str(cluster.get("headline") or "").strip()


def _group_by_topic(rows: list[Any]) -> dict[str, list[tuple[dict[str, Any], list[float]]]]:
    """Decode ``(topic, payload_json)`` rows and group usable clusters by topic.

    Rows are skipped when the payload cannot be decoded as JSON, does not
    decode to a dict, or has no usable embedding. Each kept cluster is stored
    together with its already-converted embedding so the conversion happens
    once per cluster rather than once per comparison.
    """
    by_topic: dict[str, list[tuple[dict[str, Any], list[float]]]] = defaultdict(list)
    for topic, payload_json in rows:
        try:
            cluster = json.loads(payload_json)
        except (TypeError, ValueError):  # JSONDecodeError is a ValueError
            continue
        if not isinstance(cluster, dict):
            # Valid JSON that is not an object has no fields to inspect.
            continue
        emb = _embedding(cluster)
        if not emb:
            continue
        # Prefer the row's topic column, then the payload's own topic; fall
        # back to "other" when both are missing or empty/null.
        by_topic[topic or cluster.get("topic") or "other"].append((cluster, emb))
    return by_topic


def _score_pairs(
    by_topic: dict[str, list[tuple[dict[str, Any], list[float]]]],
    floor: float,
) -> list[tuple[float, dict[str, Any]]]:
    """Score every same-topic cluster pair once; keep those with sim >= floor.

    Returns ``(raw_similarity, report_item)`` tuples sorted by the rounded
    similarity in descending order. The sort is stable, so ties keep their
    discovery order (topic order, then pair order within the topic).
    """
    scored: list[tuple[float, dict[str, Any]]] = []
    for topic, members in by_topic.items():
        # combinations() yields pairs in the same (i < j) order as a nested
        # index loop over the list.
        for (a, ea), (b, eb) in combinations(members, 2):
            sim = cosine_similarity(ea, eb)
            if sim >= floor:
                scored.append(
                    (
                        sim,
                        {
                            "topic": topic,
                            "similarity": round(sim, 4),
                            "a": a.get("cluster_id"),
                            "a_headline": _title(a),
                            "b": b.get("cluster_id"),
                            "b_headline": _title(b),
                        },
                    )
                )
    scored.sort(key=lambda pair: pair[1]["similarity"], reverse=True)
    return scored


def main() -> None:
    """Parse CLI arguments, scan the cluster table and print merge candidates.

    Prints a summary dict, then for each threshold (ascending) a header, the
    number of candidate pairs, and up to ``--top`` pairs as JSON lines ordered
    by descending similarity. Nothing is written to the database by this
    function; it only issues a SELECT.
    """
    parser = argparse.ArgumentParser(description="Analyze news cluster embeddings for merge candidates")
    parser.add_argument("--db", type=Path, default=DB_PATH)
    parser.add_argument("--threshold", type=float, default=0.85, help="Single threshold to analyze")
    parser.add_argument("--thresholds", type=float, nargs="*", default=None, help="Multiple thresholds to analyze")
    parser.add_argument("--limit", type=int, default=None, help="Optional maximum number of clusters to scan")
    parser.add_argument("--top", type=int, default=50, help="Maximum candidate pairs to print per threshold")
    args = parser.parse_args()

    # --thresholds wins when given (and non-empty); otherwise use --threshold.
    # De-duplicate and sort so output is ordered from loosest to strictest.
    thresholds = sorted({float(t) for t in (args.thresholds or [args.threshold])})

    store = SQLiteClusterStore(args.db)
    with store._conn() as conn:  # noqa: SLF001 - maintenance script
        rows = conn.execute("SELECT topic, payload FROM clusters ORDER BY updated_at ASC").fetchall()
    if args.limit is not None:
        rows = rows[: args.limit]

    by_topic = _group_by_topic(rows)

    print({"topics": len(by_topic), "thresholds": thresholds, "clusters_scanned": sum(len(v) for v in by_topic.values())})

    # Similarities do not depend on the threshold, so score each pair once
    # against the loosest threshold and filter that list per threshold below.
    scored = _score_pairs(by_topic, thresholds[0])

    for threshold in thresholds:
        # Filter on the raw (unrounded) similarity, as the comparison is
        # defined on the value returned by cosine_similarity.
        candidates = [item for sim, item in scored if sim >= threshold]
        print(f"\n=== threshold {threshold:.3f} ===")
        print({"candidate_pairs": len(candidates)})
        for item in candidates[: args.top]:
            print(json.dumps(item, ensure_ascii=False))


if __name__ == "__main__":
    main()