- """Canonical type classification pipeline for Atlas entities."""
- from __future__ import annotations
- import json
- import os
- from dataclasses import dataclass
- from datetime import datetime, timezone
- from typing import Optional
- import httpx
- from app.models import AtlasProvenance
# Closed set of canonical Atlas entity types. LLM answers are validated
# against this list (see _classify_via_llm); "Other" is the catch-all for
# entities that fit none of the specific types.
CANONICAL_TYPES = [
    "Person",
    "Organization",
    "Location",
    "CreativeWork",
    "Event",
    "Product",
    "Other",
]
# Wikidata P31 ("instance of") QIDs mapped to canonical Atlas types.
# Every value here MUST be a member of CANONICAL_TYPES: _classify_via_wikidata
# treats the mapped value as a final canonical_type without re-validating it.
WIKIDATA_CLASS_MAP = {
    "Q5": "Person",
    "Q43229": "Organization",
    "Q17334923": "Location",  # human settlement
    "Q515": "Location",  # city
    "Q82794": "Location",  # geographic region
    # Q16521 (taxon) previously mapped to "Taxon", which is not a member of
    # CANONICAL_TYPES and would leak an invalid type downstream; fold it into
    # the "Other" bucket instead.
    "Q16521": "Other",  # taxon
    "Q571": "CreativeWork",  # book
    "Q11424": "CreativeWork",  # film
    "Q49848": "CreativeWork",  # album
    "Q1656682": "Event",
    "Q191067": "Product",
}
# LLM provider configuration, read once from the environment at import time.
# The ATLAS_-prefixed variables take precedence over the generic ones so this
# service can override a deployment-wide default.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_MODEL = os.getenv("ATLAS_OPENAI_MODEL", os.getenv("OPENAI_MODEL", "gpt-4o-mini"))
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
GROQ_MODEL = os.getenv(
    "ATLAS_GROQ_MODEL",
    os.getenv("GROQ_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct"),
)
@dataclass
class TypeClassification:
    """Outcome of the entity-type classification pipeline."""

    # Member of CANONICAL_TYPES, or None when no classifier produced a result.
    canonical_type: Optional[str]
    # How/where the type was determined; None when classification failed.
    provenance: Optional[AtlasProvenance]
    # True when the result is absent or low-confidence and needs human review.
    needs_curation: bool
async def classify_entity_type(subject: str, resolution: dict, context: Optional[str]) -> TypeClassification:
    """Classify an entity into a canonical Atlas type.

    Tries Wikidata first, then falls back to an LLM classifier; when neither
    produces an answer, the entity is returned untyped and flagged for
    manual curation.
    """
    label = resolution.get("canonical_label") or subject
    result = await _classify_via_wikidata(label)
    if result is None:
        result = await _classify_via_llm(subject, resolution, context)
    if result is None:
        result = TypeClassification(canonical_type=None, provenance=None, needs_curation=True)
    return result
async def _classify_via_wikidata(label: str) -> Optional[TypeClassification]:
    """Look up *label* on Wikidata and map its P31 class to a canonical type.

    Returns a high-confidence TypeClassification when any instance-of (P31)
    claim matches WIKIDATA_CLASS_MAP, otherwise None. All network/parse
    failures are treated as "no result" so the caller can fall back to the
    LLM classifier.
    """
    # Guard: an empty or whitespace-only label would still trigger two API
    # round-trips in the original; skip the lookup entirely.
    if not label or not label.strip():
        return None
    search_params = {
        "action": "wbsearchentities",
        "search": label,
        "language": "en",
        "limit": 1,
        "format": "json",
    }
    try:
        async with httpx.AsyncClient(timeout=8) as client:
            # Step 1: resolve the free-text label to a Wikidata entity id.
            search_resp = await client.get("https://www.wikidata.org/w/api.php", params=search_params)
            search_resp.raise_for_status()
            search_data = search_resp.json()
            if not search_data.get("search"):
                return None
            entity_id = search_data["search"][0].get("id")
            if not entity_id:
                return None
            # Step 2: fetch the full entity payload to read its P31 claims.
            data_resp = await client.get(
                f"https://www.wikidata.org/wiki/Special:EntityData/{entity_id}.json",
                params={"flavor": "dump"},
            )
            data_resp.raise_for_status()
            entity_block = data_resp.json().get("entities", {}).get(entity_id)
            if not entity_block:
                return None
            # First P31 claim whose class id appears in the map wins.
            for claim in entity_block.get("claims", {}).get("P31", []):
                mainsnak = claim.get("mainsnak", {})
                value = mainsnak.get("datavalue", {}).get("value", {})
                # NOTE(review): "value" is not always a dict for every snak
                # type; a non-dict falls into the except below, matching the
                # original behavior.
                wid = value.get("id")
                canonical = WIKIDATA_CLASS_MAP.get(wid)
                if canonical:
                    prov = AtlasProvenance(
                        source="wikidata",
                        retrieval_method="type-classification",
                        confidence=0.97,
                        retrieved_at=datetime.now(timezone.utc).isoformat(),
                    )
                    return TypeClassification(canonical_type=canonical, provenance=prov, needs_curation=False)
    except Exception:
        # Deliberate best-effort: any transport, HTTP, or payload-shape error
        # simply means "no Wikidata answer" and the caller falls back.
        return None
    return None
