| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216 |
- from __future__ import annotations
- """Merge embedding-similar clusters in SQLite.
- This is a maintenance script for the embedding-first clustering rollout.
- It supports dry-run mode, reports candidate groups, and when run wet it merges
- clusters within the same topic whose embeddings are similar enough.
- Usage:
- ./.venv/bin/python scripts/merge_cluster_embeddings.py --dry-run --threshold 0.90
- ./.venv/bin/python scripts/merge_cluster_embeddings.py --threshold 0.88
- """
- import argparse
- import json
- import sys
- from collections import defaultdict
- from pathlib import Path
- from typing import Any
- ROOT = Path(__file__).resolve().parents[1]
- sys.path.insert(0, str(ROOT))
- from news_mcp.config import DB_PATH
- from news_mcp.dedup.embedding_support import cosine_similarity
- from news_mcp.storage.sqlite_store import SQLiteClusterStore
- def _embedding(cluster: dict[str, Any]) -> list[float] | None:
- emb = cluster.get("embedding")
- if isinstance(emb, list) and emb:
- try:
- return [float(x) for x in emb]
- except Exception:
- return None
- return None
- def _avg_vectors(vectors: list[list[float]]) -> list[float] | None:
- if not vectors:
- return None
- size = len(vectors[0])
- if any(len(v) != size for v in vectors):
- return None
- out = [0.0] * size
- for v in vectors:
- for i, x in enumerate(v):
- out[i] += x
- n = float(len(vectors))
- return [x / n for x in out]
- def _uniq_by_url(items: list[dict[str, Any]]) -> list[dict[str, Any]]:
- seen = set()
- out = []
- for item in items:
- key = item.get("url") or item.get("title")
- if key in seen:
- continue
- seen.add(key)
- out.append(item)
- return out
- def _merge_payloads(clusters: list[dict[str, Any]]) -> dict[str, Any]:
- # Choose the most populated cluster as the base.
- base = max(clusters, key=lambda c: len(c.get("articles", []) or []))
- merged = dict(base)
- all_articles: list[dict[str, Any]] = []
- all_sources: list[str] = []
- all_entities: list[str] = []
- all_keywords: list[str] = []
- embeddings: list[list[float]] = []
- sent_scores: list[float] = []
- first_seen = None
- last_updated = None
- for c in clusters:
- all_articles.extend([a for a in (c.get("articles", []) or []) if isinstance(a, dict)])
- all_sources.extend([str(s) for s in (c.get("sources", []) or [])])
- all_entities.extend([str(e) for e in (c.get("entities", []) or [])])
- all_keywords.extend([str(k) for k in (c.get("keywords", []) or [])])
- emb = _embedding(c)
- if emb:
- embeddings.append(emb)
- if c.get("sentimentScore") is not None:
- try:
- sent_scores.append(float(c["sentimentScore"]))
- except Exception:
- pass
- fs = str(c.get("first_seen") or c.get("timestamp") or "")
- lu = str(c.get("last_updated") or c.get("timestamp") or "")
- if fs and (first_seen is None or fs < first_seen):
- first_seen = fs
- if lu and (last_updated is None or lu > last_updated):
- last_updated = lu
- merged["articles"] = _uniq_by_url(all_articles)
- merged["sources"] = list(dict.fromkeys(all_sources))
- merged["entities"] = list(dict.fromkeys(e for e in all_entities if e))
- merged["keywords"] = list(dict.fromkeys(k for k in all_keywords if k))
- merged["first_seen"] = first_seen or merged.get("first_seen")
- merged["last_updated"] = last_updated or merged.get("last_updated")
- merged["importance"] = max(float(c.get("importance", 0.0) or 0.0) for c in clusters)
- if sent_scores:
- merged["sentimentScore"] = sum(sent_scores) / len(sent_scores)
- merged["embedding"] = _avg_vectors(embeddings) or merged.get("embedding")
- merged["embedding_model"] = merged.get("embedding_model") or "ollama:nomic-embed-text"
- return merged
- def main() -> None:
- parser = argparse.ArgumentParser(description="Merge embedding-similar news clusters")
- parser.add_argument("--db", type=Path, default=DB_PATH)
- parser.add_argument("--threshold", type=float, default=0.9)
- parser.add_argument("--dry-run", action="store_true")
- parser.add_argument("--limit", type=int, default=None, help="Optional maximum number of rows to scan")
- parser.add_argument("--top", type=int, default=50, help="Max groups to print in dry-run mode")
- args = parser.parse_args()
- store = SQLiteClusterStore(args.db)
- with store._conn() as conn: # noqa: SLF001 - maintenance script
- rows = conn.execute("SELECT cluster_id, topic, payload FROM clusters ORDER BY updated_at ASC").fetchall()
- if args.limit is not None:
- rows = rows[: args.limit]
- by_topic: dict[str, list[dict[str, Any]]] = defaultdict(list)
- for cluster_id, topic, payload_json in rows:
- try:
- cluster = json.loads(payload_json)
- except Exception:
- continue
- emb = _embedding(cluster)
- if not emb:
- continue
- cluster["_topic_key"] = topic or cluster.get("topic", "other")
- by_topic[cluster["_topic_key"]].append(cluster)
- groups = []
- for topic, clusters in by_topic.items():
- if len(clusters) < 2:
- continue
- parent = list(range(len(clusters)))
- def find(x: int) -> int:
- while parent[x] != x:
- parent[x] = parent[parent[x]]
- x = parent[x]
- return x
- def union(a: int, b: int) -> None:
- ra, rb = find(a), find(b)
- if ra != rb:
- parent[rb] = ra
- for i in range(len(clusters)):
- ei = _embedding(clusters[i])
- if not ei:
- continue
- for j in range(i + 1, len(clusters)):
- ej = _embedding(clusters[j])
- if not ej:
- continue
- if cosine_similarity(ei, ej) >= args.threshold:
- union(i, j)
- buckets: dict[int, list[dict[str, Any]]] = defaultdict(list)
- for idx, c in enumerate(clusters):
- buckets[find(idx)].append(c)
- for bucket in buckets.values():
- if len(bucket) < 2:
- continue
- bucket = sorted(bucket, key=lambda c: str(c.get("updated_at") or c.get("last_updated") or ""), reverse=True)
- groups.append((topic, bucket))
- print({"threshold": args.threshold, "merge_groups": len(groups), "dry_run": args.dry_run})
- if args.dry_run:
- for topic, bucket in groups[: args.top]:
- print(json.dumps({
- "topic": topic,
- "size": len(bucket),
- "cluster_ids": [c.get("cluster_id") for c in bucket],
- "headlines": [c.get("headline") for c in bucket],
- }, ensure_ascii=False))
- return
- with store._conn() as conn: # noqa: SLF001
- for topic, bucket in groups:
- merged = _merge_payloads(bucket)
- cluster_id = merged.get("cluster_id")
- payload = json.dumps(merged, ensure_ascii=False)
- conn.execute(
- "INSERT INTO clusters(cluster_id, topic, payload, updated_at) VALUES(?,?,?,?) "
- "ON CONFLICT(cluster_id) DO UPDATE SET topic=excluded.topic, payload=excluded.payload, updated_at=excluded.updated_at",
- (cluster_id, topic, payload, merged.get("last_updated") or merged.get("updated_at") or ""),
- )
- # Delete absorbed duplicates; keep the representative row.
- dup_ids = [c.get("cluster_id") for c in bucket if c.get("cluster_id") != cluster_id]
- for cid in dup_ids:
- conn.execute("DELETE FROM clusters WHERE cluster_id=?", (cid,))
- conn.commit()
- print({"merged_groups": len(groups), "threshold": args.threshold})
- if __name__ == "__main__":
- main()
|