# mcp_server_fastmcp.py
"""FastMCP server exposing Google Trends attention tools over SSE."""
from __future__ import annotations

from fastapi import FastAPI

from mcp.server.fastmcp import FastMCP
from mcp.server.transport_security import TransportSecuritySettings

from trends_mcp.aliases import normalize_entity
from trends_mcp.cache import cache_stats, get_cache, set_cache
from trends_mcp.ledger import entity_history, prune_snapshots, read_recent, store_snapshot, summarize
from trends_mcp.providers.google_trends import GoogleTrendsError, GoogleTrendsProvider

# MCP server instance. DNS-rebinding protection is disabled so the SSE
# endpoint accepts arbitrary Host headers (e.g. behind a reverse proxy).
# NOTE(review): confirm this is acceptable for the deployment environment.
mcp = FastMCP(
    "trends-mcp",
    transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False),
)

# Single shared Google Trends provider used by every tool below.
provider = GoogleTrendsProvider()

# TTL for interest-over-time cache entries: 30 minutes.
CACHE_TTL_SECONDS = 30 * 60
  15. def _trend_label(series: list[int]) -> str:
  16. if not series:
  17. return "flat"
  18. first, last = series[0], series[-1]
  19. if last - first >= 10:
  20. return "rising"
  21. if first - last >= 10:
  22. return "falling"
  23. return "flat"
  24. @mcp.tool(description="Experimental: show attention trend for a keyword or entity over time.")
  25. async def get_interest_over_time(keyword: str, timeframe: str = "7d"):
  26. keyword_norm = normalize_entity(keyword)
  27. cache_key = f"interest:{keyword_norm}:{timeframe}"
  28. cached = get_cache(cache_key)
  29. if cached:
  30. return cached
  31. try:
  32. result = provider.interest_over_time(keyword, timeframe)
  33. except GoogleTrendsError as exc:
  34. return {
  35. "keyword": keyword,
  36. "normalized_keyword": keyword_norm,
  37. "timeframe": timeframe,
  38. "error": str(exc),
  39. }
  40. payload = {
  41. "keyword": keyword,
  42. "normalized_keyword": keyword_norm,
  43. "timeframe": timeframe,
  44. "series": result.series,
  45. "trend": _trend_label(result.series),
  46. "fetched_at": result.fetched_at,
  47. }
  48. set_cache(cache_key, payload, CACHE_TTL_SECONDS)
  49. return payload
  50. @mcp.tool(description="Resolve an entity to Knowledge Graph MID candidates and a best canonical label.")
  51. async def resolve_entity(keyword: str):
  52. cache_key = f"resolve:{normalize_entity(keyword)}"
  53. cached = get_cache(cache_key)
  54. if cached:
  55. return cached
  56. suggestions = provider.suggestions(keyword)
  57. best = suggestions[0] if suggestions else None
  58. payload = {
  59. "keyword": keyword,
  60. "canonical_label": best.get("title") if best else normalize_entity(keyword),
  61. "mid": best.get("mid") if best else None,
  62. "type": best.get("type") if best else None,
  63. "candidates": suggestions,
  64. }
  65. set_cache(cache_key, payload, 24 * 60 * 60)
  66. store_snapshot(
  67. tool="resolve_entity",
  68. keyword=keyword,
  69. normalized_keyword=normalize_entity(keyword),
  70. mid=payload["mid"],
  71. canonical_label=payload["canonical_label"],
  72. payload=payload,
  73. )
  74. return payload
  75. @mcp.tool(description="Get related search queries for an entity.")
  76. async def get_related_queries(keyword: str):
  77. cache_key = f"related:{normalize_entity(keyword)}"
  78. cached = get_cache(cache_key)
  79. if cached:
  80. return cached
  81. entity_info = await resolve_entity(keyword)
  82. related = provider.related_queries(keyword)
  83. out = related.get(keyword) or related.get(normalize_entity(keyword)) or {}
  84. def _rows(df):
  85. if df is None:
  86. return []
  87. try:
  88. return df.reset_index().to_dict(orient="records")
  89. except Exception:
  90. return []
  91. payload = {
  92. "keyword": keyword,
  93. "top": _rows(out.get("top") if isinstance(out, dict) else None),
  94. "rising": _rows(out.get("rising") if isinstance(out, dict) else None),
  95. }
  96. set_cache(cache_key, payload, 24 * 60 * 60)
  97. store_snapshot(
  98. tool="get_related_queries",
  99. keyword=keyword,
  100. normalized_keyword=normalize_entity(keyword),
  101. mid=entity_info.get("mid"),
  102. canonical_label=entity_info.get("canonical_label"),
  103. payload=payload,
  104. )
  105. return payload
  106. @mcp.tool(description="Get related topics for an entity.")
  107. async def get_related_topics(keyword: str):
  108. cache_key = f"topics:{normalize_entity(keyword)}"
  109. cached = get_cache(cache_key)
  110. if cached:
  111. return cached
  112. entity_info = await resolve_entity(keyword)
  113. try:
  114. related = provider.related_topics(keyword)
  115. out = related.get(keyword) or related.get(normalize_entity(keyword)) or {}
  116. except GoogleTrendsError:
  117. # pytrends' related_topics is flaky; fall back to related_queries so the tool stays useful.
  118. related = provider.related_queries(keyword)
  119. out = related.get(keyword) or related.get(normalize_entity(keyword)) or {}
  120. def _rows(df):
  121. if df is None:
  122. return []
  123. try:
  124. return df.reset_index().to_dict(orient="records")
  125. except Exception:
  126. return []
  127. payload = {
  128. "keyword": keyword,
  129. "top": _rows(out.get("top") if isinstance(out, dict) else None),
  130. "rising": _rows(out.get("rising") if isinstance(out, dict) else None),
  131. }
  132. set_cache(cache_key, payload, 24 * 60 * 60)
  133. store_snapshot(
  134. tool="get_related_topics",
  135. keyword=keyword,
  136. normalized_keyword=normalize_entity(keyword),
  137. mid=entity_info.get("mid"),
  138. canonical_label=entity_info.get("canonical_label"),
  139. payload=payload,
  140. )
  141. return payload
  142. @mcp.tool(description="Read the most recent ledger events.")
  143. async def get_ledger_recent(limit: int = 50):
  144. return read_recent(limit=max(1, min(int(limit), 200)))
  145. @mcp.tool(description="Summarize what the ledger is saying.")
  146. async def get_ledger_summary(limit: int = 500):
  147. return summarize(limit=max(1, min(int(limit), 2000)))
  148. @mcp.tool(description="Show the ledger history for one entity or MID.")
  149. async def get_entity_history(entity: str, limit: int = 500):
  150. return entity_history(entity, limit=max(1, min(int(limit), 2000)))
  151. @mcp.tool(description="Prune stored snapshots older than the configured retention window.")
  152. async def prune_history(retention_days: int = 30):
  153. deleted = prune_snapshots(retention_days=max(1, min(int(retention_days), 3650)))
  154. return {"deleted": deleted, "retention_days": retention_days}
  155. @mcp.tool(description="Compare attention between multiple keywords or entities.")
  156. async def compare_interest(keywords: list[str], timeframe: str = "7d"):
  157. if not keywords:
  158. return {"winner": None, "ratios": {}}
  159. normalized = [normalize_entity(k) for k in keywords]
  160. series_map = {}
  161. for original, keyword in zip(keywords, normalized):
  162. cache_key = f"interest:{keyword}:{timeframe}"
  163. cached = get_cache(cache_key)
  164. if cached:
  165. series = cached["series"]
  166. else:
  167. series = provider.interest_over_time(original, timeframe).series
  168. set_cache(
  169. cache_key,
  170. {
  171. "keyword": original,
  172. "normalized_keyword": keyword,
  173. "timeframe": timeframe,
  174. "series": series,
  175. "trend": _trend_label(series),
  176. },
  177. CACHE_TTL_SECONDS,
  178. )
  179. series_map[keyword] = series
  180. scores = {k: sum(v) for k, v in series_map.items()}
  181. winner = max(scores, key=scores.get)
  182. top_score = float(scores[winner]) or 1.0
  183. ratios = {k: round(v / top_score, 3) for k, v in scores.items()}
  184. return {"winner": winner, "ratios": ratios, "timeframe": timeframe}
  185. @mcp.tool(description="Get a compact attention score for a known entity.")
  186. async def get_attention_score(entity: str, timeframe: str = "24h"):
  187. normalized = normalize_entity(entity)
  188. try:
  189. series = provider.interest_over_time(entity, timeframe).series
  190. except GoogleTrendsError as exc:
  191. return {"entity": entity, "normalized_entity": normalized, "error": str(exc), "timeframe": timeframe}
  192. score = round(sum(series) / (len(series) * 100), 3)
  193. baseline = round(series[-1] / max(1, series[0]), 3) if series[0] else float(series[-1])
  194. return {
  195. "entity": entity,
  196. "normalized_entity": normalized,
  197. "score": score,
  198. "relative_to_baseline": baseline,
  199. "timeframe": timeframe,
  200. }
  201. app = FastAPI(title="Trends MCP Server")
  202. app.mount("/mcp", mcp.sse_app())
  203. @app.get("/")
  204. def root():
  205. return {
  206. "status": "ok",
  207. "transport": "fastmcp+sse",
  208. "mount": "/mcp",
  209. "tools": ["resolve_entity", "get_related_queries", "get_related_topics", "get_ledger_recent", "get_ledger_summary", "get_entity_history", "prune_history", "get_interest_over_time", "compare_interest", "get_attention_score"],
  210. }
  211. @mcp.tool(description="Debug Google Trends connectivity, suggestions, and timeframe handling.")
  212. async def debug_google_trends(keyword: str, timeframe: str = "7d"):
  213. keyword_norm = normalize_entity(keyword)
  214. try:
  215. suggestions = provider.suggestions(keyword)
  216. except GoogleTrendsError as exc:
  217. suggestions = {"error": str(exc)}
  218. try:
  219. result = provider.interest_over_time(keyword, timeframe)
  220. payload = {
  221. "keyword": keyword,
  222. "normalized_keyword": keyword_norm,
  223. "timeframe": timeframe,
  224. "suggestions": suggestions,
  225. "series": result.series,
  226. "trend": _trend_label(result.series),
  227. "fetched_at": result.fetched_at,
  228. }
  229. except GoogleTrendsError as exc:
  230. payload = {
  231. "keyword": keyword,
  232. "normalized_keyword": keyword_norm,
  233. "timeframe": timeframe,
  234. "suggestions": suggestions,
  235. "error": str(exc),
  236. }
  237. return payload
  238. @app.get("/health")
  239. def health():
  240. return {"status": "ok", "service": "trends-mcp", "cache": cache_stats()}
def main():
    """Run the ASGI app with uvicorn on 0.0.0.0:8507 (no auto-reload)."""
    # Local import so uvicorn is only required when running as a server.
    import uvicorn
    uvicorn.run("trends_mcp.mcp_server_fastmcp:app", host="0.0.0.0", port=8507, reload=False)


if __name__ == "__main__":
    main()