| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267 |
- """FastMCP transport for Atlas tools."""
- from __future__ import annotations
- from pathlib import Path
- import json
- from mcp.server.fastmcp import FastMCP
- from mcp.server.transport_security import TransportSecuritySettings
- from .atlas import enrich_entity, resolve_entity
- from .claims import build_claim_sets
- from .models import AtlasClaim, AtlasClaimObject, AtlasProvenance
- from .storage_service import AtlasStorageService
- from .type_classifier import CANONICAL_TYPES, classify_entity_type
- from .triple_export import entity_to_turtle
# Transport security settings for the MCP server; DNS-rebinding protection is switched off.
# NOTE(review): that protection guards HTTP transports against browser-driven DNS-rebinding
# requests — confirm the deployment (e.g. a reverse proxy enforcing Host checks) makes
# disabling it safe, especially since tools below can write files (resolve_entity debug_path).
_TRANSPORT_SECURITY = TransportSecuritySettings(enable_dns_rebinding_protection=False)

# Server instance that the @mcp.tool decorators in this module register against.
mcp = FastMCP("atlas", transport_security=_TRANSPORT_SECURITY)
- def _extract_bindings(result_payload):
- if isinstance(result_payload, list) and result_payload:
- text = getattr(result_payload[0], "text", None)
- if text:
- try:
- result_payload = json.loads(text)
- except Exception:
- return []
- if isinstance(result_payload, dict):
- return result_payload.get("results", {}).get("bindings", [])
- return []
- def _curie(value: str | None) -> str | None:
- if not value:
- return value
- if value.startswith("http://world.eu.org/atlas_ontology#"):
- return f"atlas:{value.split('#', 1)[-1]}"
- return value
- def _layer_value(value: str | None) -> str:
- v = (value or "").strip().lower()
- if v.endswith("#raw") or v.endswith(":raw") or v == "raw":
- return "raw"
- if v.endswith("#derived") or v.endswith(":derived") or v == "derived":
- return "derived"
- return "raw"
- def _id_type_value(value: str | None) -> str | None:
- if not value:
- return None
- tail = value.rsplit("#", 1)[-1].rsplit(":", 1)[-1]
- low = tail.lower()
- if low == "mid":
- return "mid"
- if low in {"wikidataqid", "qid"}:
- return "qid"
- return low
- def _human_claims(entity: AtlasEntity, raw_claims: list[dict], derived_claims: list[dict]) -> list[str]:
- out: list[str] = []
- seen: set[tuple[str, str]] = set()
- for claim in raw_claims + derived_claims:
- pred = claim.get("predicate")
- obj = claim.get("object", {}) or {}
- if pred == "atlas:hasIdentifier":
- ident_type = (obj.get("id_type") or "identifier").split(":")[-1].replace("-", " ").title()
- value = obj.get("value")
- if value:
- key = (ident_type, value)
- if key in seen:
- continue
- seen.add(key)
- out.append(f"{ident_type}: {value}")
- elif pred in {"atlas:hasLatitude", "atlas:hasLongitude"}:
- label = "Latitude" if pred.endswith("Latitude") else "Longitude"
- value = obj.get("value")
- if value:
- key = (label, value)
- if key in seen:
- continue
- seen.add(key)
- out.append(f"{label}: {value}")
- elif pred == "atlas:hasBirthDate":
- value = obj.get("value")
- if value:
- key = ("Birth date", value)
- if key in seen:
- continue
- seen.add(key)
- out.append(f"Birth date: {value}")
- elif pred == "atlas:hasCountry":
- value = obj.get("value")
- if value:
- key = ("Country", value)
- if key in seen:
- continue
- seen.add(key)
- out.append(f"Country: {value}")
- elif pred == "atlas:hasCanonicalType":
- value = obj.get("value") or ""
- if value.startswith("atlas:"):
- value = value.split(":", 1)[-1]
- else:
- # skip noisy canonical-type claims that don't reference atlas namespace
- continue
- if value not in CANONICAL_TYPES:
- continue
- key = ("Type", value)
- if key in seen:
- continue
- seen.add(key)
- out.append(f"Type: {value}")
- elif pred == "atlas:hasAlias":
- out.append(f"Alias: {obj.get('value')}")
- else:
- value = obj.get("value")
- if value:
- out.append(f"{pred.split(':')[-1]}: {value}")
- # keep output compact and readable
- return out[:50]
def _binding_value(binding: dict, key: str) -> str | None:
    """Return ``binding[key]["value"]``, or None when the variable is unbound or malformed."""
    cell = binding.get(key)
    return cell.get("value") if isinstance(cell, dict) else None


