embedding_support.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. from __future__ import annotations
  2. import asyncio
  3. import json
  4. from dataclasses import dataclass
  5. from datetime import datetime, timezone, timedelta
  6. from math import sqrt
  7. from typing import Any
  8. import httpx
  9. from news_mcp.config import (
  10. NEWS_EMBEDDINGS_ENABLED,
  11. OLLAMA_BASE_URL,
  12. OLLAMA_EMBEDDING_MODEL,
  13. _NEEDLE_OLLAMA_MAX_CONCURRENCY,
  14. )
# Module-wide limit on concurrent Ollama embedding requests. `ollama_embed`
# holds this semaphore around each HTTP call; the permit count comes from
# the `_NEEDLE_OLLAMA_MAX_CONCURRENCY` config value.
_ollama_semaphore = asyncio.Semaphore(_NEEDLE_OLLAMA_MAX_CONCURRENCY)
  16. @dataclass(frozen=True)
  17. class CandidateRules:
  18. """Cheap, non-embedding filters before we compare vectors."""
  19. require_topic_match: bool = True
  20. require_entity_overlap: int = 1
  21. max_age_hours: int = 72
  22. def cosine_similarity(a: list[float], b: list[float]) -> float:
  23. if not a or not b or len(a) != len(b):
  24. return 0.0
  25. dot = sum(x * y for x, y in zip(a, b))
  26. na = sqrt(sum(x * x for x in a))
  27. nb = sqrt(sum(y * y for y in b))
  28. if na == 0.0 or nb == 0.0:
  29. return 0.0
  30. return dot / (na * nb)
  31. def _to_dt(value: Any) -> datetime | None:
  32. if not value:
  33. return None
  34. if isinstance(value, datetime):
  35. return value
  36. try:
  37. s = str(value).replace("Z", "+00:00")
  38. dt = datetime.fromisoformat(s)
  39. if dt.tzinfo is None:
  40. return dt.replace(tzinfo=timezone.utc)
  41. return dt
  42. except Exception:
  43. return None
  44. def cluster_is_candidate(
  45. article: dict[str, Any],
  46. cluster: dict[str, Any],
  47. *,
  48. rules: CandidateRules | None = None,
  49. article_topic: str | None = None,
  50. ) -> bool:
  51. rules = rules or CandidateRules()
  52. if rules.require_topic_match and article_topic is not None:
  53. if str(article_topic).strip().lower() != str(cluster.get("topic", "")).strip().lower():
  54. return False
  55. # Require some overlap in extracted entities if both sides have them.
  56. article_entities = {
  57. str(e).strip().lower()
  58. for e in (article.get("entities", []) or [])
  59. if str(e).strip()
  60. }
  61. cluster_entities = {
  62. str(e).strip().lower()
  63. for e in (cluster.get("entities", []) or [])
  64. if str(e).strip()
  65. }
  66. if article_entities and cluster_entities:
  67. overlap = len(article_entities & cluster_entities)
  68. if overlap < rules.require_entity_overlap:
  69. return False
  70. # Age gate: keep comparisons within a recent window.
  71. article_dt = _to_dt(article.get("timestamp"))
  72. cluster_dt = _to_dt(cluster.get("last_updated") or cluster.get("timestamp"))
  73. if article_dt and cluster_dt:
  74. age = abs(article_dt - cluster_dt)
  75. if age > timedelta(hours=rules.max_age_hours):
  76. return False
  77. return True
  78. async def ollama_embed(text: str, timeout: float = 20.0) -> list[float] | None:
  79. """Async Ollama embedding call with concurrency limiting.
  80. Returns None on any failure so the caller falls back to heuristic clustering.
  81. """
  82. if not NEWS_EMBEDDINGS_ENABLED:
  83. return None
  84. payload = json.dumps({"model": OLLAMA_EMBEDDING_MODEL, "prompt": text}).encode("utf-8")
  85. url = f"{OLLAMA_BASE_URL.rstrip('/')}/api/embeddings"
  86. async with _ollama_semaphore:
  87. try:
  88. async with httpx.AsyncClient(timeout=timeout) as client:
  89. resp = await client.post(
  90. url,
  91. content=payload,
  92. headers={"Content-Type": "application/json"},
  93. )
  94. resp.raise_for_status()
  95. data = resp.json()
  96. emb = data.get("embedding")
  97. if isinstance(emb, list) and emb:
  98. return [float(x) for x in emb]
  99. except Exception:
  100. return None
  101. return None