"""Virtuoso MCP bridge for cached entity lookups."""

from __future__ import annotations

import json
import os
from collections import OrderedDict
from typing import Any, Optional

from mcp import ClientSession
from mcp.client.sse import sse_client

from app.models import AtlasAlias, AtlasClaim, AtlasClaimObject, AtlasEntity, AtlasProvenance

# Connection settings for the Virtuoso MCP server (SSE transport), overridable via env.
VIRTUOSO_MCP_SSE_URL = os.getenv("ATLAS_VIRTUOSO_MCP_SSE_URL", "http://192.168.0.249:8501/mcp/sse")
VIRTUOSO_MCP_TIMEOUT = float(os.getenv("ATLAS_VIRTUOSO_MCP_TIMEOUT", "10"))
VIRTUOSO_MCP_SSE_READ_TIMEOUT = float(os.getenv("ATLAS_VIRTUOSO_MCP_SSE_READ_TIMEOUT", str(60 * 5)))

# Named graph holding Atlas data, and the ontology namespace for classes and properties.
ATLAS_GRAPH_IRI = os.getenv("ATLAS_GRAPH_IRI", "http://world.eu.org/atlas_data#")
PREFIX_ATLAS = os.getenv("ATLAS_PREFIX_IRI", "http://world.eu.org/atlas_ontology#")
# Default ontology namespace, still stripped from type IRIs when PREFIX_ATLAS is overridden.
DEFAULT_PREFIX_ATLAS = "http://world.eu.org/atlas_ontology#"


class VirtuosoEntityStore:
    """Looks up Atlas entities by label through Virtuoso MCP, with an in-process LRU cache.

    Only successful lookups are cached; misses go back to the server on every call.
    """

    def __init__(self, max_cache_entries: int = 256):
        self.max_cache_entries = max_cache_entries
        self._cache: OrderedDict[str, AtlasEntity] = OrderedDict()

    def _cache_key(self, token: str) -> str:
        # Case- and surrounding-whitespace-insensitive key; None maps to "".
        return str(token or "").strip().lower()

    def _cache_get(self, token: str) -> Optional[AtlasEntity]:
        key = self._cache_key(token)
        if not key:
            return None
        hit = self._cache.get(key)
        if hit is not None:
            # Mark the entry as most recently used.
            self._cache.move_to_end(key)
        return hit

    def _cache_set(self, token: str, entity: AtlasEntity) -> None:
        key = self._cache_key(token)
        if not key:
            return
        self._cache[key] = entity
        self._cache.move_to_end(key)
        # Evict least recently used entries until the cache fits its bound.
        while len(self._cache) > self.max_cache_entries:
            self._cache.popitem(last=False)

    async def lookup(self, token: str) -> Optional[AtlasEntity]:
        """Return the entity whose canonical label matches ``token`` (case-insensitive), or None."""
        cached = self._cache_get(token)
        if cached is not None:
            return cached
        entity = await self._lookup_remote(token)
        if entity is not None:
            self._cache_set(token, entity)
        return entity

    async def _lookup_remote(self, token: str) -> Optional[AtlasEntity]:
        # Same normalization as the cache key; also tolerates a None token.
        literal = self._cache_key(token)
        if not literal or not VIRTUOSO_MCP_SSE_URL:
            return None
        query = _build_sparql_query(literal)
        try:
            async with sse_client(
                VIRTUOSO_MCP_SSE_URL,
                timeout=VIRTUOSO_MCP_TIMEOUT,
                sse_read_timeout=VIRTUOSO_MCP_SSE_READ_TIMEOUT,
            ) as (read_stream, write_stream):
                async with ClientSession(read_stream, write_stream) as session:
                    await session.initialize()
                    result = await session.call_tool("sparql_query", {"input": {"query": query}})
                    if result.isError:
                        return None
                    # Prefer structured output; fall back to parsing the first text block.
                    payload = result.structuredContent or _content_to_json(result.content)
                    if not isinstance(payload, dict):
                        return None
                    # SPARQL JSON results: {"results": {"bindings": [...]}}
                    results = payload.get("results")
                    bindings = results.get("bindings", []) if isinstance(results, dict) else []
                    if not bindings:
                        return None
                    return _entity_from_binding(bindings[0])
        except Exception:
            # Best-effort lookup: any transport, protocol or parsing error is treated as a miss.
            return None


