"""Canonical type classification pipeline for Atlas entities."""

from __future__ import annotations

import json
import os
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Optional

import httpx

from app.models import AtlasProvenance

# Closed set of type labels the pipeline may assign.
CANONICAL_TYPES = [
    "Person",
    "Organization",
    "Location",
    "CreativeWork",
    "Event",
    "Product",
    "Other",
]

# Wikidata "instance of" (P31) class QIDs -> canonical Atlas type.
# NOTE(review): "Taxon" (Q16521) is not in CANONICAL_TYPES, so the Wikidata
# path can emit a type the LLM path would reject — confirm whether taxa
# should instead map to "Other" or be added to CANONICAL_TYPES.
WIKIDATA_CLASS_MAP = {
    "Q5": "Person",
    "Q43229": "Organization",
    "Q17334923": "Location",  # human settlement
    "Q515": "Location",  # city
    "Q82794": "Location",  # geographic region
    "Q16521": "Taxon",
    "Q571": "CreativeWork",  # book
    "Q11424": "CreativeWork",  # film
    "Q49848": "CreativeWork",  # album
    "Q1656682": "Event",
    "Q191067": "Product",
}

# LLM provider configuration, read once from the environment at import time.
# The ATLAS_-prefixed variables take precedence over the generic ones.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_MODEL = os.getenv("ATLAS_OPENAI_MODEL", os.getenv("OPENAI_MODEL", "gpt-4o-mini"))
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
GROQ_MODEL = os.getenv(
    "ATLAS_GROQ_MODEL",
    os.getenv("GROQ_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct"),
)


@dataclass
class TypeClassification:
    """Result of a single type-classification attempt."""

    # One of CANONICAL_TYPES (or "Taxon" via the Wikidata map); None when
    # no classifier produced an answer.
    canonical_type: Optional[str]
    # Where and how the type was determined; None when unclassified.
    provenance: Optional[AtlasProvenance]
    # True when a human curator should review (or supply) the assignment.
    needs_curation: bool


async def classify_entity_type(
    subject: str, resolution: dict, context: Optional[str]
) -> TypeClassification:
    """Classify *subject* into a canonical type.

    Tries Wikidata first (high confidence), then an LLM provider, and
    finally falls back to an unclassified result flagged for curation.

    Args:
        subject: Raw entity name as seen in the source text.
        resolution: Entity-resolution payload; ``canonical_label`` is
            preferred over *subject* for the Wikidata lookup.
        context: Optional surrounding text passed through to the LLM.
    """
    label = resolution.get("canonical_label") or subject
    wikidata_hit = await _classify_via_wikidata(label)
    if wikidata_hit is not None:
        return wikidata_hit
    llm_hit = await _classify_via_llm(subject, resolution, context)
    if llm_hit is not None:
        return llm_hit
    # Nothing could classify this entity; hand it to a human curator.
    return TypeClassification(canonical_type=None, provenance=None, needs_curation=True)


async def _classify_via_wikidata(label: str) -> Optional[TypeClassification]:
    """Resolve *label* on Wikidata and map its P31 class to a canonical type.

    Returns None on any failure (network error, unparseable payload, no
    search hit, unmapped P31 class) so the caller can fall back to the LLM.
    """
    search_params = {
        "action": "wbsearchentities",
        "search": label,
        "language": "en",
        "limit": 1,
        "format": "json",
    }
    try:
        async with httpx.AsyncClient(timeout=8) as client:
            search_resp = await client.get(
                "https://www.wikidata.org/w/api.php", params=search_params
            )
            search_resp.raise_for_status()
            search_data = search_resp.json()
            if not search_data.get("search"):
                return None
            entity_id = search_data["search"][0].get("id")
            if not entity_id:
                return None
            data_resp = await client.get(
                f"https://www.wikidata.org/wiki/Special:EntityData/{entity_id}.json",
                params={"flavor": "dump"},
            )
            data_resp.raise_for_status()
            data_payload = data_resp.json()
            entities = data_payload.get("entities", {})
            entity_block = entities.get(entity_id)
            if not entity_block:
                return None
            claims = entity_block.get("claims", {})
            p31 = claims.get("P31", [])
            # Take the first "instance of" claim we know how to map.
            for claim in p31:
                mainsnak = claim.get("mainsnak", {})
                datavalue = mainsnak.get("datavalue", {})
                value = datavalue.get("value", {})
                wid = value.get("id")
                canonical = WIKIDATA_CLASS_MAP.get(wid)
                if canonical:
                    prov = AtlasProvenance(
                        source="wikidata",
                        retrieval_method="type-classification",
                        confidence=0.97,
                        retrieved_at=datetime.now(timezone.utc).isoformat(),
                    )
                    return TypeClassification(
                        canonical_type=canonical,
                        provenance=prov,
                        needs_curation=False,
                    )
    except Exception:
        # Best-effort lookup: any failure simply defers to the LLM path.
        return None
    return None


