  1. from __future__ import annotations
  2. """Analyze stored cluster embeddings for possible merges.
  3. This script does not modify the DB. It scans clusters that already have an
  4. `embedding` field, compares clusters within the same topic, and reports likely
  5. merge candidates for one or more cosine-similarity thresholds.
  6. Usage:
  7. ./.venv/bin/python scripts/analyze_cluster_embedding_merges.py --thresholds 0.82 0.85 0.88
  8. ./.venv/bin/python scripts/analyze_cluster_embedding_merges.py --threshold 0.85 --limit 300
  9. """
  10. import argparse
  11. import json
  12. import sys
  13. from collections import defaultdict
  14. from pathlib import Path
  15. from typing import Any
  16. ROOT = Path(__file__).resolve().parents[1]
  17. sys.path.insert(0, str(ROOT))
  18. from news_mcp.config import DB_PATH
  19. from news_mcp.dedup.embedding_support import cosine_similarity
  20. from news_mcp.storage.sqlite_store import SQLiteClusterStore
  21. def _embedding(cluster: dict[str, Any]) -> list[float] | None:
  22. emb = cluster.get("embedding")
  23. if isinstance(emb, list) and emb:
  24. try:
  25. return [float(x) for x in emb]
  26. except Exception:
  27. return None
  28. return None
  29. def _title(cluster: dict[str, Any]) -> str:
  30. return str(cluster.get("headline") or "").strip()
  31. def main() -> None:
  32. parser = argparse.ArgumentParser(description="Analyze news cluster embeddings for merge candidates")
  33. parser.add_argument("--db", type=Path, default=DB_PATH)
  34. parser.add_argument("--threshold", type=float, default=0.85, help="Single threshold to analyze")
  35. parser.add_argument("--thresholds", type=float, nargs="*", default=None, help="Multiple thresholds to analyze")
  36. parser.add_argument("--limit", type=int, default=None, help="Optional maximum number of clusters to scan")
  37. parser.add_argument("--top", type=int, default=50, help="Maximum candidate pairs to print per threshold")
  38. args = parser.parse_args()
  39. thresholds = args.thresholds if args.thresholds else [args.threshold]
  40. thresholds = sorted(set(float(t) for t in thresholds))
  41. store = SQLiteClusterStore(args.db)
  42. with store._conn() as conn: # noqa: SLF001 - maintenance script
  43. rows = conn.execute("SELECT topic, payload FROM clusters ORDER BY updated_at ASC").fetchall()
  44. if args.limit is not None:
  45. rows = rows[: args.limit]
  46. by_topic: dict[str, list[dict[str, Any]]] = defaultdict(list)
  47. for topic, payload_json in rows:
  48. try:
  49. cluster = json.loads(payload_json)
  50. except Exception:
  51. continue
  52. if not _embedding(cluster):
  53. continue
  54. by_topic[topic or cluster.get("topic", "other")].append(cluster)
  55. print({"topics": len(by_topic), "thresholds": thresholds, "clusters_scanned": sum(len(v) for v in by_topic.values())})
  56. for threshold in thresholds:
  57. candidates = []
  58. for topic, clusters in by_topic.items():
  59. for i in range(len(clusters)):
  60. a = clusters[i]
  61. ea = _embedding(a)
  62. if not ea:
  63. continue
  64. for j in range(i + 1, len(clusters)):
  65. b = clusters[j]
  66. eb = _embedding(b)
  67. if not eb:
  68. continue
  69. sim = cosine_similarity(ea, eb)
  70. if sim >= threshold:
  71. candidates.append(
  72. {
  73. "topic": topic,
  74. "similarity": round(sim, 4),
  75. "a": a.get("cluster_id"),
  76. "a_headline": _title(a),
  77. "b": b.get("cluster_id"),
  78. "b_headline": _title(b),
  79. }
  80. )
  81. candidates.sort(key=lambda x: x["similarity"], reverse=True)
  82. print(f"\n=== threshold {threshold:.3f} ===")
  83. print({"candidate_pairs": len(candidates)})
  84. for item in candidates[: args.top]:
  85. print(json.dumps(item, ensure_ascii=False))
# Script entry point: run the analysis only when executed directly, not on import.
if __name__ == "__main__":
    main()