from __future__ import annotations

import asyncio
import json
import os
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field, fields
from typing import Any, Callable
from urllib.parse import quote, urlencode

import httpx

# Process-wide map from Wikidata ids (property ids like "P31" and item ids like
# "Q5") to the text that replaces them in parsed claims. Shared by every
# WikidataSearch instance; `_resolve_properties` currently fills it with
# identity entries (id -> id) until real label resolution is implemented.
PROPERTY_CACHE: dict[str, str] = {}

WIKIDATA_USER_AGENT = os.getenv(
    "ATLAS_WIKIDATA_USER_AGENT",
    "Atlas/1.0 (contact: lukas.goldschmidt+atlas@googlemail.com)",
)
# The Wikidata reconciliation endpoint has moved more than once, so keep it configurable.
WIKIDATA_QUICK_RESOLVE_BASE_URL = os.getenv(
    "ATLAS_WIKIDATA_QUICK_RESOLVE_URL",
    "https://wikidata.reconci.link/en/api",
)
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://192.168.0.200:11434")
OLLAMA_EMBEDDING_MODEL = os.getenv("OLLAMA_EMBEDDING_MODEL", "nomic-embed-text")


@dataclass
class WikidataOptions:
    """Request options for the MediaWiki Action API calls made by WikidataSearch."""

    search: str = ""  # search text for wbsearchentities; must be non-empty to search
    language: str = "en"  # search language and the language of labels/descriptions picked
    strictlanguage: bool = True  # NOTE(review): not sent with any request in this module
    type: str = "item"  # entity type passed to wbsearchentities
    limit: int = 7  # wbsearchentities result count; validateOptions requires 1..50
    searchAction: str = "wbsearchentities"
    getAction: str = "wbgetentities"
    apiHost: str = "www.wikidata.org"
    apiPath: str = "/w/api.php"


# Names accepted by WikidataSearch.__init__ / WikidataSearch.set. Using the
# dataclass fields (rather than hasattr) keeps dunder names like "__class__" out.
_OPTION_NAMES = frozenset(f.name for f in fields(WikidataOptions))


def _is_null(value: Any) -> bool:
    """Return True only when *value* is None."""
    return value is None


def _build_url(opts: WikidataOptions, params: dict[str, Any]) -> str:
    """Build an https Action API URL from *opts*' host/path and url-encoded *params*."""
    query = urlencode(params)
    return f"https://{opts.apiHost}{opts.apiPath}?{query}"


def _client_kwargs() -> dict[str, Any]:
    """Keyword arguments for a temporary httpx.AsyncClient talking to Wikidata services."""
    return {
        "timeout": 20,
        "headers": {"Accept": "application/json", "User-Agent": WIKIDATA_USER_AGENT},
        "follow_redirects": True,
    }


def _ollama_client_kwargs() -> dict[str, Any]:
    """Keyword arguments for a temporary httpx.AsyncClient talking to Ollama."""
    return {
        "timeout": 20,
        "base_url": OLLAMA_BASE_URL,
    }


def _ollama_embeddings_url() -> str:
    """Absolute URL of Ollama's embeddings endpoint, derived from OLLAMA_BASE_URL."""
    return f"{OLLAMA_BASE_URL.rstrip('/')}/api/embeddings"