def _parse_confidence(raw: str | None) -> float:
    """Parse a provenance confidence literal, falling back to 0.0 when missing or malformed."""
    if not raw:
        return 0.0
    try:
        return float(raw)
    except (TypeError, ValueError):
        # One bad literal should not abort loading every other persisted claim.
        return 0.0


def _claim_from_binding(binding: dict, entity_id: str, claim_id: str) -> AtlasClaim:
    """Build the AtlasClaim for *entity_id* described by one result row."""
    obj_iri = _binding_value(binding, "objIri")
    obj_lit = _binding_value(binding, "objLit")
    id_val = _binding_value(binding, "idVal")
    # Object precedence: explicit identifier value, then IRI object (local name only),
    # then literal.
    if id_val:
        obj = AtlasClaimObject(
            kind="identifier",
            value=id_val,
            id_type=_id_type_value(_binding_value(binding, "idType")),
        )
    elif obj_iri:
        obj = AtlasClaimObject(kind="identifier", value=obj_iri.rsplit("#", 1)[-1])
    else:
        obj = AtlasClaimObject(kind="literal", value=obj_lit or "")

    provenance = None
    source = _binding_value(binding, "src")
    if source:  # provenance is attached only when a source is recorded
        provenance = AtlasProvenance(
            source=source,
            retrieval_method=_binding_value(binding, "method") or "unknown",
            confidence=_parse_confidence(_binding_value(binding, "conf")),
            retrieved_at=_binding_value(binding, "ts"),
        )

    return AtlasClaim(
        claim_id=claim_id,
        subject=entity_id,
        predicate=_curie(_binding_value(binding, "pred")),
        object=obj,
        layer=_layer_value(_binding_value(binding, "layer")),
        status=_binding_value(binding, "status") or "active",
        provenance=provenance,
    )


async def _load_persisted_entity_state(entity_id: str) -> dict:
    """Load claims and canonical label/type previously persisted for *entity_id*.

    Reads rows via ``AtlasStorageService.read_entity_claims``. Any status other
    than ``"ok"`` is treated as "nothing persisted".

    Returns:
        ``{"claims": list[AtlasClaim], "canonical_label": str | None,
        "canonical_type": str | None}``. The label and type come from the first
        row that carries them. Each ``claim`` URI becomes one claim, built from the
        first row seen for it.
    """
    payload = await AtlasStorageService().read_entity_claims(entity_id)
    if not isinstance(payload, dict) or payload.get("status") != "ok":
        return {"claims": [], "canonical_label": None, "canonical_type": None}

    claims: list[AtlasClaim] = []
    seen: set[str] = set()
    canonical_label = None
    canonical_type = None
    for binding in _extract_bindings(payload.get("result")):
        if not isinstance(binding, dict):
            continue
        if canonical_label is None:
            canonical_label = _binding_value(binding, "label")
        if canonical_type is None:
            canon_val = _binding_value(binding, "canonType")
            if canon_val:
                canonical_type = canon_val.rsplit("#", 1)[-1].rsplit(":", 1)[-1]
        # NOTE(review): rows lacking a claim URI all share "clm_unknown", so only the
        # first of them survives de-duplication — confirm that is intended.
        claim_id = (_binding_value(binding, "claim") or "").rsplit("#", 1)[-1] or "clm_unknown"
        if claim_id in seen:
            continue
        seen.add(claim_id)
        claims.append(_claim_from_binding(binding, entity_id, claim_id))
    return {"claims": claims, "canonical_label": canonical_label, "canonical_type": canonical_type}
def _apply_persisted_state(entity, persisted: dict) -> None:
    """Overlay persisted claims, label, and type onto *entity* in place.

    Only non-empty persisted values override the freshly resolved ones. A
    persisted canonical type also clears ``needs_curation``.
    """
    if persisted.get("claims"):
        entity.claims = persisted["claims"]
    if persisted.get("canonical_label"):
        entity.canonical_label = persisted["canonical_label"]
    if persisted.get("canonical_type"):
        entity.entity_type = persisted["canonical_type"]
        entity.needs_curation = False


async def _ensure_canonical_type(entity, context: str | None) -> None:
    """Classify *entity* when its type is not one of CANONICAL_TYPES.

    Classification is best effort. Classifier failures are logged and the entity
    is left unchanged.
    """
    if entity.entity_type in CANONICAL_TYPES:
        return
    raw_payload = entity.raw_payload if isinstance(entity.raw_payload, dict) else {}
    wikidata = raw_payload.get("wikidata")
    # raw_payload["wikidata"] may be present but None. Guard it so that case still
    # reaches the classifier instead of raising AttributeError first.
    candidates = wikidata.get("candidates", []) if isinstance(wikidata, dict) else []
    resolution = {
        "canonical_label": entity.canonical_label,
        "type": entity.entity_type,
        "candidates": candidates,
    }
    try:
        classification = await classify_entity_type(entity.atlas_id, resolution, context)
    except Exception:
        import logging

        logging.getLogger(__name__).warning(
            "Entity type classification failed for %s", entity.atlas_id, exc_info=True
        )
        return
    if classification.canonical_type:
        entity.entity_type = classification.canonical_type
        entity.needs_curation = classification.needs_curation


