related_entities.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. from __future__ import annotations
  2. from collections import Counter
  3. from datetime import datetime, timezone
  4. from typing import Any
  5. from news_mcp.entity_normalize import normalize_entity
  6. from news_mcp.storage.sqlite_store import SQLiteClusterStore
  7. from news_mcp.trends_resolution import resolve_entity_via_trends
  8. from news_mcp.trends_related import get_related_topics
  9. def _now_iso() -> str:
  10. return datetime.now(timezone.utc).isoformat()
  def _collect_local_related(
      store: SQLiteClusterStore,
      subject_norm: str,
      subject_resolution: dict[str, Any],
      timeframe_hours: float,
      limit: int,
  ) -> list[tuple[str, int]]:
      """Count entities and keywords that co-occur with the subject in stored clusters.

      A cluster is considered to mention the subject when any of its entity
      strings, or any ``normalized`` / ``canonical_label`` / ``mid`` value of
      its ``entityResolutions``, equals one of the subject's identity terms
      (compared stripped and lower-cased).

      Args:
          store: Cluster store; queried through ``get_latest_clusters_all_topics``.
          subject_norm: Normalized subject label.
          subject_resolution: Resolver output for the subject. The
              ``normalized``, ``canonical_label`` and ``mid`` keys are read.
          timeframe_hours: Forwarded to the store as ``ttl_hours``.
          limit: Maximum number of ``(label, count)`` pairs to return.

      Returns:
          Up to ``limit`` ``(label, count)`` pairs, most frequent first.
          Entities are reported under their ``normalize_entity`` form,
          keywords under their stripped original text.
      """
      clusters = store.get_latest_clusters_all_topics(
          ttl_hours=float(timeframe_hours),
          # Fetch far more clusters than `limit` (at least 200): only the ones
          # that mention the subject contribute to the counts below.
          limit=max(limit * 20, 200),
      )

      # Every identity the subject is known by; empty values are dropped.
      subject_terms = {
          term
          for term in (
              subject_norm.strip().lower(),
              str(subject_resolution.get("normalized") or "").strip().lower(),
              str(subject_resolution.get("canonical_label") or "").strip().lower(),
              str(subject_resolution.get("mid") or "").strip().lower(),
          )
          if term
      }

      counter: Counter[str] = Counter()
      for cluster in clusters:
          entities = cluster.get("entities", []) or []

          # Raw (stripped, lower-cased) entity strings of this cluster.
          raw_entity_keys = {
              key for key in (str(ent).strip().lower() for ent in entities) if key
          }

          # Identity values taken from the cluster's entity resolutions.
          resolution_keys: set[str] = set()
          for res in cluster.get("entityResolutions", []) or []:
              if not isinstance(res, dict):
                  continue
              for field in ("normalized", "canonical_label", "mid"):
                  val = res.get(field)
                  if val:
                      key = str(val).strip().lower()
                      if key:
                          resolution_keys.add(key)

          # Skip clusters that do not mention the subject under any identity.
          if subject_terms.isdisjoint(raw_entity_keys | resolution_keys):
              continue

          # Keys of every entity in this cluster, in raw AND normalized form.
          # Entities are counted under their normalized label, so the keyword
          # pass must also recognise that form or the same label would be
          # counted twice for one cluster.
          entity_keys = set(raw_entity_keys)

          # Count the other entities of the cluster.
          for ent in entities:
              ent_norm = normalize_entity(ent)
              if not ent_norm:
                  continue
              ent_key = ent_norm.strip().lower()
              entity_keys.add(ent_key)
              if ent_key in subject_terms:
                  continue
              counter[ent_norm] += 1

          # Count keywords that are NOT already entities in this cluster.
          # Keywords are LLM-curated thematic descriptors: they capture
          # subject-matter signals that may not be named entities.
          for kw in cluster.get("keywords", []) or []:
              kw_norm = str(kw).strip()
              if not kw_norm:
                  continue
              kw_key = kw_norm.lower()
              # The entity signal is higher quality (has MID, canonical_label),
              # so an entity wins over an identical keyword.
              if kw_key in entity_keys:
                  continue
              # Never report the subject as related to itself.
              if kw_key in subject_terms:
                  continue
              counter[kw_norm] += 1

      return counter.most_common(limit)
  def _collect_trends_related(
      subject_norm: str,
      subject_resolution: dict[str, Any],
      limit: int,
  ) -> list[dict[str, Any]]:
      """Return trends-derived related topics for the subject.

      The result of ``get_related_topics`` is returned unchanged when it is
      non-empty. Otherwise the resolver's ``candidates`` are converted into
      entries with ``canonical_label``, ``normalized``, ``mid`` and ``type``.

      Args:
          subject_norm: Normalized subject label passed to ``get_related_topics``.
          subject_resolution: Resolver output; its ``candidates`` list is the
              fallback source.
          limit: Maximum number of fallback entries to build; also forwarded
              to ``get_related_topics``.

      Returns:
          A list of related-topic dicts; empty when neither source yields any.
      """
      topics = get_related_topics(subject_norm, limit=limit)
      if topics:
          return topics

      fallback: list[dict[str, Any]] = []
      # Without this guard a non-positive limit would still yield one entry,
      # because the size check below only runs after an append.
      if limit <= 0:
          return fallback

      # Fallback to autocomplete candidates if related topics are unavailable.
      for cand in subject_resolution.get("candidates") or []:
          # Ignore malformed candidates instead of failing on `.get`.
          if not isinstance(cand, dict):
              continue
          title = cand.get("title")
          if not title:
              continue
          fallback.append(
              {
                  "canonical_label": title,
                  "normalized": normalize_entity(title),
                  "mid": cand.get("mid"),
                  "type": cand.get("type"),
              }
          )
          if len(fallback) >= limit:
              break
      return fallback
  def related_recent_entities(
      store: SQLiteClusterStore,
      subject: str,
      timeframe_hours: float,
      limit: int,
      include_trends: bool = True,
  ) -> dict[str, Any]:
      """Find entities related to ``subject`` from local clusters and trends data.

      The subject is normalized and resolved, then locally co-occurring labels
      are collected. If fewer than ``limit`` distinct labels were found locally,
      the remaining slots are filled from trends results (when enabled).

      Side effects: records the request via ``store.record_entity_request`` and
      calls ``store.upsert_entity_metadata`` for the subject and for every
      related label that is processed.

      Args:
          store: Cluster store used for lookups and metadata upserts.
          subject: Raw subject text supplied by the caller.
          timeframe_hours: Forwarded to the local cluster lookup.
          limit: Target number of related entries.
          include_trends: When false, trends results are not requested.

      Returns:
          A dict with two keys. ``"subject"`` describes the subject; when the
          subject normalizes to an empty value it holds only ``raw`` and
          ``normalized``. ``"related"`` is a list of dicts with ``normalized``,
          ``canonical_label``, ``mid``, ``sources`` (sorted list) and ``scores``
          (``local_count`` and/or ``trends_rank``), ordered by descending local
          count, then trends rank, then label, and cut to ``max(1, limit)``.
      """
      subject_norm = normalize_entity(subject)
      if not subject_norm:
          return {
              "subject": {"raw": subject, "normalized": ""},
              "related": [],
          }

      subject_resolution = resolve_entity_via_trends(subject_norm)
      subject_mid = subject_resolution.get("mid")
      store.record_entity_request(subject_norm, subject_mid)
      store.upsert_entity_metadata(
          normalized_label=subject_norm,
          canonical_label=subject_resolution.get("canonical_label"),
          mid=subject_mid,
          sources=[subject_resolution.get("source") or "resolver"],
      )

      local_related = _collect_local_related(
          store=store,
          subject_norm=subject_norm,
          subject_resolution=subject_resolution,
          timeframe_hours=timeframe_hours,
          limit=limit,
      )
      trends_related = (
          _collect_trends_related(subject_norm, subject_resolution, limit)
          if include_trends
          else []
      )

      # Merged results, keyed by the stripped, lower-cased label.
      related_map: dict[str, dict[str, Any]] = {}

      def _entry(label: str) -> dict[str, Any]:
          """Return the merged entry for ``label``, creating it on first use."""
          key = label.strip().lower()
          if key not in related_map:
              related_map[key] = {
                  "normalized": label,
                  "canonical_label": label,
                  "mid": None,
                  "sources": set(),
                  "scores": {},
              }
          return related_map[key]

      for label, count in local_related:
          if not label:
              continue
          entry = _entry(label)
          entry["sources"].add("local")
          # Local labels are counted case-sensitively, but entries are keyed
          # case-insensitively. Add the counts of colliding labels together
          # instead of letting the later, smaller count replace the earlier one.
          scores = entry["scores"]
          scores["local_count"] = scores.get("local_count", 0) + int(count)
          store.upsert_entity_metadata(
              normalized_label=label,
              canonical_label=label,
              mid=None,
              sources=["local"],
          )

      # Only use enough trends results to fill remaining slots.
      remaining = max(0, limit - len(related_map))
      for idx, cand in enumerate(trends_related[:remaining], start=1):
          label = cand.get("normalized")
          if not label:
              continue
          entry = _entry(label)
          entry["sources"].add("trends")
          # Trends data refines an existing entry without blanking its fields.
          entry["canonical_label"] = cand.get("canonical_label") or entry["canonical_label"]
          entry["mid"] = cand.get("mid") or entry["mid"]
          entry["scores"]["trends_rank"] = idx
          store.upsert_entity_metadata(
              normalized_label=label,
              canonical_label=cand.get("canonical_label"),
              mid=cand.get("mid"),
              sources=["trends"],
          )

      results = list(related_map.values())
      for item in results:
          # Replace the working set with a sorted list for the returned value.
          item["sources"] = sorted(item["sources"])

      # Strongest local signal first; entries without a trends rank sort after
      # ranked ones (9999 sentinel); the label breaks remaining ties.
      results.sort(
          key=lambda item: (
              -int(item["scores"].get("local_count", 0)),
              item["scores"].get("trends_rank", 9999),
              item["canonical_label"].lower(),
          )
      )

      return {
          "subject": {
              "raw": subject,
              "normalized": subject_norm,
              "canonical_label": subject_resolution.get("canonical_label") or subject_norm,
              "mid": subject_mid,
              "resolved_at": subject_resolution.get("resolved_at") or _now_iso(),
              "source": subject_resolution.get("source"),
          },
          "related": results[: max(1, limit)],
      }