mcp_server_fastmcp.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. from __future__ import annotations
  2. from fastapi import FastAPI
  3. from mcp.server.fastmcp import FastMCP
  4. from mcp.server.transport_security import TransportSecuritySettings
  5. from trends_mcp.aliases import normalize_entity
  6. from trends_mcp.cache import cache_stats, get_cache, set_cache
  7. from trends_mcp.ledger import entity_history, prune_snapshots, read_recent, store_snapshot, summarize
  8. from trends_mcp.providers.google_trends import GoogleTrendsError, GoogleTrendsProvider
# FastMCP server instance. DNS-rebinding protection is disabled here because
# the server is mounted behind FastAPI and served over SSE (see the
# app.mount("/mcp", ...) call later in this module).
mcp = FastMCP(
    "trends-mcp",
    transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False),
)
# Shared Google Trends client used by every tool in this module.
provider = GoogleTrendsProvider()
# TTL for interest-over-time cache entries: 30 minutes.
CACHE_TTL_SECONDS = 30 * 60
  15. def _trend_label(series: list[int]) -> str:
  16. if not series:
  17. return "flat"
  18. first, last = series[0], series[-1]
  19. if last - first >= 10:
  20. return "rising"
  21. if first - last >= 10:
  22. return "falling"
  23. return "flat"
  24. @mcp.tool(description="Experimental: show attention trend for a keyword or entity over time.")
  25. async def get_interest_over_time(keyword: str, timeframe: str = "7d"):
  26. keyword_norm = normalize_entity(keyword)
  27. cache_key = f"interest:{keyword_norm}:{timeframe}"
  28. cached = get_cache(cache_key)
  29. if cached:
  30. return cached
  31. try:
  32. result = provider.interest_over_time(keyword, timeframe)
  33. except GoogleTrendsError as exc:
  34. return {
  35. "keyword": keyword,
  36. "normalized_keyword": keyword_norm,
  37. "timeframe": timeframe,
  38. "error": str(exc),
  39. }
  40. payload = {
  41. "keyword": keyword,
  42. "normalized_keyword": keyword_norm,
  43. "timeframe": timeframe,
  44. "series": result.series,
  45. "trend": _trend_label(result.series),
  46. "fetched_at": result.fetched_at,
  47. }
  48. set_cache(cache_key, payload, CACHE_TTL_SECONDS)
  49. return payload
  50. @mcp.tool(description="Resolve an entity to Knowledge Graph MID candidates and a best canonical label.")
  51. async def resolve_entity(keyword: str):
  52. cache_key = f"resolve:{normalize_entity(keyword)}"
  53. cached = get_cache(cache_key)
  54. if cached:
  55. return cached
  56. suggestions = provider.suggestions(keyword)
  57. best = suggestions[0] if suggestions else None
  58. payload = {
  59. "keyword": keyword,
  60. "canonical_label": best.get("title") if best else normalize_entity(keyword),
  61. "mid": best.get("mid") if best else None,
  62. "type": best.get("type") if best else None,
  63. "candidates": suggestions,
  64. }
  65. set_cache(cache_key, payload, 24 * 60 * 60)
  66. store_snapshot(
  67. tool="resolve_entity",
  68. keyword=keyword,
  69. normalized_keyword=normalize_entity(keyword),
  70. mid=payload["mid"],
  71. canonical_label=payload["canonical_label"],
  72. payload=payload,
  73. )
  74. return payload
  75. @mcp.tool(description="Get related search queries for an entity.")
  76. async def get_related_queries(keyword: str):
  77. cache_key = f"related:{normalize_entity(keyword)}"
  78. cached = get_cache(cache_key)
  79. if cached:
  80. return cached
  81. related = provider.related_queries(keyword)
  82. out = related.get(keyword) or related.get(normalize_entity(keyword)) or {}
  83. def _rows(df):
  84. if df is None:
  85. return []
  86. try:
  87. return df.reset_index().to_dict(orient="records")
  88. except Exception:
  89. return []
  90. payload = {
  91. "keyword": keyword,
  92. "top": _rows(out.get("top") if isinstance(out, dict) else None),
  93. "rising": _rows(out.get("rising") if isinstance(out, dict) else None),
  94. }
  95. set_cache(cache_key, payload, 24 * 60 * 60)
  96. store_snapshot(
  97. tool="get_related_queries",
  98. keyword=keyword,
  99. normalized_keyword=normalize_entity(keyword),
  100. mid=None,
  101. canonical_label=None,
  102. payload=payload,
  103. )
  104. return payload
  105. @mcp.tool(description="Get related topics for an entity.")
  106. async def get_related_topics(keyword: str):
  107. cache_key = f"topics:{normalize_entity(keyword)}"
  108. cached = get_cache(cache_key)
  109. if cached:
  110. return cached
  111. try:
  112. related = provider.related_topics(keyword)
  113. out = related.get(keyword) or related.get(normalize_entity(keyword)) or {}
  114. except GoogleTrendsError:
  115. # pytrends' related_topics is flaky; fall back to related_queries so the tool stays useful.
  116. related = provider.related_queries(keyword)
  117. out = related.get(keyword) or related.get(normalize_entity(keyword)) or {}
  118. def _rows(df):
  119. if df is None:
  120. return []
  121. try:
  122. return df.reset_index().to_dict(orient="records")
  123. except Exception:
  124. return []
  125. payload = {
  126. "keyword": keyword,
  127. "top": _rows(out.get("top") if isinstance(out, dict) else None),
  128. "rising": _rows(out.get("rising") if isinstance(out, dict) else None),
  129. }
  130. set_cache(cache_key, payload, 24 * 60 * 60)
  131. store_snapshot(
  132. tool="get_related_topics",
  133. keyword=keyword,
  134. normalized_keyword=normalize_entity(keyword),
  135. mid=None,
  136. canonical_label=None,
  137. payload=payload,
  138. )
  139. return payload
  140. @mcp.tool(description="Read the most recent ledger events.")
  141. async def get_ledger_recent(limit: int = 50):
  142. return read_recent(limit=max(1, min(int(limit), 200)))
  143. @mcp.tool(description="Summarize what the ledger is saying.")
  144. async def get_ledger_summary(limit: int = 500):
  145. return summarize(limit=max(1, min(int(limit), 2000)))
  146. @mcp.tool(description="Show the ledger history for one entity or MID.")
  147. async def get_entity_history(entity: str, limit: int = 500):
  148. return entity_history(entity, limit=max(1, min(int(limit), 2000)))
  149. @mcp.tool(description="Prune stored snapshots older than the configured retention window.")
  150. async def prune_history(retention_days: int = 30):
  151. deleted = prune_snapshots(retention_days=max(1, min(int(retention_days), 3650)))
  152. return {"deleted": deleted, "retention_days": retention_days}
  153. @mcp.tool(description="Compare attention between multiple keywords or entities.")
  154. async def compare_interest(keywords: list[str], timeframe: str = "7d"):
  155. if not keywords:
  156. return {"winner": None, "ratios": {}}
  157. normalized = [normalize_entity(k) for k in keywords]
  158. series_map = {}
  159. for original, keyword in zip(keywords, normalized):
  160. cache_key = f"interest:{keyword}:{timeframe}"
  161. cached = get_cache(cache_key)
  162. if cached:
  163. series = cached["series"]
  164. else:
  165. series = provider.interest_over_time(original, timeframe).series
  166. set_cache(
  167. cache_key,
  168. {
  169. "keyword": original,
  170. "normalized_keyword": keyword,
  171. "timeframe": timeframe,
  172. "series": series,
  173. "trend": _trend_label(series),
  174. },
  175. CACHE_TTL_SECONDS,
  176. )
  177. series_map[keyword] = series
  178. scores = {k: sum(v) for k, v in series_map.items()}
  179. winner = max(scores, key=scores.get)
  180. top_score = float(scores[winner]) or 1.0
  181. ratios = {k: round(v / top_score, 3) for k, v in scores.items()}
  182. return {"winner": winner, "ratios": ratios, "timeframe": timeframe}
  183. @mcp.tool(description="Get a compact attention score for a known entity.")
  184. async def get_attention_score(entity: str, timeframe: str = "24h"):
  185. normalized = normalize_entity(entity)
  186. try:
  187. series = provider.interest_over_time(entity, timeframe).series
  188. except GoogleTrendsError as exc:
  189. return {"entity": entity, "normalized_entity": normalized, "error": str(exc), "timeframe": timeframe}
  190. score = round(sum(series) / (len(series) * 100), 3)
  191. baseline = round(series[-1] / max(1, series[0]), 3) if series[0] else float(series[-1])
  192. return {
  193. "entity": entity,
  194. "normalized_entity": normalized,
  195. "score": score,
  196. "relative_to_baseline": baseline,
  197. "timeframe": timeframe,
  198. }
