- from __future__ import annotations
- from collections import Counter
- from datetime import datetime, timezone
- from typing import Any
- from news_mcp.entity_normalize import normalize_entity
- from news_mcp.storage.sqlite_store import SQLiteClusterStore
- from news_mcp.trends_resolution import resolve_entity_via_trends
- from news_mcp.trends_related import get_related_topics
- def _now_iso() -> str:
- return datetime.now(timezone.utc).isoformat()
def _collect_local_related(
    store: SQLiteClusterStore,
    subject_norm: str,
    subject_resolution: dict[str, Any],
    timeframe_hours: float,
    limit: int,
) -> list[tuple[str, int]]:
    """Count entities co-occurring with the subject in recent local clusters.

    Scans clusters no older than ``timeframe_hours`` and, for every cluster
    whose entities or entity resolutions mention the subject (by normalized
    label, canonical label, or mid — all compared lowercased), tallies the
    cluster's other normalized entities.

    Args:
        store: Cluster store queried for recent clusters.
        subject_norm: Normalized subject label.
        subject_resolution: Resolver output; ``normalized``,
            ``canonical_label`` and ``mid`` are used as extra identity terms.
        timeframe_hours: Recency window passed to the store as a TTL.
        limit: Maximum number of (entity, count) pairs to return.

    Returns:
        Up to ``limit`` (normalized entity, co-occurrence count) pairs,
        most frequent first.
    """
    # Over-fetch so the subject match + dedup still leaves enough results.
    clusters = store.get_latest_clusters_all_topics(
        ttl_hours=float(timeframe_hours),
        limit=max(limit * 20, 200),
    )
    # Lowercased identity terms that may denote the subject.
    subject_terms = {
        term.strip().lower()
        for term in (
            subject_norm,
            str(subject_resolution.get("normalized") or ""),
            str(subject_resolution.get("canonical_label") or ""),
            str(subject_resolution.get("mid") or ""),
        )
    }
    subject_terms.discard("")
    counter: Counter[str] = Counter()
    for cluster in clusters:
        # Gather every identity term the cluster exposes, lowercased.
        haystack: set[str] = {
            str(ent).strip().lower() for ent in cluster.get("entities", []) or []
        }
        for res in cluster.get("entityResolutions", []) or []:
            if not isinstance(res, dict):
                continue
            for key in ("normalized", "canonical_label", "mid"):
                val = res.get(key)
                if val:
                    haystack.add(str(val).strip().lower())
        haystack.discard("")
        # Skip clusters that never mention the subject under any identity.
        if haystack.isdisjoint(subject_terms):
            continue
        # Count the cluster's other entities, excluding the subject itself.
        for ent in cluster.get("entities", []) or []:
            ent_norm = normalize_entity(ent)
            if not ent_norm or ent_norm.strip().lower() in subject_terms:
                continue
            counter[ent_norm] += 1
    return counter.most_common(limit)
def _collect_trends_related(
    subject_norm: str,
    subject_resolution: dict[str, Any],
    limit: int,
) -> list[dict[str, Any]]:
    """Fetch related topics from trends, falling back to resolver candidates.

    Args:
        subject_norm: Normalized subject label to query.
        subject_resolution: Resolver output; its ``candidates`` list is used
            only when the trends lookup returns nothing.
        limit: Maximum number of topic dicts to return.

    Returns:
        A list of {canonical_label, normalized, mid, type} dicts (trends
        results are returned as-is from ``get_related_topics``).
    """
    topics = get_related_topics(subject_norm, limit=limit)
    if topics:
        return topics
    # Fallback to autocomplete candidates if related topics are unavailable.
    out: list[dict[str, Any]] = []
    for cand in subject_resolution.get("candidates") or []:
        # Guard against malformed candidate entries, mirroring the
        # entityResolutions guard in _collect_local_related.
        if not isinstance(cand, dict):
            continue
        title = cand.get("title")
        if not title:
            continue
        out.append(
            {
                "canonical_label": title,
                "normalized": normalize_entity(title),
                "mid": cand.get("mid"),
                "type": cand.get("type"),
            }
        )
        if len(out) >= limit:
            break
    return out
