# mcp_server.py
"""FastMCP transport for Atlas tools."""
from __future__ import annotations

import json
import logging
from pathlib import Path

from mcp.server.fastmcp import FastMCP
from mcp.server.transport_security import TransportSecuritySettings

from .atlas import enrich_entity, resolve_entity
from .claims import build_claim_sets
from .models import AtlasClaim, AtlasClaimObject, AtlasProvenance
from .storage_service import AtlasStorageService
from .triple_export import entity_to_turtle
from .type_classifier import CANONICAL_TYPES, classify_entity_type
  13. mcp = FastMCP(
  14. "atlas",
  15. transport_security=TransportSecuritySettings(
  16. enable_dns_rebinding_protection=False
  17. ),
  18. )
  19. def _extract_bindings(result_payload):
  20. if isinstance(result_payload, list) and result_payload:
  21. text = getattr(result_payload[0], "text", None)
  22. if text:
  23. try:
  24. result_payload = json.loads(text)
  25. except Exception:
  26. return []
  27. if isinstance(result_payload, dict):
  28. return result_payload.get("results", {}).get("bindings", [])
  29. return []
  30. def _curie(value: str | None) -> str | None:
  31. if not value:
  32. return value
  33. if value.startswith("http://world.eu.org/atlas_ontology#"):
  34. return f"atlas:{value.split('#', 1)[-1]}"
  35. return value
  36. def _layer_value(value: str | None) -> str:
  37. v = (value or "").strip().lower()
  38. if v.endswith("#raw") or v.endswith(":raw") or v == "raw":
  39. return "raw"
  40. if v.endswith("#derived") or v.endswith(":derived") or v == "derived":
  41. return "derived"
  42. return "raw"
  43. def _id_type_value(value: str | None) -> str | None:
  44. if not value:
  45. return None
  46. tail = value.rsplit("#", 1)[-1].rsplit(":", 1)[-1]
  47. low = tail.lower()
  48. if low == "mid":
  49. return "mid"
  50. if low in {"wikidataqid", "qid"}:
  51. return "qid"
  52. return low
  53. def _human_claims(entity: AtlasEntity, raw_claims: list[dict], derived_claims: list[dict]) -> list[str]:
  54. out: list[str] = []
  55. seen: set[tuple[str, str]] = set()
  56. for claim in raw_claims + derived_claims:
  57. pred = claim.get("predicate")
  58. obj = claim.get("object", {}) or {}
  59. if pred == "atlas:hasIdentifier":
  60. ident_type = (obj.get("id_type") or "identifier").split(":")[-1].replace("-", " ").title()
  61. value = obj.get("value")
  62. if value:
  63. key = (ident_type, value)
  64. if key in seen:
  65. continue
  66. seen.add(key)
  67. out.append(f"{ident_type}: {value}")
  68. elif pred in {"atlas:hasLatitude", "atlas:hasLongitude"}:
  69. label = "Latitude" if pred.endswith("Latitude") else "Longitude"
  70. value = obj.get("value")
  71. if value:
  72. key = (label, value)
  73. if key in seen:
  74. continue
  75. seen.add(key)
  76. out.append(f"{label}: {value}")
  77. elif pred == "atlas:hasBirthDate":
  78. value = obj.get("value")
  79. if value:
  80. key = ("Birth date", value)
  81. if key in seen:
  82. continue
  83. seen.add(key)
  84. out.append(f"Birth date: {value}")
  85. elif pred == "atlas:hasCountry":
  86. value = obj.get("value")
  87. if value:
  88. key = ("Country", value)
  89. if key in seen:
  90. continue
  91. seen.add(key)
  92. out.append(f"Country: {value}")
  93. elif pred == "atlas:hasCanonicalType":
  94. value = obj.get("value") or ""
  95. if value.startswith("atlas:"):
  96. value = value.split(":", 1)[-1]
  97. else:
  98. # skip noisy canonical-type claims that don't reference atlas namespace
  99. continue
  100. if value not in CANONICAL_TYPES:
  101. continue
  102. key = ("Type", value)
  103. if key in seen:
  104. continue
  105. seen.add(key)
  106. out.append(f"Type: {value}")
  107. elif pred == "atlas:hasAlias":
  108. out.append(f"Alias: {obj.get('value')}")
  109. else:
  110. value = obj.get("value")
  111. if value:
  112. out.append(f"{pred.split(':')[-1]}: {value}")
  113. # keep output compact and readable
  114. return out[:50]
  115. async def _load_persisted_entity_state(entity_id: str) -> dict:
  116. svc = AtlasStorageService()
  117. payload = await svc.read_entity_claims(entity_id)
  118. if payload.get("status") != "ok":
  119. return {"claims": [], "canonical_label": None, "canonical_type": None}
  120. bindings = _extract_bindings(payload.get("result"))
  121. claims: list[AtlasClaim] = []
  122. seen = set()
  123. canonical_label = None
  124. canonical_type = None
  125. for b in bindings:
  126. if canonical_label is None:
  127. canonical_label = b.get("label", {}).get("value")
  128. if canonical_type is None:
  129. canon_val = b.get("canonType", {}).get("value")
  130. if canon_val:
  131. canonical_type = canon_val.rsplit("#", 1)[-1].rsplit(":", 1)[-1]
  132. claim_uri = b.get("claim", {}).get("value")
  133. claim_id = (claim_uri or "").rsplit("#", 1)[-1] or "clm_unknown"
  134. if claim_id in seen:
  135. continue
  136. seen.add(claim_id)
  137. pred = _curie(b.get("pred", {}).get("value"))
  138. obj_iri = b.get("objIri", {}).get("value")
  139. obj_lit = b.get("objLit", {}).get("value")
  140. id_val = b.get("idVal", {}).get("value")
  141. id_type = b.get("idType", {}).get("value")
  142. layer = _layer_value(b.get("layer", {}).get("value"))
  143. status = b.get("status", {}).get("value") or "active"
  144. prov = None
  145. src = b.get("src", {}).get("value")
  146. if src:
  147. prov = AtlasProvenance(
  148. source=src,
  149. retrieval_method=b.get("method", {}).get("value") or "unknown",
  150. confidence=float(b.get("conf", {}).get("value") or 0.0),
  151. retrieved_at=b.get("ts", {}).get("value"),
  152. )
  153. obj = AtlasClaimObject(kind="literal", value=obj_lit or obj_iri or "")
  154. if id_val:
  155. obj = AtlasClaimObject(kind="identifier", value=id_val, id_type=_id_type_value(id_type))
  156. elif obj_iri:
  157. obj = AtlasClaimObject(kind="identifier", value=obj_iri.rsplit("#", 1)[-1])
  158. claims.append(
  159. AtlasClaim(
  160. claim_id=claim_id,
  161. subject=entity_id,
  162. predicate=pred,
  163. object=obj,
  164. layer=layer,
  165. status=status,
  166. provenance=prov,
  167. )
  168. )
  169. return {"claims": claims, "canonical_label": canonical_label, "canonical_type": canonical_type}
  170. @mcp.tool(name="resolve_entity", description="Resolve a subject string to a canonical Atlas entity. Use payloads=true to include raw payload snapshots; use debug=true for full claim/provenance detail.")
  171. async def resolve_entity_tool(subject: str, context: str | None = None, debug: bool = False, debug_path: str | None = None, payloads: bool = False):
  172. entity = await resolve_entity(subject, context)
  173. persisted = await _load_persisted_entity_state(entity.atlas_id)
  174. if persisted.get("claims"):
  175. entity.claims = persisted["claims"]
  176. if persisted.get("canonical_label"):
  177. entity.canonical_label = persisted["canonical_label"]
  178. if persisted.get("canonical_type"):
  179. entity.entity_type = persisted["canonical_type"]
  180. entity.needs_curation = False
  181. if entity.entity_type not in CANONICAL_TYPES:
  182. try:
  183. resolution = {
  184. "canonical_label": entity.canonical_label,
  185. "type": entity.entity_type,
  186. "candidates": entity.raw_payload.get("wikidata", {}).get("candidates", []) if isinstance(entity.raw_payload, dict) else [],
  187. }
  188. classification = await classify_entity_type(entity.atlas_id, resolution, context)
  189. if classification.canonical_type:
  190. entity.entity_type = classification.canonical_type
  191. entity.needs_curation = classification.needs_curation
  192. except Exception:
  193. pass
  194. raw_claims, derived_claims = build_claim_sets(entity)
  195. clean_claims = _human_claims(entity, raw_claims, derived_claims)
  196. result = {
  197. "atlas_id": entity.atlas_id,
  198. "canonical_label": entity.canonical_label,
  199. "canonical_description": entity.canonical_description,
  200. "entity_type": entity.entity_type,
  201. "needs_curation": entity.needs_curation,
  202. "aliases": [alias.label for alias in entity.aliases],
  203. "active_claims": clean_claims,
  204. }
  205. if payloads:
  206. result["g_trends_payload"] = {k: v for k, v in entity.raw_payload.items() if k not in {"wikidata", "wikidata_entity_json"}}
  207. result["wikidata_payload"] = (
  208. entity.raw_payload.get("wikidata")
  209. if entity.raw_payload.get("wikidata") is not None
  210. else {"wikidata_status": "missing"}
  211. )
  212. if debug:
  213. turtle = entity_to_turtle(entity)
  214. result["raw_claims"] = raw_claims
  215. result["derived_claims"] = derived_claims
  216. result["source_payloads"] = {
  217. "g_trends_payload": {k: v for k, v in entity.raw_payload.items() if k not in {"wikidata", "wikidata_entity_json"}},
  218. "wikidata_payload": entity.raw_payload.get("wikidata") if entity.raw_payload.get("wikidata") is not None else {"wikidata_status": "missing"},
  219. }
  220. result["turtle"] = turtle
  221. if debug_path:
  222. path = Path(debug_path)
  223. path.parent.mkdir(parents=True, exist_ok=True)
  224. path.write_text(turtle, encoding="utf-8")
  225. result["turtle_path"] = str(path)
  226. return result
  227. @mcp.tool(name="enrich_entity", description="Enrich a canonical Atlas entity into a related-entity dataset.")
  228. async def enrich_entity_tool(subject: str, depth: int = 1, context: str | None = None):
  229. entity = await resolve_entity(subject, context)
  230. result = enrich_entity(entity, depth=depth)
  231. return {
  232. "seed": {
  233. "atlas_id": result.seed_entity.atlas_id,
  234. "canonical_label": result.seed_entity.canonical_label,
  235. },
  236. "related_entities": [
  237. {
  238. "atlas_id": item.atlas_id,
  239. "canonical_label": item.canonical_label,
  240. "entity_type": item.entity_type,
  241. }
  242. for item in result.related_entities
  243. ],
  244. "query_context": result.query_context,
  245. "depth": result.depth,
  246. }