# HTTP front-end: FastAPI hosts the MCP server's SSE transport under /mcp.
app = FastAPI(title="Trends MCP Server")
app.mount("/mcp", mcp.sse_app())
  201. @app.get("/")
  202. def root():
  203. return {
  204. "status": "ok",
  205. "transport": "fastmcp+sse",
  206. "mount": "/mcp",
  207. "tools": ["resolve_entity", "get_related_queries", "get_related_topics", "get_ledger_recent", "get_ledger_summary", "get_entity_history", "prune_history", "get_interest_over_time", "compare_interest", "get_attention_score"],
  208. }
  209. @mcp.tool(description="Debug Google Trends connectivity, suggestions, and timeframe handling.")
  210. async def debug_google_trends(keyword: str, timeframe: str = "7d"):
  211. keyword_norm = normalize_entity(keyword)
  212. try:
  213. suggestions = provider.suggestions(keyword)
  214. except GoogleTrendsError as exc:
  215. suggestions = {"error": str(exc)}
  216. try:
  217. result = provider.interest_over_time(keyword, timeframe)
  218. payload = {
  219. "keyword": keyword,
  220. "normalized_keyword": keyword_norm,
  221. "timeframe": timeframe,
  222. "suggestions": suggestions,
  223. "series": result.series,
  224. "trend": _trend_label(result.series),
  225. "fetched_at": result.fetched_at,
  226. }
  227. except GoogleTrendsError as exc:
  228. payload = {
  229. "keyword": keyword,
  230. "normalized_keyword": keyword_norm,
  231. "timeframe": timeframe,
  232. "suggestions": suggestions,
  233. "error": str(exc),
  234. }
  235. return payload
  236. @app.get("/health")
  237. def health():
  238. return {"status": "ok", "service": "trends-mcp", "cache": cache_stats()}
  239. def main():
  240. import uvicorn
  241. uvicorn.run("trends_mcp.mcp_server_fastmcp:app", host="0.0.0.0", port=8507, reload=False)
  242. if __name__ == "__main__":
  243. main()