storage_service.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. """Atlas persistence/read service via virtuoso-mcp (MCP transport).
  2. We intentionally use the MCP SSE transport ("/mcp/sse") to match the standard across
  3. our MCP servers and avoid legacy direct "/rpc" calls.
  4. """
  5. from __future__ import annotations
  6. import json
  7. import logging
  8. import os
  9. import time
  10. from typing import Any, Awaitable, Callable
  11. from mcp import ClientSession
  12. from mcp.client.sse import sse_client
  13. from app.models import AtlasEntity
  14. from app.triple_export import entity_to_turtle
  15. logger = logging.getLogger(__name__)
  16. ATLAS_GRAPH_IRI = os.getenv("ATLAS_GRAPH_IRI", "http://world.eu.org/atlas_data#")
  17. VIRTUOSO_MCP_SSE_URL = os.getenv("ATLAS_VIRTUOSO_MCP_SSE_URL", "http://192.168.0.249:8501/mcp/sse")
  18. VIRTUOSO_MCP_TIMEOUT = float(os.getenv("ATLAS_VIRTUOSO_MCP_TIMEOUT", "10"))
  19. VIRTUOSO_MCP_SSE_READ_TIMEOUT = float(os.getenv("ATLAS_VIRTUOSO_MCP_SSE_READ_TIMEOUT", str(60 * 5)))
  20. CallToolFn = Callable[[str, dict[str, Any]], Awaitable[dict[str, Any]]]
  21. def _safe_fragment(value: str) -> str:
  22. value = (value or "").strip().lower()
  23. out = []
  24. for ch in value:
  25. if ch.isalnum() or ch in ["_", "-"]:
  26. out.append(ch)
  27. else:
  28. out.append("_")
  29. frag = "".join(out).strip("_")
  30. return frag or "entity"
  31. def entity_iri(entity_id: str) -> str:
  32. return f"http://world.eu.org/atlas_data#entity_{_safe_fragment(entity_id)}"
  33. class AtlasStorageService:
  34. def __init__(self, call_tool: CallToolFn | None = None):
  35. # Tests can inject a fake transport; production uses the MCP session client.
  36. self._call_tool_override = call_tool
  37. self._tool_cache: dict[str, tuple[float, dict[str, Any]]] = {}
  38. self._tool_cache_ttl_seconds = float(os.getenv("ATLAS_VIRTUOSO_CALL_CACHE_TTL", "30"))
  39. def _cache_key(self, tool_name: str, payload: dict[str, Any]) -> str:
  40. # Stable keying keeps equivalent tool calls from duplicating work.
  41. return f"{tool_name}:{json.dumps(payload, sort_keys=True, separators=(',', ':'))}"
  42. def _cache_get(self, key: str) -> dict[str, Any] | None:
  43. item = self._tool_cache.get(key)
  44. if not item:
  45. return None
  46. expires_at, value = item
  47. if expires_at < time.time():
  48. self._tool_cache.pop(key, None)
  49. return None
  50. return value
  51. def _cache_set(self, key: str, value: dict[str, Any]) -> None:
  52. self._tool_cache[key] = (time.time() + self._tool_cache_ttl_seconds, value)
  53. async def _call_tool(self, tool_name: str, payload: dict[str, Any], *, cache_result: bool = True) -> dict[str, Any]:
  54. # Cache read-heavy calls, but let write paths pass through untouched.
  55. cache_key = self._cache_key(tool_name, payload)
  56. if cache_result:
  57. cached = self._cache_get(cache_key)
  58. if cached is not None:
  59. return cached
  60. if self._call_tool_override:
  61. result = await self._call_tool_override(tool_name, payload)
  62. if cache_result:
  63. self._cache_set(cache_key, result)
  64. return result
  65. try:
  66. async with sse_client(
  67. VIRTUOSO_MCP_SSE_URL,
  68. timeout=VIRTUOSO_MCP_TIMEOUT,
  69. sse_read_timeout=VIRTUOSO_MCP_SSE_READ_TIMEOUT,
  70. ) as (read_stream, write_stream):
  71. async with ClientSession(read_stream, write_stream) as session:
  72. await session.initialize()
  73. result = await session.call_tool(tool_name, {"input": payload})
  74. if result.isError:
  75. raise RuntimeError(f"Tool {tool_name} failed: {result.error}")
  76. data = result.structuredContent if result.structuredContent is not None else result.content
  77. if cache_result and isinstance(data, dict):
  78. self._cache_set(cache_key, data)
  79. return data
  80. except Exception as exc:
  81. raise RuntimeError(f"Virtuoso MCP call failed for {tool_name}: {exc}")
  82. async def write_entity(self, entity: AtlasEntity) -> dict[str, Any]:
  83. # Turn an Atlas entity into Turtle, then hand it to Virtuoso in one insert.
  84. ttl = entity_to_turtle(entity)
  85. try:
  86. result = await self._call_tool(
  87. "batch_insert",
  88. {
  89. "ttl": ttl,
  90. "graph": ATLAS_GRAPH_IRI,
  91. },
  92. cache_result=False,
  93. )
  94. return {
  95. "status": "ok",
  96. "graph": ATLAS_GRAPH_IRI,
  97. "entity_id": entity.atlas_id,
  98. "result": result,
  99. }
  100. except Exception as exc:
  101. logger.warning(
  102. "Atlas persistence failed for %s into %s: %s",
  103. entity.atlas_id,
  104. ATLAS_GRAPH_IRI,
  105. exc,
  106. )
  107. return {
  108. "status": "unfinished",
  109. "message": "Persistence path not fully available yet",
  110. "error": str(exc),
  111. "entity_id": entity.atlas_id,
  112. }
  113. async def read_entity_claims(self, entity_id: str, include_superseded: bool = False) -> dict[str, Any]:
  114. # Pull the entity's claim graph, with active claims by default.
  115. iri = entity_iri(entity_id)
  116. status_filter = "" if include_superseded else 'FILTER(?status = "active")'
  117. query = f"""
  118. PREFIX atlas: <http://world.eu.org/atlas_ontology#>
  119. SELECT ?entity ?label ?canonType ?claim ?pred ?objIri ?objLit ?idVal ?idType ?layer ?status ?prov ?src ?method ?conf ?ts
  120. WHERE {{
  121. VALUES ?entity {{ <{iri}> }}
  122. ?entity a atlas:Entity ;
  123. atlas:canonicalLabel ?label ;
  124. atlas:hasCanonicalType ?canonType ;
  125. atlas:hasClaim ?claim .
  126. ?claim atlas:claimSubjectIri ?entity ;
  127. atlas:claimPredicate ?pred ;
  128. atlas:claimLayer ?layer ;
  129. atlas:claimStatus ?status .
  130. OPTIONAL {{ ?claim atlas:claimObjectIri ?objIri . }}
  131. OPTIONAL {{ ?claim atlas:claimObjectLiteral ?objLit . }}
  132. OPTIONAL {{ ?objIri atlas:identifierValue ?idVal . }}
  133. OPTIONAL {{ ?objIri atlas:identifierType ?idType . }}
  134. OPTIONAL {{
  135. ?claim atlas:hasProvenance ?prov .
  136. ?prov atlas:provenanceSource ?src .
  137. OPTIONAL {{ ?prov atlas:retrievalMethod ?method . }}
  138. OPTIONAL {{ ?prov atlas:confidence ?conf . }}
  139. OPTIONAL {{ ?prov atlas:retrievedAt ?ts . }}
  140. }}
  141. {status_filter}
  142. }}
  143. ORDER BY ?claim
  144. """
  145. try:
  146. result = await self._call_tool("sparql_query", {"query": query})
  147. return {
  148. "status": "ok",
  149. "entity_id": entity_id,
  150. "query": query,
  151. "result": result,
  152. }
  153. except Exception as exc:
  154. return {
  155. "status": "unfinished",
  156. "message": "Read path not fully available yet",
  157. "error": str(exc),
  158. "entity_id": entity_id,
  159. "query": query,
  160. }
  161. async def sparql_update(self, query: str) -> dict[str, Any]:
  162. # Write raw SPARQL when a higher-level helper would just get in the way.
  163. return await self._call_tool("sparql_update", {"query": query}, cache_result=False)
  164. async def supersede_claims(self, claim_iris: list[str]) -> None:
  165. if not claim_iris:
  166. return
  167. values = " ".join(f"<{uri}>" for uri in claim_iris)
  168. query = f"""
  169. PREFIX atlas: <http://world.eu.org/atlas_ontology#>
  170. WITH <{ATLAS_GRAPH_IRI}>
  171. DELETE {{ ?claim atlas:claimStatus ?old . }}
  172. INSERT {{ ?claim atlas:claimStatus "superseded" }}
  173. WHERE {{
  174. VALUES ?claim {{ {values} }}
  175. OPTIONAL {{ ?claim atlas:claimStatus ?old . }}
  176. }}
  177. """
  178. await self.sparql_update(query)
  179. async def replace_entity_core(self, entity_id: str, *, canonical_label: str, canonical_description: str | None, canonical_type: str | None) -> None:
  180. # Replace the entity's canonical fields without disturbing its claims.
  181. iri = entity_iri(entity_id)
  182. desc_insert = f' <{iri}> atlas:canonicalDescription "{canonical_description.replace("\\", "\\\\").replace("\"", "\\\"")}" .\n' if canonical_description else ""
  183. type_insert = f" <{iri}> atlas:hasCanonicalType atlas:{canonical_type} .\n" if canonical_type else ""
  184. query = f"""
  185. PREFIX atlas: <http://world.eu.org/atlas_ontology#>
  186. WITH <{ATLAS_GRAPH_IRI}>
  187. DELETE {{
  188. <{iri}> atlas:canonicalLabel ?oldLabel .
  189. <{iri}> atlas:canonicalDescription ?oldDesc .
  190. <{iri}> atlas:hasCanonicalType ?oldType .
  191. }}
  192. INSERT {{
  193. <{iri}> atlas:canonicalLabel "{canonical_label.replace("\\", "\\\\").replace("\"", "\\\"")}" .
  194. {desc_insert}{type_insert}}}
  195. WHERE {{
  196. OPTIONAL {{ <{iri}> atlas:canonicalLabel ?oldLabel . }}
  197. OPTIONAL {{ <{iri}> atlas:canonicalDescription ?oldDesc . }}
  198. OPTIONAL {{ <{iri}> atlas:hasCanonicalType ?oldType . }}
  199. }}
  200. """
  201. await self.sparql_update(query)