from __future__ import annotations from collections import Counter from datetime import datetime, timezone from typing import Any from news_mcp.entity_normalize import normalize_entity from news_mcp.storage.sqlite_store import SQLiteClusterStore from news_mcp.trends_resolution import resolve_entity_via_trends from news_mcp.trends_related import get_related_topics def _now_iso() -> str: return datetime.now(timezone.utc).isoformat() def _collect_local_related( store: SQLiteClusterStore, subject_norm: str, subject_resolution: dict[str, Any], timeframe_hours: float, limit: int, ) -> list[tuple[str, int]]: clusters = store.get_latest_clusters_all_topics( ttl_hours=float(timeframe_hours), limit=max(limit * 20, 200), ) counter: Counter[str] = Counter() subject_terms = { subject_norm.strip().lower(), str(subject_resolution.get("normalized") or "").strip().lower(), str(subject_resolution.get("canonical_label") or "").strip().lower(), str(subject_resolution.get("mid") or "").strip().lower(), } subject_terms = {t for t in subject_terms if t} for cluster in clusters: # Match clusters by any of the resolved identity terms. haystack: list[str] = [] for ent in cluster.get("entities", []) or []: haystack.append(str(ent).strip().lower()) for res in cluster.get("entityResolutions", []) or []: if not isinstance(res, dict): continue for key in ("normalized", "canonical_label", "mid"): val = res.get(key) if val: haystack.append(str(val).strip().lower()) haystack_set = set([h for h in haystack if h]) if not (haystack_set & subject_terms): continue # Count other entities normalized. for ent in cluster.get("entities", []) or []: ent_norm = normalize_entity(ent) if not ent_norm: continue ent_key = ent_norm.strip().lower() if ent_key in subject_terms: continue counter[ent_norm] += 1 return counter.most_common(limit) def _collect_trends_related(subject_norm: str, subject_resolution: dict[str, Any], limit: int) -> list[dict[str, Any]]: topics = get_related_topics(subject_norm, limit=limit) if topics: return topics # Fallback to autocomplete candidates if related topics are unavailable. candidates = subject_resolution.get("candidates") or [] out = [] for cand in candidates: title = cand.get("title") if not title: continue out.append( { "canonical_label": title, "normalized": normalize_entity(title), "mid": cand.get("mid"), "type": cand.get("type"), } ) if len(out) >= limit: return out return out def related_recent_entities( store: SQLiteClusterStore, subject: str, timeframe_hours: float, limit: int, include_trends: bool = True, ) -> dict[str, Any]: subject_norm = normalize_entity(subject) if not subject_norm: return { "subject": {"raw": subject, "normalized": ""}, "related": [], } subject_resolution = resolve_entity_via_trends(subject_norm) store.record_entity_request(subject_norm, subject_resolution.get("mid")) store.upsert_entity_metadata( normalized_label=subject_norm, canonical_label=subject_resolution.get("canonical_label"), mid=subject_resolution.get("mid"), sources=[subject_resolution.get("source") or "resolver"], ) local_related = _collect_local_related( store=store, subject_norm=subject_norm, subject_resolution=subject_resolution, timeframe_hours=timeframe_hours, limit=limit, ) trends_related = _collect_trends_related(subject_norm, subject_resolution, limit) if include_trends else [] related_map: dict[str, dict[str, Any]] = {} def _entry(label: str) -> dict[str, Any]: key = label.strip().lower() if key not in related_map: related_map[key] = { "normalized": label, "canonical_label": label, "mid": None, "sources": set(), "scores": {}, } return related_map[key] for label, count in local_related: if not label: continue entry = _entry(label) entry["sources"].add("local") entry["scores"]["local_count"] = int(count) store.upsert_entity_metadata( normalized_label=label, canonical_label=label, mid=None, sources=["local"], ) # Only use enough trends results to fill remaining slots. remaining = max(0, limit - len(related_map)) for idx, cand in enumerate(trends_related[:remaining], start=1): label = cand.get("normalized") if not label: continue entry = _entry(label) entry["sources"].add("trends") entry["canonical_label"] = cand.get("canonical_label") or entry["canonical_label"] entry["mid"] = cand.get("mid") or entry["mid"] entry["scores"]["trends_rank"] = idx store.upsert_entity_metadata( normalized_label=label, canonical_label=cand.get("canonical_label"), mid=cand.get("mid"), sources=["trends"], ) results = list(related_map.values()) for item in results: item["sources"] = sorted(item["sources"]) results.sort( key=lambda item: ( -int(item["scores"].get("local_count", 0)), item["scores"].get("trends_rank", 9999), item["canonical_label"].lower(), ) ) return { "subject": { "raw": subject, "normalized": subject_norm, "canonical_label": subject_resolution.get("canonical_label") or subject_norm, "mid": subject_resolution.get("mid"), "resolved_at": subject_resolution.get("resolved_at") or _now_iso(), "source": subject_resolution.get("source"), }, "related": results[: max(1, limit)], }