- from __future__ import annotations
- """Backfill cluster embeddings into news-mcp's SQLite store.
- This precomputes a cluster-level embedding for older rows so the optional
- Ollama-first clustering path has data to work with before live traffic resumes.
- Usage:
- ./.venv/bin/python scripts/backfill_news_embeddings.py --dry-run --limit 200
- ./.venv/bin/python scripts/backfill_news_embeddings.py --limit 1000
- """
- import argparse
- import json
- import sys
- from pathlib import Path
- ROOT = Path(__file__).resolve().parents[1]
- sys.path.insert(0, str(ROOT))
- from news_mcp.config import DB_PATH, OLLAMA_BASE_URL, OLLAMA_EMBEDDING_MODEL, NEWS_EMBEDDINGS_ENABLED
- from news_mcp.dedup.embedding_support import ollama_embed
- from news_mcp.storage.sqlite_store import SQLiteClusterStore
- def _cluster_text(cluster: dict) -> str:
- parts = [cluster.get("headline", ""), cluster.get("summary", "") or ""]
- return "\n".join(p for p in parts if p).strip()
def main() -> None:
    """Backfill embeddings for stored clusters that do not yet have one.

    Scans every cluster row (oldest first), computes an Ollama embedding for
    rows whose payload lacks one, and writes the enriched payload back unless
    ``--dry-run`` is given. Prints a summary dict when done.
    """
    parser = argparse.ArgumentParser(description="Backfill embeddings for stored news clusters")
    parser.add_argument("--db", type=Path, default=DB_PATH)
    parser.add_argument("--limit", type=int, default=None, help="Optional maximum number of rows to process")
    parser.add_argument("--dry-run", action="store_true", help="Do not write back changes")
    args = parser.parse_args()

    if not NEWS_EMBEDDINGS_ENABLED:
        print("NEWS_EMBEDDINGS_ENABLED is false; nothing to backfill.")
        return

    store = SQLiteClusterStore(args.db)
    # Oldest-first so an interrupted run resumes roughly where it left off.
    with store._conn() as conn:  # noqa: SLF001 - one-off maintenance script
        rows = conn.execute("SELECT cluster_id, topic, payload FROM clusters ORDER BY updated_at ASC").fetchall()
    if args.limit is not None:
        rows = rows[: args.limit]

    total = 0
    updated = 0
    skipped = 0
    print(f"starting embeddings backfill: clusters={len(rows)} dry_run={args.dry_run} model={OLLAMA_EMBEDDING_MODEL} url={OLLAMA_BASE_URL}")
    for _cluster_id, topic, payload_json in rows:
        total += 1
        try:
            cluster = json.loads(payload_json)
        except (TypeError, ValueError):
            # TypeError: payload is NULL/non-str; ValueError covers JSONDecodeError.
            skipped += 1
            continue
        if cluster.get("embedding"):
            # Already embedded (earlier run or live traffic); leave untouched.
            continue
        emb = ollama_embed(_cluster_text(cluster))
        if not emb:
            # Embedding call failed or text was empty; skip rather than abort.
            skipped += 1
            continue
        cluster = dict(cluster)  # shallow copy: don't mutate the parsed payload in place
        cluster["embedding"] = emb
        cluster["embedding_model"] = f"ollama:{OLLAMA_EMBEDDING_MODEL}"
        if not args.dry_run:
            # Fall back to the payload's own topic, then "other", when the column is empty.
            store.upsert_clusters([cluster], topic=topic or cluster.get("topic", "other"))
        updated += 1
        if updated % 25 == 0:
            print(f"updated={updated} processed={total}")
    print({"total_scanned": total, "updated": updated, "skipped": skipped, "dry_run": args.dry_run})
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()