# test_atlas_contracts.py (4.5 KB)

  1. import pytest
  2. import app.atlas as atlas_module
  3. from app.atlas import enrich_entity, resolve_entity
  4. from app.models import AtlasAlias, AtlasClaim, AtlasClaimObject, AtlasEntity, AtlasProvenance
  5. from app.type_classifier import TypeClassification
  6. @pytest.mark.anyio
  7. async def test_resolve_entity_returns_canonical_structure():
  8. entity = await resolve_entity("Trump")
  9. assert entity.atlas_id.startswith("atlas:")
  10. assert len(entity.atlas_id) > 10
  11. assert entity.canonical_label
  12. assert entity.aliases[0].label.lower() == "trump" or entity.aliases[0].label.lower() == "donald trump"
  13. assert entity.claims
  14. assert entity.raw_payload["raw"] == "Trump"
  15. @pytest.mark.anyio
  16. async def test_enrich_entity_returns_dataset_shape():
  17. entity = await resolve_entity("Trump")
  18. result = enrich_entity(entity, constraints={"type": "person"}, depth=2)
  19. assert result.seed_entity.atlas_id == entity.atlas_id
  20. assert result.query_context == {"type": "person"}
  21. assert result.depth == 2
  22. assert result.related_entities == []
  23. def test_internal_models_support_identity_and_provenance():
  24. entity = AtlasEntity(
  25. atlas_id="atlas:abcd1234abcd1234",
  26. canonical_label="Donald Trump",
  27. entity_type="person",
  28. aliases=[AtlasAlias(label="Trump")],
  29. claims=[
  30. AtlasClaim(
  31. claim_id="clm_raw_ident_qid_Q22686",
  32. subject="atlas:abcd1234abcd1234",
  33. predicate="atlas:hasIdentifier",
  34. object=AtlasClaimObject(kind="identifier", id_type="qid", value="Q22686"),
  35. layer="raw",
  36. provenance=AtlasProvenance(source="google-trends", retrieval_method="entity-resolution", confidence=0.93),
  37. )
  38. ],
  39. )
  40. assert entity.atlas_id == "atlas:abcd1234abcd1234"
  41. assert entity.aliases[0].label == "Trump"
  42. assert entity.claims[0].object.value == "Q22686"
  43. assert entity.claims[0].provenance.source == "google-trends"
  44. @pytest.mark.anyio
  45. async def test_resolve_entity_passes_context_to_classifier(monkeypatch):
  46. captured = {}
  47. async def fake_classifier(subject, resolution, context):
  48. captured["context"] = context
  49. return TypeClassification(canonical_type="Person", provenance=None, needs_curation=False)
  50. def fake_trends(subject):
  51. return {
  52. "canonical_label": subject,
  53. "normalized": subject,
  54. "mid": None,
  55. "type": "Person",
  56. "source": "resolver",
  57. "retrieved_at": "2026-04-03T00:00:00Z",
  58. "candidates": [],
  59. "raw": subject,
  60. }
  61. writes = []
  62. async def fake_write(entity):
  63. writes.append(entity)
  64. return {"status": "ok"}
  65. monkeypatch.setattr("app.atlas.classify_entity_type", fake_classifier)
  66. monkeypatch.setattr("app.atlas.resolve_entity_via_trends", fake_trends)
  67. monkeypatch.setattr(atlas_module._storage, "write_entity", fake_write)
  68. entity = await resolve_entity("Sample", context="news paragraph")
  69. assert captured["context"] == "news paragraph"
  70. assert entity.entity_type == "Person"
  71. assert writes and writes[0].canonical_label == "Sample"
  72. @pytest.mark.anyio
  73. async def test_resolve_entity_persists_cached_hits(monkeypatch):
  74. cached_entity = AtlasEntity(atlas_id="atlas:x", canonical_label="Cached Entity")
  75. monkeypatch.setattr("app.atlas._entity_cache.get", lambda token: cached_entity)
  76. writes = []
  77. async def fake_write(entity):
  78. writes.append(entity)
  79. return {"status": "ok"}
  80. monkeypatch.setattr(atlas_module._storage, "write_entity", fake_write)
  81. entity = await resolve_entity("Cached Entity")
  82. assert entity is cached_entity
  83. assert writes and writes[0] is cached_entity
  84. @pytest.mark.anyio
  85. async def test_resolve_entity_marks_needs_curation(monkeypatch):
  86. async def fake_classifier(subject, resolution, context):
  87. return TypeClassification(canonical_type=None, provenance=None, needs_curation=True)
  88. def fake_trends(subject):
  89. return {
  90. "canonical_label": subject,
  91. "normalized": subject,
  92. "mid": None,
  93. "type": "Unknown",
  94. "source": "resolver",
  95. "retrieved_at": "2026-04-03T00:00:00Z",
  96. "candidates": [],
  97. "raw": subject,
  98. }
  99. monkeypatch.setattr("app.atlas.classify_entity_type", fake_classifier)
  100. monkeypatch.setattr("app.atlas.resolve_entity_via_trends", fake_trends)
  101. entity = await resolve_entity("Mysterious")
  102. assert entity.needs_curation is True