  1. from __future__ import annotations
  2. """Backfill cluster embeddings into news-mcp's SQLite store.
  3. This precomputes a cluster-level embedding for older rows so the optional
  4. Ollama-first clustering path has data to work with before live traffic resumes.
  5. Usage:
  6. ./.venv/bin/python scripts/backfill_news_embeddings.py --dry-run --limit 200
  7. ./.venv/bin/python scripts/backfill_news_embeddings.py --limit 1000
  8. """
  9. import argparse
  10. import json
  11. import sys
  12. from pathlib import Path
  13. ROOT = Path(__file__).resolve().parents[1]
  14. sys.path.insert(0, str(ROOT))
  15. from news_mcp.config import DB_PATH, OLLAMA_BASE_URL, OLLAMA_EMBEDDING_MODEL, NEWS_EMBEDDINGS_ENABLED
  16. from news_mcp.dedup.embedding_support import ollama_embed
  17. from news_mcp.storage.sqlite_store import SQLiteClusterStore
  18. def _cluster_text(cluster: dict) -> str:
  19. parts = [cluster.get("headline", ""), cluster.get("summary", "") or ""]
  20. return "\n".join(p for p in parts if p).strip()
  21. def main() -> None:
  22. parser = argparse.ArgumentParser(description="Backfill embeddings for stored news clusters")
  23. parser.add_argument("--db", type=Path, default=DB_PATH)
  24. parser.add_argument("--limit", type=int, default=None, help="Optional maximum number of rows to process")
  25. parser.add_argument("--dry-run", action="store_true", help="Do not write back changes")
  26. args = parser.parse_args()
  27. if not NEWS_EMBEDDINGS_ENABLED:
  28. print("NEWS_EMBEDDINGS_ENABLED is false; nothing to backfill.")
  29. return
  30. store = SQLiteClusterStore(args.db)
  31. with store._conn() as conn: # noqa: SLF001 - one-off maintenance script
  32. rows = conn.execute("SELECT cluster_id, topic, payload FROM clusters ORDER BY updated_at ASC").fetchall()
  33. if args.limit is not None:
  34. rows = rows[: args.limit]
  35. total = 0
  36. updated = 0
  37. skipped = 0
  38. print(f"starting embeddings backfill: clusters={len(rows)} dry_run={args.dry_run} model={OLLAMA_EMBEDDING_MODEL} url={OLLAMA_BASE_URL}")
  39. for cluster_id, topic, payload_json in rows:
  40. total += 1
  41. try:
  42. cluster = json.loads(payload_json)
  43. except Exception:
  44. skipped += 1
  45. continue
  46. if cluster.get("embedding"):
  47. continue
  48. emb = ollama_embed(_cluster_text(cluster))
  49. if not emb:
  50. skipped += 1
  51. continue
  52. cluster = dict(cluster)
  53. cluster["embedding"] = emb
  54. cluster["embedding_model"] = f"ollama:{OLLAMA_EMBEDDING_MODEL}"
  55. if not args.dry_run:
  56. store.upsert_clusters([cluster], topic=topic or cluster.get("topic", "other"))
  57. updated += 1
  58. if updated % 25 == 0:
  59. print(f"updated={updated} processed={total}")
  60. print({"total_scanned": total, "updated": updated, "skipped": skipped, "dry_run": args.dry_run})
# Script entry point: run the backfill only when executed directly, not on import.
if __name__ == "__main__":
    main()