# virtuoso_store.py
"""Virtuoso MCP bridge for cached entity lookups."""
from __future__ import annotations

import json
import logging
import os
from collections import OrderedDict
from typing import Optional

from mcp import ClientSession
from mcp.client.sse import sse_client

from app.models import AtlasAlias, AtlasEntity, AtlasIdentifier, AtlasProvenance
  10. VIRTUOSO_MCP_SSE_URL = os.getenv("ATLAS_VIRTUOSO_MCP_SSE_URL", "http://192.168.0.249:8501/mcp/sse")
  11. VIRTUOSO_MCP_TIMEOUT = float(os.getenv("ATLAS_VIRTUOSO_MCP_TIMEOUT", "10"))
  12. VIRTUOSO_MCP_SSE_READ_TIMEOUT = float(os.getenv("ATLAS_VIRTUOSO_MCP_SSE_READ_TIMEOUT", str(60 * 5)))
  13. ATLAS_GRAPH_IRI = os.getenv("ATLAS_GRAPH_IRI", "http://world.eu.org/atlas_data#")
  14. PREFIX_ATLAS = os.getenv("ATLAS_PREFIX_IRI", "http://world.eu.org/atlas_ontology#")
  15. class VirtuosoEntityStore:
  16. def __init__(self, max_cache_entries: int = 256):
  17. self.max_cache_entries = max_cache_entries
  18. self._cache: OrderedDict[str, AtlasEntity] = OrderedDict()
  19. def _cache_key(self, token: str) -> str:
  20. return str(token or "").strip().lower()
  21. def _cache_get(self, token: str) -> Optional[AtlasEntity]:
  22. key = self._cache_key(token)
  23. if not key:
  24. return None
  25. hit = self._cache.get(key)
  26. if hit is not None:
  27. self._cache.move_to_end(key)
  28. return hit
  29. def _cache_set(self, token: str, entity: AtlasEntity) -> None:
  30. key = self._cache_key(token)
  31. if not key:
  32. return
  33. self._cache[key] = entity
  34. self._cache.move_to_end(key)
  35. while len(self._cache) > self.max_cache_entries:
  36. self._cache.popitem(last=False)
  37. async def lookup(self, token: str) -> Optional[AtlasEntity]:
  38. cached = self._cache_get(token)
  39. if cached is not None:
  40. return cached
  41. entity = await self._lookup_remote(token)
  42. if entity is not None:
  43. self._cache_set(token, entity)
  44. return entity
  45. async def _lookup_remote(self, token: str) -> Optional[AtlasEntity]:
  46. literal = token.strip().lower()
  47. if not literal or not VIRTUOSO_MCP_SSE_URL:
  48. return None
  49. query = _build_sparql_query(literal)
  50. try:
  51. async with sse_client(
  52. VIRTUOSO_MCP_SSE_URL,
  53. timeout=VIRTUOSO_MCP_TIMEOUT,
  54. sse_read_timeout=VIRTUOSO_MCP_SSE_READ_TIMEOUT,
  55. ) as (read_stream, write_stream):
  56. async with ClientSession(read_stream, write_stream) as session:
  57. await session.initialize()
  58. result = await session.call_tool("sparql_query", {"query": query})
  59. if result.isError:
  60. return None
  61. payload = result.structuredContent or _content_to_json(result.content)
  62. if not isinstance(payload, dict):
  63. return None
  64. bindings = (
  65. payload.get("results", {})
  66. .get("bindings", [])
  67. if isinstance(payload.get("results"), dict)
  68. else []
  69. )
  70. if not bindings:
  71. return None
  72. return _entity_from_binding(bindings[0])
  73. except Exception:
  74. return None
  75. def _content_to_json(content):
  76. if not content:
  77. return None
  78. first = content[0]
  79. text = getattr(first, "text", None)
  80. if not text:
  81. return None
  82. try:
  83. return json.loads(text)
  84. except Exception:
  85. return None
  86. def _build_sparql_query(literal: str) -> str:
  87. esc = literal.replace("\\", "\\\\").replace("\"", "\\\"")
  88. return f"""
  89. PREFIX atlas: <{PREFIX_ATLAS}>
  90. SELECT ?entity ?label ?type ?mid WHERE {{
  91. GRAPH <{ATLAS_GRAPH_IRI}> {{
  92. ?entity atlas:canonicalLabel ?label .
  93. OPTIONAL {{ ?entity atlas:entityType ?type. }}
  94. OPTIONAL {{
  95. ?entity atlas:hasExternalIdentifier ?identifier .
  96. ?identifier atlas:identifierType "mid" .
  97. ?identifier atlas:identifierValue ?mid .
  98. }}
  99. }}
  100. FILTER(LCASE(?label) = \"{esc}\")
  101. }}
  102. LIMIT 1
  103. """
  104. def _entity_from_binding(binding: dict) -> AtlasEntity:
  105. label = binding.get("label", {}).get("value", "")
  106. entity_uri = binding.get("entity", {}).get("value", "")
  107. entity_type = binding.get("type", {}).get("value", "unknown")
  108. mid = binding.get("mid", {}).get("value")
  109. identifiers = []
  110. if mid:
  111. identifiers.append(AtlasIdentifier(value=mid, source="virtuoso", identifier_type="mid"))
  112. provenance = [
  113. AtlasProvenance(
  114. source="virtuoso-cache",
  115. retrieval_method="sparql",
  116. confidence=0.95,
  117. )
  118. ]
  119. return AtlasEntity(
  120. atlas_id=entity_uri or f"atlas:{label.strip().lower().replace(' ', '-')}",
  121. canonical_label=label or entity_uri,
  122. entity_type=entity_type or "unknown",
  123. aliases=[AtlasAlias(label=label or entity_uri)],
  124. identifiers=identifiers,
  125. provenance=provenance,
  126. raw_payload={"source": "virtuoso", "binding": binding},
  127. )