async def _classify_via_llm(
    subject: str, resolution: dict, context: Optional[str]
) -> Optional[TypeClassification]:
    """Classify via an OpenAI-compatible chat API (Groq preferred).

    Returns None when no provider key is configured, on any request
    failure, or when the model's answer is missing, unparseable, or not
    one of CANONICAL_TYPES.
    """
    provider = None
    if GROQ_API_KEY:
        provider = "groq"
    elif OPENAI_API_KEY:
        provider = "openai"
    if provider is None:
        return None
    prompt = _build_llm_prompt(subject, resolution, context)
    payload = {
        "model": GROQ_MODEL if provider == "groq" else OPENAI_MODEL,
        "messages": [
            {
                "role": "system",
                "content": (
                    "You classify named entities into canonical Atlas types. "
                    "Valid types: Person, Organization, Location, CreativeWork, "
                    "Event, Product, Other. "
                    # The <type>/<reason> placeholders were stripped from the
                    # original by HTML-like extraction; restored here.
                    "Respond with JSON: "
                    "{\"type\": <type>, \"confidence\": <0-1>, \"reason\": <reason>}"
                ),
            },
            {"role": "user", "content": prompt},
        ],
        "temperature": 0,  # deterministic classification
    }
    headers = {"Content-Type": "application/json"}
    url = "https://api.groq.com/openai/v1/chat/completions"
    if provider == "groq":
        headers["Authorization"] = f"Bearer {GROQ_API_KEY}"
    else:
        headers["Authorization"] = f"Bearer {OPENAI_API_KEY}"
        url = "https://api.openai.com/v1/chat/completions"
    try:
        async with httpx.AsyncClient(timeout=15) as client:
            resp = await client.post(url, json=payload, headers=headers)
            resp.raise_for_status()
            data = resp.json()
            choice = data.get("choices", [{}])[0]
            message = choice.get("message", {})
            content = message.get("content")
            if not content:
                return None
            parsed = _parse_llm_json(content)
            if not parsed:
                return None
            canonical_type = parsed.get("type")
            confidence = float(parsed.get("confidence", 0))
            if canonical_type not in CANONICAL_TYPES:
                return None
            # Low-confidence answers are accepted but flagged for review.
            needs_curation = confidence < 0.6
            prov = AtlasProvenance(
                source=f"{provider}-llm",
                retrieval_method="type-classification",
                confidence=confidence,
                retrieved_at=datetime.now(timezone.utc).isoformat(),
            )
            return TypeClassification(
                canonical_type=canonical_type,
                provenance=prov,
                needs_curation=needs_curation,
            )
    except Exception:
        # Best-effort: any provider/parse failure yields "unclassified".
        return None
    return None


def _build_llm_prompt(subject: str, resolution: dict, context: Optional[str]) -> str:
    """Assemble the user-turn prompt from the resolution payload.

    Candidate titles are deduplicated and sorted so the prompt is
    deterministic for a given input.
    """
    raw_type = resolution.get("type") or resolution.get("raw_type") or ""
    candidates = resolution.get("candidates") or []
    candidate_titles = ", ".join(
        sorted({c.get("title") for c in candidates if c.get("title")})
    )
    parts = [
        f"Subject: {subject}",
        f"Canonical label: {resolution.get('canonical_label')}",
        f"Raw type hints: {raw_type}",
        f"Candidates: {candidate_titles}",
    ]
    if context:
        parts.append(f"Context: {context}")
    parts.append(
        "Return JSON with keys type/confidence/reason. "
        f"Types allowed: {', '.join(CANONICAL_TYPES)}"
    )
    return "\n".join(parts)


def _parse_llm_json(text: str) -> Optional[dict]:
    """Parse the model's JSON reply, tolerating a Markdown code fence.

    Returns None when the text is not valid JSON.
    """
    text = text.strip()
    if text.startswith("```") and text.endswith("```"):
        inner = text.strip("`")
        # Drop an optional "json" language tag after the opening fence.
        if inner.lower().startswith("json"):
            inner = inner[4:]
        text = inner
    try:
        return json.loads(text)
    except Exception:
        return None