from __future__ import annotations """Merge embedding-similar clusters in SQLite. This is a maintenance script for the embedding-first clustering rollout. It supports dry-run mode, reports candidate groups, and when run wet it merges clusters within the same topic whose embeddings are similar enough. Usage: ./.venv/bin/python scripts/merge_cluster_embeddings.py --dry-run --threshold 0.90 ./.venv/bin/python scripts/merge_cluster_embeddings.py --threshold 0.88 """ import argparse import json import sys from collections import defaultdict from pathlib import Path from typing import Any ROOT = Path(__file__).resolve().parents[1] sys.path.insert(0, str(ROOT)) from news_mcp.config import DB_PATH from news_mcp.dedup.embedding_support import cosine_similarity from news_mcp.storage.sqlite_store import SQLiteClusterStore def _embedding(cluster: dict[str, Any]) -> list[float] | None: emb = cluster.get("embedding") if isinstance(emb, list) and emb: try: return [float(x) for x in emb] except Exception: return None return None def _avg_vectors(vectors: list[list[float]]) -> list[float] | None: if not vectors: return None size = len(vectors[0]) if any(len(v) != size for v in vectors): return None out = [0.0] * size for v in vectors: for i, x in enumerate(v): out[i] += x n = float(len(vectors)) return [x / n for x in out] def _uniq_by_url(items: list[dict[str, Any]]) -> list[dict[str, Any]]: seen = set() out = [] for item in items: key = item.get("url") or item.get("title") if key in seen: continue seen.add(key) out.append(item) return out def _merge_payloads(clusters: list[dict[str, Any]]) -> dict[str, Any]: # Choose the most populated cluster as the base. base = max(clusters, key=lambda c: len(c.get("articles", []) or [])) merged = dict(base) all_articles: list[dict[str, Any]] = [] all_sources: list[str] = [] all_entities: list[str] = [] all_keywords: list[str] = [] embeddings: list[list[float]] = [] sent_scores: list[float] = [] first_seen = None last_updated = None for c in clusters: all_articles.extend([a for a in (c.get("articles", []) or []) if isinstance(a, dict)]) all_sources.extend([str(s) for s in (c.get("sources", []) or [])]) all_entities.extend([str(e) for e in (c.get("entities", []) or [])]) all_keywords.extend([str(k) for k in (c.get("keywords", []) or [])]) emb = _embedding(c) if emb: embeddings.append(emb) if c.get("sentimentScore") is not None: try: sent_scores.append(float(c["sentimentScore"])) except Exception: pass fs = str(c.get("first_seen") or c.get("timestamp") or "") lu = str(c.get("last_updated") or c.get("timestamp") or "") if fs and (first_seen is None or fs < first_seen): first_seen = fs if lu and (last_updated is None or lu > last_updated): last_updated = lu merged["articles"] = _uniq_by_url(all_articles) merged["sources"] = list(dict.fromkeys(all_sources)) merged["entities"] = list(dict.fromkeys(e for e in all_entities if e)) merged["keywords"] = list(dict.fromkeys(k for k in all_keywords if k)) merged["first_seen"] = first_seen or merged.get("first_seen") merged["last_updated"] = last_updated or merged.get("last_updated") merged["importance"] = max(float(c.get("importance", 0.0) or 0.0) for c in clusters) if sent_scores: merged["sentimentScore"] = sum(sent_scores) / len(sent_scores) merged["embedding"] = _avg_vectors(embeddings) or merged.get("embedding") merged["embedding_model"] = merged.get("embedding_model") or "ollama:nomic-embed-text" return merged def main() -> None: parser = argparse.ArgumentParser(description="Merge embedding-similar news clusters") parser.add_argument("--db", type=Path, default=DB_PATH) parser.add_argument("--threshold", type=float, default=0.9) parser.add_argument("--dry-run", action="store_true") parser.add_argument("--limit", type=int, default=None, help="Optional maximum number of rows to scan") parser.add_argument("--top", type=int, default=50, help="Max groups to print in dry-run mode") args = parser.parse_args() store = SQLiteClusterStore(args.db) with store._conn() as conn: # noqa: SLF001 - maintenance script rows = conn.execute("SELECT cluster_id, topic, payload FROM clusters ORDER BY updated_at ASC").fetchall() if args.limit is not None: rows = rows[: args.limit] by_topic: dict[str, list[dict[str, Any]]] = defaultdict(list) for cluster_id, topic, payload_json in rows: try: cluster = json.loads(payload_json) except Exception: continue emb = _embedding(cluster) if not emb: continue cluster["_topic_key"] = topic or cluster.get("topic", "other") by_topic[cluster["_topic_key"]].append(cluster) groups = [] for topic, clusters in by_topic.items(): if len(clusters) < 2: continue parent = list(range(len(clusters))) def find(x: int) -> int: while parent[x] != x: parent[x] = parent[parent[x]] x = parent[x] return x def union(a: int, b: int) -> None: ra, rb = find(a), find(b) if ra != rb: parent[rb] = ra for i in range(len(clusters)): ei = _embedding(clusters[i]) if not ei: continue for j in range(i + 1, len(clusters)): ej = _embedding(clusters[j]) if not ej: continue if cosine_similarity(ei, ej) >= args.threshold: union(i, j) buckets: dict[int, list[dict[str, Any]]] = defaultdict(list) for idx, c in enumerate(clusters): buckets[find(idx)].append(c) for bucket in buckets.values(): if len(bucket) < 2: continue bucket = sorted(bucket, key=lambda c: str(c.get("updated_at") or c.get("last_updated") or ""), reverse=True) groups.append((topic, bucket)) print({"threshold": args.threshold, "merge_groups": len(groups), "dry_run": args.dry_run}) if args.dry_run: for topic, bucket in groups[: args.top]: print(json.dumps({ "topic": topic, "size": len(bucket), "cluster_ids": [c.get("cluster_id") for c in bucket], "headlines": [c.get("headline") for c in bucket], }, ensure_ascii=False)) return with store._conn() as conn: # noqa: SLF001 for topic, bucket in groups: merged = _merge_payloads(bucket) cluster_id = merged.get("cluster_id") payload = json.dumps(merged, ensure_ascii=False) conn.execute( "INSERT INTO clusters(cluster_id, topic, payload, updated_at) VALUES(?,?,?,?) " "ON CONFLICT(cluster_id) DO UPDATE SET topic=excluded.topic, payload=excluded.payload, updated_at=excluded.updated_at", (cluster_id, topic, payload, merged.get("last_updated") or merged.get("updated_at") or ""), ) # Delete absorbed duplicates; keep the representative row. dup_ids = [c.get("cluster_id") for c in bucket if c.get("cluster_id") != cluster_id] for cid in dup_ids: conn.execute("DELETE FROM clusters WHERE cluster_id=?", (cid,)) conn.commit() print({"merged_groups": len(groups), "threshold": args.threshold}) if __name__ == "__main__": main()