test_storage_service.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import pytest
  2. from app.models import AtlasAlias, AtlasEntity, AtlasIdentifier, AtlasProvenance
  3. from app.storage_service import AtlasStorageService, entity_iri
  4. @pytest.mark.anyio
  5. async def test_write_entity_uses_batch_insert():
  6. calls = []
  7. async def fake_call(tool, payload):
  8. calls.append((tool, payload))
  9. return {"ok": True}
  10. svc = AtlasStorageService(call_tool=fake_call)
  11. entity = AtlasEntity(
  12. atlas_id="atlas:mid:/m/0cqt90",
  13. canonical_label="Donald Trump",
  14. canonical_description="45th and 47th U.S. President",
  15. entity_type="Person",
  16. aliases=[AtlasAlias(label="Donald Trump")],
  17. identifiers=[AtlasIdentifier(value="/m/0cqt90", source="google", identifier_type="mid")],
  18. provenance=[AtlasProvenance(source="google", retrieval_method="trends-resolution", confidence=0.9)],
  19. )
  20. result = await svc.write_entity(entity)
  21. assert result["status"] == "ok"
  22. assert calls[0][0] == "batch_insert"
  23. assert "ttl" in calls[0][1]
  24. @pytest.mark.anyio
  25. async def test_read_entity_claims_uses_sparql_query():
  26. calls = []
  27. async def fake_call(tool, payload):
  28. calls.append((tool, payload))
  29. return {"results": {"bindings": []}}
  30. svc = AtlasStorageService(call_tool=fake_call)
  31. result = await svc.read_entity_claims("atlas:mid:/m/0cqt90")
  32. assert result["status"] == "ok"
  33. assert calls[0][0] == "sparql_query"
  34. assert entity_iri("atlas:mid:/m/0cqt90") in calls[0][1]["query"]
  35. assert 'FILTER(?status = "active")' in calls[0][1]["query"]
  36. @pytest.mark.anyio
  37. async def test_read_entity_claims_include_superseded_removes_filter():
  38. calls = []
  39. async def fake_call(tool, payload):
  40. calls.append((tool, payload))
  41. return {"results": {"bindings": []}}
  42. svc = AtlasStorageService(call_tool=fake_call)
  43. result = await svc.read_entity_claims("atlas:mid:/m/0cqt90", include_superseded=True)
  44. assert result["status"] == "ok"
  45. assert calls[0][0] == "sparql_query"
  46. assert 'FILTER(?status = "active")' not in calls[0][1]["query"]
  47. @pytest.mark.anyio
  48. async def test_write_entity_unfinished_on_failure():
  49. async def fake_call(tool, payload):
  50. raise RuntimeError("backend down")
  51. svc = AtlasStorageService(call_tool=fake_call)
  52. entity = AtlasEntity(atlas_id="atlas:x", canonical_label="X")
  53. result = await svc.write_entity(entity)
  54. assert result["status"] == "unfinished"
  55. assert "backend down" in result["error"]