def _content_to_json(content: Any) -> Any:
    """Parse the first text block of an MCP tool result as JSON; return None on failure."""
    if not content:
        return None
    first = content[0]
    text = getattr(first, "text", None)
    if not text:
        return None
    try:
        return json.loads(text)
    except (TypeError, ValueError):
        return None


def _build_sparql_query(literal: str) -> str:
    # Escape for a SPARQL double-quoted string literal: backslash first, then the quote,
    # then CR/LF, which may not appear unescaped inside such a literal.
    esc = (
        literal.replace("\\", "\\\\")
        .replace('"', '\\"')
        .replace("\n", "\\n")
        .replace("\r", "\\r")
    )
    # The MID identifier pattern is not OPTIONAL, so only entities with a MID can match.
    # LIMIT 1 without ORDER BY returns an arbitrary row when several types or MIDs exist.
    return f"""
PREFIX atlas: <{PREFIX_ATLAS}>
SELECT ?entity ?label ?type ?mid ?desc ?rawWd ?rawTrends
WHERE {{
  GRAPH <{ATLAS_GRAPH_IRI}> {{
    ?entity a atlas:Entity ;
            atlas:canonicalLabel ?label .
    OPTIONAL {{ ?entity atlas:canonicalDescription ?desc . }}
    OPTIONAL {{ ?entity atlas:rawWikidataJson ?rawWd . }}
    OPTIONAL {{ ?entity atlas:rawTrendsJson ?rawTrends . }}
    OPTIONAL {{ ?entity atlas:hasCanonicalType ?type . }}
    ?entity atlas:hasIdentifier ?identifier .
    ?identifier atlas:identifierValue ?mid ;
                atlas:identifierType atlas:Mid .
  }}
  FILTER(LCASE(STR(?label)) = LCASE("{esc}"))
}}
LIMIT 1
"""


def _entity_from_binding(binding: dict) -> AtlasEntity:
    """Map one SPARQL JSON result binding to an AtlasEntity."""
    label = binding.get("label", {}).get("value", "")
    entity_uri = binding.get("entity", {}).get("value", "")

    # ?type is expected to be a class node like atlas:Person; keep only its local name.
    entity_type = binding.get("type", {}).get("value", "unknown")
    for prefix in (PREFIX_ATLAS, DEFAULT_PREFIX_ATLAS):
        if prefix and entity_type.startswith(prefix):
            entity_type = entity_type[len(prefix):]
            break

    mid = binding.get("mid", {}).get("value")
    desc = binding.get("desc", {}).get("value")
    raw_wd = binding.get("rawWd", {}).get("value")
    raw_trends = binding.get("rawTrends", {}).get("value")

    if "#" in entity_uri:
        # Fragment "entity_<slug>" becomes "atlas:<slug>".
        atlas_id = entity_uri.split("#", 1)[-1].replace("entity_", "atlas:")
    else:
        # Fall back to a slug derived from the label.
        atlas_id = f"atlas:{label.strip().lower().replace(' ', '-')}"

    base_prov = AtlasProvenance(
        source="virtuoso-cache",
        retrieval_method="sparql",
        confidence=0.95,
    )

    claims: list[AtlasClaim] = []
    if mid:
        claims.append(
            AtlasClaim(
                claim_id=f"clm_raw_ident_mid_{mid}",
                subject=atlas_id,
                predicate="atlas:hasIdentifier",
                object=AtlasClaimObject(kind="identifier", id_type="mid", value=mid),
                layer="raw",
                provenance=base_prov,
            )
        )
    if entity_type and entity_type != "unknown":
        claims.append(
            AtlasClaim(
                claim_id="clm_drv_canonical_type",
                subject=atlas_id,
                predicate="atlas:hasCanonicalType",
                object=AtlasClaimObject(kind="type", value=f"atlas:{entity_type}"),
                layer="derived",
                provenance=base_prov,
            )
        )

    return AtlasEntity(
        atlas_id=atlas_id,
        canonical_label=label or entity_uri,
        canonical_description=desc,
        entity_type=entity_type or "unknown",
        aliases=[AtlasAlias(label=label or entity_uri)],
        claims=claims,
        raw_payload={
            "source": "virtuoso",
            "raw": label or entity_uri,
            "normalized": label or entity_uri,
            "wikidata": json.loads(raw_wd) if raw_wd else {"status": "missing"},
            # Stored trends JSON (if any) is merged into the top level of the payload.
            **(json.loads(raw_trends) if raw_trends else {}),
        },
        needs_curation=(entity_type or "unknown") == "unknown",
    )