merge_cluster_embeddings.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. from __future__ import annotations
  2. """Merge embedding-similar clusters in SQLite.
  3. This is a maintenance script for the embedding-first clustering rollout.
  4. It supports dry-run mode, reports candidate groups, and when run wet it merges
  5. clusters within the same topic whose embeddings are similar enough.
  6. Usage:
  7. ./.venv/bin/python scripts/merge_cluster_embeddings.py --dry-run --threshold 0.90
  8. ./.venv/bin/python scripts/merge_cluster_embeddings.py --threshold 0.88
  9. """
  10. import argparse
  11. import json
  12. import sys
  13. from collections import defaultdict
  14. from pathlib import Path
  15. from typing import Any
  16. ROOT = Path(__file__).resolve().parents[1]
  17. sys.path.insert(0, str(ROOT))
  18. from news_mcp.config import DB_PATH
  19. from news_mcp.dedup.embedding_support import cosine_similarity
  20. from news_mcp.storage.sqlite_store import SQLiteClusterStore, _normalize_ts
  21. def _embedding(cluster: dict[str, Any]) -> list[float] | None:
  22. emb = cluster.get("embedding")
  23. if isinstance(emb, list) and emb:
  24. try:
  25. return [float(x) for x in emb]
  26. except Exception:
  27. return None
  28. return None
  29. def _avg_vectors(vectors: list[list[float]]) -> list[float] | None:
  30. if not vectors:
  31. return None
  32. size = len(vectors[0])
  33. if any(len(v) != size for v in vectors):
  34. return None
  35. out = [0.0] * size
  36. for v in vectors:
  37. for i, x in enumerate(v):
  38. out[i] += x
  39. n = float(len(vectors))
  40. return [x / n for x in out]
  41. def _uniq_by_url(items: list[dict[str, Any]]) -> list[dict[str, Any]]:
  42. seen = set()
  43. out = []
  44. for item in items:
  45. key = item.get("url") or item.get("title")
  46. if key in seen:
  47. continue
  48. seen.add(key)
  49. out.append(item)
  50. return out
  51. def _merge_payloads(clusters: list[dict[str, Any]]) -> dict[str, Any]:
  52. # Choose the most populated cluster as the base.
  53. base = max(clusters, key=lambda c: len(c.get("articles", []) or []))
  54. merged = dict(base)
  55. all_articles: list[dict[str, Any]] = []
  56. all_sources: list[str] = []
  57. all_entities: list[str] = []
  58. all_keywords: list[str] = []
  59. embeddings: list[list[float]] = []
  60. sent_scores: list[float] = []
  61. first_seen = None
  62. last_updated = None
  63. for c in clusters:
  64. all_articles.extend([a for a in (c.get("articles", []) or []) if isinstance(a, dict)])
  65. all_sources.extend([str(s) for s in (c.get("sources", []) or [])])
  66. all_entities.extend([str(e) for e in (c.get("entities", []) or [])])
  67. all_keywords.extend([str(k) for k in (c.get("keywords", []) or [])])
  68. emb = _embedding(c)
  69. if emb:
  70. embeddings.append(emb)
  71. if c.get("sentimentScore") is not None:
  72. try:
  73. sent_scores.append(float(c["sentimentScore"]))
  74. except Exception:
  75. pass
  76. fs = str(c.get("first_seen") or c.get("timestamp") or "")
  77. lu = str(c.get("last_updated") or c.get("timestamp") or "")
  78. if fs and (first_seen is None or fs < first_seen):
  79. first_seen = fs
  80. if lu and (last_updated is None or lu > last_updated):
  81. last_updated = lu
  82. merged["articles"] = _uniq_by_url(all_articles)
  83. # Normalize article timestamps
  84. for a in merged["articles"]:
  85. if "timestamp" in a and a["timestamp"]:
  86. a["timestamp"] = _normalize_ts(a["timestamp"])
  87. merged["sources"] = list(dict.fromkeys(all_sources))
  88. merged["entities"] = list(dict.fromkeys(e for e in all_entities if e))
  89. merged["keywords"] = list(dict.fromkeys(k for k in all_keywords if k))
  90. # Normalize cluster-level timestamps
  91. for field in ("timestamp", "first_seen", "last_seen", "last_updated"):
  92. val = merged.get(field) or (first_seen if field == "first_seen" else last_updated if field == "last_updated" else "")
  93. if val:
  94. merged[field] = _normalize_ts(val)
  95. merged["importance"] = max(float(c.get("importance", 0.0) or 0.0) for c in clusters)
  96. if sent_scores:
  97. merged["sentimentScore"] = sum(sent_scores) / len(sent_scores)
  98. merged["embedding"] = _avg_vectors(embeddings) or merged.get("embedding")
  99. merged["embedding_model"] = merged.get("embedding_model") or "ollama:nomic-embed-text"
  100. return merged
  101. def main() -> None:
  102. parser = argparse.ArgumentParser(description="Merge embedding-similar news clusters")
  103. parser.add_argument("--db", type=Path, default=DB_PATH)
  104. parser.add_argument("--threshold", type=float, default=0.9)
  105. parser.add_argument("--dry-run", action="store_true")
  106. parser.add_argument("--limit", type=int, default=None, help="Optional maximum number of rows to scan")
  107. parser.add_argument("--top", type=int, default=50, help="Max groups to print in dry-run mode")
  108. args = parser.parse_args()
  109. store = SQLiteClusterStore(args.db)
  110. with store._conn() as conn: # noqa: SLF001 - maintenance script
  111. rows = conn.execute("SELECT cluster_id, topic, payload FROM clusters ORDER BY updated_at ASC").fetchall()
  112. if args.limit is not None:
  113. rows = rows[: args.limit]
  114. by_topic: dict[str, list[dict[str, Any]]] = defaultdict(list)
  115. for cluster_id, topic, payload_json in rows:
  116. try:
  117. cluster = json.loads(payload_json)
  118. except Exception:
  119. continue
  120. emb = _embedding(cluster)
  121. if not emb:
  122. continue
  123. cluster["_topic_key"] = topic or cluster.get("topic", "other")
  124. by_topic[cluster["_topic_key"]].append(cluster)
  125. groups = []
  126. for topic, clusters in by_topic.items():
  127. if len(clusters) < 2:
  128. continue
  129. parent = list(range(len(clusters)))
  130. def find(x: int) -> int:
  131. while parent[x] != x:
  132. parent[x] = parent[parent[x]]
  133. x = parent[x]
  134. return x
  135. def union(a: int, b: int) -> None:
  136. ra, rb = find(a), find(b)
  137. if ra != rb:
  138. parent[rb] = ra
  139. for i in range(len(clusters)):
  140. ei = _embedding(clusters[i])
  141. if not ei:
  142. continue
  143. for j in range(i + 1, len(clusters)):
  144. ej = _embedding(clusters[j])
  145. if not ej:
  146. continue
  147. if cosine_similarity(ei, ej) >= args.threshold:
  148. union(i, j)
  149. buckets: dict[int, list[dict[str, Any]]] = defaultdict(list)
  150. for idx, c in enumerate(clusters):
  151. buckets[find(idx)].append(c)
  152. for bucket in buckets.values():
  153. if len(bucket) < 2:
  154. continue
  155. bucket = sorted(bucket, key=lambda c: str(c.get("updated_at") or c.get("last_updated") or ""), reverse=True)
  156. groups.append((topic, bucket))
  157. print({"threshold": args.threshold, "merge_groups": len(groups), "dry_run": args.dry_run})
  158. if args.dry_run:
  159. for topic, bucket in groups[: args.top]:
  160. print(json.dumps({
  161. "topic": topic,
  162. "size": len(bucket),
  163. "cluster_ids": [c.get("cluster_id") for c in bucket],
  164. "headlines": [c.get("headline") for c in bucket],
  165. }, ensure_ascii=False))
  166. return
  167. with store._conn() as conn: # noqa: SLF001
  168. for topic, bucket in groups:
  169. merged = _merge_payloads(bucket)
  170. cluster_id = merged.get("cluster_id")
  171. payload = json.dumps(merged, ensure_ascii=False)
  172. conn.execute(
  173. "INSERT INTO clusters(cluster_id, topic, payload, updated_at) VALUES(?,?,?,?) "
  174. "ON CONFLICT(cluster_id) DO UPDATE SET topic=excluded.topic, payload=excluded.payload, updated_at=excluded.updated_at",
  175. (cluster_id, topic, payload, merged.get("last_updated") or merged.get("updated_at") or ""),
  176. )
  177. # Delete absorbed duplicates; keep the representative row.
  178. dup_ids = [c.get("cluster_id") for c in bucket if c.get("cluster_id") != cluster_id]
  179. for cid in dup_ids:
  180. conn.execute("DELETE FROM clusters WHERE cluster_id=?", (cid,))
  181. conn.commit()
  182. print({"merged_groups": len(groups), "threshold": args.threshold})
  183. if __name__ == "__main__":
  184. main()