# mcp_server_fastmcp.py
  1. from __future__ import annotations
  2. from fastapi import FastAPI
  3. from mcp.server.fastmcp import FastMCP
  4. from mcp.server.transport_security import TransportSecuritySettings
  5. from news_mcp.config import CLUSTERS_TTL_HOURS, DEFAULT_TOPICS, DB_PATH
  6. from news_mcp.config import NEWS_REFRESH_INTERVAL_SECONDS, NEWS_BACKGROUND_REFRESH_ENABLED, NEWS_BACKGROUND_REFRESH_ON_START
  7. from news_mcp.jobs.poller import refresh_clusters
  8. from news_mcp.storage.sqlite_store import SQLiteClusterStore
  9. from news_mcp.enrichment.llm_enrich import summarize_cluster_groq
  10. from news_mcp.llm import active_llm_config
  11. from collections import Counter
  12. import logging
# FastMCP server instance exposing the news tools below.
# DNS-rebinding protection is disabled so the SSE endpoint can be reached
# through proxies / non-localhost hostnames.
# NOTE(review): confirm this is acceptable for the intended deployment.
mcp = FastMCP(
    "news-mcp",
    transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False),
)
  17. @mcp.tool(description="What is happening right now? Return the latest deduplicated news clusters for a topic.")
  18. async def get_latest_events(topic: str = "crypto", limit: int = 5):
  19. limit = max(1, min(int(limit), 20))
  20. # In v1, `topic` is a coarse category. If the caller passes an entity name
  21. # (e.g. "trump"/"iran"), gracefully fall back to `other`.
  22. topic_norm = str(topic).strip().lower()
  23. allowed = {t.lower() for t in DEFAULT_TOPICS}
  24. if topic_norm not in allowed:
  25. topic_norm = "other"
  26. store = SQLiteClusterStore(DB_PATH)
  27. # Cache-first: only refresh if we currently have no fresh clusters for this topic.
  28. clusters = store.get_latest_clusters(topic=topic_norm, ttl_hours=CLUSTERS_TTL_HOURS, limit=limit)
  29. if not clusters:
  30. await refresh_clusters(topic=topic_norm, limit=200)
  31. clusters = store.get_latest_clusters(topic=topic_norm, ttl_hours=CLUSTERS_TTL_HOURS, limit=limit)
  32. # Ensure the response is compact and agent-friendly.
  33. clusters_sorted = sorted(clusters, key=lambda x: float(x.get("importance", 0.0)), reverse=True)
  34. out = []
  35. for c in clusters_sorted:
  36. out.append(
  37. {
  38. "cluster_id": c.get("cluster_id"),
  39. "headline": c.get("headline"),
  40. "summary": c.get("summary"),
  41. "entities": c.get("entities", []),
  42. "sentiment": c.get("sentiment", "neutral"),
  43. "importance": c.get("importance", 0.0),
  44. "sources": c.get("sources", []),
  45. "timestamp": c.get("timestamp"),
  46. }
  47. )
  48. return out
  49. @mcp.tool(description="What's happening with X? Filter latest clusters by extracted entity substring (case-insensitive).")
  50. async def get_events_for_entity(entity: str, limit: int = 10):
  51. limit = max(1, min(int(limit), 30))
  52. query = str(entity).strip().lower()
  53. if not query:
  54. return []
  55. # Cache-first: search recent clusters across all topics.
  56. store = SQLiteClusterStore(DB_PATH)
  57. clusters = store.get_latest_clusters_all_topics(ttl_hours=CLUSTERS_TTL_HOURS, limit=limit * 5)
  58. hits = []
  59. for c in clusters:
  60. ents = c.get("entities") or []
  61. if any(query in str(e).lower() for e in ents):
  62. hits.append(c)
  63. if len(hits) >= limit:
  64. break
  65. # Compress to tool response shape.
  66. out = []
  67. for c in hits:
  68. out.append(
  69. {
  70. "cluster_id": c.get("cluster_id"),
  71. "headline": c.get("headline"),
  72. "summary": c.get("summary"),
  73. "entities": c.get("entities", []),
  74. "sentiment": c.get("sentiment", "neutral"),
  75. "importance": c.get("importance", 0.0),
  76. "sources": c.get("sources", []),
  77. "timestamp": c.get("timestamp"),
  78. }
  79. )
  80. return out
  81. @mcp.tool(description="Explain an event clearly by cluster_id (Groq summary).")
  82. async def get_event_summary(event_id: str):
  83. store = SQLiteClusterStore(DB_PATH)
  84. # Summary cache: reuse if present within TTL.
  85. cached_summary = store.get_cluster_summary(
  86. cluster_id=event_id,
  87. ttl_hours=CLUSTERS_TTL_HOURS,
  88. )
  89. if cached_summary:
  90. return {
  91. "event_id": event_id,
  92. "headline": cached_summary.get("headline"),
  93. "mergedSummary": cached_summary.get("mergedSummary"),
  94. "keyFacts": cached_summary.get("keyFacts", []),
  95. "sources": cached_summary.get("sources", []),
  96. }
  97. cluster = store.get_cluster_by_id(event_id)
  98. if not cluster:
  99. return {
  100. "event_id": event_id,
  101. "error": "NOT_FOUND",
  102. }
  103. summary = await summarize_cluster_groq(cluster)
  104. store.upsert_cluster_summary(event_id, summary)
  105. return {
  106. "event_id": event_id,
  107. "headline": summary.get("headline"),
  108. "mergedSummary": summary.get("mergedSummary"),
  109. "keyFacts": summary.get("keyFacts", []),
  110. "sources": summary.get("sources", []),
  111. }
  112. @mcp.tool(description="Detect emerging topics/entities from recent cached news clusters.")
  113. async def detect_emerging_topics(limit: int = 10):
  114. limit = max(1, min(int(limit), 20))
  115. store = SQLiteClusterStore(DB_PATH)
  116. clusters = store.get_latest_clusters_all_topics(ttl_hours=CLUSTERS_TTL_HOURS, limit=200)
  117. from collections import Counter
  118. import re
  119. entity_counts = Counter()
  120. phrase_counts = Counter()
  121. topic_counts = Counter()
  122. for c in clusters:
  123. topic_counts[c.get("topic", "other")] += 1
  124. for ent in c.get("entities", []) or []:
  125. key = str(ent).strip().lower()
  126. if key:
  127. entity_counts[key] += 1
  128. text = f"{c.get('headline','')} {c.get('summary','')}"
  129. words = [w for w in re.findall(r"[A-Za-z][A-Za-z0-9\-]{2,}", text.lower())]
  130. for i in range(len(words) - 1):
  131. phrase = f"{words[i]} {words[i+1]}"
  132. if len(phrase) > 6:
  133. phrase_counts[phrase] += 1
  134. emerging = []
  135. for ent, count in entity_counts.most_common(limit):
  136. emerging.append({
  137. "topic": ent,
  138. "trend_score": min(0.99, round(0.25 + 0.15 * count, 2)),
  139. "related_entities": [ent],
  140. "signal_type": "entity",
  141. "count": count,
  142. })
  143. for phrase, count in phrase_counts.most_common(limit * 2):
  144. if any(item["topic"] == phrase for item in emerging):
  145. continue
  146. emerging.append({
  147. "topic": phrase.title(),
  148. "trend_score": min(0.99, round(0.20 + 0.10 * count, 2)),
  149. "related_entities": [],
  150. "signal_type": "phrase",
  151. "count": count,
  152. })
  153. if len(emerging) >= limit:
  154. break
  155. return emerging[:limit]
  156. @mcp.tool(description="What's the overall sentiment around an entity within a timeframe?")
  157. async def get_news_sentiment(entity: str, timeframe: str = "24h"):
  158. store = SQLiteClusterStore(DB_PATH)
  159. ent = str(entity).strip().lower()
  160. if not ent:
  161. return {
  162. "entity": entity,
  163. "sentiment": "neutral",
  164. "score": 0.0,
  165. "cluster_count": 0,
  166. }
  167. # timeframe: accept '24h' or '24'
  168. tf = str(timeframe).strip().lower()
  169. try:
  170. hours = int(tf[:-1]) if tf.endswith("h") else int(tf)
  171. except Exception:
  172. hours = 24
  173. hours = max(1, min(int(hours), 168))
  174. clusters = store.get_latest_clusters_all_topics(ttl_hours=hours, limit=500)
  175. matched = []
  176. for c in clusters:
  177. ents = c.get("entities") or []
  178. if any(ent in str(e).lower() for e in ents):
  179. matched.append(c)
  180. if not matched:
  181. return {
  182. "entity": entity,
  183. "sentiment": "neutral",
  184. "score": 0.0,
  185. "cluster_count": 0,
  186. }
  187. scores = []
  188. labels = []
  189. for c in matched:
  190. s = c.get("sentimentScore")
  191. if s is not None:
  192. try:
  193. scores.append(float(s))
  194. except Exception:
  195. pass
  196. lbl = c.get("sentiment")
  197. if lbl:
  198. labels.append(str(lbl).lower())
  199. avg_score = sum(scores) / len(scores) if scores else 0.0
  200. # Majority vote on sentiment label, fall back to sign of avg score.
  201. if labels:
  202. majority = Counter(labels).most_common(1)[0][0]
  203. if majority in {"positive", "negative", "neutral"}:
  204. sentiment = majority
  205. else:
  206. sentiment = "positive" if avg_score > 0 else "negative" if avg_score < 0 else "neutral"
  207. else:
  208. sentiment = "positive" if avg_score > 0 else "negative" if avg_score < 0 else "neutral"
  209. return {
  210. "entity": entity,
  211. "sentiment": sentiment,
  212. "score": round(avg_score, 3),
  213. "cluster_count": len(matched),
  214. }
