atlas.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. """Atlas semantic core for entity resolution and enrichment."""
  2. from __future__ import annotations
  3. from app.cache import EntityCache
  4. from app.entity_normalize import normalize_entity
  5. from app.models import (
  6. AtlasAlias,
  7. AtlasEntity,
  8. AtlasEnrichmentDataset,
  9. AtlasIdentifier,
  10. AtlasProvenance,
  11. )
  12. from app.trends_resolution import resolve_entity_via_trends
  13. from app.type_classifier import TypeClassification, classify_entity_type
  14. from app.storage_service import AtlasStorageService
  15. from app.virtuoso_store import VirtuosoEntityStore
  16. from app.wikidata_lookup import lookup_wikidata
# Module-level instances shared by every call to ``resolve_entity`` below.

# Entity cache consulted first during resolution; constructed with
# ``max_entries=512`` (presumably an upper bound on cached entries -- confirm
# against ``EntityCache``).
_entity_cache = EntityCache(max_entries=512)
# Virtuoso-backed entity store consulted on a cache miss; constructed with
# ``max_cache_entries=256``.
_virtuoso_store = VirtuosoEntityStore(max_cache_entries=256)
# Storage service that ``resolve_entity`` writes every returned entity to.
_storage = AtlasStorageService()
  20. async def resolve_entity(subject: str, context: str | None = None) -> AtlasEntity:
  21. normalized = normalize_entity(subject)
  22. token = normalized.strip().lower()
  23. cached = _entity_cache.get(token)
  24. if cached is not None:
  25. try:
  26. await _storage.write_entity(cached)
  27. except Exception:
  28. pass
  29. return cached
  30. virt_hit = await _virtuoso_store.lookup(token)
  31. if virt_hit is not None:
  32. # Make the returned raw payload reflect the original caller input
  33. # (so tests and UI/debug output stay stable).
  34. if isinstance(virt_hit.raw_payload, dict):
  35. virt_hit.raw_payload.setdefault("source", "virtuoso")
  36. virt_hit.raw_payload["raw"] = subject
  37. virt_hit.raw_payload["normalized"] = normalized
  38. _entity_cache.store(virt_hit, extra_tokens=[subject, normalized])
  39. try:
  40. await _storage.write_entity(virt_hit)
  41. except Exception:
  42. pass
  43. return virt_hit
  44. resolution = resolve_entity_via_trends(subject)
  45. classification = await classify_entity_type(subject, resolution, context)
  46. wikidata = await lookup_wikidata(subject)
  47. entity = _entity_from_resolution(subject, resolution, classification, wikidata)
  48. _entity_cache.store(entity, extra_tokens=[subject, normalized])
  49. try:
  50. await _storage.write_entity(entity)
  51. except Exception:
  52. pass
  53. return entity
  54. def _entity_from_resolution(subject: str, resolution: dict, classification: TypeClassification, wikidata: dict | None = None) -> AtlasEntity:
  55. canonical_label = (
  56. resolution.get("canonical_label")
  57. or resolution.get("normalized")
  58. or subject.strip()
  59. )
  60. atlas_id = resolution.get("mid")
  61. if atlas_id:
  62. atlas_id = f"atlas:mid:{atlas_id.strip()}"
  63. else:
  64. slug = canonical_label.strip().lower().replace(" ", "-") or "entity"
  65. atlas_id = f"atlas:{slug}"
  66. identifiers = []
  67. mid = resolution.get("mid")
  68. if mid:
  69. identifiers.append(
  70. AtlasIdentifier(value=mid, source="google", identifier_type="mid")
  71. )
  72. if wikidata and wikidata.get("qid"):
  73. identifiers.append(
  74. AtlasIdentifier(value=wikidata["qid"], source="wikidata", identifier_type="qid")
  75. )
  76. provenance = [
  77. AtlasProvenance(
  78. source=resolution.get("source") or "resolver",
  79. retrieval_method="trends-resolution",
  80. confidence=0.9 if resolution.get("mid") else 0.3,
  81. retrieved_at=resolution.get("resolved_at"),
  82. )
  83. ]
  84. if classification.provenance:
  85. provenance.append(classification.provenance)
  86. if wikidata and wikidata.get("qid"):
  87. provenance.append(
  88. AtlasProvenance(
  89. source="wikidata",
  90. retrieval_method="wbsearchentities + entitydata",
  91. confidence=0.99,
  92. retrieved_at=wikidata.get("retrieved_at"),
  93. )
  94. )
  95. canonical_type = (
  96. classification.canonical_type
  97. or resolution.get("type")
  98. or "unknown"
  99. )
  100. payload = dict(resolution)
  101. if wikidata:
  102. payload["wikidata"] = {
  103. "status": "ok",
  104. "qid": wikidata.get("qid"),
  105. "label": wikidata.get("label"),
  106. "description": wikidata.get("description"),
  107. "retrieved_at": wikidata.get("retrieved_at"),
  108. }
  109. else:
  110. payload["wikidata"] = {"status": "missing"}
  111. return AtlasEntity(
  112. atlas_id=atlas_id,
  113. canonical_label=canonical_label,
  114. canonical_description=(wikidata or {}).get("description"),
  115. entity_type=canonical_type,
  116. aliases=[AtlasAlias(label=subject.strip() or canonical_label)],
  117. identifiers=identifiers,
  118. provenance=provenance,
  119. raw_payload=payload,
  120. needs_curation=classification.needs_curation,
  121. )
  122. def enrich_entity(entity: AtlasEntity, constraints=None, depth: int = 1) -> AtlasEnrichmentDataset:
  123. return AtlasEnrichmentDataset(
  124. seed_entity=entity,
  125. related_entities=[],
  126. query_context=constraints or {},
  127. depth=depth,
  128. )