main.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
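"""
Crypto MCP Server — FastAPI entry point.

MCP (JSON-RPC 2.0) endpoints:
    POST /mcp            → JSON-RPC methods: initialize, tools/list (alias listTools),
                           tools/call (alias callTool), ping
    GET  /mcp, GET /     → list available MCP tools in a JSON-RPC-shaped envelope

REST convenience endpoints:
    GET  /tools          → list available MCP tools
    POST /tools/{name}   → call a tool; the JSON body holds the tool's parameters

Internal:
    GET  /health         → server health + cache stats
"""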
  9. import sys
  10. import os
  11. sys.path.insert(0, os.path.dirname(__file__))
  12. from fastapi import FastAPI, Request
  13. from fastapi.responses import JSONResponse
  14. from fastapi.middleware.cors import CORSMiddleware
  15. import services
  16. from mcp_tools import MCP_TOOLS
  17. from errors import CryptoMCPError
  18. from cache import get_cache_stats
  19. MCP_SPEC = {
  20. "protocolVersion": "2024-11-05",
  21. "serverInfo": {"name": "crypto-mcp", "version": "1.0.0"},
  22. }
  23. _sessions: dict[str, dict] = {}
  24. app = FastAPI(
  25. title="Crypto MCP Server",
  26. description="Agent-friendly crypto market data + technical indicators",
  27. version="1.0.0",
  28. )
  29. @app.get("/")
  30. async def root():
  31. return {"jsonrpc": "2.0", "result": {"tools": MCP_TOOLS}, "id": None}
  32. @app.get("/mcp")
  33. async def mcp_root():
  34. return {"jsonrpc": "2.0", "result": {"tools": MCP_TOOLS}, "id": None}
  35. @app.post("/mcp")
  36. async def mcp_rpc(request: Request):
  37. try:
  38. payload = await request.json()
  39. except Exception:
  40. return _rpc_error(None, -32700, "Parse error")
  41. if payload.get("jsonrpc") != "2.0":
  42. return _rpc_error(payload.get("id"), -32600, "Invalid Request")
  43. method = payload.get("method")
  44. params = payload.get("params", {}) or {}
  45. req_id = payload.get("id")
  46. try:
  47. if method == "initialize":
  48. session_id = params.get("sessionId") or _new_session_id()
  49. _sessions.setdefault(session_id, {"initialized": True})
  50. return _rpc_result(req_id, {**MCP_SPEC, "sessionId": session_id, "capabilities": {"tools": {"listChanged": False}}})
  51. if method in ("tools/list", "listTools"):
  52. return _rpc_result(req_id, {"tools": MCP_TOOLS})
  53. if method in ("tools/call", "callTool"):
  54. name = params.get("name") or params.get("toolName")
  55. arguments = params.get("arguments") or params.get("params") or {}
  56. if not name:
  57. return _rpc_error(req_id, -32602, "Missing tool name")
  58. result = await _call_tool(name, arguments)
  59. return _rpc_result(req_id, result)
  60. if method == "ping":
  61. return _rpc_result(req_id, {"ok": True})
  62. return _rpc_error(req_id, -32601, f"Method not found: {method}")
  63. except CryptoMCPError as exc:
  64. return _rpc_error(req_id, 400, exc.to_dict())
  65. def _rpc_result(req_id, result):
  66. return {"jsonrpc": "2.0", "id": req_id, "result": result}
  67. def _rpc_error(req_id, code, message):
  68. return {"jsonrpc": "2.0", "id": req_id, "error": {"code": code, "message": message}}
  69. def _new_session_id() -> str:
  70. import uuid
  71. return uuid.uuid4().hex
  72. async def _call_tool(tool_name: str, body: dict):
  73. match tool_name:
  74. case "get_price":
  75. return await services.get_price(_require(body, "symbol"))
  76. case "get_ohlcv":
  77. return await services.get_ohlcv(_require(body, "symbol"), body.get("timeframe", "1h"), int(body.get("limit", 100)))
  78. case "get_indicator":
  79. return await services.get_indicator(_require(body, "symbol"), _require(body, "indicator"), body.get("timeframe", "1h"), body.get("params", {}))
  80. case "get_market_snapshot":
  81. return await services.get_market_snapshot(_require(body, "symbol"))
  82. case "get_top_movers":
  83. return await services.get_top_movers(int(body.get("limit", 10)))
  84. case _:
  85. return {"error": "TOOL_NOT_FOUND", "detail": f"No tool named '{tool_name}'"}
# Accept cross-origin requests from any origin, with any method and any header.
# allow_credentials is not passed, so Starlette's CORSMiddleware default applies
# (NOTE(review): False in current Starlette - confirm on upgrades).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
  92. # ---------------------------------------------------------------------------
  93. # Global error handler
  94. # ---------------------------------------------------------------------------
  95. @app.exception_handler(CryptoMCPError)
  96. async def crypto_error_handler(request: Request, exc: CryptoMCPError):
  97. return JSONResponse(status_code=400, content=exc.to_dict())
  98. @app.exception_handler(Exception)
  99. async def generic_error_handler(request: Request, exc: Exception):
  100. return JSONResponse(
  101. status_code=500,
  102. content={"error": "INTERNAL_ERROR", "detail": str(exc)},
  103. )
  104. # ---------------------------------------------------------------------------
  105. # Health
  106. # ---------------------------------------------------------------------------
  107. @app.get("/health")
  108. async def health():
  109. return {"status": "ok", "cache": get_cache_stats()}
  110. # ---------------------------------------------------------------------------
  111. # MCP Tool Registry
  112. # ---------------------------------------------------------------------------
  113. @app.get("/tools")
  114. async def list_tools():
  115. """Return all available MCP tool definitions."""
  116. return {"tools": MCP_TOOLS}
  117. # ---------------------------------------------------------------------------
  118. # MCP Tool Dispatch
  119. # ---------------------------------------------------------------------------
  120. @app.post("/tools/{tool_name}")
  121. async def call_tool(tool_name: str, request: Request):
  122. """
  123. Dispatch a tool call by name.
  124. Body: tool parameters as JSON object.
  125. """
  126. try:
  127. body = await request.json()
  128. except Exception:
  129. body = {}
  130. match tool_name:
  131. case "get_price":
  132. symbol = _require(body, "symbol")
  133. return await services.get_price(symbol)
  134. case "get_ohlcv":
  135. symbol = _require(body, "symbol")
  136. timeframe = body.get("timeframe", "1h")
  137. limit = int(body.get("limit", 100))
  138. return await services.get_ohlcv(symbol, timeframe, limit)
  139. case "get_indicator":
  140. symbol = _require(body, "symbol")
  141. indicator = _require(body, "indicator")
  142. timeframe = body.get("timeframe", "1h")
  143. params = body.get("params", {})
  144. return await services.get_indicator(symbol, indicator, timeframe, params)
  145. case "get_market_snapshot":
  146. symbol = _require(body, "symbol")
  147. return await services.get_market_snapshot(symbol)
  148. case "get_top_movers":
  149. limit = int(body.get("limit", 10))
  150. return await services.get_top_movers(limit)
  151. case _:
  152. return JSONResponse(
  153. status_code=404,
  154. content={"error": "TOOL_NOT_FOUND", "detail": f"No tool named '{tool_name}'"},
  155. )
  156. # ---------------------------------------------------------------------------
  157. # Helper
  158. # ---------------------------------------------------------------------------
  159. def _require(body: dict, key: str) -> str:
  160. from errors import InvalidParamsError
  161. val = body.get(key)
  162. if not val:
  163. raise InvalidParamsError(f"Missing required parameter: '{key}'")
  164. return str(val)
  165. # ---------------------------------------------------------------------------
  166. # Dev runner
  167. # ---------------------------------------------------------------------------
  168. if __name__ == "__main__":
  169. import uvicorn
  170. uvicorn.run("main:app", host="0.0.0.0", port=8505, reload=True)