def _split_source_payloads(raw_payload) -> tuple[dict, object]:
    """Split an entity raw payload into ``(g_trends_payload, wikidata_payload)``.

    Every key except ``wikidata`` / ``wikidata_entity_json`` goes into the
    g_trends part. A missing or None Wikidata payload is reported as
    ``{"wikidata_status": "missing"}``. A non-dict raw payload is treated as empty.
    """
    payload = raw_payload if isinstance(raw_payload, dict) else {}
    excluded = {"wikidata", "wikidata_entity_json"}
    g_trends_payload = {k: v for k, v in payload.items() if k not in excluded}
    wikidata_payload = payload.get("wikidata")
    if wikidata_payload is None:
        wikidata_payload = {"wikidata_status": "missing"}
    return g_trends_payload, wikidata_payload


@mcp.tool(name="resolve_entity", description="Resolve a subject string to a canonical Atlas entity. Use payloads=true to include raw payload snapshots; use debug=true for full claim/provenance detail.")
async def resolve_entity_tool(subject: str, context: str | None = None, debug: bool = False, debug_path: str | None = None, payloads: bool = False):
    """Resolve *subject* to an Atlas entity and return a JSON-ready summary.

    Args:
        subject: Subject string to resolve.
        context: Optional context forwarded to resolution and type classification.
        debug: Also include raw/derived claims, source payloads, and a Turtle export.
        debug_path: With ``debug``, also write the Turtle export to this file path.
        payloads: Also include the raw source payload snapshots.
    """
    entity = await resolve_entity(subject, context)
    _apply_persisted_state(entity, await _load_persisted_entity_state(entity.atlas_id))
    await _ensure_canonical_type(entity, context)

    raw_claims, derived_claims = build_claim_sets(entity)
    result = {
        "atlas_id": entity.atlas_id,
        "canonical_label": entity.canonical_label,
        "canonical_description": entity.canonical_description,
        "entity_type": entity.entity_type,
        "needs_curation": entity.needs_curation,
        "aliases": [alias.label for alias in entity.aliases],
        "active_claims": _human_claims(entity, raw_claims, derived_claims),
    }
    g_trends_payload, wikidata_payload = _split_source_payloads(entity.raw_payload)
    if payloads:
        result["g_trends_payload"] = g_trends_payload
        result["wikidata_payload"] = wikidata_payload
    if debug:
        turtle = entity_to_turtle(entity)
        result["raw_claims"] = raw_claims
        result["derived_claims"] = derived_claims
        result["source_payloads"] = {
            "g_trends_payload": g_trends_payload,
            "wikidata_payload": wikidata_payload,
        }
        result["turtle"] = turtle
        if debug_path:
            # NOTE(review): debug_path comes straight from the tool caller and is written
            # (creating parent directories) with no location restriction. Consider confining
            # it to a fixed debug directory, given DNS-rebinding protection is disabled on
            # this server.
            path = Path(debug_path)
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text(turtle, encoding="utf-8")
            result["turtle_path"] = str(path)
    return result
@mcp.tool(name="enrich_entity", description="Enrich a canonical Atlas entity into a related-entity dataset.")
async def enrich_entity_tool(subject: str, depth: int = 1, context: str | None = None):
    """Resolve *subject*, enrich it to *depth*, and return a JSON-ready summary.

    The response contains:

    - the seed entity's ``atlas_id`` / ``canonical_label``;
    - one ``atlas_id`` / ``canonical_label`` / ``entity_type`` record per related entity;
    - the ``query_context`` and ``depth`` reported by the enrichment result.
    """
    seed_entity = await resolve_entity(subject, context)
    # NOTE(review): enrich_entity is called without await while resolve_entity is awaited;
    # presumably enrich_entity is synchronous — confirm.
    dataset = enrich_entity(seed_entity, depth=depth)

    seed = dataset.seed_entity
    seed_summary = {"atlas_id": seed.atlas_id, "canonical_label": seed.canonical_label}

    related_summaries = []
    for related in dataset.related_entities:
        related_summaries.append(
            {
                "atlas_id": related.atlas_id,
                "canonical_label": related.canonical_label,
                "entity_type": related.entity_type,
            }
        )

    return {
        "seed": seed_summary,
        "related_entities": related_summaries,
        "query_context": dataset.query_context,
        "depth": dataset.depth,
    }
|