# embedding_support.py (3.4 KB)
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. from datetime import datetime, timezone, timedelta
  4. import json
  5. import urllib.request
  6. from math import sqrt
  7. from typing import Any
  8. from news_mcp.config import NEWS_EMBEDDINGS_ENABLED, OLLAMA_BASE_URL, OLLAMA_EMBEDDING_MODEL
  9. @dataclass(frozen=True)
  10. class CandidateRules:
  11. """Cheap, non-embedding filters before we compare vectors."""
  12. require_topic_match: bool = True
  13. require_entity_overlap: int = 1
  14. max_age_hours: int = 72
  15. def cosine_similarity(a: list[float], b: list[float]) -> float:
  16. if not a or not b or len(a) != len(b):
  17. return 0.0
  18. dot = sum(x * y for x, y in zip(a, b))
  19. na = sqrt(sum(x * x for x in a))
  20. nb = sqrt(sum(y * y for y in b))
  21. if na == 0.0 or nb == 0.0:
  22. return 0.0
  23. return dot / (na * nb)
  24. def _to_dt(value: Any) -> datetime | None:
  25. if not value:
  26. return None
  27. if isinstance(value, datetime):
  28. return value
  29. try:
  30. s = str(value).replace("Z", "+00:00")
  31. dt = datetime.fromisoformat(s)
  32. if dt.tzinfo is None:
  33. return dt.replace(tzinfo=timezone.utc)
  34. return dt
  35. except Exception:
  36. return None
  37. def cluster_is_candidate(
  38. article: dict[str, Any],
  39. cluster: dict[str, Any],
  40. *,
  41. rules: CandidateRules | None = None,
  42. article_topic: str | None = None,
  43. ) -> bool:
  44. rules = rules or CandidateRules()
  45. if rules.require_topic_match and article_topic is not None:
  46. if str(article_topic).strip().lower() != str(cluster.get("topic", "")).strip().lower():
  47. return False
  48. # Require some overlap in extracted entities if both sides have them.
  49. article_entities = {
  50. str(e).strip().lower()
  51. for e in (article.get("entities", []) or [])
  52. if str(e).strip()
  53. }
  54. cluster_entities = {
  55. str(e).strip().lower()
  56. for e in (cluster.get("entities", []) or [])
  57. if str(e).strip()
  58. }
  59. if article_entities and cluster_entities:
  60. overlap = len(article_entities & cluster_entities)
  61. if overlap < rules.require_entity_overlap:
  62. return False
  63. # Age gate: keep comparisons within a recent window.
  64. article_dt = _to_dt(article.get("timestamp"))
  65. cluster_dt = _to_dt(cluster.get("last_updated") or cluster.get("timestamp"))
  66. if article_dt and cluster_dt:
  67. age = abs(article_dt - cluster_dt)
  68. if age > timedelta(hours=rules.max_age_hours):
  69. return False
  70. return True
  71. def ollama_embed(text: str, timeout: float = 20.0) -> list[float] | None:
  72. """Best-effort Ollama embedding call; returns None on any failure.
  73. Embeddings are intentionally optional. The caller should fall back to the
  74. heuristic path when this returns None.
  75. """
  76. if not NEWS_EMBEDDINGS_ENABLED:
  77. return None
  78. payload = json.dumps({"model": OLLAMA_EMBEDDING_MODEL, "prompt": text}).encode("utf-8")
  79. req = urllib.request.Request(
  80. f"{OLLAMA_BASE_URL.rstrip('/')}/api/embeddings",
  81. data=payload,
  82. headers={"Content-Type": "application/json"},
  83. method="POST",
  84. )
  85. try:
  86. with urllib.request.urlopen(req, timeout=timeout) as resp:
  87. data = json.loads(resp.read().decode("utf-8"))
  88. emb = data.get("embedding")
  89. if isinstance(emb, list) and emb:
  90. return [float(x) for x in emb]
  91. except Exception:
  92. return None
  93. return None