  1. """Canonical type classification pipeline for Atlas entities."""
  2. from __future__ import annotations
  3. import json
  4. import os
  5. from dataclasses import dataclass
  6. from datetime import datetime, timezone
  7. from typing import Optional
  8. import httpx
  9. from app.models import AtlasProvenance
  10. CANONICAL_TYPES = [
  11. "Person",
  12. "Organization",
  13. "Location",
  14. "CreativeWork",
  15. "Event",
  16. "Product",
  17. "Other",
  18. ]
  19. WIKIDATA_CLASS_MAP = {
  20. "Q5": "Person",
  21. "Q43229": "Organization",
  22. "Q17334923": "Location", # human settlement
  23. "Q515": "Location", # city
  24. "Q82794": "Location", # geographic region
  25. "Q16521": "Taxon",
  26. "Q571": "CreativeWork",
  27. "Q11424": "CreativeWork", # film
  28. "Q49848": "CreativeWork", # album
  29. "Q1656682": "Event",
  30. "Q191067": "Product",
  31. }
  32. OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
  33. OPENAI_MODEL = os.getenv("ATLAS_OPENAI_MODEL", os.getenv("OPENAI_MODEL", "gpt-4o-mini"))
  34. GROQ_API_KEY = os.getenv("GROQ_API_KEY")
  35. GROQ_MODEL = os.getenv(
  36. "ATLAS_GROQ_MODEL",
  37. os.getenv("GROQ_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct"),
  38. )