# FastAPI wrapper: serves the MCP SSE transport under /mcp plus plain HTTP
# endpoints (/ and /health) for banners and probes.
app = FastAPI(title="News MCP Server")
logger = logging.getLogger("news_mcp.startup")
app.mount("/mcp", mcp.sse_app())
# Guards against scheduling the background refresh loop more than once.
_background_task_started = False
  219. @app.on_event("startup")
  220. async def _start_background_refresh():
  221. global _background_task_started
  222. if _background_task_started:
  223. return
  224. if not NEWS_BACKGROUND_REFRESH_ENABLED:
  225. return
  226. _background_task_started = True
  227. logger.info("news-mcp llm config: %s", active_llm_config())
  228. async def _loop():
  229. if not NEWS_BACKGROUND_REFRESH_ON_START:
  230. await asyncio.sleep(float(NEWS_REFRESH_INTERVAL_SECONDS))
  231. while True:
  232. try:
  233. # Refresh all topics by passing topic=None
  234. await refresh_clusters(topic=None, limit=200)
  235. except Exception:
  236. # Avoid crashing the server on network errors.
  237. pass
  238. await asyncio.sleep(float(NEWS_REFRESH_INTERVAL_SECONDS))
  239. import asyncio
  240. asyncio.create_task(_loop())
  241. @app.get("/")
  242. def root():
  243. return {
  244. "status": "ok",
  245. "transport": "fastmcp+sse",
  246. "mount": "/mcp",
  247. "tools": ["get_latest_events", "get_events_for_entity", "get_event_summary", "detect_emerging_topics"],
  248. "refresh": {
  249. "enabled": NEWS_BACKGROUND_REFRESH_ENABLED,
  250. "interval_seconds": NEWS_REFRESH_INTERVAL_SECONDS,
  251. },
  252. }
  253. @app.get("/health")
  254. def health():
  255. store = SQLiteClusterStore(DB_PATH)
  256. return {
  257. "status": "ok",
  258. "ttl_hours": CLUSTERS_TTL_HOURS,
  259. "db": str(DB_PATH),
  260. "refresh": store.get_feed_state("breakingthenews"),
  261. }