def related_recent_entities(
    store: SQLiteClusterStore,
    subject: str,
    timeframe_hours: float,
    limit: int,
    include_trends: bool = True,
) -> dict[str, Any]:
    """Resolve *subject* and return entities related to it in recent news.

    Merges two signals: co-occurrence counts from locally stored clusters
    within ``timeframe_hours`` (via ``_collect_local_related``) and, when
    ``include_trends`` is true, trends-related topics (via
    ``_collect_trends_related``).

    Side effects: records the lookup with ``store.record_entity_request`` and
    upserts entity metadata rows for the subject and every related entity.

    Args:
        store: Cluster/metadata store used for lookups and upserts.
        subject: Raw subject string as supplied by the caller.
        timeframe_hours: Recency window for the local co-occurrence scan.
        limit: Maximum number of related entries to return.
        include_trends: When False, skip the trends lookup entirely.

    Returns:
        A dict with a "subject" identity block and a "related" list of
        {normalized, canonical_label, mid, sources, scores} entries, ordered
        by local co-occurrence count (desc), then trends rank (asc), then
        label.
    """
    subject_norm = normalize_entity(subject)
    if not subject_norm:
        # Nothing resolvable in the input; return an empty result shell.
        return {
            "subject": {"raw": subject, "normalized": ""},
            "related": [],
        }
    subject_resolution = resolve_entity_via_trends(subject_norm)
    store.record_entity_request(subject_norm, subject_resolution.get("mid"))
    store.upsert_entity_metadata(
        normalized_label=subject_norm,
        canonical_label=subject_resolution.get("canonical_label"),
        mid=subject_resolution.get("mid"),
        sources=[subject_resolution.get("source") or "resolver"],
    )
    local_related = _collect_local_related(
        store=store,
        subject_norm=subject_norm,
        subject_resolution=subject_resolution,
        timeframe_hours=timeframe_hours,
        limit=limit,
    )
    trends_related = _collect_trends_related(subject_norm, subject_resolution, limit) if include_trends else []
    # Merge both sources keyed by lowercased label so duplicates collapse.
    related_map: dict[str, dict[str, Any]] = {}
    def _entry(label: str) -> dict[str, Any]:
        # Get-or-create the merged record for a label (case-insensitive key).
        key = label.strip().lower()
        if key not in related_map:
            related_map[key] = {
                "normalized": label,
                "canonical_label": label,
                "mid": None,
                "sources": set(),
                "scores": {},
            }
        return related_map[key]
    # Local results first: they carry a co-occurrence count.
    for label, count in local_related:
        if not label:
            continue
        entry = _entry(label)
        entry["sources"].add("local")
        entry["scores"]["local_count"] = int(count)
        store.upsert_entity_metadata(
            normalized_label=label,
            canonical_label=label,
            mid=None,
            sources=["local"],
        )
    # Only use enough trends results to fill remaining slots.
    # NOTE(review): a trends topic that duplicates a local entry merges into
    # it instead of filling a slot, so fewer than `limit` total entries can
    # result — confirm this is intended.
    remaining = max(0, limit - len(related_map))
    for idx, cand in enumerate(trends_related[:remaining], start=1):
        label = cand.get("normalized")
        if not label:
            continue
        entry = _entry(label)
        entry["sources"].add("trends")
        # Trends data can enrich a local entry with a canonical label / mid.
        entry["canonical_label"] = cand.get("canonical_label") or entry["canonical_label"]
        entry["mid"] = cand.get("mid") or entry["mid"]
        entry["scores"]["trends_rank"] = idx
        store.upsert_entity_metadata(
            normalized_label=label,
            canonical_label=cand.get("canonical_label"),
            mid=cand.get("mid"),
            sources=["trends"],
        )
    results = list(related_map.values())
    # Emit sources as a sorted list (deterministic, serialization-friendly).
    for item in results:
        item["sources"] = sorted(item["sources"])
    # Order: highest local count first, then best trends rank, then label.
    results.sort(
        key=lambda item: (
            -int(item["scores"].get("local_count", 0)),
            item["scores"].get("trends_rank", 9999),
            item["canonical_label"].lower(),
        )
    )
    return {
        "subject": {
            "raw": subject,
            "normalized": subject_norm,
            "canonical_label": subject_resolution.get("canonical_label") or subject_norm,
            "mid": subject_resolution.get("mid"),
            "resolved_at": subject_resolution.get("resolved_at") or _now_iso(),
            "source": subject_resolution.get("source"),
        },
        # NOTE(review): max(1, limit) returns one entry even when limit == 0
        # — confirm that is intended.
        "related": results[: max(1, limit)],
    }
|