|
|
@@ -0,0 +1,86 @@
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+"""Backfill cluster embeddings into news-mcp's SQLite store.
|
|
|
+
|
|
|
+This precomputes a cluster-level embedding for older rows so the optional
|
|
|
+Ollama-first clustering path has data to work with before live traffic resumes.
|
|
|
+
|
|
|
+Usage:
|
|
|
+ ./.venv/bin/python scripts/backfill_news_embeddings.py --dry-run --limit 200
|
|
|
+ ./.venv/bin/python scripts/backfill_news_embeddings.py --limit 1000
|
|
|
+"""
|
|
|
+
|
|
|
+import argparse
|
|
|
+import json
|
|
|
+import sys
|
|
|
+from pathlib import Path
|
|
|
+
|
|
|
+ROOT = Path(__file__).resolve().parents[1]
|
|
|
+sys.path.insert(0, str(ROOT))
|
|
|
+
|
|
|
+from news_mcp.config import DB_PATH, OLLAMA_BASE_URL, OLLAMA_EMBEDDING_MODEL, NEWS_EMBEDDINGS_ENABLED
|
|
|
+from news_mcp.dedup.embedding_support import ollama_embed
|
|
|
+from news_mcp.storage.sqlite_store import SQLiteClusterStore
|
|
|
+
|
|
|
+
|
|
|
+def _cluster_text(cluster: dict) -> str:
|
|
|
+ parts = [cluster.get("headline", ""), cluster.get("summary", "") or ""]
|
|
|
+ return "\n".join(p for p in parts if p).strip()
|
|
|
+
|
|
|
+
|
|
|
def main() -> None:
    """Backfill cluster-level embeddings for stored news clusters.

    Scans every row in the ``clusters`` table (oldest first), computes an
    Ollama embedding for payloads that lack one, and writes the enriched
    payload back via the store — unless ``--dry-run`` is given, in which
    case only the counters are reported.
    """
    parser = argparse.ArgumentParser(description="Backfill embeddings for stored news clusters")
    parser.add_argument("--db", type=Path, default=DB_PATH)
    parser.add_argument("--limit", type=int, default=None, help="Optional maximum number of rows to process")
    parser.add_argument("--dry-run", action="store_true", help="Do not write back changes")
    args = parser.parse_args()

    if not NEWS_EMBEDDINGS_ENABLED:
        print("NEWS_EMBEDDINGS_ENABLED is false; nothing to backfill.")
        return

    store = SQLiteClusterStore(args.db)
    # One-off maintenance script: reach into the store's private connection
    # helper rather than widening its public API.
    with store._conn() as conn:  # noqa: SLF001 - one-off maintenance script
        rows = conn.execute(
            "SELECT cluster_id, topic, payload FROM clusters ORDER BY updated_at ASC"
        ).fetchall()

    if args.limit is not None:
        rows = rows[: args.limit]

    total = 0
    updated = 0
    skipped = 0
    already = 0  # rows that had an embedding before this run (previously uncounted)

    print(f"starting embeddings backfill: clusters={len(rows)} dry_run={args.dry_run} model={OLLAMA_EMBEDDING_MODEL} url={OLLAMA_BASE_URL}")

    for _cluster_id, topic, payload_json in rows:
        total += 1
        try:
            cluster = json.loads(payload_json)
        except (TypeError, ValueError):
            # NULL payload (TypeError) or malformed JSON (ValueError,
            # incl. JSONDecodeError) — narrower than the old bare Exception.
            skipped += 1
            continue

        if cluster.get("embedding"):
            # Fix: count these so the final stats add up (total == updated
            # + skipped + already); previously they vanished from the tally.
            already += 1
            continue

        text = _cluster_text(cluster)
        if not text:
            # Nothing to embed — don't waste an Ollama round-trip on "".
            skipped += 1
            continue

        emb = ollama_embed(text)
        if not emb:
            skipped += 1
            continue

        # Copy before mutating so we never alias the parsed payload.
        cluster = dict(cluster)
        cluster["embedding"] = emb
        cluster["embedding_model"] = f"ollama:{OLLAMA_EMBEDDING_MODEL}"

        if not args.dry_run:
            store.upsert_clusters([cluster], topic=topic or cluster.get("topic", "other"))
        updated += 1

        # Lightweight progress heartbeat every 25 (would-be) updates.
        if updated % 25 == 0:
            print(f"updated={updated} processed={total}")

    print({"total_scanned": total, "updated": updated, "already_embedded": already, "skipped": skipped, "dry_run": args.dry_run})
|
|
|
+
|
|
|
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()
|