mcp_server_fastmcp.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
from __future__ import annotations

import asyncio
import re
from collections import Counter

from fastapi import FastAPI
from mcp.server.fastmcp import FastMCP
from mcp.server.transport_security import TransportSecuritySettings

from news_mcp.config import (
    CLUSTERS_TTL_HOURS,
    DB_PATH,
    DEFAULT_TOPICS,
    NEWS_BACKGROUND_REFRESH_ENABLED,
    NEWS_BACKGROUND_REFRESH_ON_START,
    NEWS_REFRESH_INTERVAL_SECONDS,
)
from news_mcp.enrichment.groq_enrich import summarize_cluster_groq
from news_mcp.jobs.poller import refresh_clusters
from news_mcp.storage.sqlite_store import SQLiteClusterStore
# MCP server instance exposed over SSE (mounted under /mcp below).
# NOTE(review): DNS-rebinding protection is explicitly disabled — presumably the
# server runs behind a trusted proxy/localhost; confirm before public exposure.
mcp = FastMCP(
    "news-mcp",
    transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False),
)
  14. @mcp.tool(description="What is happening right now? Return the latest deduplicated news clusters for a topic.")
  15. async def get_latest_events(topic: str = "crypto", limit: int = 5):
  16. limit = max(1, min(int(limit), 20))
  17. # In v1, `topic` is a coarse category. If the caller passes an entity name
  18. # (e.g. "trump"/"iran"), gracefully fall back to `other`.
  19. topic_norm = str(topic).strip().lower()
  20. allowed = {t.lower() for t in DEFAULT_TOPICS}
  21. if topic_norm not in allowed:
  22. topic_norm = "other"
  23. store = SQLiteClusterStore(DB_PATH)
  24. # Cache-first: only refresh if we currently have no fresh clusters for this topic.
  25. clusters = store.get_latest_clusters(topic=topic_norm, ttl_hours=CLUSTERS_TTL_HOURS, limit=limit)
  26. if not clusters:
  27. await refresh_clusters(topic=topic_norm, limit=200)
  28. clusters = store.get_latest_clusters(topic=topic_norm, ttl_hours=CLUSTERS_TTL_HOURS, limit=limit)
  29. # Ensure the response is compact and agent-friendly.
  30. clusters_sorted = sorted(clusters, key=lambda x: float(x.get("importance", 0.0)), reverse=True)
  31. out = []
  32. for c in clusters_sorted:
  33. out.append(
  34. {
  35. "cluster_id": c.get("cluster_id"),
  36. "headline": c.get("headline"),
  37. "summary": c.get("summary"),
  38. "entities": c.get("entities", []),
  39. "sentiment": c.get("sentiment", "neutral"),
  40. "importance": c.get("importance", 0.0),
  41. "sources": c.get("sources", []),
  42. "timestamp": c.get("timestamp"),
  43. }
  44. )
  45. return out
  46. @mcp.tool(description="What's happening with X? Filter latest clusters by extracted entity substring (case-insensitive).")
  47. async def get_events_for_entity(entity: str, limit: int = 10):
  48. limit = max(1, min(int(limit), 30))
  49. query = str(entity).strip().lower()
  50. if not query:
  51. return []
  52. # Cache-first: search recent clusters across all topics.
  53. store = SQLiteClusterStore(DB_PATH)
  54. clusters = store.get_latest_clusters_all_topics(ttl_hours=CLUSTERS_TTL_HOURS, limit=limit * 5)
  55. hits = []
  56. for c in clusters:
  57. ents = c.get("entities") or []
  58. if any(query in str(e).lower() for e in ents):
  59. hits.append(c)
  60. if len(hits) >= limit:
  61. break
  62. # Compress to tool response shape.
  63. out = []
  64. for c in hits:
  65. out.append(
  66. {
  67. "cluster_id": c.get("cluster_id"),
  68. "headline": c.get("headline"),
  69. "summary": c.get("summary"),
  70. "entities": c.get("entities", []),
  71. "sentiment": c.get("sentiment", "neutral"),
  72. "importance": c.get("importance", 0.0),
  73. "sources": c.get("sources", []),
  74. "timestamp": c.get("timestamp"),
  75. }
  76. )
  77. return out
  78. @mcp.tool(description="Explain an event clearly by cluster_id (Groq summary).")
  79. async def get_event_summary(event_id: str):
  80. store = SQLiteClusterStore(DB_PATH)
  81. # Summary cache: reuse if present within TTL.
  82. cached_summary = store.get_cluster_summary(
  83. cluster_id=event_id,
  84. ttl_hours=CLUSTERS_TTL_HOURS,
  85. )
  86. if cached_summary:
  87. return {
  88. "event_id": event_id,
  89. "headline": cached_summary.get("headline"),
  90. "mergedSummary": cached_summary.get("mergedSummary"),
  91. "keyFacts": cached_summary.get("keyFacts", []),
  92. "sources": cached_summary.get("sources", []),
  93. }
  94. cluster = store.get_cluster_by_id(event_id)
  95. if not cluster:
  96. return {
  97. "event_id": event_id,
  98. "error": "NOT_FOUND",
  99. }
  100. summary = await summarize_cluster_groq(cluster)
  101. store.upsert_cluster_summary(event_id, summary)
  102. return {
  103. "event_id": event_id,
  104. "headline": summary.get("headline"),
  105. "mergedSummary": summary.get("mergedSummary"),
  106. "keyFacts": summary.get("keyFacts", []),
  107. "sources": summary.get("sources", []),
  108. }
  109. @mcp.tool(description="Detect emerging topics/entities from recent cached news clusters.")
  110. async def detect_emerging_topics(limit: int = 10):
  111. limit = max(1, min(int(limit), 20))
  112. store = SQLiteClusterStore(DB_PATH)
  113. clusters = store.get_latest_clusters_all_topics(ttl_hours=CLUSTERS_TTL_HOURS, limit=200)
  114. from collections import Counter
  115. import re
  116. entity_counts = Counter()
  117. phrase_counts = Counter()
  118. topic_counts = Counter()
  119. for c in clusters:
  120. topic_counts[c.get("topic", "other")] += 1
  121. for ent in c.get("entities", []) or []:
  122. key = str(ent).strip().lower()
  123. if key:
  124. entity_counts[key] += 1
  125. text = f"{c.get('headline','')} {c.get('summary','')}"
  126. words = [w for w in re.findall(r"[A-Za-z][A-Za-z0-9\-]{2,}", text.lower())]
  127. for i in range(len(words) - 1):
  128. phrase = f"{words[i]} {words[i+1]}"
  129. if len(phrase) > 6:
  130. phrase_counts[phrase] += 1
  131. emerging = []
  132. for ent, count in entity_counts.most_common(limit):
  133. emerging.append({
  134. "topic": ent,
  135. "trend_score": min(0.99, round(0.25 + 0.15 * count, 2)),
  136. "related_entities": [ent],
  137. "signal_type": "entity",
  138. "count": count,
  139. })
  140. for phrase, count in phrase_counts.most_common(limit * 2):
  141. if any(item["topic"] == phrase for item in emerging):
  142. continue
  143. emerging.append({
  144. "topic": phrase.title(),
  145. "trend_score": min(0.99, round(0.20 + 0.10 * count, 2)),
  146. "related_entities": [],
  147. "signal_type": "phrase",
  148. "count": count,
  149. })
  150. if len(emerging) >= limit:
  151. break
  152. return emerging[:limit]
