embedding_support.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. from datetime import datetime, timezone, timedelta
  4. from math import sqrt
  5. from typing import Any
  6. @dataclass(frozen=True)
  7. class CandidateRules:
  8. """Cheap, non-embedding filters before we compare vectors."""
  9. require_topic_match: bool = True
  10. require_entity_overlap: int = 1
  11. max_age_hours: int = 72
  12. def cosine_similarity(a: list[float], b: list[float]) -> float:
  13. if not a or not b or len(a) != len(b):
  14. return 0.0
  15. dot = sum(x * y for x, y in zip(a, b))
  16. na = sqrt(sum(x * x for x in a))
  17. nb = sqrt(sum(y * y for y in b))
  18. if na == 0.0 or nb == 0.0:
  19. return 0.0
  20. return dot / (na * nb)
  21. def _to_dt(value: Any) -> datetime | None:
  22. if not value:
  23. return None
  24. if isinstance(value, datetime):
  25. return value
  26. try:
  27. s = str(value).replace("Z", "+00:00")
  28. dt = datetime.fromisoformat(s)
  29. if dt.tzinfo is None:
  30. return dt.replace(tzinfo=timezone.utc)
  31. return dt
  32. except Exception:
  33. return None
  34. def cluster_is_candidate(
  35. article: dict[str, Any],
  36. cluster: dict[str, Any],
  37. *,
  38. rules: CandidateRules | None = None,
  39. article_topic: str | None = None,
  40. ) -> bool:
  41. rules = rules or CandidateRules()
  42. if rules.require_topic_match and article_topic is not None:
  43. if str(article_topic).strip().lower() != str(cluster.get("topic", "")).strip().lower():
  44. return False
  45. # Require some overlap in extracted entities if both sides have them.
  46. article_entities = {
  47. str(e).strip().lower()
  48. for e in (article.get("entities", []) or [])
  49. if str(e).strip()
  50. }
  51. cluster_entities = {
  52. str(e).strip().lower()
  53. for e in (cluster.get("entities", []) or [])
  54. if str(e).strip()
  55. }
  56. if article_entities and cluster_entities:
  57. overlap = len(article_entities & cluster_entities)
  58. if overlap < rules.require_entity_overlap:
  59. return False
  60. # Age gate: keep comparisons within a recent window.
  61. article_dt = _to_dt(article.get("timestamp"))
  62. cluster_dt = _to_dt(cluster.get("last_updated") or cluster.get("timestamp"))
  63. if article_dt and cluster_dt:
  64. age = abs(article_dt - cluster_dt)
  65. if age > timedelta(hours=rules.max_age_hours):
  66. return False
  67. return True