mcp_server_fastmcp.py

from __future__ import annotations

import asyncio
import logging
import re
from collections import Counter

from fastapi import FastAPI
from mcp.server.fastmcp import FastMCP
from mcp.server.transport_security import TransportSecuritySettings

from news_mcp.config import (
    CLUSTERS_TTL_HOURS,
    DB_PATH,
    DEFAULT_TOPICS,
    NEWS_BACKGROUND_REFRESH_ENABLED,
    NEWS_BACKGROUND_REFRESH_ON_START,
    NEWS_REFRESH_INTERVAL_SECONDS,
)
from news_mcp.entity_normalize import normalize_query
from news_mcp.enrichment.llm_enrich import summarize_cluster_groq
from news_mcp.jobs.poller import refresh_clusters
from news_mcp.llm import active_llm_config
from news_mcp.storage.sqlite_store import SQLiteClusterStore
from news_mcp.trends_resolution import resolve_entity_via_trends

mcp = FastMCP(
    "news-mcp",
    transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False),
)
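# NOTE: DNS-rebinding protection is turned off above, so the SSE endpoint
# accepts requests with arbitrary Host headers (convenient behind a reverse
# proxy or in local development). Consider re-enabling it if this server is
# ever exposed directly.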


def _cluster_entity_haystack(cluster: dict) -> list[str]:
    """Collect the normalized entity clues attached to a cluster."""
    values: list[str] = []
    for ent in cluster.get("entities", []) or []:
        values.append(str(ent).strip().lower())
    for res in cluster.get("entityResolutions", []) or []:
        if not isinstance(res, dict):
            continue
        for key in ("normalized", "canonical_label", "mid"):
            val = res.get(key)
            if val:
                values.append(str(val).strip().lower())
    return [v for v in values if v]
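
# Illustrative cluster shape this helper expects (field names as used above;
# the values below are made up, not an authoritative schema):
#   {"entities": ["nvidia", "jensen huang"],
#    "entityResolutions": [{"normalized": "nvidia",
#                           "canonical_label": "Nvidia",
#                           "mid": "/m/0abc12"}]}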


@mcp.tool(description="What is happening right now? Return the latest deduplicated news clusters for a topic.")
async def get_latest_events(topic: str = "crypto", limit: int = 5):
    limit = max(1, min(int(limit), 20))
    # If the caller passes an entity-like value, resolve it and use the canonical
    # entity as the query lens. Otherwise keep the original topic path.
    topic_norm = normalize_query(topic).lower()
    resolved = resolve_entity_via_trends(topic_norm)
    allowed = {t.lower() for t in DEFAULT_TOPICS}
    is_topic = topic_norm in allowed
    query_terms = {
        topic_norm,
        str(resolved.get("normalized") or "").strip().lower(),
        str(resolved.get("canonical_label") or "").strip().lower(),
        str(resolved.get("mid") or "").strip().lower(),
    }
    query_terms = {q for q in query_terms if q}
    store = SQLiteClusterStore(DB_PATH)
    if is_topic:
        # Cache-first: only refresh if we currently have no fresh clusters for this topic.
        clusters = store.get_latest_clusters(topic=topic_norm, ttl_hours=CLUSTERS_TTL_HOURS, limit=limit)
        if not clusters:
            await refresh_clusters(topic=topic_norm, limit=200)
            clusters = store.get_latest_clusters(topic=topic_norm, ttl_hours=CLUSTERS_TTL_HOURS, limit=limit)
    else:
        # Entity-aware mode: search recent clusters across all topics and match by
        # raw entity, canonical label, or MID. Oversample (limit * 8) because most
        # clusters will not match the entity filter.
        clusters = store.get_latest_clusters_all_topics(ttl_hours=CLUSTERS_TTL_HOURS, limit=limit * 8)
        filtered = []
        for c in clusters:
            haystack = _cluster_entity_haystack(c)
            if any(any(term in item for item in haystack) for term in query_terms):
                filtered.append(c)
                if len(filtered) >= limit:
                    break
        clusters = filtered
    # Ensure the response is compact and agent-friendly.
    clusters_sorted = sorted(clusters, key=lambda x: float(x.get("importance", 0.0)), reverse=True)
    out = []
    for c in clusters_sorted:
        out.append(
            {
                "cluster_id": c.get("cluster_id"),
                "headline": c.get("headline"),
                "summary": c.get("summary"),
                "entities": c.get("entities", []),
                "sentiment": c.get("sentiment", "neutral"),
                "importance": c.get("importance", 0.0),
                "sources": c.get("sources", []),
                "timestamp": c.get("timestamp"),
            }
        )
    return out
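
# Usage sketch (assuming "crypto" is in DEFAULT_TOPICS and "nvidia" is not):
#   await get_latest_events("crypto")   -> cache-first topic path
#   await get_latest_events("nvidia")   -> entity-aware path, matched via
#                                          the resolved entity terms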


@mcp.tool(description="What's happening with X? Filter latest clusters by extracted entity substring (case-insensitive).")
async def get_events_for_entity(entity: str, limit: int = 10):
    limit = max(1, min(int(limit), 30))
    query = normalize_query(entity).strip().lower()
    if not query:
        return []
    resolved = resolve_entity_via_trends(query)
    query_terms = {
        query,
        str(resolved.get("normalized") or "").strip().lower(),
        str(resolved.get("canonical_label") or "").strip().lower(),
        str(resolved.get("mid") or "").strip().lower(),
    }
    query_terms = {q for q in query_terms if q}
    # Cache-first: search recent clusters across all topics.
    store = SQLiteClusterStore(DB_PATH)

    def _match_clusters(clusters: list[dict]) -> list[dict]:
        hits: list[dict] = []
        for c in clusters:
            haystack = _cluster_entity_haystack(c)
            if any(any(term in item for item in haystack) for term in query_terms):
                hits.append(c)
                if len(hits) >= limit:
                    break
        return hits

    clusters = store.get_latest_clusters_all_topics(ttl_hours=CLUSTERS_TTL_HOURS, limit=limit * 5)
    hits = _match_clusters(clusters)
    # If the recent slice misses, broaden the search window before giving up.
    if not hits:
        clusters = store.get_latest_clusters_all_topics(ttl_hours=24 * 7, limit=500)
        hits = _match_clusters(clusters)
    # Compress to the tool response shape.
    out = []
    for c in hits:
        out.append(
            {
                "cluster_id": c.get("cluster_id"),
                "headline": c.get("headline"),
                "summary": c.get("summary"),
                "entities": c.get("entities", []),
                "sentiment": c.get("sentiment", "neutral"),
                "importance": c.get("importance", 0.0),
                "sources": c.get("sources", []),
                "timestamp": c.get("timestamp"),
            }
        )
    return out


@mcp.tool(description="Explain an event clearly by cluster_id (Groq summary).")
async def get_event_summary(event_id: str):
    store = SQLiteClusterStore(DB_PATH)
    # Summary cache: reuse a stored summary if one exists within the TTL.
    cached_summary = store.get_cluster_summary(
        cluster_id=event_id,
        ttl_hours=CLUSTERS_TTL_HOURS,
    )
    if cached_summary:
        return {
            "event_id": event_id,
            "headline": cached_summary.get("headline"),
            "mergedSummary": cached_summary.get("mergedSummary"),
            "keyFacts": cached_summary.get("keyFacts", []),
            "sources": cached_summary.get("sources", []),
        }
    cluster = store.get_cluster_by_id(event_id)
    if not cluster:
        return {
            "event_id": event_id,
            "error": "NOT_FOUND",
        }
    summary = await summarize_cluster_groq(cluster)
    store.upsert_cluster_summary(event_id, summary)
    return {
        "event_id": event_id,
        "headline": summary.get("headline"),
        "mergedSummary": summary.get("mergedSummary"),
        "keyFacts": summary.get("keyFacts", []),
        "sources": summary.get("sources", []),
    }


@mcp.tool(description="Detect emerging topics/entities from recent cached news clusters.")
async def detect_emerging_topics(limit: int = 10):
    limit = max(1, min(int(limit), 20))
    store = SQLiteClusterStore(DB_PATH)
    clusters = store.get_latest_clusters_all_topics(ttl_hours=CLUSTERS_TTL_HOURS, limit=200)
    entity_counts = Counter()
    phrase_counts = Counter()
    topic_counts = Counter()  # collected for diagnostics; not returned below
    for c in clusters:
        topic_counts[c.get("topic", "other")] += 1
        for ent in c.get("entities", []) or []:
            key = str(ent).strip().lower()
            if key:
                entity_counts[key] += 1
        # Count word bigrams from headline + summary as weaker "phrase" signals.
        text = f"{c.get('headline', '')} {c.get('summary', '')}"
        words = re.findall(r"[A-Za-z][A-Za-z0-9\-]{2,}", text.lower())
        for i in range(len(words) - 1):
            phrase = f"{words[i]} {words[i + 1]}"
            if len(phrase) > 6:
                phrase_counts[phrase] += 1
    emerging = []
    for ent, count in entity_counts.most_common(limit):
        emerging.append({
            "topic": ent,
            "trend_score": min(0.99, round(0.25 + 0.15 * count, 2)),
            "related_entities": [ent],
            "signal_type": "entity",
            "count": count,
        })
    for phrase, count in phrase_counts.most_common(limit * 2):
        if any(item["topic"] == phrase for item in emerging):
            continue
        emerging.append({
            "topic": phrase.title(),
            "trend_score": min(0.99, round(0.20 + 0.10 * count, 2)),
            "related_entities": [],
            "signal_type": "phrase",
            "count": count,
        })
        if len(emerging) >= limit:
            break
    return emerging[:limit]
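
# Rough score behavior (worked example): an entity seen once scores
# 0.25 + 0.15 * 1 = 0.40; five mentions give 1.00, capped at 0.99.
# Phrases are weighted lower (0.20 + 0.10 * count) because bigrams are noisier.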


@mcp.tool(description="What's the overall sentiment around an entity within a timeframe?")
async def get_news_sentiment(entity: str, timeframe: str = "24h"):
    store = SQLiteClusterStore(DB_PATH)
    ent = normalize_query(entity).strip().lower()
    if not ent:
        return {
            "entity": entity,
            "sentiment": "neutral",
            "score": 0.0,
            "cluster_count": 0,
        }
    resolved = resolve_entity_via_trends(ent)
    query_terms = {
        ent,
        str(resolved.get("normalized") or "").strip().lower(),
        str(resolved.get("canonical_label") or "").strip().lower(),
        str(resolved.get("mid") or "").strip().lower(),
    }
    query_terms = {q for q in query_terms if q}
    # timeframe: accept '24h' or '24'; clamp to 1..168 hours (one week).
    tf = str(timeframe).strip().lower()
    try:
        hours = int(tf[:-1]) if tf.endswith("h") else int(tf)
    except Exception:
        hours = 24
    hours = max(1, min(int(hours), 168))
    clusters = store.get_latest_clusters_all_topics(ttl_hours=hours, limit=500)
    matched = []
    for c in clusters:
        haystack = _cluster_entity_haystack(c)
        if any(any(term in item for item in haystack) for term in query_terms):
            matched.append(c)
    if not matched:
        return {
            "entity": entity,
            "sentiment": "neutral",
            "score": 0.0,
            "cluster_count": 0,
        }
    scores = []
    for c in matched:
        s = c.get("sentimentScore")
        if s is not None:
            try:
                scores.append(float(s))
            except Exception:
                pass
    avg_score = sum(scores) / len(scores) if scores else 0.0
    # Keep the label aligned with the numeric score; magnitudes below 0.15
    # are treated as neutral to avoid noisy label flips.
    if avg_score >= 0.15:
        sentiment = "positive"
    elif avg_score <= -0.15:
        sentiment = "negative"
    else:
        sentiment = "neutral"
    return {
        "entity": entity,
        "sentiment": sentiment,
        "score": round(avg_score, 3),
        "cluster_count": len(matched),
    }
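
# Timeframe parsing examples: "24h" -> 24, "72" -> 72, "2w" -> 24 (unparseable,
# falls back to the default), "500h" -> clamped to 168. Label thresholds:
# score >= 0.15 is "positive", <= -0.15 is "negative", otherwise "neutral".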


app = FastAPI(title="News MCP Server")
logger = logging.getLogger("news_mcp.startup")
app.mount("/mcp", mcp.sse_app())
_background_task_started = False


@app.on_event("startup")
async def _start_background_refresh():
    global _background_task_started
    if _background_task_started:
        return
    if not NEWS_BACKGROUND_REFRESH_ENABLED:
        return
    _background_task_started = True
    logger.info("news-mcp llm config: %s", active_llm_config())

    async def _loop():
        if not NEWS_BACKGROUND_REFRESH_ON_START:
            await asyncio.sleep(float(NEWS_REFRESH_INTERVAL_SECONDS))
        while True:
            try:
                # Refresh all topics by passing topic=None.
                await refresh_clusters(topic=None, limit=200)
            except Exception:
                # Keep the loop alive on network/feed errors, but leave a trace
                # instead of swallowing the exception silently.
                logger.exception("background refresh failed; will retry")
            await asyncio.sleep(float(NEWS_REFRESH_INTERVAL_SECONDS))

    asyncio.create_task(_loop())


@app.get("/")
def root():
    return {
        "status": "ok",
        "transport": "fastmcp+sse",
        "mount": "/mcp",
        "tools": [
            "get_latest_events",
            "get_events_for_entity",
            "get_event_summary",
            "detect_emerging_topics",
            "get_news_sentiment",
        ],
        "refresh": {
            "enabled": NEWS_BACKGROUND_REFRESH_ENABLED,
            "interval_seconds": NEWS_REFRESH_INTERVAL_SECONDS,
        },
    }


@app.get("/health")
def health():
    store = SQLiteClusterStore(DB_PATH)
    return {
        "status": "ok",
        "ttl_hours": CLUSTERS_TTL_HOURS,
        "db": str(DB_PATH),
        "refresh": store.get_feed_state("breakingthenews"),
    }
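

# Minimal local-run sketch (assumes uvicorn is installed; the host and port
# are illustrative defaults, not part of the original file):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)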