sqlite_store.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. from __future__ import annotations
  2. import json
  3. import sqlite3
  4. from dataclasses import dataclass
  5. from datetime import datetime, timezone, timedelta
  6. from pathlib import Path
  7. from typing import Any
  8. @dataclass
  9. class ClusterRow:
  10. cluster_id: str
  11. topic: str
  12. payload: dict
  13. updated_at: datetime
  14. class SQLiteClusterStore:
  15. def __init__(self, db_path: str | Path):
  16. self.db_path = str(db_path)
  17. self._init_db()
  18. def _conn(self) -> sqlite3.Connection:
  19. return sqlite3.connect(self.db_path)
  20. def _init_db(self) -> None:
  21. Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
  22. with self._conn() as conn:
  23. conn.execute(
  24. """
  25. CREATE TABLE IF NOT EXISTS clusters (
  26. cluster_id TEXT PRIMARY KEY,
  27. topic TEXT NOT NULL,
  28. payload TEXT NOT NULL,
  29. updated_at TEXT NOT NULL
  30. )
  31. """
  32. )
  33. conn.execute(
  34. "CREATE INDEX IF NOT EXISTS idx_clusters_topic ON clusters(topic)"
  35. )
  36. def upsert_clusters(self, clusters: list[dict], topic: str) -> None:
  37. now = datetime.now(timezone.utc)
  38. with self._conn() as conn:
  39. for c in clusters:
  40. cluster_id = c["cluster_id"]
  41. payload = json.dumps(c, ensure_ascii=False)
  42. conn.execute(
  43. "INSERT INTO clusters(cluster_id, topic, payload, updated_at) VALUES(?,?,?,?) "
  44. "ON CONFLICT(cluster_id) DO UPDATE SET topic=excluded.topic, payload=excluded.payload, updated_at=excluded.updated_at",
  45. (cluster_id, topic, payload, now.isoformat()),
  46. )
  47. def get_latest_clusters(self, topic: str, ttl_hours: float, limit: int) -> list[dict]:
  48. cutoff = datetime.now(timezone.utc) - timedelta(hours=ttl_hours)
  49. cutoff_iso = cutoff.isoformat()
  50. with self._conn() as conn:
  51. cur = conn.execute(
  52. "SELECT payload FROM clusters WHERE topic=? AND updated_at >= ? ORDER BY updated_at DESC LIMIT ?",
  53. (topic, cutoff_iso, int(limit)),
  54. )
  55. rows = [json.loads(r[0]) for r in cur.fetchall()]
  56. return rows