# related_entities.py
  1. from __future__ import annotations
  2. from collections import Counter
  3. from datetime import datetime, timezone
  4. from typing import Any
  5. from news_mcp.entity_normalize import normalize_entity
  6. from news_mcp.storage.sqlite_store import SQLiteClusterStore
  7. from news_mcp.trends_resolution import resolve_entity_via_trends
  8. from news_mcp.trends_related import get_related_topics
  9. def _now_iso() -> str:
  10. return datetime.now(timezone.utc).isoformat()
  11. def _collect_local_related(
  12. store: SQLiteClusterStore,
  13. subject_norm: str,
  14. subject_resolution: dict[str, Any],
  15. timeframe_hours: float,
  16. limit: int,
  17. ) -> list[tuple[str, int]]:
  18. clusters = store.get_latest_clusters_all_topics(
  19. ttl_hours=float(timeframe_hours),
  20. limit=max(limit * 20, 200),
  21. )
  22. counter: Counter[str] = Counter()
  23. subject_terms = {
  24. subject_norm.strip().lower(),
  25. str(subject_resolution.get("normalized") or "").strip().lower(),
  26. str(subject_resolution.get("canonical_label") or "").strip().lower(),
  27. str(subject_resolution.get("mid") or "").strip().lower(),
  28. }
  29. subject_terms = {t for t in subject_terms if t}
  30. for cluster in clusters:
  31. # Match clusters by any of the resolved identity terms.
  32. haystack: list[str] = []
  33. for ent in cluster.get("entities", []) or []:
  34. haystack.append(str(ent).strip().lower())
  35. for res in cluster.get("entityResolutions", []) or []:
  36. if not isinstance(res, dict):
  37. continue
  38. for key in ("normalized", "canonical_label", "mid"):
  39. val = res.get(key)
  40. if val:
  41. haystack.append(str(val).strip().lower())
  42. haystack_set = set([h for h in haystack if h])
  43. if not (haystack_set & subject_terms):
  44. continue
  45. # Count other entities normalized.
  46. for ent in cluster.get("entities", []) or []:
  47. ent_norm = normalize_entity(ent)
  48. if not ent_norm:
  49. continue
  50. ent_key = ent_norm.strip().lower()
  51. if ent_key in subject_terms:
  52. continue
  53. counter[ent_norm] += 1
  54. return counter.most_common(limit)
  55. def _collect_trends_related(subject_norm: str, subject_resolution: dict[str, Any], limit: int) -> list[dict[str, Any]]:
  56. topics = get_related_topics(subject_norm, limit=limit)
  57. if topics:
  58. return topics
  59. # Fallback to autocomplete candidates if related topics are unavailable.
  60. candidates = subject_resolution.get("candidates") or []
  61. out = []
  62. for cand in candidates:
  63. title = cand.get("title")
  64. if not title:
  65. continue
  66. out.append(
  67. {
  68. "canonical_label": title,
  69. "normalized": normalize_entity(title),
  70. "mid": cand.get("mid"),
  71. "type": cand.get("type"),
  72. }
  73. )
  74. if len(out) >= limit:
  75. return out
  76. return out
def related_recent_entities(
    store: SQLiteClusterStore,
    subject: str,
    timeframe_hours: float,
    limit: int,
    include_trends: bool = True,
) -> dict[str, Any]:
    """Resolve *subject* and return entities recently associated with it.

    Merges two sources into one ranked list: co-occurrence counts from
    locally stored clusters, and (when ``include_trends``) Trends related
    topics used only to fill slots local data did not. As a side effect,
    records the request and upserts entity metadata rows in ``store``.

    Args:
        store: Cluster/metadata store; read for clusters, written for
            request logging and entity metadata.
        subject: Raw subject string as supplied by the caller.
        timeframe_hours: Recency window for the local cluster scan.
        limit: Target number of related entities.
        include_trends: When False, skip the Trends lookup entirely.

    Returns:
        Dict with a ``subject`` identity block and a ``related`` list of
        merged entity records.
    """
    subject_norm = normalize_entity(subject)
    if not subject_norm:
        # Normalization stripped the subject empty — nothing to look up.
        return {
            "subject": {"raw": subject, "normalized": ""},
            "related": [],
        }
    subject_resolution = resolve_entity_via_trends(subject_norm)
    # Side effects: log the lookup and refresh the subject's metadata row.
    store.record_entity_request(subject_norm, subject_resolution.get("mid"))
    store.upsert_entity_metadata(
        normalized_label=subject_norm,
        canonical_label=subject_resolution.get("canonical_label"),
        mid=subject_resolution.get("mid"),
        sources=[subject_resolution.get("source") or "resolver"],
    )
    local_related = _collect_local_related(
        store=store,
        subject_norm=subject_norm,
        subject_resolution=subject_resolution,
        timeframe_hours=timeframe_hours,
        limit=limit,
    )
    trends_related = _collect_trends_related(subject_norm, subject_resolution, limit) if include_trends else []
    # Merge accumulator keyed by lowercased normalized label.
    related_map: dict[str, dict[str, Any]] = {}

    def _entry(label: str) -> dict[str, Any]:
        # Get-or-create the merged record for a label (case-insensitive key).
        key = label.strip().lower()
        if key not in related_map:
            related_map[key] = {
                "normalized": label,
                "canonical_label": label,
                "mid": None,
                "sources": set(),
                "scores": {},
            }
        return related_map[key]

    # Local co-occurrence results enter first and dominate ranking below.
    for label, count in local_related:
        if not label:
            continue
        entry = _entry(label)
        entry["sources"].add("local")
        entry["scores"]["local_count"] = int(count)
        store.upsert_entity_metadata(
            normalized_label=label,
            canonical_label=label,
            mid=None,
            sources=["local"],
        )
    # Only use enough trends results to fill remaining slots.
    remaining = max(0, limit - len(related_map))
    for idx, cand in enumerate(trends_related[:remaining], start=1):
        label = cand.get("normalized")
        if not label:
            continue
        entry = _entry(label)
        entry["sources"].add("trends")
        # Trends may supply a better canonical label and a knowledge-graph mid.
        entry["canonical_label"] = cand.get("canonical_label") or entry["canonical_label"]
        entry["mid"] = cand.get("mid") or entry["mid"]
        entry["scores"]["trends_rank"] = idx
        store.upsert_entity_metadata(
            normalized_label=label,
            canonical_label=cand.get("canonical_label"),
            mid=cand.get("mid"),
            sources=["trends"],
        )
    results = list(related_map.values())
    for item in results:
        # Sets do not serialize to JSON; emit sources as a sorted list.
        item["sources"] = sorted(item["sources"])
    # Rank: local count desc, then trends rank asc (9999 = no rank), then label.
    results.sort(
        key=lambda item: (
            -int(item["scores"].get("local_count", 0)),
            item["scores"].get("trends_rank", 9999),
            item["canonical_label"].lower(),
        )
    )
    return {
        "subject": {
            "raw": subject,
            "normalized": subject_norm,
            "canonical_label": subject_resolution.get("canonical_label") or subject_norm,
            "mid": subject_resolution.get("mid"),
            "resolved_at": subject_resolution.get("resolved_at") or _now_iso(),
            "source": subject_resolution.get("source"),
        },
        # NOTE(review): max(1, limit) yields one item even when limit <= 0 —
        # inconsistent with the helpers; confirm this is intentional.
        "related": results[: max(1, limit)],
    }
  167. }