
news-mcp: add related_recent_entities with entity metadata

Lukas Goldschmidt, 1 month ago
commit af0049a967
6 changed files with 437 additions and 86 deletions
  1. README.md (+3 -3)
  2. news_mcp/mcp_server_fastmcp.py (+25 -83)
  3. news_mcp/related_entities.py (+189 -0)
  4. news_mcp/storage/sqlite_store.py (+100 -0)
  5. news_mcp/trends_related.py (+118 -0)
  6. run.sh (+2 -0)

+ 3 - 3
README.md

@@ -52,9 +52,9 @@ Health:
 5) `get_news_sentiment(entity, timeframe)`
 - aggregates sentiment around an entity from cached enriched clusters
 
-6) `get_related_entities(subject, timeframe, limit)`
-- entity-only co-occurrence neighborhood: for a given subject entity, returns related entities with aggregated
-  `count`, `avg_importance`, and `sentiment`
+6) `get_related_recent_entities(subject, timeframe, limit, include_trends=true)`
+- merges recent co-occurrence data from cached clusters with Google Trends suggestions and returns
+  related entities (with `mid` when available) plus source/score metadata
 
 ### Entity aliasing
 
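
A sketch of what a `get_related_recent_entities` call returns, read off the new
`related_entities.py` below; the labels, mid, and timestamp are illustrative:

    result = {
        "subject": {
            "raw": "OpenAI",
            "normalized": "openai",
            "canonical_label": "OpenAI",
            "mid": "/m/xxxxx",  # placeholder; None when resolution fails
            "resolved_at": "2025-01-01T00:00:00+00:00",
            "source": "trends",
        },
        "related": [
            {
                "normalized": "microsoft",
                "canonical_label": "Microsoft",
                "mid": None,
                "sources": ["local"],          # "local", "trends", or both
                "scores": {"local_count": 5},  # and/or {"trends_rank": 1}
            },
        ],
    }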

+ 25 - 83
news_mcp/mcp_server_fastmcp.py

@@ -25,6 +25,7 @@ from news_mcp.enrichment.llm_enrich import summarize_cluster_llm
 from news_mcp.trends_resolution import resolve_entity_via_trends
 from news_mcp.llm import active_llm_config
 from news_mcp.entity_normalize import normalize_query
+from news_mcp.related_entities import related_recent_entities
 
 
 mcp = FastMCP(
@@ -209,6 +210,22 @@ async def get_events_for_entity(entity: str, limit: int = 10, timeframe: str = "
     return out
 
 
+@mcp.tool(description="Return entities most commonly associated with the subject in recent clusters, optionally blended with Google Trends suggestions.")
+async def get_related_recent_entities(subject: str, timeframe: str = "72h", limit: int = 10, include_trends: bool = True):
+    limit = max(1, min(int(limit), 25))
+    hours = _parse_timeframe_to_hours(timeframe)
+    include_trends_bool = str(include_trends).strip().lower() not in {"false", "0", "no"}
+    store = SQLiteClusterStore(DB_PATH)
+    result = related_recent_entities(
+        store=store,
+        subject=subject,
+        timeframe_hours=hours,
+        limit=limit,
+        include_trends=include_trends_bool,
+    )
+    return result
+
+
 @mcp.tool(description="Investigate one cluster in depth and return a concise LLM-written explanation plus key facts.")
 async def get_event_summary(event_id: str, include_articles: bool = False):
     store = SQLiteClusterStore(DB_PATH)
@@ -463,88 +480,6 @@ def _parse_timeframe_to_hours(timeframe: str) -> int:
         return 24
 
 
-@mcp.tool(
-    description="Investigate which entities tend to appear alongside a subject entity in recent clusters, based on co-occurrence."
-)
-async def get_related_entities(subject: str, timeframe: str = "24h", limit: int = 10):
-    store = SQLiteClusterStore(DB_PATH)
-    limit = max(1, min(int(limit), 30))
-
-    subj = normalize_query(subject).strip().lower()
-    if not subj:
-        return []
-
-    resolved = resolve_entity_via_trends(subj)
-    query_terms = {
-        subj,
-        str(resolved.get("normalized") or "").strip().lower(),
-        str(resolved.get("canonical_label") or "").strip().lower(),
-        str(resolved.get("mid") or "").strip().lower(),
-    }
-    query_terms = {q for q in query_terms if q}
-
-    hours = _parse_timeframe_to_hours(timeframe)
-    clusters = store.get_latest_clusters_all_topics(ttl_hours=hours, limit=500)
-
-    # Aggregate related metrics per entity.
-    rel_count = Counter()
-    rel_imp_sum = Counter()
-    rel_sent_sum = Counter()
-    rel_sent_n = Counter()
-
-    for c in clusters:
-        haystack = _cluster_entity_haystack(c)
-        if not any(term in item for item in haystack for term in query_terms):
-            continue
-
-        ents = [str(e).strip().lower() for e in (c.get("entities", []) or []) if str(e).strip()]
-        # remove generic/meta-ish short tokens conservatively
-        ents = [e for e in ents if len(e) >= 4]
-        for e in ents:
-            if e in query_terms:
-                continue
-            rel_count[e] += 1
-            try:
-                rel_imp_sum[e] += float(c.get("importance", 0.0) or 0.0)
-            except Exception:
-                pass
-
-            # sentiment aggregation based on sentimentScore if available.
-            s = c.get("sentimentScore")
-            if s is not None:
-                try:
-                    rel_sent_sum[e] += float(s)
-                    rel_sent_n[e] += 1
-                except Exception:
-                    pass
-
-    # Sort by count, then avg importance.
-    items = []
-    for ent, cnt in rel_count.most_common():
-        avg_imp = rel_imp_sum[ent] / max(1, cnt)
-        avg_score = rel_sent_sum[ent] / max(1, rel_sent_n[ent]) if rel_sent_n[ent] else 0.0
-        if avg_score >= 0.15:
-            sentiment = "positive"
-        elif avg_score <= -0.15:
-            sentiment = "negative"
-        else:
-            sentiment = "neutral"
-
-        items.append(
-            {
-                "entity": ent,
-                "count": cnt,
-                "avg_importance": round(avg_imp, 3),
-                "sentiment": sentiment,
-                "score": round(avg_score, 3),
-            }
-        )
-        if len(items) >= limit:
-            break
-
-    return items
-
-
 app = FastAPI(title="News MCP Server")
 
 logger = logging.getLogger("news_mcp.startup")
@@ -594,7 +529,14 @@ def root():
         "status": "ok",
         "transport": "fastmcp+sse",
         "mount": "/mcp",
-        "tools": ["get_latest_events", "get_events_for_entity", "get_event_summary", "detect_emerging_topics"],
+        "tools": [
+            "get_latest_events",
+            "get_events_for_entity",
+            "get_event_summary",
+            "detect_emerging_topics",
+            "get_news_sentiment",
+            "get_related_recent_entities",
+        ],
         "refresh": {
             "enabled": NEWS_BACKGROUND_REFRESH_ENABLED,
             "interval_seconds": NEWS_REFRESH_INTERVAL_SECONDS,

+ 189 - 0
news_mcp/related_entities.py

@@ -0,0 +1,189 @@
+from __future__ import annotations
+
+from collections import Counter
+from datetime import datetime, timezone
+from typing import Any
+
+from news_mcp.entity_normalize import normalize_entity
+from news_mcp.storage.sqlite_store import SQLiteClusterStore
+from news_mcp.trends_resolution import resolve_entity_via_trends
+from news_mcp.trends_related import get_related_topics
+
+
+def _now_iso() -> str:
+    return datetime.now(timezone.utc).isoformat()
+
+
+def _collect_local_related(
+    store: SQLiteClusterStore,
+    subject_norm: str,
+    subject_resolution: dict[str, Any],
+    timeframe_hours: float,
+    limit: int,
+) -> list[tuple[str, int]]:
+    clusters = store.get_latest_clusters_all_topics(
+        ttl_hours=float(timeframe_hours),
+        limit=max(limit * 20, 200),
+    )
+    counter: Counter[str] = Counter()
+    subject_terms = {
+        subject_norm.strip().lower(),
+        str(subject_resolution.get("normalized") or "").strip().lower(),
+        str(subject_resolution.get("canonical_label") or "").strip().lower(),
+        str(subject_resolution.get("mid") or "").strip().lower(),
+    }
+    subject_terms = {t for t in subject_terms if t}
+    for cluster in clusters:
+        # Match clusters by any of the resolved identity terms.
+        haystack: list[str] = []
+        for ent in cluster.get("entities", []) or []:
+            haystack.append(str(ent).strip().lower())
+        for res in cluster.get("entityResolutions", []) or []:
+            if not isinstance(res, dict):
+                continue
+            for key in ("normalized", "canonical_label", "mid"):
+                val = res.get(key)
+                if val:
+                    haystack.append(str(val).strip().lower())
+
+        haystack_set = {h for h in haystack if h}
+        if not (haystack_set & subject_terms):
+            continue
+
+        # Count other entities normalized.
+        for ent in cluster.get("entities", []) or []:
+            ent_norm = normalize_entity(ent)
+            if not ent_norm:
+                continue
+            ent_key = ent_norm.strip().lower()
+            if ent_key in subject_terms:
+                continue
+            counter[ent_norm] += 1
+    return counter.most_common(limit)
+
+
+def _collect_trends_related(subject_norm: str, subject_resolution: dict[str, Any], limit: int) -> list[dict[str, Any]]:
+    topics = get_related_topics(subject_norm, limit=limit)
+    if topics:
+        return topics
+
+    # Fallback to autocomplete candidates if related topics are unavailable.
+    candidates = subject_resolution.get("candidates") or []
+    out = []
+    for cand in candidates:
+        title = cand.get("title")
+        if not title:
+            continue
+        out.append(
+            {
+                "canonical_label": title,
+                "normalized": normalize_entity(title),
+                "mid": cand.get("mid"),
+                "type": cand.get("type"),
+            }
+        )
+        if len(out) >= limit:
+            return out
+    return out
+
+
+def related_recent_entities(
+    store: SQLiteClusterStore,
+    subject: str,
+    timeframe_hours: float,
+    limit: int,
+    include_trends: bool = True,
+) -> dict[str, Any]:
+    subject_norm = normalize_entity(subject)
+    if not subject_norm:
+        return {
+            "subject": {"raw": subject, "normalized": ""},
+            "related": [],
+        }
+
+    subject_resolution = resolve_entity_via_trends(subject_norm)
+    store.record_entity_request(subject_norm, subject_resolution.get("mid"))
+    store.upsert_entity_metadata(
+        normalized_label=subject_norm,
+        canonical_label=subject_resolution.get("canonical_label"),
+        mid=subject_resolution.get("mid"),
+        sources=[subject_resolution.get("source") or "resolver"],
+    )
+
+    local_related = _collect_local_related(
+        store=store,
+        subject_norm=subject_norm,
+        subject_resolution=subject_resolution,
+        timeframe_hours=timeframe_hours,
+        limit=limit,
+    )
+    trends_related = _collect_trends_related(subject_norm, subject_resolution, limit) if include_trends else []
+
+    related_map: dict[str, dict[str, Any]] = {}
+
+    def _entry(label: str) -> dict[str, Any]:
+        key = label.strip().lower()
+        if key not in related_map:
+            related_map[key] = {
+                "normalized": label,
+                "canonical_label": label,
+                "mid": None,
+                "sources": set(),
+                "scores": {},
+            }
+        return related_map[key]
+
+    for label, count in local_related:
+        if not label:
+            continue
+        entry = _entry(label)
+        entry["sources"].add("local")
+        entry["scores"]["local_count"] = int(count)
+        store.upsert_entity_metadata(
+            normalized_label=label,
+            canonical_label=label,
+            mid=None,
+            sources=["local"],
+        )
+
+    # Only use enough trends results to fill remaining slots.
+    remaining = max(0, limit - len(related_map))
+    for idx, cand in enumerate(trends_related[:remaining], start=1):
+        label = cand.get("normalized")
+        if not label:
+            continue
+        entry = _entry(label)
+        entry["sources"].add("trends")
+        entry["canonical_label"] = cand.get("canonical_label") or entry["canonical_label"]
+        entry["mid"] = cand.get("mid") or entry["mid"]
+        entry["scores"]["trends_rank"] = idx
+        store.upsert_entity_metadata(
+            normalized_label=label,
+            canonical_label=cand.get("canonical_label"),
+            mid=cand.get("mid"),
+            sources=["trends"],
+        )
+
+    results = list(related_map.values())
+    for item in results:
+        item["sources"] = sorted(item["sources"])
+
+    results.sort(
+        key=lambda item: (
+            -int(item["scores"].get("local_count", 0)),
+            item["scores"].get("trends_rank", 9999),
+            item["canonical_label"].lower(),
+        )
+    )
+
+    return {
+        "subject": {
+            "raw": subject,
+            "normalized": subject_norm,
+            "canonical_label": subject_resolution.get("canonical_label") or subject_norm,
+            "mid": subject_resolution.get("mid"),
+            "resolved_at": subject_resolution.get("resolved_at") or _now_iso(),
+            "source": subject_resolution.get("source"),
+        },
+        "related": results[: max(1, limit)],
+    }
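
The sort at the end of `related_recent_entities` puts locally co-occurring
entities first (by descending count), then Trends-only suggestions (by
ascending rank, 9999 when absent), with the lowercased label as the final
tiebreaker. A minimal sketch with made-up entries:

    items = [
        {"canonical_label": "Nvidia", "scores": {"trends_rank": 1}},
        {"canonical_label": "Microsoft", "scores": {"local_count": 3}},
        {"canonical_label": "Anthropic", "scores": {"local_count": 7, "trends_rank": 2}},
    ]
    items.sort(key=lambda item: (
        -int(item["scores"].get("local_count", 0)),
        item["scores"].get("trends_rank", 9999),
        item["canonical_label"].lower(),
    ))
    # -> Anthropic (count 7), Microsoft (count 3), Nvidia (trends-only)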

+ 100 - 0
news_mcp/storage/sqlite_store.py

@@ -134,6 +134,30 @@ class SQLiteClusterStore:
                 "CREATE INDEX IF NOT EXISTS idx_clusters_updated_at ON clusters(updated_at)"
             )
 
+            try:
+                cur = conn.execute("PRAGMA table_info(entity_metadata)")
+                cols = [row[1] for row in cur.fetchall()]
+                if cols and "entity_id" not in cols:
+                    conn.execute("DROP TABLE entity_metadata")
+            except sqlite3.OperationalError:
+                pass
+            conn.execute(
+                """
+                CREATE TABLE IF NOT EXISTS entity_metadata (
+                  entity_id TEXT PRIMARY KEY,
+                  normalized_label TEXT NOT NULL,
+                  canonical_label TEXT,
+                  mid TEXT,
+                  sources_json TEXT,
+                  updated_at TEXT,
+                  last_requested_at TEXT
+                )
+                """
+            )
+            conn.execute(
+                "CREATE UNIQUE INDEX IF NOT EXISTS idx_entity_metadata_mid ON entity_metadata(mid) WHERE mid IS NOT NULL"
+            )
+
             conn.execute(
                 """
                 CREATE TABLE IF NOT EXISTS feed_state (
@@ -347,6 +371,82 @@ class SQLiteClusterStore:
                 (key, value),
             )
 
+    def upsert_entity_metadata(
+        self,
+        normalized_label: str,
+        canonical_label: str | None = None,
+        mid: str | None = None,
+        sources: list[str] | None = None,
+    ) -> None:
+        normalized_label = str(normalized_label or "").strip()
+        if not normalized_label:
+            return
+        canonical_label = str(canonical_label).strip() if canonical_label else None
+        mid = str(mid).strip() if mid else None
+        entity_id = mid if mid else f"local:{normalized_label}"
+        sources = sorted({s for s in (sources or []) if s})
+        sources_json = json.dumps(sources, ensure_ascii=False)
+        now = datetime.now(timezone.utc).isoformat()
+        with self._conn() as conn:
+            conn.execute(
+                """
+                INSERT INTO entity_metadata(entity_id, normalized_label, canonical_label, mid, sources_json, updated_at)
+                VALUES(?,?,?,?,?,?)
+                ON CONFLICT(entity_id) DO UPDATE SET
+                  canonical_label=excluded.canonical_label,
+                  mid=excluded.mid,
+                  sources_json=excluded.sources_json,
+                  updated_at=excluded.updated_at
+                """,
+                (entity_id, normalized_label, canonical_label, mid, sources_json, now),
+            )
+
+    def get_entity_metadata(self, normalized_label: str) -> dict[str, Any] | None:
+        normalized_label = str(normalized_label or "").strip()
+        if not normalized_label:
+            return None
+        with self._conn() as conn:
+            cur = conn.execute(
+                "SELECT entity_id, canonical_label, mid, sources_json, updated_at, last_requested_at FROM entity_metadata WHERE normalized_label=?",
+                (normalized_label,),
+            )
+            row = cur.fetchone()
+            if not row:
+                return None
+            sources = []
+            if row[3]:
+                try:
+                    sources = json.loads(row[3])
+                except Exception:
+                    sources = []
+            return {
+                "entity_id": row[0],
+                "normalized_label": normalized_label,
+                "canonical_label": row[1],
+                "mid": row[2],
+                "sources": sources,
+                "updated_at": row[4],
+                "last_requested_at": row[5],
+            }
+
+    def record_entity_request(self, normalized_label: str, mid: str | None = None) -> None:
+        normalized_label = str(normalized_label or "").strip()
+        if not normalized_label:
+            return
+        mid = str(mid).strip() if mid else None
+        entity_id = mid if mid else f"local:{normalized_label}"
+        now = datetime.now(timezone.utc).isoformat()
+        with self._conn() as conn:
+            conn.execute(
+                """
+                INSERT INTO entity_metadata(entity_id, normalized_label, canonical_label, mid, sources_json, updated_at, last_requested_at)
+                VALUES(?,?,?,?,?,?,?)
+                ON CONFLICT(entity_id) DO UPDATE SET
+                  last_requested_at=excluded.last_requested_at
+                """,
+                (entity_id, normalized_label, None, mid, json.dumps([], ensure_ascii=False), now, now),
+            )
+
     def prune_clusters(self, retention_days: float) -> int:
         retention_days = float(retention_days)
         if retention_days <= 0:
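
The `entity_id` scheme keys rows by Knowledge Graph mid when one is known and
by a `local:` prefix otherwise, so the same label can exist under two ids
before and after resolution; note also that each upsert replaces `sources_json`
rather than merging it. A usage sketch (the db path and mid are placeholders):

    store = SQLiteClusterStore("news.db")
    store.upsert_entity_metadata(
        normalized_label="openai",
        canonical_label="OpenAI",
        mid="/m/xxxxx",  # placeholder mid
        sources=["trends"],
    )
    meta = store.get_entity_metadata("openai")
    # meta["entity_id"] == "/m/xxxxx"; with mid=None it would be "local:openai"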

+ 118 - 0
news_mcp/trends_related.py

@@ -0,0 +1,118 @@
+from __future__ import annotations
+
+import json
+from functools import lru_cache
+from typing import Any
+
+import httpx
+
+from news_mcp.entity_normalize import normalize_entity
+
+
+class GoogleTrendsRelatedError(RuntimeError):
+    pass
+
+
+class GoogleTrendsRelatedProvider:
+    _EXPLORE_URL = "https://trends.google.com/trends/api/explore"
+    _RELATED_URL = "https://trends.google.com/trends/api/widgetdata/relatedsearches/"
+
+    def __init__(self, *, hl: str = "en-US", tz: int = 120, timeout: float = 10.0):
+        self.hl = hl
+        self.tz = tz
+        self.timeout = timeout
+        self._headers = {
+            "User-Agent": (
+                "Mozilla/5.0 (X11; Linux x86_64) "
+                "AppleWebKit/537.36 (KHTML, like Gecko) "
+                "Chrome/135.0.0.0 Safari/537.36"
+            ),
+            "Accept": "application/json,text/javascript,*/*;q=0.1",
+        }
+
+    def _request(self, url: str, params: dict[str, Any]) -> dict[str, Any]:
+        response = httpx.get(
+            url,
+            params=params,
+            headers=self._headers,
+            timeout=self.timeout,
+            follow_redirects=True,
+        )
+        response.raise_for_status()
+        text = response.text.strip()
+        if text.startswith(")]}',"):
+            text = text[5:]
+        return json.loads(text)
+
+    def _fetch_widget(self, keyword: str, time_window: str) -> dict[str, Any] | None:
+        req_payload = {
+            "comparisonItem": [
+                {
+                    "keyword": keyword,
+                    "geo": "",
+                    "time": time_window,
+                }
+            ],
+            "category": 0,
+            "property": "",
+        }
+        params = {
+            "hl": self.hl,
+            "tz": str(self.tz),
+            "req": json.dumps(req_payload, separators=(",", ":")),
+            "property": "",
+        }
+        data = self._request(self._EXPLORE_URL, params)
+        widgets = (data.get("widgets") or []) if isinstance(data, dict) else []
+        by_id = {w.get("id"): w for w in widgets}
+        # Prefer RELATED_TOPICS (its items carry `topic.mid`, matching the
+        # parser below); fall back to plain related queries otherwise.
+        return by_id.get("RELATED_TOPICS") or by_id.get("RELATED_QUERIES")
+
+    def related_topics(self, keyword: str, *, time_window: str = "now 7-d", limit: int = 10) -> list[dict[str, Any]]:
+        widget = self._fetch_widget(keyword, time_window)
+        if not widget:
+            return []
+        request_payload = widget.get("request") or {}
+        token = widget.get("token")
+        if not request_payload or not token:
+            return []
+        params = {
+            "hl": self.hl,
+            "tz": str(self.tz),
+            "req": json.dumps(request_payload, separators=(",", ":")),
+            "token": token,
+        }
+        data = self._request(self._RELATED_URL, params)
+        ranked = []
+        ranked_lists = data.get("default", {}).get("rankedList", []) if isinstance(data, dict) else []
+        for ranked_list in ranked_lists:
+            for item in ranked_list.get("rankedKeyword", []):
+                topic = item.get("topic") or {}
+                title = topic.get("title") or item.get("query")
+                if not title:
+                    continue
+                ranked.append(
+                    {
+                        "canonical_label": title,
+                        "normalized": normalize_entity(title),
+                        "mid": topic.get("mid"),
+                        "type": topic.get("type"),
+                        "value": item.get("value"),
+                    }
+                )
+                if len(ranked) >= limit:
+                    return ranked
+        return ranked
+
+
+@lru_cache(maxsize=256)
+def get_related_topics(keyword: str, *, time_window: str = "now 7-d", limit: int = 10) -> list[dict[str, Any]]:
+    normalized = normalize_entity(keyword)
+    if not normalized:
+        return []
+    provider = GoogleTrendsRelatedProvider()
+    try:
+        return provider.related_topics(normalized, time_window=time_window, limit=limit)
+    except Exception:
+        return []
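
`get_related_topics` calls an unofficial, undocumented Google Trends endpoint,
so treat results as best-effort: any HTTP or parse failure is swallowed, and
the empty result is then memoized by `lru_cache` for the process lifetime.
Illustrative use:

    topics = get_related_topics("nvidia", time_window="now 7-d", limit=5)
    for topic in topics:
        # mid/type come from the Trends "topic" payload and may be None
        print(topic["canonical_label"], topic.get("mid"), topic.get("value"))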

+ 2 - 0
run.sh

@@ -22,6 +22,8 @@ if [ -z "$UVICORN_BIN" ]; then
   fi
 fi
 
+export PYTHONPATH="$(pwd):${PYTHONPATH:-}"
+
 nohup "$UVICORN_BIN" "$APP_MODULE" --host 0.0.0.0 --port "$PORT" > "$LOGFILE" 2>&1 &
 echo $! > "$PIDFILE"
 echo "Uvicorn started on port $PORT (PID $(cat "$PIDFILE"))"