"""Atlas persistence/read service via virtuoso-mcp (MCP transport). We intentionally use the MCP SSE transport ("/mcp/sse") to match the standard across our MCP servers and avoid legacy direct "/rpc" calls. """ from __future__ import annotations import json import logging import os import time from typing import Any, Awaitable, Callable from mcp import ClientSession from mcp.client.sse import sse_client from app.models import AtlasEntity from app.triple_export import entity_to_turtle logger = logging.getLogger(__name__) ATLAS_GRAPH_IRI = os.getenv("ATLAS_GRAPH_IRI", "http://world.eu.org/atlas_data#") VIRTUOSO_MCP_SSE_URL = os.getenv("ATLAS_VIRTUOSO_MCP_SSE_URL", "http://192.168.0.249:8501/mcp/sse") VIRTUOSO_MCP_TIMEOUT = float(os.getenv("ATLAS_VIRTUOSO_MCP_TIMEOUT", "10")) VIRTUOSO_MCP_SSE_READ_TIMEOUT = float(os.getenv("ATLAS_VIRTUOSO_MCP_SSE_READ_TIMEOUT", str(60 * 5))) CallToolFn = Callable[[str, dict[str, Any]], Awaitable[dict[str, Any]]] def _safe_fragment(value: str) -> str: value = (value or "").strip().lower() out = [] for ch in value: if ch.isalnum() or ch in ["_", "-"]: out.append(ch) else: out.append("_") frag = "".join(out).strip("_") return frag or "entity" def entity_iri(entity_id: str) -> str: return f"http://world.eu.org/atlas_data#entity_{_safe_fragment(entity_id)}" class AtlasStorageService: def __init__(self, call_tool: CallToolFn | None = None): # Tests can inject a fake transport; production uses the MCP session client. self._call_tool_override = call_tool self._tool_cache: dict[str, tuple[float, dict[str, Any]]] = {} self._tool_cache_ttl_seconds = float(os.getenv("ATLAS_VIRTUOSO_CALL_CACHE_TTL", "30")) def _cache_key(self, tool_name: str, payload: dict[str, Any]) -> str: # Stable keying keeps equivalent tool calls from duplicating work. return f"{tool_name}:{json.dumps(payload, sort_keys=True, separators=(',', ':'))}" def _cache_get(self, key: str) -> dict[str, Any] | None: item = self._tool_cache.get(key) if not item: return None expires_at, value = item if expires_at < time.time(): self._tool_cache.pop(key, None) return None return value def _cache_set(self, key: str, value: dict[str, Any]) -> None: self._tool_cache[key] = (time.time() + self._tool_cache_ttl_seconds, value) async def _call_tool(self, tool_name: str, payload: dict[str, Any], *, cache_result: bool = True) -> dict[str, Any]: # Cache read-heavy calls, but let write paths pass through untouched. cache_key = self._cache_key(tool_name, payload) if cache_result: cached = self._cache_get(cache_key) if cached is not None: return cached if self._call_tool_override: result = await self._call_tool_override(tool_name, payload) if cache_result: self._cache_set(cache_key, result) return result try: async with sse_client( VIRTUOSO_MCP_SSE_URL, timeout=VIRTUOSO_MCP_TIMEOUT, sse_read_timeout=VIRTUOSO_MCP_SSE_READ_TIMEOUT, ) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: await session.initialize() result = await session.call_tool(tool_name, {"input": payload}) if result.isError: raise RuntimeError(f"Tool {tool_name} failed: {result.error}") data = result.structuredContent if result.structuredContent is not None else result.content if cache_result and isinstance(data, dict): self._cache_set(cache_key, data) return data except Exception as exc: raise RuntimeError(f"Virtuoso MCP call failed for {tool_name}: {exc}") async def write_entity(self, entity: AtlasEntity) -> dict[str, Any]: # Turn an Atlas entity into Turtle, then hand it to Virtuoso in one insert. ttl = entity_to_turtle(entity) try: result = await self._call_tool( "batch_insert", { "ttl": ttl, "graph": ATLAS_GRAPH_IRI, }, cache_result=False, ) return { "status": "ok", "graph": ATLAS_GRAPH_IRI, "entity_id": entity.atlas_id, "result": result, } except Exception as exc: logger.warning( "Atlas persistence failed for %s into %s: %s", entity.atlas_id, ATLAS_GRAPH_IRI, exc, ) return { "status": "unfinished", "message": "Persistence path not fully available yet", "error": str(exc), "entity_id": entity.atlas_id, } async def read_entity_claims(self, entity_id: str, include_superseded: bool = False) -> dict[str, Any]: # Pull the entity's claim graph, with active claims by default. iri = entity_iri(entity_id) status_filter = "" if include_superseded else 'FILTER(?status = "active")' query = f""" PREFIX atlas: SELECT ?entity ?label ?canonType ?claim ?pred ?objIri ?objLit ?idVal ?idType ?layer ?status ?prov ?src ?method ?conf ?ts WHERE {{ VALUES ?entity {{ <{iri}> }} ?entity a atlas:Entity ; atlas:canonicalLabel ?label ; atlas:hasCanonicalType ?canonType ; atlas:hasClaim ?claim . ?claim atlas:claimSubjectIri ?entity ; atlas:claimPredicate ?pred ; atlas:claimLayer ?layer ; atlas:claimStatus ?status . OPTIONAL {{ ?claim atlas:claimObjectIri ?objIri . }} OPTIONAL {{ ?claim atlas:claimObjectLiteral ?objLit . }} OPTIONAL {{ ?objIri atlas:identifierValue ?idVal . }} OPTIONAL {{ ?objIri atlas:identifierType ?idType . }} OPTIONAL {{ ?claim atlas:hasProvenance ?prov . ?prov atlas:provenanceSource ?src . OPTIONAL {{ ?prov atlas:retrievalMethod ?method . }} OPTIONAL {{ ?prov atlas:confidence ?conf . }} OPTIONAL {{ ?prov atlas:retrievedAt ?ts . }} }} {status_filter} }} ORDER BY ?claim """ try: result = await self._call_tool("sparql_query", {"query": query}) return { "status": "ok", "entity_id": entity_id, "query": query, "result": result, } except Exception as exc: return { "status": "unfinished", "message": "Read path not fully available yet", "error": str(exc), "entity_id": entity_id, "query": query, } async def sparql_update(self, query: str) -> dict[str, Any]: # Write raw SPARQL when a higher-level helper would just get in the way. return await self._call_tool("sparql_update", {"query": query}, cache_result=False) async def supersede_claims(self, claim_iris: list[str]) -> None: if not claim_iris: return values = " ".join(f"<{uri}>" for uri in claim_iris) query = f""" PREFIX atlas: WITH <{ATLAS_GRAPH_IRI}> DELETE {{ ?claim atlas:claimStatus ?old . }} INSERT {{ ?claim atlas:claimStatus "superseded" }} WHERE {{ VALUES ?claim {{ {values} }} OPTIONAL {{ ?claim atlas:claimStatus ?old . }} }} """ await self.sparql_update(query) async def replace_entity_core(self, entity_id: str, *, canonical_label: str, canonical_description: str | None, canonical_type: str | None) -> None: # Replace the entity's canonical fields without disturbing its claims. iri = entity_iri(entity_id) desc_insert = f' <{iri}> atlas:canonicalDescription "{canonical_description.replace("\\", "\\\\").replace("\"", "\\\"")}" .\n' if canonical_description else "" type_insert = f" <{iri}> atlas:hasCanonicalType atlas:{canonical_type} .\n" if canonical_type else "" query = f""" PREFIX atlas: WITH <{ATLAS_GRAPH_IRI}> DELETE {{ <{iri}> atlas:canonicalLabel ?oldLabel . <{iri}> atlas:canonicalDescription ?oldDesc . <{iri}> atlas:hasCanonicalType ?oldType . }} INSERT {{ <{iri}> atlas:canonicalLabel "{canonical_label.replace("\\", "\\\\").replace("\"", "\\\"")}" . {desc_insert}{type_insert}}} WHERE {{ OPTIONAL {{ <{iri}> atlas:canonicalLabel ?oldLabel . }} OPTIONAL {{ <{iri}> atlas:canonicalDescription ?oldDesc . }} OPTIONAL {{ <{iri}> atlas:hasCanonicalType ?oldType . }} }} """ await self.sparql_update(query)