merge_cluster_embeddings.py

"""Merge embedding-similar clusters in SQLite.

This is a maintenance script for the embedding-first clustering rollout.
It supports dry-run mode, reports candidate groups, and, when run wet, merges
clusters within the same topic whose embeddings are similar enough.

Usage:
    ./.venv/bin/python scripts/merge_cluster_embeddings.py --dry-run --threshold 0.90
    ./.venv/bin/python scripts/merge_cluster_embeddings.py --threshold 0.88
"""
from __future__ import annotations

import argparse
import json
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any

# Make the repo root importable when the script is run directly.
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))

from news_mcp.config import DB_PATH
from news_mcp.dedup.embedding_support import cosine_similarity
from news_mcp.storage.sqlite_store import SQLiteClusterStore


def _embedding(cluster: dict[str, Any]) -> list[float] | None:
    """Return the cluster's embedding as a float list, or None if absent or invalid."""
    emb = cluster.get("embedding")
    if isinstance(emb, list) and emb:
        try:
            return [float(x) for x in emb]
        except Exception:
            return None
    return None
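
# Example (illustrative payloads):
#   _embedding({"embedding": [0.1, 0.2]}) -> [0.1, 0.2]
#   _embedding({"embedding": []})         -> None   (empty list rejected)
#   _embedding({"embedding": ["x"]})      -> None   (non-numeric rejected)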


def _avg_vectors(vectors: list[list[float]]) -> list[float] | None:
    """Element-wise mean of same-length vectors; None on empty or ragged input."""
    if not vectors:
        return None
    size = len(vectors[0])
    if any(len(v) != size for v in vectors):
        return None
    out = [0.0] * size
    for v in vectors:
        for i, x in enumerate(v):
            out[i] += x
    n = float(len(vectors))
    return [x / n for x in out]
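
# Example (illustrative): _avg_vectors([[1.0, 3.0], [3.0, 5.0]]) -> [2.0, 4.0],
# while ragged input such as [[1.0], [1.0, 2.0]] -> None.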


def _uniq_by_url(items: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Deduplicate articles by URL, falling back to title when the URL is missing."""
    seen = set()
    out = []
    for item in items:
        # Items lacking both fields share the key None, so only the first is kept.
        key = item.get("url") or item.get("title")
        if key in seen:
            continue
        seen.add(key)
        out.append(item)
    return out
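
# Example (illustrative): [{"url": "a"}, {"url": "a", "x": 1}, {"title": "t"}]
# dedupes to [{"url": "a"}, {"title": "t"}]; the first occurrence wins.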


def _merge_payloads(clusters: list[dict[str, Any]]) -> dict[str, Any]:
    """Fold a group of similar clusters into one payload, unioning their fields."""
    # Choose the most populated cluster as the base.
    base = max(clusters, key=lambda c: len(c.get("articles", []) or []))
    merged = dict(base)
    all_articles: list[dict[str, Any]] = []
    all_sources: list[str] = []
    all_entities: list[str] = []
    all_keywords: list[str] = []
    embeddings: list[list[float]] = []
    sent_scores: list[float] = []
    first_seen = None
    last_updated = None
    for c in clusters:
        all_articles.extend([a for a in (c.get("articles", []) or []) if isinstance(a, dict)])
        all_sources.extend([str(s) for s in (c.get("sources", []) or [])])
        all_entities.extend([str(e) for e in (c.get("entities", []) or [])])
        all_keywords.extend([str(k) for k in (c.get("keywords", []) or [])])
        emb = _embedding(c)
        if emb:
            embeddings.append(emb)
        if c.get("sentimentScore") is not None:
            try:
                sent_scores.append(float(c["sentimentScore"]))
            except Exception:
                pass
        fs = str(c.get("first_seen") or c.get("timestamp") or "")
        lu = str(c.get("last_updated") or c.get("timestamp") or "")
        if fs and (first_seen is None or fs < first_seen):
            first_seen = fs
        if lu and (last_updated is None or lu > last_updated):
            last_updated = lu
    merged["articles"] = _uniq_by_url(all_articles)
    merged["sources"] = list(dict.fromkeys(all_sources))
    merged["entities"] = list(dict.fromkeys(e for e in all_entities if e))
    merged["keywords"] = list(dict.fromkeys(k for k in all_keywords if k))
    merged["first_seen"] = first_seen or merged.get("first_seen")
    merged["last_updated"] = last_updated or merged.get("last_updated")
    merged["importance"] = max(float(c.get("importance", 0.0) or 0.0) for c in clusters)
    if sent_scores:
        merged["sentimentScore"] = sum(sent_scores) / len(sent_scores)
    merged["embedding"] = _avg_vectors(embeddings) or merged.get("embedding")
    merged["embedding_model"] = merged.get("embedding_model") or "ollama:nomic-embed-text"
    merged.pop("_topic_key", None)  # drop the grouping helper so it is not persisted
    return merged
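
# Illustrative merge (hypothetical payloads): folding
#   {"articles": [{"url": "a"}], "importance": 0.4}
#   {"articles": [{"url": "a"}, {"url": "b"}], "importance": 0.7}
# keeps the second cluster as base (more articles), unions the article lists to
# [{"url": "a"}, {"url": "b"}], and takes the max importance, 0.7.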


def main() -> None:
    parser = argparse.ArgumentParser(description="Merge embedding-similar news clusters")
    parser.add_argument("--db", type=Path, default=DB_PATH)
    parser.add_argument("--threshold", type=float, default=0.9)
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--limit", type=int, default=None, help="Optional maximum number of rows to scan")
    parser.add_argument("--top", type=int, default=50, help="Max groups to print in dry-run mode")
    args = parser.parse_args()

    store = SQLiteClusterStore(args.db)
    with store._conn() as conn:  # noqa: SLF001 - maintenance script
        rows = conn.execute("SELECT cluster_id, topic, payload FROM clusters ORDER BY updated_at ASC").fetchall()
    if args.limit is not None:
        rows = rows[: args.limit]

    # Group clusters by topic; rows without a usable embedding are skipped.
    by_topic: dict[str, list[dict[str, Any]]] = defaultdict(list)
    for cluster_id, topic, payload_json in rows:
        try:
            cluster = json.loads(payload_json)
        except Exception:
            continue
        emb = _embedding(cluster)
        if not emb:
            continue
        # Keep the row key on the payload so reports and deletes can use it.
        cluster.setdefault("cluster_id", cluster_id)
        cluster["_topic_key"] = topic or cluster.get("topic", "other")
        by_topic[cluster["_topic_key"]].append(cluster)

    # Union-find over pairwise cosine similarity within each topic.
    groups = []
    for topic, clusters in by_topic.items():
        if len(clusters) < 2:
            continue
        parent = list(range(len(clusters)))

        def find(x: int) -> int:
            # Path-halving find.
            while parent[x] != x:
                parent[x] = parent[parent[x]]
                x = parent[x]
            return x

        def union(a: int, b: int) -> None:
            ra, rb = find(a), find(b)
            if ra != rb:
                parent[rb] = ra

        for i in range(len(clusters)):
            ei = _embedding(clusters[i])
            if not ei:
                continue
            for j in range(i + 1, len(clusters)):
                ej = _embedding(clusters[j])
                if not ej:
                    continue
                if cosine_similarity(ei, ej) >= args.threshold:
                    union(i, j)
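
        # Note: union-find makes grouping transitive; if A~B and B~C both clear
        # the threshold, A, B, and C share one bucket even when sim(A, C) falls
        # below it. Lower thresholds therefore merge more aggressively.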

        buckets: dict[int, list[dict[str, Any]]] = defaultdict(list)
        for idx, c in enumerate(clusters):
            buckets[find(idx)].append(c)
        for bucket in buckets.values():
            if len(bucket) < 2:
                continue
            bucket = sorted(bucket, key=lambda c: str(c.get("updated_at") or c.get("last_updated") or ""), reverse=True)
            groups.append((topic, bucket))

    print({"threshold": args.threshold, "merge_groups": len(groups), "dry_run": args.dry_run})
    if args.dry_run:
        for topic, bucket in groups[: args.top]:
            print(json.dumps({
                "topic": topic,
                "size": len(bucket),
                "cluster_ids": [c.get("cluster_id") for c in bucket],
                "headlines": [c.get("headline") for c in bucket],
            }, ensure_ascii=False))
        return

    # Wet run: write merged payloads, then delete the absorbed rows.
    with store._conn() as conn:  # noqa: SLF001
        for topic, bucket in groups:
            merged = _merge_payloads(bucket)
            store.upsert_clusters([merged], topic=topic)
            # Delete absorbed duplicates; keep the row the merged payload landed
            # on (the base cluster, which is not necessarily bucket[0]).
            merged_id = merged.get("cluster_id")
            dup_ids = [c.get("cluster_id") for c in bucket if c.get("cluster_id") and c.get("cluster_id") != merged_id]
            for cid in dup_ids:
                conn.execute("DELETE FROM clusters WHERE cluster_id=?", (cid,))
        conn.commit()
    print({"merged_groups": len(groups), "threshold": args.threshold})


if __name__ == "__main__":
    main()