async def _classify_via_llm(subject: str, resolution: dict, context: Optional[str]) -> Optional[TypeClassification]:
    """Classify via a chat-completion LLM (Groq preferred, OpenAI fallback).

    Returns None when no provider is configured, the request fails, or the
    model's answer is not a member of CANONICAL_TYPES. Answers with
    confidence < 0.6 are returned but flagged needs_curation=True.
    """
    # Provider selection: Groq wins when both API keys are configured.
    if GROQ_API_KEY:
        provider = "groq"
    elif OPENAI_API_KEY:
        provider = "openai"
    else:
        return None
    prompt = _build_llm_prompt(subject, resolution, context)
    payload = {
        "model": GROQ_MODEL if provider == "groq" else OPENAI_MODEL,
        "messages": [
            {
                "role": "system",
                "content": (
                    "You classify named entities into canonical Atlas types. "
                    "Valid types: Person, Organization, Location, CreativeWork, Event, Product, Other. "
                    "Respond with JSON: {\"type\": <type>, \"confidence\": <0-1>, \"reason\": <short>}"
                ),
            },
            {"role": "user", "content": prompt},
        ],
        "temperature": 0,  # deterministic classification
    }
    headers = {"Content-Type": "application/json"}
    if provider == "groq":
        headers["Authorization"] = f"Bearer {GROQ_API_KEY}"
        url = "https://api.groq.com/openai/v1/chat/completions"
    else:
        headers["Authorization"] = f"Bearer {OPENAI_API_KEY}"
        url = "https://api.openai.com/v1/chat/completions"
    try:
        async with httpx.AsyncClient(timeout=15) as client:
            resp = await client.post(url, json=payload, headers=headers)
            resp.raise_for_status()
            data = resp.json()
            content = data.get("choices", [{}])[0].get("message", {}).get("content")
            if not content:
                return None
            parsed = _parse_llm_json(content)
            # The model may emit any JSON value; only an object is usable.
            # (Previously a list/bool answer raised AttributeError that was
            # silently swallowed by the broad except below.)
            if not isinstance(parsed, dict):
                return None
            canonical_type = parsed.get("type")
            if canonical_type not in CANONICAL_TYPES:
                return None
            # A malformed confidence should not discard a valid type answer
            # (previously float() raised and the whole result was dropped);
            # treat it as 0 -> needs curation, and clamp into [0, 1].
            try:
                confidence = float(parsed.get("confidence", 0))
            except (TypeError, ValueError):
                confidence = 0.0
            confidence = max(0.0, min(1.0, confidence))
            prov = AtlasProvenance(
                source=f"{provider}-llm",
                retrieval_method="type-classification",
                confidence=confidence,
                retrieved_at=datetime.now(timezone.utc).isoformat(),
            )
            return TypeClassification(
                canonical_type=canonical_type,
                provenance=prov,
                needs_curation=confidence < 0.6,
            )
    except Exception:
        # Deliberate best-effort: any transport or payload error means
        # "no LLM answer" and the caller marks the entity for curation.
        return None
    return None
def _build_llm_prompt(subject: str, resolution: dict, context: Optional[str]) -> str:
    """Assemble the user-role prompt describing the entity to classify."""
    type_hint = resolution.get("type") or resolution.get("raw_type") or ""
    titles = {c.get("title") for c in (resolution.get("candidates") or []) if c.get("title")}
    lines = [
        f"Subject: {subject}",
        f"Canonical label: {resolution.get('canonical_label')}",
        f"Raw type hints: {type_hint}",
        f"Candidates: {', '.join(sorted(titles))}",
    ]
    if context:
        lines.append(f"Context: {context}")
    lines.append(f"Return JSON with keys type/confidence/reason. Types allowed: {', '.join(CANONICAL_TYPES)}")
    return "\n".join(lines)
- def _parse_llm_json(text: str) -> Optional[dict]:
- text = text.strip()
- if text.startswith("```") and text.endswith("```"):
- inner = text.strip("`")
- if inner.lower().startswith("json"):
- inner = inner[4:]
- text = inner
- try:
- return json.loads(text)
- except Exception:
- return None