# HTTP entry point: a plain FastAPI app with the MCP SSE transport mounted
# under /mcp (clients connect to <host>/mcp for the MCP protocol).
app = FastAPI(title="News MCP Server")
app.mount("/mcp", mcp.sse_app())
  155. _background_task_started = False
  156. @app.on_event("startup")
  157. async def _start_background_refresh():
  158. global _background_task_started
  159. if _background_task_started:
  160. return
  161. if not NEWS_BACKGROUND_REFRESH_ENABLED:
  162. return
  163. _background_task_started = True
  164. async def _loop():
  165. if not NEWS_BACKGROUND_REFRESH_ON_START:
  166. await asyncio.sleep(float(NEWS_REFRESH_INTERVAL_SECONDS))
  167. while True:
  168. try:
  169. # Refresh all topics by passing topic=None
  170. await refresh_clusters(topic=None, limit=200)
  171. except Exception:
  172. # Avoid crashing the server on network errors.
  173. pass
  174. await asyncio.sleep(float(NEWS_REFRESH_INTERVAL_SECONDS))
  175. import asyncio
  176. asyncio.create_task(_loop())
  177. @app.get("/")
  178. def root():
  179. return {
  180. "status": "ok",
  181. "transport": "fastmcp+sse",
  182. "mount": "/mcp",
  183. "tools": ["get_latest_events", "get_events_for_entity", "get_event_summary", "detect_emerging_topics"],
  184. "refresh": {
  185. "enabled": NEWS_BACKGROUND_REFRESH_ENABLED,
  186. "interval_seconds": NEWS_REFRESH_INTERVAL_SECONDS,
  187. },
  188. }
  189. @app.get("/health")
  190. def health():
  191. store = SQLiteClusterStore(DB_PATH)
  192. return {
  193. "status": "ok",
  194. "ttl_hours": CLUSTERS_TTL_HOURS,
  195. "db": str(DB_PATH),
  196. "refresh": store.get_feed_state("breakingthenews"),
  197. }