|
@@ -0,0 +1,189 @@
|
|
|
|
|
+from __future__ import annotations
|
|
|
|
|
+
|
|
|
|
|
+from collections import Counter
|
|
|
|
|
+from datetime import datetime, timezone
|
|
|
|
|
+from typing import Any
|
|
|
|
|
+
|
|
|
|
|
+from news_mcp.entity_normalize import normalize_entity
|
|
|
|
|
+from news_mcp.storage.sqlite_store import SQLiteClusterStore
|
|
|
|
|
+from news_mcp.trends_resolution import resolve_entity_via_trends
|
|
|
|
|
+from news_mcp.trends_related import get_related_topics
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _now_iso() -> str:
|
|
|
|
|
+ return datetime.now(timezone.utc).isoformat()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _collect_local_related(
    store: SQLiteClusterStore,
    subject_norm: str,
    subject_resolution: dict[str, Any],
    timeframe_hours: float,
    limit: int,
) -> list[tuple[str, int]]:
    """Count entities co-occurring with the subject in recent local clusters.

    Scans clusters stored within ``timeframe_hours`` and, for every cluster
    whose entities or entity resolutions match any of the subject's identity
    terms (normalized label, canonical label, or MID), tallies the *other*
    normalized entities that cluster contains.

    Returns the ``limit`` most common ``(normalized_entity, count)`` pairs,
    most frequent first.
    """
    clusters = store.get_latest_clusters_all_topics(
        ttl_hours=float(timeframe_hours),
        # Over-fetch so that matching clusters aren't starved by the limit.
        limit=max(limit * 20, 200),
    )
    # Lowercase identity terms that all count as "the subject" for matching.
    subject_terms = {
        subject_norm.strip().lower(),
        str(subject_resolution.get("normalized") or "").strip().lower(),
        str(subject_resolution.get("canonical_label") or "").strip().lower(),
        str(subject_resolution.get("mid") or "").strip().lower(),
    }
    # Every member is already stripped/lowered, so the only falsy entry
    # possible is the empty string — drop it in place.
    subject_terms.discard("")

    counter: Counter[str] = Counter()
    for cluster in clusters:
        # Match clusters by any of the resolved identity terms.
        haystack: list[str] = []
        for ent in cluster.get("entities", []) or []:
            haystack.append(str(ent).strip().lower())
        for res in cluster.get("entityResolutions", []) or []:
            if not isinstance(res, dict):
                continue
            for key in ("normalized", "canonical_label", "mid"):
                val = res.get(key)
                if val:
                    haystack.append(str(val).strip().lower())

        haystack_set = {h for h in haystack if h}
        if not (haystack_set & subject_terms):
            continue

        # Count the cluster's other entities, normalized, excluding any
        # term that resolves back to the subject itself.
        for ent in cluster.get("entities", []) or []:
            ent_norm = normalize_entity(ent)
            if not ent_norm:
                continue
            if ent_norm.strip().lower() in subject_terms:
                continue
            counter[ent_norm] += 1
    return counter.most_common(limit)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _collect_trends_related(subject_norm: str, subject_resolution: dict[str, Any], limit: int) -> list[dict[str, Any]]:
    """Return up to ``limit`` trends-related topic dicts for the subject.

    Prefers live related topics; if none are available, falls back to the
    resolver's autocomplete candidates, reshaped into the same dict form.
    """
    related = get_related_topics(subject_norm, limit=limit)
    if related:
        return related

    # Fallback to autocomplete candidates if related topics are unavailable.
    results: list[dict[str, Any]] = []
    for candidate in subject_resolution.get("candidates") or []:
        title = candidate.get("title")
        if not title:
            continue
        results.append(
            {
                "canonical_label": title,
                "normalized": normalize_entity(title),
                "mid": candidate.get("mid"),
                "type": candidate.get("type"),
            }
        )
        if len(results) >= limit:
            break
    return results
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def related_recent_entities(
    store: SQLiteClusterStore,
    subject: str,
    timeframe_hours: float,
    limit: int,
    include_trends: bool = True,
) -> dict[str, Any]:
    """Resolve *subject* and return entities related to it in recent news.

    Combines two signals:
      * co-occurrence counts from locally stored clusters
        (``_collect_local_related``), and
      * trends-related topics (``_collect_trends_related``), used only to
        fill whatever slots the local results left open.

    Side effects: records the entity request and upserts entity metadata
    into *store* — once for the subject and once per related entity found.

    Returns a dict with ``subject`` (resolution details) and ``related``
    (a ranked list of entity dicts carrying ``sources`` and ``scores``).
    """
    subject_norm = normalize_entity(subject)
    if not subject_norm:
        # Nothing usable after normalization; return an empty result shell.
        return {
            "subject": {"raw": subject, "normalized": ""},
            "related": [],
        }

    subject_resolution = resolve_entity_via_trends(subject_norm)
    # Persist the lookup and the subject's resolved identity.
    store.record_entity_request(subject_norm, subject_resolution.get("mid"))
    store.upsert_entity_metadata(
        normalized_label=subject_norm,
        canonical_label=subject_resolution.get("canonical_label"),
        mid=subject_resolution.get("mid"),
        sources=[subject_resolution.get("source") or "resolver"],
    )

    local_related = _collect_local_related(
        store=store,
        subject_norm=subject_norm,
        subject_resolution=subject_resolution,
        timeframe_hours=timeframe_hours,
        limit=limit,
    )
    trends_related = _collect_trends_related(subject_norm, subject_resolution, limit) if include_trends else []

    # Merged results keyed by lowercase normalized label, to dedupe entries
    # that arrive from both sources.
    related_map: dict[str, dict[str, Any]] = {}

    def _entry(label: str) -> dict[str, Any]:
        # Get-or-create the merged entry for a normalized label.
        key = label.strip().lower()
        if key not in related_map:
            related_map[key] = {
                "normalized": label,
                "canonical_label": label,
                "mid": None,
                "sources": set(),
                "scores": {},
            }
        return related_map[key]

    for label, count in local_related:
        if not label:
            continue
        entry = _entry(label)
        entry["sources"].add("local")
        entry["scores"]["local_count"] = int(count)
        store.upsert_entity_metadata(
            normalized_label=label,
            canonical_label=label,
            mid=None,
            sources=["local"],
        )

    # Only use enough trends results to fill remaining slots.
    remaining = max(0, limit - len(related_map))
    for idx, cand in enumerate(trends_related[:remaining], start=1):
        label = cand.get("normalized")
        if not label:
            continue
        entry = _entry(label)
        entry["sources"].add("trends")
        # Trends data may carry richer identity info; prefer it when present.
        entry["canonical_label"] = cand.get("canonical_label") or entry["canonical_label"]
        entry["mid"] = cand.get("mid") or entry["mid"]
        entry["scores"]["trends_rank"] = idx
        store.upsert_entity_metadata(
            normalized_label=label,
            canonical_label=cand.get("canonical_label"),
            mid=cand.get("mid"),
            sources=["trends"],
        )

    results = list(related_map.values())
    for item in results:
        # Emit a deterministic sorted list instead of the working set.
        item["sources"] = sorted(item["sources"])

    # Rank: local co-occurrence count desc, then trends rank asc, then label.
    results.sort(
        key=lambda item: (
            -int(item["scores"].get("local_count", 0)),
            item["scores"].get("trends_rank", 9999),
            item["canonical_label"].lower(),
        )
    )

    return {
        "subject": {
            "raw": subject,
            "normalized": subject_norm,
            "canonical_label": subject_resolution.get("canonical_label") or subject_norm,
            "mid": subject_resolution.get("mid"),
            "resolved_at": subject_resolution.get("resolved_at") or _now_iso(),
            "source": subject_resolution.get("source"),
        },
        # NOTE(review): max(1, limit) still yields one item when limit == 0 —
        # presumably intentional as a floor; confirm against callers.
        "related": results[: max(1, limit)],
    }
|