| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227 |
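"""
Crypto MCP Server — FastAPI entry point.

MCP JSON-RPC 2.0 endpoints:
    POST /mcp            → JSON-RPC methods: initialize, tools/list (alias
                           listTools), tools/call (alias callTool), ping
    GET  /mcp, GET /     → list available MCP tools in a JSON-RPC envelope

REST endpoints:
    GET  /tools          → list available MCP tools
    POST /tools/{name}   → call a tool (body: tool parameters as a JSON object)

Internal:
    GET  /health         → server health + cache stats
"""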
- import sys
- import os
- sys.path.insert(0, os.path.dirname(__file__))
- from fastapi import FastAPI, Request
- from fastapi.responses import JSONResponse
- from fastapi.middleware.cors import CORSMiddleware
- import services
- from mcp_tools import MCP_TOOLS
- from errors import CryptoMCPError
- from cache import get_cache_stats
- MCP_SPEC = {
- "protocolVersion": "2024-11-05",
- "serverInfo": {"name": "crypto-mcp", "version": "1.0.0"},
- }
- _sessions: dict[str, dict] = {}
# The ASGI application object every route below is registered on.
app = FastAPI(
    version="1.0.0",
    title="Crypto MCP Server",
    description="Agent-friendly crypto market data + technical indicators",
)
@app.get("/")
async def root():
    """Serve the MCP tool catalogue at the site root.

    The payload is shaped like a JSON-RPC 2.0 response (``result`` holds the
    tool list, ``id`` is ``None``) even though no request id exists here.
    """
    listing = dict(tools=MCP_TOOLS)
    return {"jsonrpc": "2.0", "result": listing, "id": None}
@app.get("/mcp")
async def mcp_root():
    """Answer a plain GET on ``/mcp`` with the tool catalogue.

    Returns a JSON-RPC 2.0-shaped envelope whose ``result`` carries
    ``MCP_TOOLS`` and whose ``id`` is ``None``.
    """
    tools_result = {"tools": MCP_TOOLS}
    return dict(jsonrpc="2.0", result=tools_result, id=None)
@app.post("/mcp")
async def mcp_rpc(request: Request):
    """Handle a single MCP JSON-RPC 2.0 request posted to ``/mcp``.

    Supported methods: ``initialize``, ``tools/list`` (alias ``listTools``),
    ``tools/call`` (alias ``callTool``) and ``ping``. A ``notifications/*``
    message sent without an ``id`` is acknowledged with an empty
    ``202 Accepted`` response, because JSON-RPC forbids replying to
    notifications.

    Every other outcome is a JSON-RPC envelope:

    * ``-32700`` body is not valid JSON
    * ``-32600`` body is not a JSON-RPC 2.0 request object
    * ``-32601`` unknown method
    * ``-32602`` malformed ``params``, missing tool name, or an
      ``InvalidParamsError`` raised by a tool
    * ``-32000`` any other ``CryptoMCPError`` (``data`` is ``exc.to_dict()``)
    * ``-32603`` any unexpected exception (logged with traceback)
    """
    try:
        payload = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        return _rpc_error(None, -32700, "Parse error")

    if not isinstance(payload, dict):
        # Batch arrays and bare scalars are not supported; without this guard
        # payload.get() raises AttributeError and escapes as a non-RPC 500.
        return _rpc_error(None, -32600, "Invalid Request")

    req_id = payload.get("id")
    if payload.get("jsonrpc") != "2.0":
        return _rpc_error(req_id, -32600, "Invalid Request")

    method = payload.get("method")
    if (
        "id" not in payload
        and isinstance(method, str)
        and method.startswith("notifications/")
    ):
        # e.g. "notifications/initialized": acknowledge, never reply with a body.
        from fastapi import Response

        return Response(status_code=202)

    params = payload.get("params") or {}
    if not isinstance(params, dict):
        return _rpc_error(req_id, -32602, "Invalid params")

    def error_with_data(code: int, message: str, data) -> dict:
        """Build a JSON-RPC error envelope for req_id carrying structured ``data``."""
        envelope = _rpc_error(req_id, code, message)
        envelope["error"]["data"] = data
        return envelope

    try:
        if method == "initialize":
            session_id = params.get("sessionId") or _new_session_id()
            _sessions.setdefault(session_id, {"initialized": True})
            return _rpc_result(
                req_id,
                {
                    **MCP_SPEC,
                    "sessionId": session_id,
                    "capabilities": {"tools": {"listChanged": False}},
                },
            )
        if method in ("tools/list", "listTools"):
            return _rpc_result(req_id, {"tools": MCP_TOOLS})
        if method in ("tools/call", "callTool"):
            name = params.get("name") or params.get("toolName")
            arguments = params.get("arguments") or params.get("params") or {}
            if not name:
                return _rpc_error(req_id, -32602, "Missing tool name")
            result = await _call_tool(name, arguments)
            return _rpc_result(req_id, result)
        if method == "ping":
            return _rpc_result(req_id, {"ok": True})
        return _rpc_error(req_id, -32601, f"Method not found: {method}")
    except CryptoMCPError as exc:
        # JSON-RPC requires an integer code and a *string* message; the
        # structured error body travels in "data".
        from errors import InvalidParamsError

        code = -32602 if isinstance(exc, InvalidParamsError) else -32000
        return error_with_data(code, str(exc) or type(exc).__name__, exc.to_dict())
    except Exception as exc:
        # Keep the reply inside the JSON-RPC envelope (with the caller's id)
        # instead of falling through to the plain-JSON 500 handler; log it
        # because swallowing here hides the traceback from the server log.
        import logging

        logging.getLogger(__name__).exception(
            "Unhandled error in MCP method %r", method
        )
        return error_with_data(
            -32603,
            "Internal error",
            {"error": "INTERNAL_ERROR", "detail": str(exc)},
        )
- def _rpc_result(req_id, result):
- return {"jsonrpc": "2.0", "id": req_id, "result": result}
- def _rpc_error(req_id, code, message):
- return {"jsonrpc": "2.0", "id": req_id, "error": {"code": code, "message": message}}
- def _new_session_id() -> str:
- import uuid
- return uuid.uuid4().hex
- async def _call_tool(tool_name: str, body: dict):
- match tool_name:
- case "get_price":
- return await services.get_price(_require(body, "symbol"))
- case "get_ohlcv":
- return await services.get_ohlcv(_require(body, "symbol"), body.get("timeframe", "1h"), int(body.get("limit", 100)))
- case "get_indicator":
- return await services.get_indicator(_require(body, "symbol"), _require(body, "indicator"), body.get("timeframe", "1h"), body.get("params", {}))
- case "get_market_snapshot":
- return await services.get_market_snapshot(_require(body, "symbol"))
- case "get_top_movers":
- return await services.get_top_movers(int(body.get("limit", 10)))
- case _:
- return {"error": "TOOL_NOT_FOUND", "detail": f"No tool named '{tool_name}'"}
# Fully permissive CORS: any origin, any HTTP method and any request header
# may call this API from a browser.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
- # ---------------------------------------------------------------------------
- # Global error handler
- # ---------------------------------------------------------------------------
@app.exception_handler(CryptoMCPError)
async def crypto_error_handler(request: Request, exc: CryptoMCPError):
    """Render a ``CryptoMCPError`` raised by a route as HTTP 400.

    The response body is the error's own ``to_dict()`` representation.
    """
    error_body = exc.to_dict()
    return JSONResponse(content=error_body, status_code=400)
@app.exception_handler(Exception)
async def generic_error_handler(request: Request, exc: Exception):
    """Catch-all handler: turn any otherwise-unhandled exception into HTTP 500.

    NOTE(review): ``detail`` echoes ``str(exc)`` back to the client, which can
    expose internal details (upstream URLs, file paths); consider a generic
    message outside development.
    """
    error_body = {"error": "INTERNAL_ERROR", "detail": str(exc)}
    return JSONResponse(content=error_body, status_code=500)
- # ---------------------------------------------------------------------------
- # Health
- # ---------------------------------------------------------------------------
@app.get("/health")
async def health():
    """Health probe: report ``"ok"`` plus the current cache statistics."""
    status = "ok"
    return dict(status=status, cache=get_cache_stats())
- # ---------------------------------------------------------------------------
- # MCP Tool Registry
- # ---------------------------------------------------------------------------
@app.get("/tools")
async def list_tools():
    """List every MCP tool definition as plain JSON (no JSON-RPC envelope)."""
    catalogue = MCP_TOOLS
    return dict(tools=catalogue)
- # ---------------------------------------------------------------------------
- # MCP Tool Dispatch
- # ---------------------------------------------------------------------------
- @app.post("/tools/{tool_name}")
- async def call_tool(tool_name: str, request: Request):
- """
- Dispatch a tool call by name.
- Body: tool parameters as JSON object.
- """
- try:
- body = await request.json()
- except Exception:
- body = {}
- match tool_name:
- case "get_price":
- symbol = _require(body, "symbol")
- return await services.get_price(symbol)
- case "get_ohlcv":
- symbol = _require(body, "symbol")
- timeframe = body.get("timeframe", "1h")
- limit = int(body.get("limit", 100))
- return await services.get_ohlcv(symbol, timeframe, limit)
- case "get_indicator":
- symbol = _require(body, "symbol")
- indicator = _require(body, "indicator")
- timeframe = body.get("timeframe", "1h")
- params = body.get("params", {})
- return await services.get_indicator(symbol, indicator, timeframe, params)
- case "get_market_snapshot":
- symbol = _require(body, "symbol")
- return await services.get_market_snapshot(symbol)
- case "get_top_movers":
- limit = int(body.get("limit", 10))
- return await services.get_top_movers(limit)
- case _:
- return JSONResponse(
- status_code=404,
- content={"error": "TOOL_NOT_FOUND", "detail": f"No tool named '{tool_name}'"},
- )
- # ---------------------------------------------------------------------------
- # Helper
- # ---------------------------------------------------------------------------
def _require(body: dict, key: str) -> str:
    """Return ``body[key]`` converted with ``str()``, insisting that it is present.

    Raises:
        InvalidParamsError: when *key* is absent or its value is falsy
            (``None``, ``""``, ``0``, an empty container, ...).
    """
    from errors import InvalidParamsError

    value = body.get(key)
    if value:
        return str(value)
    raise InvalidParamsError(f"Missing required parameter: '{key}'")
- # ---------------------------------------------------------------------------
- # Dev runner
- # ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Local development server with auto-reload on port 8505, all interfaces.
    # NOTE: the "main:app" import string assumes this file is main.py.
    import uvicorn

    uvicorn.run("main:app", reload=True, host="0.0.0.0", port=8505)
|