resolve.py

from __future__ import annotations

import datetime
import hashlib
import logging
import math
import os
import time
import uuid
from dataclasses import dataclass
from typing import Any

from .atlas_model import Entity, Identifier
from .atlas_store import load_entity_by_subject, save_entity_minimal
from .wikidata import WikidataSearch

ATLAS = "http://world.eu.org/atlas_ontology#"
DEFAULT_ENDPOINT = os.getenv("ATLAS_VIRTUOSO_MCP_SSE_URL", "http://192.168.0.249:8501/mcp/sse")
# Note: this reads the same environment variable as DEFAULT_ENDPOINT, so the
# update endpoint currently always mirrors the read endpoint.
DEFAULT_UPDATE_ENDPOINT = os.getenv("ATLAS_VIRTUOSO_MCP_SSE_URL", DEFAULT_ENDPOINT)
DEBUG_LOGS = os.getenv("ATLAS_DEBUG_LOGS", "false").lower() in {"1", "true", "yes", "on"}

logger = logging.getLogger(__name__)


def _hash_id(subject: str) -> str:
    # Stable 16-character id derived from the normalized subject string.
    return hashlib.sha1(subject.strip().lower().encode("utf-8")).hexdigest()[:16]


def _entity_iri(atlas_id: str) -> str:
    return f"atlas_data:entity_{atlas_id}"
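

# Example (illustrative): ids are deterministic, so repeated resolutions of the
# same subject always map to the same entity IRI:
#   _entity_iri(_hash_id("Paris"))  ->  "atlas_data:entity_<16 hex chars>"
#   _hash_id("Paris") == _hash_id("  paris  ")  (case and whitespace are normalized)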


async def _wikidata_lookup(subject: str, language: str = "en", limit: int = 1) -> list[dict[str, Any]]:
    search = WikidataSearch({"search": subject, "limit": limit, "language": language})
    result = await search.quick_resolve(subject, limit=limit)
    return result.get("results", []) or []


def _candidate_text(subject: str, wd: dict[str, Any], hints: dict[str, Any] | None = None) -> str:
    hints = hints or {}
    aliases = hints.get("aliases") or []
    parts = [subject, wd.get("label") or "", wd.get("description") or "", " ".join(str(a) for a in aliases)]
    return " | ".join(part for part in parts if part)


def _cosine_similarity(a: list[float] | None, b: list[float] | None) -> float:
    if not a or not b or len(a) != len(b):
        return 0.0
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if not norm_a or not norm_b:
        return 0.0
    return dot / (norm_a * norm_b)
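

# Worked example (illustrative): identical vectors score 1.0, orthogonal
# vectors score 0.0, and mismatched lengths fall back to 0.0:
#   _cosine_similarity([1.0, 0.0], [1.0, 0.0])  -> 1.0
#   _cosine_similarity([1.0, 0.0], [0.0, 1.0])  -> 0.0
#   _cosine_similarity([1.0], [1.0, 0.0])       -> 0.0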


def _infer_atlas_type(label: str | None, description: str | None) -> str:
    # Crude keyword heuristic over the combined label and description.
    text = f"{label or ''} {description or ''}".lower()
    if any(k in text for k in ["president", "person", "singer", "composer", "human", "actor", "writer"]):
        return "atlas:Person"
    if any(k in text for k in ["city", "town", "village", "country", "state", "location", "place"]):
        return "atlas:Location"
    if any(k in text for k in ["company", "organization", "organisation", "institution", "foundation", "band"]):
        return "atlas:Organization"
    return "atlas:Other"


def _score_wikidata_candidate(
    subject: str,
    wd: dict[str, Any],
    *,
    context: dict[str, Any] | None = None,
    hints: dict[str, Any] | None = None,
    use_embeddings: bool = False,
    subject_embedding: list[float] | None = None,
    candidate_embedding: list[float] | None = None,
) -> tuple[float, dict[str, float]]:
    context = context or {}
    hints = hints or {}
    score = 0.0
    breakdown: dict[str, float] = {}
    subject_norm = subject.strip().lower()
    label = (wd.get("label") or "").strip()
    description = (wd.get("description") or "").strip()
    label_norm = label.lower()
    description_norm = description.lower()
    # Label match is the dominant signal: exact beats partial.
    if label_norm == subject_norm:
        score += 0.75
        breakdown["exact_label"] = 0.75
    elif subject_norm and subject_norm in label_norm:
        score += 0.45
        breakdown["partial_label"] = 0.45
    for alias in hints.get("aliases") or []:
        alias_norm = str(alias).strip().lower()
        if alias_norm and alias_norm == label_norm:
            score += 0.15
            breakdown["alias_match"] = 0.15
            break
    expected_type = (hints.get("expected_type") or "").strip().lower()
    inferred_type = _infer_atlas_type(label, description).lower()
    if expected_type and expected_type in inferred_type:
        score += 0.1
        breakdown["expected_type"] = 0.1
    realm = (context.get("realm") or "").strip().lower()
    if realm and realm in description_norm:
        score += 0.1
        breakdown["realm"] = 0.1
    if wd.get("id"):
        score += 0.05
        breakdown["has_qid"] = 0.05
    if use_embeddings:
        # Embedding similarity contributes at most 0.25 to the total.
        sim = _cosine_similarity(subject_embedding, candidate_embedding)
        if sim > 0:
            emb_score = max(0.0, min(0.25, sim * 0.25))
            score += emb_score
            breakdown["embedding_similarity"] = emb_score
    # Never report full certainty from heuristics alone.
    score = min(score, 0.99)
    return score, breakdown
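

# Worked example (illustrative): with no hints, context, or embeddings,
#   _score_wikidata_candidate("Paris", {"label": "Paris",
#                                       "description": "capital of France",
#                                       "id": "Q90"})
# yields (0.80, {"exact_label": 0.75, "has_qid": 0.05}).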


def _entity_from_wikidata(subject: str, wd: dict[str, Any]) -> Entity:
    atlas_id = _hash_id(subject)
    label = wd.get("label") or subject
    description = wd.get("description")
    qid = wd.get("id")
    entity_type = _infer_atlas_type(label, description)
    ent = Entity(
        id=atlas_id,
        label=label,
        description=description,
        type=entity_type,
        # Keep the original subject as an alias when it differs from the label.
        aliases=[subject] if subject.lower() != label.lower() else [],
        identifiers=[Identifier(scheme="wikidata-qid", value=qid)] if qid else [],
        needs_curation=True,
    )
    return ent


def _flatten_exception_details(exc: BaseException) -> list[str]:
    # Recursively unwraps ExceptionGroup-style `exceptions` attributes.
    parts = [f"{type(exc).__name__}: {exc}"]
    nested = getattr(exc, "exceptions", None)
    if nested:
        for sub in nested:
            parts.extend(_flatten_exception_details(sub))
    return parts
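

# Example (illustrative, Python 3.11+ ExceptionGroup):
#   _flatten_exception_details(ExceptionGroup("boom", [ValueError("a"), TypeError("b")]))
# returns three entries: the group itself, then "ValueError: a" and "TypeError: b".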


async def _persist_entity(entity: Entity) -> None:
    await save_entity_minimal(entity, DEFAULT_UPDATE_ENDPOINT)


async def _load_entity(subject: str) -> dict[str, Any] | None:
    return await load_entity_by_subject(subject, DEFAULT_ENDPOINT)


def _required_confidence(mode: str, constraints: dict[str, Any]) -> float:
    # An explicit min_confidence constraint always wins over mode defaults.
    requested = constraints.get("min_confidence")
    if requested is not None:
        return float(requested)
    if mode == "quick":
        return 0.55
    if mode in {"ranked", "hybrid", "llm_select"}:
        return 0.85
    if mode == "interactive":
        return 0.0
    return 0.5


def _is_ambiguous_subject(subject: str, wd_candidates: list[dict[str, Any]]) -> bool:
    if len(wd_candidates) < 2:
        return False
    subject_norm = subject.strip().lower()
    labels = [(cand.get("label") or "").strip().lower() for cand in wd_candidates]
    exact_matches = sum(1 for label in labels if label == subject_norm)
    # Ambiguous when several candidates share the exact label, or when the
    # single exact match is not the top-ranked candidate.
    return exact_matches >= 2 or (exact_matches == 1 and any(label == subject_norm for label in labels[1:]))
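

# Examples (illustrative), for subject "Mercury":
#   candidates labeled ["Mercury", "Mercury"]           -> True  (two exact matches)
#   candidates labeled ["Mercury", "Mercury (planet)"]  -> False (only the top candidate matches exactly)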


def _cache_can_satisfy(stored: dict[str, Any], mode: str, constraints: dict[str, Any]) -> bool:
    stored_confidence = float(stored.get("confidence") or 0.0)
    return stored_confidence >= _required_confidence(mode, constraints)


def _debug_decision(
    *,
    mode: str,
    top_confidence: float,
    auto_accept_threshold: float,
    interactive_below_threshold: bool,
    required_confidence: float,
    used_cache: bool,
    cache_confidence: float | None = None,
) -> dict[str, Any]:
    return {
        "mode": mode,
        "top_confidence": top_confidence,
        "auto_accept_threshold": auto_accept_threshold,
        "interactive_below_threshold": interactive_below_threshold,
        "required_confidence": required_confidence,
        "used_cache": used_cache,
        "cache_confidence": cache_confidence,
        "decision": (
            "cache_hit"
            if used_cache
            else "resolved"
            if top_confidence >= auto_accept_threshold
            else "ambiguous_below_threshold"
        ),
    }


@dataclass
class ResolveService:
    # Injection points for tests: each field holds an async callable.
    load_entity_fn: Any = _load_entity
    wikidata_lookup_fn: Any = _wikidata_lookup
    persist_entity_fn: Any = _persist_entity

    async def resolve(
        self,
        *,
        subject: str,
        context: dict[str, Any] | None = None,
        constraints: dict[str, Any] | None = None,
        hints: dict[str, Any] | None = None,
        debug: dict[str, Any] | None = None,
        strategy: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        context = context or {}
        constraints = constraints or {}
        hints = hints or {}
        debug = debug or {}
        strategy = strategy or {}
        language = (context.get("language") or "en").strip() or "en"
        mode = (strategy.get("mode") or "quick").strip().lower() or "quick"
        use_embeddings = bool(strategy.get("use_embeddings", False))
        max_candidates = int(constraints.get("max_candidates") or 5)
        auto_accept_threshold = float(strategy.get("auto_accept_threshold") or 0.85)
        interactive_below_threshold = bool(strategy.get("interactive_below_threshold", True))
        required_confidence = _required_confidence(mode, constraints)
        try:
            request_id = str(uuid.uuid4())
            ts = datetime.datetime.now(datetime.timezone.utc).isoformat()
            start = time.time()
            subject = (subject or "").strip()
            if not subject:
                return {
                    "status": "not_found",
                    "entity": None,
                    "confidence": 0.0,
                    "candidates": [],
                    "ambiguity": None,
                    "resolution_path": [],
                    "meta": {"request_id": request_id, "timestamp": ts, "duration_ms": 0},
                    "error": None,
                }
            if DEBUG_LOGS:
                logger.info("resolve start subject=%s", subject)
            # Phase 1: check the triple store before calling out to Wikidata.
            stored = await self.load_entity_fn(subject)
            if stored:
                if DEBUG_LOGS:
                    logger.info("store hit subject=%s atlas_id=%s", subject, stored.get("atlas_id"))
                stored_confidence = float(
                    stored.get("confidence") or (0.9 if not stored.get("needs_curation", False) else 0.6)
                )
                if _cache_can_satisfy(stored, mode, constraints):
                    return {
                        "status": "resolved",
                        "entity": {
                            "id": stored.get("atlas_id"),
                            "label": stored.get("label"),
                            "type": stored.get("type"),
                            "description": stored.get("description"),
                            "source": None,
                            "uri": None,
                            "attributes": {},
                        },
                        "confidence": stored_confidence,
                        "candidates": [],
                        "ambiguity": None,
                        "resolution_path": [
                            {"phase": "cache", "action": "store_hit", "source": "triple_store"}
                        ],
                        "meta": {
                            "request_id": request_id,
                            "timestamp": ts,
                            "duration_ms": int((time.time() - start) * 1000),
                            **(
                                {
                                    "debug": _debug_decision(
                                        mode=mode,
                                        top_confidence=stored_confidence,
                                        auto_accept_threshold=auto_accept_threshold,
                                        interactive_below_threshold=interactive_below_threshold,
                                        required_confidence=required_confidence,
                                        used_cache=True,
                                        cache_confidence=stored_confidence,
                                    )
                                }
                                if debug.get("include_explanations")
                                else {}
                            ),
                        },
                        "error": None,
                    }
                if DEBUG_LOGS:
                    logger.info(
                        "cache confidence too low subject=%s mode=%s confidence=%.3f required=%.3f",
                        subject,
                        mode,
                        stored_confidence,
                        required_confidence,
                    )
            # Phase 2: query Wikidata (one candidate in quick mode, up to
            # max_candidates otherwise, capped at 10).
            wd_candidates = await self.wikidata_lookup_fn(
                subject,
                language,
                1 if mode == "quick" else max(1, min(max_candidates, 10)),
            )
            if not wd_candidates:
                if DEBUG_LOGS:
                    logger.info("wikidata miss subject=%s mode=%s", subject, mode)
                return {
                    "status": "not_found",
                    "entity": {
                        "id": None,
                        "label": None,
                        "type": None,
                        "description": None,
                        "source": None,
                        "uri": None,
                        "attributes": {},
                    },
                    "confidence": 0.0,
                    "candidates": [],
                    "ambiguity": None,
                    "resolution_path": [
                        {"phase": "query", "action": "wikidata_quick_resolve", "source": "remote"}
                    ],
                    "meta": {
                        "request_id": request_id,
                        "timestamp": ts,
                        "duration_ms": int((time.time() - start) * 1000),
                        **(
                            {
                                "debug": _debug_decision(
                                    mode=mode,
                                    top_confidence=0.0,
                                    auto_accept_threshold=auto_accept_threshold,
                                    interactive_below_threshold=interactive_below_threshold,
                                    required_confidence=required_confidence,
                                    used_cache=False,
                                )
                            }
                            if debug.get("include_explanations")
                            else {}
                        ),
                    },
                    "error": None,
                }
            # Phase 3: score and rank every candidate.
            ranked_candidates = []
            subject_embedding = None
            embedder = None
            if use_embeddings:
                embedder = WikidataSearch()
                subject_embedding = await embedder.embed_text(
                    _candidate_text(subject, {"label": subject, "description": "", "aliases": []}, hints)
                )
            for wd in wd_candidates:
                candidate_embedding = None
                if use_embeddings and embedder is not None:
                    candidate_embedding = await embedder.embed_text(_candidate_text(subject, wd, hints))
                confidence, breakdown = _score_wikidata_candidate(
                    subject,
                    wd,
                    context=context,
                    hints=hints,
                    use_embeddings=use_embeddings,
                    subject_embedding=subject_embedding,
                    candidate_embedding=candidate_embedding,
                )
                ranked_candidates.append({**wd, "confidence": confidence, "score_breakdown": breakdown})
            # Highest confidence first; ties broken by label, descending.
            ranked_candidates.sort(
                key=lambda item: ((item.get("confidence") or 0.0), item.get("label") or ""),
                reverse=True,
            )
            wd = ranked_candidates[0]
            entity = _entity_from_wikidata(subject, wd)
            if mode == "quick":
                # Quick mode never auto-accepts: cap confidence below the
                # default auto-accept threshold.
                wd["confidence"] = min(wd.get("confidence", 0.0), 0.6)
            if DEBUG_LOGS:
                logger.info(
                    "wikidata hit subject=%s qid=%s atlas_id=%s type=%s",
                    subject,
                    wd.get("id"),
                    entity.id,
                    entity.type,
                )
            # Phase 4: persist the best candidate, flagged for curation.
            await self.persist_entity_fn(entity)
            resolution_path = [
                {"phase": "query", "action": "wikidata_quick_resolve", "source": "remote"},
                {"phase": "ranking", "action": f"mode_{mode}", "source": "resolver"},
            ]
            if use_embeddings:
                resolution_path.append(
                    {
                        "phase": "ranking",
                        "action": "embedding_similarity",
                        "source": "ollama",
                        "note": "embedding similarity used to score candidate order",
                    }
                )
            status = "ambiguous"
            ambiguity = {"reason": "pre-maintenance", "dimension": 0.5}
            if mode == "quick":
                status = "ambiguous"
            elif (wd.get("confidence") or 0.0) >= auto_accept_threshold:
                status = "resolved"
                ambiguity = None
            elif interactive_below_threshold:
                status = "ambiguous"
            return {
                "status": status,
                "entity": {
                    "id": entity.id,
                    "label": entity.label,
                    "type": entity.type,
                    "description": entity.description,
                    "source": "wikidata",
                    "uri": None,
                    "attributes": {
                        "wikidata_id": wd.get("id"),
                        "alias": subject,
                    },
                },
                "confidence": wd.get("confidence", 0.6),
                "candidates": [
                    {
                        "id": cand.get("id"),
                        "label": cand.get("label"),
                        "type": cand.get("type"),
                        "source": "wikidata",
                        "confidence": cand.get("confidence", 0.0),
                        "score_breakdown": cand.get("score_breakdown", {})
                        if debug.get("include_explanations")
                        else {},
                    }
                    for cand in ranked_candidates
                ]
                if mode in {"ranked", "interactive", "hybrid", "llm_select"}
                else [],
                "ambiguity": ambiguity,
                "resolution_path": resolution_path
                + [{"phase": "persistence", "action": "store_save_minimal", "source": "triple_store"}],
                "meta": {
                    "request_id": request_id,
                    "timestamp": ts,
                    "duration_ms": int((time.time() - start) * 1000),
                    **(
                        {
                            "debug": _debug_decision(
                                mode=mode,
                                top_confidence=wd.get("confidence", 0.0),
                                auto_accept_threshold=auto_accept_threshold,
                                interactive_below_threshold=interactive_below_threshold,
                                required_confidence=required_confidence,
                                used_cache=False,
                            )
                        }
                        if debug.get("include_explanations")
                        else {}
                    ),
                },
                "error": None,
            }
        except Exception as exc:
            detail = " | ".join(_flatten_exception_details(exc))
            return {
                "status": "error",
                "entity": None,
                "confidence": 0.0,
                "candidates": [],
                "ambiguity": None,
                "resolution_path": [],
                "meta": {
                    "request_id": str(uuid.uuid4()),
                    "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
                    "duration_ms": 0,
                },
                "error": {"code": "RESOLVE_FAILED", "message": detail},
            }
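

# Minimal usage sketch (illustrative, not part of the module's public API).
# The stub callables below are hypothetical stand-ins for the real triple-store
# and Wikidata lookups; they show how ResolveService's injection points can be
# wired up in a test or script without touching the network.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        async def fake_load(subject: str) -> dict[str, Any] | None:
            return None  # simulate a cache miss

        async def fake_lookup(subject: str, language: str, limit: int) -> list[dict[str, Any]]:
            return [{"id": "Q90", "label": "Paris", "description": "capital of France"}]

        async def fake_persist(entity: Entity) -> None:
            print(f"would persist {entity.id} ({entity.label})")

        service = ResolveService(
            load_entity_fn=fake_load,
            wikidata_lookup_fn=fake_lookup,
            persist_entity_fn=fake_persist,
        )
        result = await service.resolve(subject="Paris", strategy={"mode": "ranked"})
        # With the stub above, the top candidate scores 0.80 (< 0.85 threshold),
        # so the status comes back "ambiguous".
        print(result["status"], result["confidence"])

    asyncio.run(_demo())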