"""Analyze stored cluster embeddings for possible merges.

This script does not modify the DB. It scans clusters that already have an
`embedding` field, compares clusters within the same topic, and reports likely
merge candidates for one or more cosine-similarity thresholds.

Usage:
    ./.venv/bin/python scripts/analyze_cluster_embedding_merges.py --thresholds 0.82 0.85 0.88
    ./.venv/bin/python scripts/analyze_cluster_embedding_merges.py --threshold 0.85 --limit 300
"""

from __future__ import annotations

import argparse
import json
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any

# Make the repository root importable when running this file directly.
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))

from news_mcp.config import DB_PATH  # noqa: E402
from news_mcp.dedup.embedding_support import cosine_similarity  # noqa: E402
from news_mcp.storage.sqlite_store import SQLiteClusterStore  # noqa: E402
- def _embedding(cluster: dict[str, Any]) -> list[float] | None:
- emb = cluster.get("embedding")
- if isinstance(emb, list) and emb:
- try:
- return [float(x) for x in emb]
- except Exception:
- return None
- return None
- def _title(cluster: dict[str, Any]) -> str:
- return str(cluster.get("headline") or "").strip()
def main() -> None:
    """CLI entry point: report likely cluster merges per similarity threshold.

    Read-only analysis. Loads all cluster payloads from the SQLite store,
    groups them by topic, then for each requested threshold prints the pairs
    of same-topic clusters whose embedding cosine similarity meets it.
    """
    parser = argparse.ArgumentParser(description="Analyze news cluster embeddings for merge candidates")
    parser.add_argument("--db", type=Path, default=DB_PATH)
    parser.add_argument("--threshold", type=float, default=0.85, help="Single threshold to analyze")
    parser.add_argument("--thresholds", type=float, nargs="*", default=None, help="Multiple thresholds to analyze")
    parser.add_argument("--limit", type=int, default=None, help="Optional maximum number of clusters to scan")
    parser.add_argument("--top", type=int, default=50, help="Maximum candidate pairs to print per threshold")
    args = parser.parse_args()

    # --thresholds wins over --threshold; dedupe and sort ascending.
    thresholds = sorted({float(t) for t in (args.thresholds or [args.threshold])})

    store = SQLiteClusterStore(args.db)
    with store._conn() as conn:  # noqa: SLF001 - maintenance script
        rows = conn.execute("SELECT topic, payload FROM clusters ORDER BY updated_at ASC").fetchall()
    if args.limit is not None:
        rows = rows[: args.limit]

    # Group (cluster, embedding) pairs by topic. The embedding is converted
    # exactly once per cluster here, instead of being re-converted inside the
    # O(n^2) pair loop below (the original recomputed it per comparison).
    by_topic: dict[str, list[tuple[dict[str, Any], list[float]]]] = defaultdict(list)
    for topic, payload_json in rows:
        try:
            cluster = json.loads(payload_json)
        except (TypeError, ValueError):  # bad payload types or malformed JSON
            continue
        emb = _embedding(cluster)
        if not emb:
            continue
        # Prefer the DB column's topic; fall back to the payload, then "other".
        by_topic[topic or cluster.get("topic", "other")].append((cluster, emb))

    print({"topics": len(by_topic), "thresholds": thresholds, "clusters_scanned": sum(len(v) for v in by_topic.values())})

    for threshold in thresholds:
        candidates = []
        for topic, entries in by_topic.items():
            # Compare every unordered pair of clusters within the topic.
            for i, (a, ea) in enumerate(entries):
                for b, eb in entries[i + 1 :]:
                    sim = cosine_similarity(ea, eb)
                    if sim >= threshold:
                        candidates.append(
                            {
                                "topic": topic,
                                "similarity": round(sim, 4),
                                "a": a.get("cluster_id"),
                                "a_headline": _title(a),
                                "b": b.get("cluster_id"),
                                "b_headline": _title(b),
                            }
                        )
        candidates.sort(key=lambda x: x["similarity"], reverse=True)
        print(f"\n=== threshold {threshold:.3f} ===")
        print({"candidate_pairs": len(candidates)})
        for item in candidates[: args.top]:
            print(json.dumps(item, ensure_ascii=False))
- if __name__ == "__main__":
- main()
|