class WikidataSearch:
    """Async client for Wikidata search, reconciliation and entity lookup.

    It also requests text embeddings from an Ollama server.

    An ``httpx.AsyncClient`` may be injected via ``client``. It is then used for
    every request (Wikidata and Ollama alike) and is never closed by this class.
    Without one, each call opens its own short-lived client and closes it before
    returning.
    """

    def __init__(self, options: dict[str, Any] | None = None, *, client: httpx.AsyncClient | None = None):
        """Create a searcher.

        Args:
            options: Overrides for WikidataOptions fields; unknown keys are ignored.
            client: Optional shared AsyncClient; the caller keeps ownership of it.
        """
        self.defaultOptions = WikidataOptions()
        known = {key: value for key, value in (options or {}).items() if key in _OPTION_NAMES}
        self.options = WikidataOptions(**known)
        self._client = client

    @asynccontextmanager
    async def _client_scope(
        self, kwargs_factory: Callable[[], dict[str, Any]]
    ) -> AsyncIterator[httpx.AsyncClient]:
        """Yield the injected client, or a temporary one built from *kwargs_factory*.

        A temporary client is always closed on exit, including on error. An
        injected client is left open for its owner.
        """
        if self._client is not None:
            yield self._client
            return
        async with httpx.AsyncClient(**kwargs_factory()) as temporary:
            yield temporary

    def set(self, key: str, value: Any) -> None:
        """Set option *key* to *value*; keys that are not WikidataOptions fields are ignored."""
        if key in _OPTION_NAMES:
            setattr(self.options, key, value)

    def validateOptions(self) -> bool:
        """Return True when a search can be issued.

        This requires a non-empty search string and a limit in 1..50.
        """
        if not self.options.search:  # also rejects None set via `set`
            return False
        return 1 <= self.options.limit <= 50

    def clearPropertyCache(self) -> None:
        """Empty the process-wide PROPERTY_CACHE (affects all instances)."""
        PROPERTY_CACHE.clear()

    async def _request_embedding(self, client: httpx.AsyncClient, text: str) -> list[float] | None:
        """POST *text* to Ollama's embeddings endpoint through *client*.

        Returns the ``embedding`` list from the JSON response, or None if it is
        absent or not a list. Raises httpx.HTTPStatusError on a non-2xx reply.
        """
        # Absolute URL: an injected client is not guaranteed to have Ollama's
        # base_url, so a relative path could fail or reach the wrong host.
        resp = await client.post(
            _ollama_embeddings_url(),
            json={"model": OLLAMA_EMBEDDING_MODEL, "prompt": text},
        )
        resp.raise_for_status()
        embedding = resp.json().get("embedding")
        return embedding if isinstance(embedding, list) else None

    async def embed_text(self, text: str) -> list[float] | None:
        """Return the Ollama embedding of *text*, or None if the response has none."""
        async with self._client_scope(_ollama_client_kwargs) as client:
            return await self._request_embedding(client, text)

    async def search(self) -> dict[str, Any]:
        """Run wbsearchentities with the current options.

        Returns ``{"results": [...]}``. Each result keeps only the url, id, label
        and (if present) description fields. Results missing url, id or label are
        dropped. If validateOptions fails, returns
        ``{"results": [], "error": "Bad options"}`` without any network call.
        Raises httpx.HTTPStatusError on a non-2xx reply.
        """
        if not self.validateOptions():
            return {"results": [], "error": "Bad options"}
        params = {
            "action": self.options.searchAction,
            "language": self.options.language,
            "search": self.options.search,
            "type": self.options.type,
            "limit": self.options.limit,
            "format": "json",
        }
        url = _build_url(self.options, params)
        async with self._client_scope(_client_kwargs) as client:
            resp = await client.get(url)
            resp.raise_for_status()
            data = resp.json()

        results = []
        for item in data.get("search", []):
            trimmed = {key: item[key] for key in ("url", "id", "label", "description") if item.get(key)}
            if {"url", "id", "label"}.issubset(trimmed):
                results.append(trimmed)
        return {"results": results}

    async def quick_resolve(self, query: str, *, limit: int = 1) -> dict[str, Any]:
        """Use wikidata.reconci.link quick resolve endpoint.

        Sends a single reconciliation query (key "q0") to
        WIKIDATA_QUICK_RESOLVE_BASE_URL.

        Returns a payload shaped like:
        {"results": [{"id": "Q..", "label": "..", "description": "..", "type": ".."}, ...]}
        where ``type`` is the id of the first type the service reports, or None.
        Raises httpx.HTTPStatusError on a non-2xx reply.
        """
        params = {
            "queries": json.dumps({"q0": {"query": query, "limit": limit}}),
        }
        async with self._client_scope(_client_kwargs) as client:
            resp = await client.get(WIKIDATA_QUICK_RESOLVE_BASE_URL, params=params)
            resp.raise_for_status()
            data = resp.json()

        results = []
        bucket = data.get("q0") or {}
        for row in bucket.get("result") or []:
            # "type" is a list of {id, name}; only the first entry is reported.
            types = row.get("type") or []
            first = types[0] if types else None
            type_id = first.get("id") if isinstance(first, dict) else None
            results.append(
                {
                    "id": row.get("id"),
                    "label": row.get("name"),
                    "description": row.get("description"),
                    "type": type_id,
                }
            )
        return {"results": results}

    @staticmethod
    def _candidate_text(cand: dict[str, Any]) -> str:
        """Join a candidate's label, description and aliases with " | ", skipping empty parts."""
        parts = [
            cand.get("label") or "",
            cand.get("description") or "",
            " ".join(cand.get("aliases") or []),
        ]
        return " | ".join(part for part in parts if part)

    async def candidate_embeddings(self, candidates: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Return copies of *candidates*, each with an added ``"embedding"`` key.

        The embedding comes from Ollama for the candidate's label/description/
        aliases text, and is None if Ollama returned none. Requests are sent
        sequentially over one shared client; input order is preserved.
        """
        if not candidates:
            return []
        out = []
        async with self._client_scope(_ollama_client_kwargs) as client:
            for cand in candidates:
                embedding = await self._request_embedding(client, self._candidate_text(cand))
                out.append({**cand, "embedding": embedding})
        return out

    def search_sync(self) -> dict[str, Any]:
        """Blocking wrapper around search().

        Uses asyncio.run, so it must not be called from a running event loop.
        """
        return asyncio.run(self.search())

    async def get_entities(self, entities: list[str], resolve_properties: bool = True) -> dict[str, Any]:
        """Fetch claims, labels and descriptions for Wikidata ids via wbgetentities.

        Only the first 50 ids are requested (the API's per-request limit for
        ordinary clients); the rest are silently dropped. Returns
        ``{"error": ...}`` if *entities* is not a list, and ``{"entities": []}``
        for an empty list. Otherwise returns the result of _parse_entities.
        Raises httpx.HTTPStatusError on a non-2xx reply.
        """
        if not isinstance(entities, list):
            return {"error": "Bad |entities| parameter. Must be an array of strings"}
        if not entities:
            return {"entities": []}
        ids = entities[:50]
        params = {
            "action": self.options.getAction,
            "languages": self.options.language,
            "redirects": "yes",
            "props": "claims|descriptions|labels",
            "normalize": "true",
            "ids": "|".join(ids),
            "format": "json",
        }
        url = _build_url(self.options, params)
        async with self._client_scope(_client_kwargs) as client:
            resp = await client.get(url)
            resp.raise_for_status()
            data = resp.json()
        return self._parse_entities(data, resolve_properties)

    def get_entities_sync(self, entities: list[str], resolve_properties: bool = True) -> dict[str, Any]:
        """Blocking wrapper around get_entities().

        Uses asyncio.run, so it must not be called from a running event loop.
        """
        return asyncio.run(self.get_entities(entities, resolve_properties))

    @staticmethod
    def _snak_value(prop_type: str, dv: Any) -> Any:
        """Flatten a mainsnak datavalue to a simple value.

        Returns None for datatypes this module does not handle. Handled types
        are wikibase-item, string, url, external-id, time, globe-coordinate
        and quantity.
        """
        if prop_type == "wikibase-item":
            return f"Q{dv.get('numeric-id')}"
        if prop_type in {"string", "url", "external-id"}:
            return dv
        if prop_type == "time":
            return dv.get("time", "")
        if prop_type == "globe-coordinate":
            return f"{dv.get('longitude')},{dv.get('latitude')}"
        if prop_type == "quantity":
            amount = dv.get("amount", "")
            unit = dv.get("unit")
            # "1" is the unitless marker; any other unit is appended verbatim.
            return f"{amount}{unit}" if unit and unit != "1" else amount
        return None

    def _parse_entities(self, data: dict[str, Any], resolve_properties: bool) -> dict[str, Any]:
        """Convert a wbgetentities response into a compact structure.

        The result is ``{"entities": [{"label", "description", "claims": [...]}]}``.
        Each claim is ``{"property", "value", "type"}``. Only claims whose
        mainsnak is of snaktype "value" with a handled datatype are kept.
        Entities lacking a label, a description (in the configured language) or
        any kept claim are dropped.

        With *resolve_properties*, property ids and wikibase-item value ids are
        replaced by their PROPERTY_CACHE entries. Ids missing from the cache are
        first passed to _resolve_properties. A value that is actually replaced
        by different text gets type "string"; otherwise its original datatype is
        kept, whether the cache was warm or cold.
        """
        lang = self.options.language
        out_entities = []
        pending: set[str] = set()  # ids to resolve after the first pass

        for entity in data.get("entities", {}).values():
            description = entity.get("descriptions", {}).get(lang, {}).get("value", "")
            label = entity.get("labels", {}).get(lang, {}).get("value", "")
            claims = []
            for claim_group in entity.get("claims", {}).values():
                for claim in claim_group:
                    snak = claim.get("mainsnak", {})
                    if snak.get("snaktype") != "value":
                        continue
                    prop = snak.get("property", "")
                    prop_type = snak.get("datatype", "")
                    dv = snak.get("datavalue", {}).get("value")
                    if not dv:
                        continue
                    val = self._snak_value(prop_type, dv)
                    if not (prop and val and prop_type):
                        continue

                    if not resolve_properties:
                        claims.append({"property": prop, "value": val, "type": prop_type})
                        continue

                    prop_cached = prop in PROPERTY_CACHE
                    if prop_cached:
                        prop = PROPERTY_CACHE[prop]
                    else:
                        pending.add(prop)

                    claim_type = prop_type
                    value_cached = True  # non-item values never need resolution
                    if prop_type == "wikibase-item":
                        value_cached = val in PROPERTY_CACHE
                        if value_cached:
                            resolved = PROPERTY_CACHE[val]
                            if resolved != val:
                                val, claim_type = resolved, "string"
                        else:
                            pending.add(val)

                    claims.append(
                        {
                            "property": prop,
                            "value": val,
                            "type": claim_type,
                            "propertyCached": prop_cached,
                            "valueCached": value_cached,
                        }
                    )
            if description and label and claims:
                out_entities.append({"label": label, "description": description, "claims": claims})

        if not resolve_properties:
            return {"entities": out_entities}

        if pending:
            self._resolve_properties(list(pending))

        # Second pass: apply freshly resolved entries and strip bookkeeping flags.
        for ent in out_entities:
            for claim in ent["claims"]:
                prop_cached = claim.pop("propertyCached", False)
                val_cached = claim.pop("valueCached", False)
                if not prop_cached and claim["property"] in PROPERTY_CACHE:
                    claim["property"] = PROPERTY_CACHE[claim["property"]]
                if not val_cached and claim["value"] in PROPERTY_CACHE:
                    resolved = PROPERTY_CACHE[claim["value"]]
                    if resolved != claim["value"]:
                        claim["value"] = resolved
                        claim["type"] = "string"
        return {"entities": out_entities}

    def _resolve_properties(self, property_list: list[str]) -> None:
        """Placeholder for batch label resolution of the given ids.

        Currently this only records each id as mapping to itself in
        PROPERTY_CACHE. Existing entries are left untouched. No network call
        is made.
        """
        # TODO: resolve real labels with wbgetentities in batches of 50.
        for prop_id in property_list:
            PROPERTY_CACHE.setdefault(prop_id, prop_id)

    async def get_entity_data(self, qid: str) -> dict[str, Any]:
        """Return the raw Special:EntityData JSON (flavor=dump) for *qid*.

        The request goes to the configured apiHost. *qid* is percent-encoded
        so it cannot alter the request path. Raises httpx.HTTPStatusError on a
        non-2xx reply.
        """
        url = f"https://{self.options.apiHost}/wiki/Special:EntityData/{quote(qid, safe='')}.json"
        async with self._client_scope(_client_kwargs) as client:
            resp = await client.get(url, params={"flavor": "dump"})
            resp.raise_for_status()
            return resp.json()