# virtuoso_store.py
"""Virtuoso MCP bridge for cached entity lookups."""
from __future__ import annotations

import json
import logging
import os
from collections import OrderedDict
from typing import Optional

from mcp import ClientSession
from mcp.client.sse import sse_client

from app.models import AtlasAlias, AtlasClaim, AtlasClaimObject, AtlasEntity, AtlasProvenance
  10. VIRTUOSO_MCP_SSE_URL = os.getenv("ATLAS_VIRTUOSO_MCP_SSE_URL", "http://192.168.0.249:8501/mcp/sse")
  11. VIRTUOSO_MCP_TIMEOUT = float(os.getenv("ATLAS_VIRTUOSO_MCP_TIMEOUT", "10"))
  12. VIRTUOSO_MCP_SSE_READ_TIMEOUT = float(os.getenv("ATLAS_VIRTUOSO_MCP_SSE_READ_TIMEOUT", str(60 * 5)))
  13. ATLAS_GRAPH_IRI = os.getenv("ATLAS_GRAPH_IRI", "http://world.eu.org/atlas_data#")
  14. PREFIX_ATLAS = os.getenv("ATLAS_PREFIX_IRI", "http://world.eu.org/atlas_ontology#")
  15. class VirtuosoEntityStore:
  16. def __init__(self, max_cache_entries: int = 256):
  17. self.max_cache_entries = max_cache_entries
  18. self._cache: OrderedDict[str, AtlasEntity] = OrderedDict()
  19. def _cache_key(self, token: str) -> str:
  20. return str(token or "").strip().lower()
  21. def _cache_get(self, token: str) -> Optional[AtlasEntity]:
  22. key = self._cache_key(token)
  23. if not key:
  24. return None
  25. hit = self._cache.get(key)
  26. if hit is not None:
  27. self._cache.move_to_end(key)
  28. return hit
  29. def _cache_set(self, token: str, entity: AtlasEntity) -> None:
  30. key = self._cache_key(token)
  31. if not key:
  32. return
  33. self._cache[key] = entity
  34. self._cache.move_to_end(key)
  35. while len(self._cache) > self.max_cache_entries:
  36. self._cache.popitem(last=False)
  37. async def lookup(self, token: str) -> Optional[AtlasEntity]:
  38. cached = self._cache_get(token)
  39. if cached is not None:
  40. return cached
  41. entity = await self._lookup_remote(token)
  42. if entity is not None:
  43. self._cache_set(token, entity)
  44. return entity
  45. async def _lookup_remote(self, token: str) -> Optional[AtlasEntity]:
  46. literal = token.strip().lower()
  47. if not literal or not VIRTUOSO_MCP_SSE_URL:
  48. return None
  49. query = _build_sparql_query(literal)
  50. try:
  51. async with sse_client(
  52. VIRTUOSO_MCP_SSE_URL,
  53. timeout=VIRTUOSO_MCP_TIMEOUT,
  54. sse_read_timeout=VIRTUOSO_MCP_SSE_READ_TIMEOUT,
  55. ) as (read_stream, write_stream):
  56. async with ClientSession(read_stream, write_stream) as session:
  57. await session.initialize()
  58. result = await session.call_tool("sparql_query", {"input": {"query": query}})
  59. if result.isError:
  60. return None
  61. payload = result.structuredContent or _content_to_json(result.content)
  62. if not isinstance(payload, dict):
  63. return None
  64. bindings = (
  65. payload.get("results", {})
  66. .get("bindings", [])
  67. if isinstance(payload.get("results"), dict)
  68. else []
  69. )
  70. if not bindings:
  71. return None
  72. return _entity_from_binding(bindings[0])
  73. except Exception:
  74. return None
  75. def _content_to_json(content):
  76. if not content:
  77. return None
  78. first = content[0]
  79. text = getattr(first, "text", None)
  80. if not text:
  81. return None
  82. try:
  83. return json.loads(text)
  84. except Exception:
  85. return None
  86. def _build_sparql_query(literal: str) -> str:
  87. esc = literal.replace("\\", "\\\\").replace("\"", "\\\"")
  88. return f"""
  89. PREFIX atlas: <{PREFIX_ATLAS}>
  90. SELECT ?entity ?label ?type ?mid ?desc ?rawWd ?rawTrends WHERE {{
  91. GRAPH <{ATLAS_GRAPH_IRI}> {{
  92. ?entity a atlas:Entity ;
  93. atlas:canonicalLabel ?label .
  94. OPTIONAL {{ ?entity atlas:canonicalDescription ?desc . }}
  95. OPTIONAL {{ ?entity atlas:rawWikidataJson ?rawWd . }}
  96. OPTIONAL {{ ?entity atlas:rawTrendsJson ?rawTrends . }}
  97. OPTIONAL {{
  98. ?entity atlas:hasCanonicalType ?type .
  99. }}
  100. ?entity atlas:hasIdentifier ?identifier .
  101. ?identifier atlas:identifierValue ?mid ;
  102. atlas:identifierType atlas:Mid .
  103. }}
  104. FILTER(LCASE(STR(?label)) = LCASE("{esc}"))
  105. }}
  106. LIMIT 1
  107. """
  108. def _entity_from_binding(binding: dict) -> AtlasEntity:
  109. label = binding.get("label", {}).get("value", "")
  110. entity_uri = binding.get("entity", {}).get("value", "")
  111. # ?type is expected to be a class node like atlas:Person
  112. entity_type = binding.get("type", {}).get("value", "unknown")
  113. if entity_type.startswith(PREFIX_ATLAS):
  114. entity_type = entity_type.split("#", 1)[-1]
  115. if entity_type.startswith("http://world.eu.org/atlas_ontology#"):
  116. entity_type = entity_type.split("#", 1)[-1]
  117. mid = binding.get("mid", {}).get("value")
  118. desc = binding.get("desc", {}).get("value")
  119. raw_wd = binding.get("rawWd", {}).get("value")
  120. raw_trends = binding.get("rawTrends", {}).get("value")
  121. atlas_id = entity_uri.split("#", 1)[-1].replace("entity_", "atlas:") if "#" in entity_uri else f"atlas:{label.strip().lower().replace(' ', '-') }"
  122. base_prov = AtlasProvenance(
  123. source="virtuoso-cache",
  124. retrieval_method="sparql",
  125. confidence=0.95,
  126. )
  127. claims: list[AtlasClaim] = []
  128. if mid:
  129. claims.append(
  130. AtlasClaim(
  131. claim_id=f"clm_raw_ident_mid_{mid}",
  132. subject=atlas_id,
  133. predicate="atlas:hasIdentifier",
  134. object=AtlasClaimObject(kind="identifier", id_type="mid", value=mid),
  135. layer="raw",
  136. provenance=base_prov,
  137. )
  138. )
  139. if entity_type and entity_type != "unknown":
  140. claims.append(
  141. AtlasClaim(
  142. claim_id="clm_drv_canonical_type",
  143. subject=atlas_id,
  144. predicate="atlas:hasCanonicalType",
  145. object=AtlasClaimObject(kind="type", value=f"atlas:{entity_type}"),
  146. layer="derived",
  147. provenance=base_prov,
  148. )
  149. )
  150. return AtlasEntity(
  151. atlas_id=atlas_id,
  152. canonical_label=label or entity_uri,
  153. canonical_description=desc,
  154. entity_type=entity_type or "unknown",
  155. aliases=[AtlasAlias(label=label or entity_uri)],
  156. claims=claims,
  157. raw_payload={
  158. "source": "virtuoso",
  159. "raw": label or entity_uri,
  160. "normalized": (label or entity_uri),
  161. "wikidata": (json.loads(raw_wd) if raw_wd else {"status": "missing"}),
  162. **(json.loads(raw_trends) if raw_trends else {}),
  163. },
  164. needs_curation=(entity_type or "unknown") == "unknown",
  165. )