@dataclass
class TypeClassification:
    """Outcome of the type-classification pipeline for a single entity."""
    canonical_type: Optional[str]  # a member of CANONICAL_TYPES, or None when undetermined
    provenance: Optional[AtlasProvenance]  # how/where the type was determined; None when undetermined
    needs_curation: bool  # True when no classifier decided, or the LLM's confidence was low
  44. async def classify_entity_type(subject: str, resolution: dict, context: Optional[str]) -> TypeClassification:
  45. label = resolution.get("canonical_label") or subject
  46. wikidata_hit = await _classify_via_wikidata(label)
  47. if wikidata_hit is not None:
  48. return wikidata_hit
  49. llm_hit = await _classify_via_llm(subject, resolution, context)
  50. if llm_hit is not None:
  51. return llm_hit
  52. return TypeClassification(canonical_type=None, provenance=None, needs_curation=True)
  53. async def _classify_via_wikidata(label: str) -> Optional[TypeClassification]:
  54. search_params = {
  55. "action": "wbsearchentities",
  56. "search": label,
  57. "language": "en",
  58. "limit": 1,
  59. "format": "json",
  60. }
  61. try:
  62. async with httpx.AsyncClient(timeout=8) as client:
  63. search_resp = await client.get("https://www.wikidata.org/w/api.php", params=search_params)
  64. search_resp.raise_for_status()
  65. search_data = search_resp.json()
  66. if not search_data.get("search"):
  67. return None
  68. entity_id = search_data["search"][0].get("id")
  69. if not entity_id:
  70. return None
  71. data_resp = await client.get(
  72. f"https://www.wikidata.org/wiki/Special:EntityData/{entity_id}.json",
  73. params={"flavor": "dump"},
  74. )
  75. data_resp.raise_for_status()
  76. data_payload = data_resp.json()
  77. entities = data_payload.get("entities", {})
  78. entity_block = entities.get(entity_id)
  79. if not entity_block:
  80. return None
  81. claims = entity_block.get("claims", {})
  82. p31 = claims.get("P31", [])
  83. for claim in p31:
  84. mainsnak = claim.get("mainsnak", {})
  85. datavalue = mainsnak.get("datavalue", {})
  86. value = datavalue.get("value", {})
  87. wid = value.get("id")
  88. canonical = WIKIDATA_CLASS_MAP.get(wid)
  89. if canonical:
  90. prov = AtlasProvenance(
  91. source="wikidata",
  92. retrieval_method="type-classification",
  93. confidence=0.97,
  94. retrieved_at=datetime.now(timezone.utc).isoformat(),
  95. )
  96. return TypeClassification(canonical_type=canonical, provenance=prov, needs_curation=False)
  97. except Exception:
  98. return None
  99. return None
  100. async def _classify_via_llm(subject: str, resolution: dict, context: Optional[str]) -> Optional[TypeClassification]:
  101. provider = None
  102. if GROQ_API_KEY:
  103. provider = "groq"
  104. elif OPENAI_API_KEY:
  105. provider = "openai"
  106. if provider is None:
  107. return None
  108. prompt = _build_llm_prompt(subject, resolution, context)
  109. payload = {
  110. "model": GROQ_MODEL if provider == "groq" else OPENAI_MODEL,
  111. "messages": [
  112. {
  113. "role": "system",
  114. "content": (
  115. "You classify named entities into canonical Atlas types. "
  116. "Valid types: Person, Organization, Location, CreativeWork, Event, Product, Other. "
  117. "Respond with JSON: {\"type\": <type>, \"confidence\": <0-1>, \"reason\": <short>}"
  118. ),
  119. },
  120. {"role": "user", "content": prompt},
  121. ],
  122. "temperature": 0,
  123. }
  124. headers = {"Content-Type": "application/json"}
  125. url = "https://api.groq.com/openai/v1/chat/completions"
  126. if provider == "groq":
  127. headers["Authorization"] = f"Bearer {GROQ_API_KEY}"
  128. else:
  129. headers["Authorization"] = f"Bearer {OPENAI_API_KEY}"
  130. url = "https://api.openai.com/v1/chat/completions"
  131. try:
  132. async with httpx.AsyncClient(timeout=15) as client:
  133. resp = await client.post(url, json=payload, headers=headers)
  134. resp.raise_for_status()
  135. data = resp.json()
  136. choice = data.get("choices", [{}])[0]
  137. message = choice.get("message", {})
  138. content = message.get("content")
  139. if not content:
  140. return None
  141. parsed = _parse_llm_json(content)
  142. if not parsed:
  143. return None
  144. canonical_type = parsed.get("type")
  145. confidence = float(parsed.get("confidence", 0))
  146. if canonical_type not in CANONICAL_TYPES:
  147. return None
  148. needs_curation = confidence < 0.6
  149. prov = AtlasProvenance(
  150. source=f"{provider}-llm",
  151. retrieval_method="type-classification",
  152. confidence=confidence,
  153. retrieved_at=datetime.now(timezone.utc).isoformat(),
  154. )
  155. return TypeClassification(canonical_type=canonical_type, provenance=prov, needs_curation=needs_curation)
  156. except Exception:
  157. return None
  158. return None
  159. def _build_llm_prompt(subject: str, resolution: dict, context: Optional[str]) -> str:
  160. raw_type = resolution.get("type") or resolution.get("raw_type") or ""
  161. candidates = resolution.get("candidates") or []
  162. candidate_titles = ", ".join(sorted({c.get("title") for c in candidates if c.get("title")}))
  163. parts = [
  164. f"Subject: {subject}",
  165. f"Canonical label: {resolution.get('canonical_label')}",
  166. f"Raw type hints: {raw_type}",
  167. f"Candidates: {candidate_titles}",
  168. ]
  169. if context:
  170. parts.append(f"Context: {context}")
  171. parts.append(f"Return JSON with keys type/confidence/reason. Types allowed: {', '.join(CANONICAL_TYPES)}")
  172. return "\n".join(parts)
  173. def _parse_llm_json(text: str) -> Optional[dict]:
  174. text = text.strip()
  175. if text.startswith("```") and text.endswith("```"):
  176. inner = text.strip("`")
  177. if inner.lower().startswith("json"):
  178. inner = inner[4:]
  179. text = inner
  180. try:
  181. return json.loads(text)
  182. except Exception:
  183. return None