type_classifier.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. """Canonical type classification pipeline for Atlas entities."""
  2. from __future__ import annotations
  3. import json
  4. import os
  5. from dataclasses import dataclass
  6. from datetime import datetime, timezone
  7. from typing import Optional
  8. import httpx
  9. from app.models import AtlasProvenance
  10. from .wikidata_type_reasoner import infer_atlas_type_from_p31
# Closed set of type labels the Atlas pipeline accepts; LLM answers are
# validated against this list before being trusted.
CANONICAL_TYPES = [
    "Person",
    "Organization",
    "Location",
    "CreativeWork",
    "Event",
    "Product",
    "Other",
]
# Fallback mapping from Wikidata P31 (instance-of) QIDs to Atlas types, used
# only when infer_atlas_type_from_p31 yields nothing.
WIKIDATA_CLASS_MAP = {
    "Q5": "Person",
    "Q43229": "Organization",
    "Q17334923": "Location",  # human settlement
    "Q515": "Location",  # city
    "Q82794": "Location",  # geographic region
    # NOTE(review): "Taxon" is not in CANONICAL_TYPES, so the Wikidata path can
    # emit a type the LLM path would reject — confirm downstream handles it.
    "Q16521": "Taxon",
    "Q571": "CreativeWork",
    "Q11424": "CreativeWork",  # film
    "Q49848": "CreativeWork",  # album
    "Q1656682": "Event",
    "Q191067": "Product",
}
# Provider credentials and model names, all overridable via environment.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_MODEL = os.getenv("ATLAS_OPENAI_MODEL", os.getenv("OPENAI_MODEL", "gpt-4o-mini"))
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
GROQ_MODEL = os.getenv(
    "ATLAS_GROQ_MODEL",
    os.getenv("GROQ_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct"),
)
# "auto" (the default) prefers Groq when its key is present, else OpenAI.
TYPE_CLASSIFIER_PROVIDER = os.getenv("ATLAS_TYPE_CLASSIFIER_PROVIDER", "auto")
# Optional explicit model override, applied regardless of provider.
TYPE_CLASSIFIER_MODEL = os.getenv("ATLAS_TYPE_CLASSIFIER_MODEL")
@dataclass
class TypeClassification:
    """Result of running the type-classification pipeline for one entity."""

    # Resolved type label, or None when no classifier succeeded. Usually one
    # of CANONICAL_TYPES, though the Wikidata map can also yield "Taxon".
    canonical_type: Optional[str]
    # How/where the type was determined; None when unclassified.
    provenance: Optional[AtlasProvenance]
    # True when no classifier succeeded or the LLM confidence was below 0.6.
    needs_curation: bool
  47. async def classify_entity_type(subject: str, resolution: dict, context: Optional[str]) -> TypeClassification:
  48. label = resolution.get("canonical_label") or subject
  49. wikidata_hit = await _classify_via_wikidata(label)
  50. if wikidata_hit is not None:
  51. return wikidata_hit
  52. llm_hit = await _classify_via_llm(subject, resolution, context)
  53. if llm_hit is not None:
  54. return llm_hit
  55. return TypeClassification(canonical_type=None, provenance=None, needs_curation=True)
  56. async def _classify_via_wikidata(label: str) -> Optional[TypeClassification]:
  57. search_params = {
  58. "action": "wbsearchentities",
  59. "search": label,
  60. "language": "en",
  61. "limit": 1,
  62. "format": "json",
  63. }
  64. try:
  65. async with httpx.AsyncClient(timeout=8) as client:
  66. search_resp = await client.get("https://www.wikidata.org/w/api.php", params=search_params)
  67. search_resp.raise_for_status()
  68. search_data = search_resp.json()
  69. if not search_data.get("search"):
  70. return None
  71. entity_id = search_data["search"][0].get("id")
  72. if not entity_id:
  73. return None
  74. data_resp = await client.get(
  75. f"https://www.wikidata.org/wiki/Special:EntityData/{entity_id}.json",
  76. params={"flavor": "dump"},
  77. )
  78. data_resp.raise_for_status()
  79. data_payload = data_resp.json()
  80. entities = data_payload.get("entities", {})
  81. entity_block = entities.get(entity_id)
  82. if not entity_block:
  83. return None
  84. claims = entity_block.get("claims", {})
  85. p31 = claims.get("P31", [])
  86. qids: list[str] = []
  87. for claim in p31:
  88. mainsnak = claim.get("mainsnak", {})
  89. datavalue = mainsnak.get("datavalue", {})
  90. value = datavalue.get("value", {})
  91. wid = value.get("id")
  92. if wid:
  93. qids.append(wid)
  94. canonical = infer_atlas_type_from_p31(tuple(qids)) if qids else None
  95. if not canonical:
  96. for wid in qids:
  97. canonical = WIKIDATA_CLASS_MAP.get(wid)
  98. if canonical:
  99. break
  100. if canonical:
  101. prov = AtlasProvenance(
  102. source="wikidata",
  103. retrieval_method="type-classification",
  104. confidence=0.97,
  105. retrieved_at=datetime.now(timezone.utc).isoformat(),
  106. )
  107. return TypeClassification(canonical_type=canonical, provenance=prov, needs_curation=False)
  108. except Exception:
  109. return None
  110. return None
  111. async def _classify_via_llm(subject: str, resolution: dict, context: Optional[str]) -> Optional[TypeClassification]:
  112. provider = None
  113. if TYPE_CLASSIFIER_PROVIDER == "groq" and GROQ_API_KEY:
  114. provider = "groq"
  115. elif TYPE_CLASSIFIER_PROVIDER == "openai" and OPENAI_API_KEY:
  116. provider = "openai"
  117. elif GROQ_API_KEY:
  118. provider = "groq"
  119. elif OPENAI_API_KEY:
  120. provider = "openai"
  121. if provider is None:
  122. return None
  123. prompt = _build_llm_prompt(subject, resolution, context)
  124. payload = {
  125. "model": TYPE_CLASSIFIER_MODEL or (GROQ_MODEL if provider == "groq" else OPENAI_MODEL),
  126. "messages": [
  127. {
  128. "role": "system",
  129. "content": (
  130. "You classify named entities into canonical Atlas types. "
  131. "Valid types: Person, Organization, Location, CreativeWork, Event, Product, Other. "
  132. "Respond with JSON: {\"type\": <type>, \"confidence\": <0-1>, \"reason\": <short>}"
  133. ),
  134. },
  135. {"role": "user", "content": prompt},
  136. ],
  137. "temperature": 0,
  138. }
  139. headers = {"Content-Type": "application/json"}
  140. url = "https://api.groq.com/openai/v1/chat/completions"
  141. if provider == "groq":
  142. headers["Authorization"] = f"Bearer {GROQ_API_KEY}"
  143. else:
  144. headers["Authorization"] = f"Bearer {OPENAI_API_KEY}"
  145. url = "https://api.openai.com/v1/chat/completions"
  146. try:
  147. async with httpx.AsyncClient(timeout=15) as client:
  148. resp = await client.post(url, json=payload, headers=headers)
  149. resp.raise_for_status()
  150. data = resp.json()
  151. choice = data.get("choices", [{}])[0]
  152. message = choice.get("message", {})
  153. content = message.get("content")
  154. if not content:
  155. return None
  156. parsed = _parse_llm_json(content)
  157. if not parsed:
  158. return None
  159. canonical_type = parsed.get("type")
  160. confidence = float(parsed.get("confidence", 0))
  161. if canonical_type not in CANONICAL_TYPES:
  162. return None
  163. needs_curation = confidence < 0.6
  164. prov = AtlasProvenance(
  165. source=f"{provider}-llm",
  166. retrieval_method="type-classification",
  167. confidence=confidence,
  168. retrieved_at=datetime.now(timezone.utc).isoformat(),
  169. provider=provider,
  170. model=TYPE_CLASSIFIER_MODEL or (GROQ_MODEL if provider == "groq" else OPENAI_MODEL),
  171. )
  172. return TypeClassification(canonical_type=canonical_type, provenance=prov, needs_curation=needs_curation)
  173. except Exception:
  174. return None
  175. return None
  176. def _build_llm_prompt(subject: str, resolution: dict, context: Optional[str]) -> str:
  177. raw_type = resolution.get("type") or resolution.get("raw_type") or ""
  178. candidates = resolution.get("candidates") or []
  179. candidate_titles = ", ".join(sorted({c.get("title") for c in candidates if c.get("title")}))
  180. parts = [
  181. f"Subject: {subject}",
  182. f"Canonical label: {resolution.get('canonical_label')}",
  183. f"Raw type hints: {raw_type}",
  184. f"Candidates: {candidate_titles}",
  185. ]
  186. if context:
  187. parts.append(f"Context: {context}")
  188. parts.append(f"Return JSON with keys type/confidence/reason. Types allowed: {', '.join(CANONICAL_TYPES)}")
  189. return "\n".join(parts)
  190. def _parse_llm_json(text: str) -> Optional[dict]:
  191. text = text.strip()
  192. if text.startswith("```") and text.endswith("```"):
  193. inner = text.strip("`")
  194. if inner.lower().startswith("json"):
  195. inner = inner[4:]
  196. text = inner
  197. try:
  198. return json.loads(text)
  199. except Exception:
  200. return None