wikidata.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. from __future__ import annotations
  2. import asyncio
  3. import os
  4. import json
  5. from dataclasses import dataclass, field
  6. from typing import Any, Callable
  7. from urllib.parse import urlencode
  8. import httpx
  9. PROPERTY_CACHE: dict[str, str] = {}
  10. WIKIDATA_USER_AGENT = os.getenv(
  11. "ATLAS_WIKIDATA_USER_AGENT",
  12. "Atlas/1.0 (contact: lukas.goldschmidt+atlas@googlemail.com)",
  13. )
  14. WIKIDATA_QUICK_RESOLVE_BASE_URL = os.getenv(
  15. "ATLAS_WIKIDATA_QUICK_RESOLVE_URL",
  16. "https://wikidata.reconci.link/en/api",
  17. )
  18. @dataclass
  19. class WikidataOptions:
  20. search: str = ""
  21. language: str = "en"
  22. strictlanguage: bool = True
  23. type: str = "item"
  24. limit: int = 7
  25. searchAction: str = "wbsearchentities"
  26. getAction: str = "wbgetentities"
  27. apiHost: str = "www.wikidata.org"
  28. apiPath: str = "/w/api.php"
  29. def _is_null(value: Any) -> bool:
  30. return value is None
  31. def _build_url(opts: WikidataOptions, params: dict[str, Any]) -> str:
  32. query = urlencode(params)
  33. return f"https://{opts.apiHost}{opts.apiPath}?{query}"
  34. class WikidataSearch:
  35. def __init__(self, options: dict[str, Any] | None = None, *, client: httpx.AsyncClient | None = None):
  36. self.defaultOptions = WikidataOptions()
  37. self.options = WikidataOptions(**{k: v for k, v in (options or {}).items() if hasattr(self.defaultOptions, k)})
  38. self._client = client
  39. def set(self, key: str, value: Any) -> None:
  40. if hasattr(self.options, key):
  41. setattr(self.options, key, value)
  42. def validateOptions(self) -> bool:
  43. if len(self.options.search) == 0:
  44. return False
  45. if self.options.limit > 50 or self.options.limit < 1:
  46. return False
  47. return True
  48. def clearPropertyCache(self) -> None:
  49. PROPERTY_CACHE.clear()
  50. async def search(self) -> dict[str, Any]:
  51. if not self.validateOptions():
  52. return {"results": [], "error": "Bad options"}
  53. params = {
  54. "action": self.options.searchAction,
  55. "language": self.options.language,
  56. "search": self.options.search,
  57. "type": self.options.type,
  58. "limit": self.options.limit,
  59. "format": "json",
  60. }
  61. url = _build_url(self.options, params)
  62. client = self._client or httpx.AsyncClient(
  63. timeout=20,
  64. headers={"Accept": "application/json", "User-Agent": WIKIDATA_USER_AGENT},
  65. )
  66. close_client = self._client is None
  67. try:
  68. resp = await client.get(url)
  69. resp.raise_for_status()
  70. data = resp.json()
  71. results = []
  72. for item in data.get("search", []):
  73. trimmed = {}
  74. if item.get("url"):
  75. trimmed["url"] = item["url"]
  76. if item.get("id"):
  77. trimmed["id"] = item["id"]
  78. if item.get("label"):
  79. trimmed["label"] = item["label"]
  80. if item.get("description"):
  81. trimmed["description"] = item["description"]
  82. if {"url", "id", "label"}.issubset(trimmed):
  83. results.append(trimmed)
  84. return {"results": results}
  85. finally:
  86. if close_client:
  87. await client.aclose()
  88. async def quick_resolve(self, query: str, *, limit: int = 1) -> dict[str, Any]:
  89. """Use wikidata.reconci.link quick resolve endpoint.
  90. Returns a payload shaped like:
  91. {"results": [{"id": "Q..", "label": "..", "description": "..", "type": ".."}, ...]}
  92. """
  93. endpoint = WIKIDATA_QUICK_RESOLVE_BASE_URL
  94. params = {
  95. "queries": json.dumps({"q0": {"query": query, "limit": limit}}),
  96. }
  97. client = self._client or httpx.AsyncClient(
  98. timeout=20,
  99. headers={"Accept": "application/json", "User-Agent": WIKIDATA_USER_AGENT},
  100. )
  101. close_client = self._client is None
  102. try:
  103. resp = await client.get(endpoint, params=params)
  104. resp.raise_for_status()
  105. data = resp.json()
  106. finally:
  107. if close_client:
  108. await client.aclose()
  109. results = []
  110. for row in (data.get("q0", {}) or {}).get("result", []) or []:
  111. # type is an array of {id,name}; pick the first.
  112. t0 = (row.get("type") or [])
  113. type_id = t0[0].get("id") if t0 else None
  114. results.append(
  115. {
  116. "id": row.get("id"),
  117. "label": row.get("name"),
  118. "description": row.get("description"),
  119. "type": type_id,
  120. }
  121. )
  122. return {"results": results}
  123. def search_sync(self) -> dict[str, Any]:
  124. return asyncio.run(self.search())
  125. async def get_entities(self, entities: list[str], resolve_properties: bool = True) -> dict[str, Any]:
  126. if not isinstance(entities, list):
  127. return {"error": "Bad |entities| parameter. Must be an array of strings"}
  128. if len(entities) == 0:
  129. return {"entities": []}
  130. if len(entities) > 50:
  131. entities = entities[:50]
  132. params = {
  133. "action": self.options.getAction,
  134. "languages": self.options.language,
  135. "redirects": "yes",
  136. "props": "claims|descriptions|labels",
  137. "normalize": "true",
  138. "ids": "|".join(entities),
  139. "format": "json",
  140. }
  141. url = _build_url(self.options, params)
  142. client = self._client or httpx.AsyncClient(
  143. timeout=20,
  144. headers={"Accept": "application/json", "User-Agent": WIKIDATA_USER_AGENT},
  145. )
  146. close_client = self._client is None
  147. try:
  148. resp = await client.get(url)
  149. resp.raise_for_status()
  150. data = resp.json()
  151. return self._parse_entities(data, resolve_properties)
  152. finally:
  153. if close_client:
  154. await client.aclose()
  155. def get_entities_sync(self, entities: list[str], resolve_properties: bool = True) -> dict[str, Any]:
  156. return asyncio.run(self.get_entities(entities, resolve_properties))
  157. def _parse_entities(self, data: dict[str, Any], resolve_properties: bool) -> dict[str, Any]:
  158. out_entities = []
  159. combined_property_list: set[str] = set()
  160. for entity in data.get("entities", {}).values():
  161. description = entity.get("descriptions", {}).get(self.options.language, {}).get("value", "")
  162. label = entity.get("labels", {}).get(self.options.language, {}).get("value", "")
  163. claims = []
  164. for claim_group in entity.get("claims", {}).values():
  165. for claim in claim_group:
  166. snak = claim.get("mainsnak", {})
  167. if snak.get("snaktype") != "value":
  168. continue
  169. prop = snak.get("property", "")
  170. prop_type = snak.get("datatype", "")
  171. val = ""
  172. dv = snak.get("datavalue", {}).get("value")
  173. if not dv:
  174. continue
  175. if prop_type == "wikibase-item":
  176. val = f"Q{dv.get('numeric-id')}"
  177. elif prop_type in {"string", "url", "external-id"}:
  178. val = dv
  179. elif prop_type == "time":
  180. val = dv.get("time", "")
  181. elif prop_type == "globe-coordinate":
  182. val = f"{dv.get('longitude')},{dv.get('latitude')}"
  183. elif prop_type == "quantity":
  184. val = dv.get("amount", "")
  185. if dv.get("unit") and dv.get("unit") != "1":
  186. val = f"{val}{dv.get('unit')}"
  187. else:
  188. continue
  189. if prop and val and prop_type:
  190. if resolve_properties:
  191. prop_cached = prop in PROPERTY_CACHE
  192. if prop_cached:
  193. prop = PROPERTY_CACHE[prop]
  194. else:
  195. combined_property_list.add(prop)
  196. value_cached = True
  197. if prop_type == "wikibase-item":
  198. value_cached = val in PROPERTY_CACHE
  199. if value_cached:
  200. val = PROPERTY_CACHE[val]
  201. else:
  202. combined_property_list.add(val)
  203. claims.append({"property": prop, "value": val, "type": prop_type, "propertyCached": prop_cached, "valueCached": value_cached})
  204. else:
  205. claims.append({"property": prop, "value": val, "type": prop_type})
  206. if description and label and claims:
  207. out_entities.append({"label": label, "description": description, "claims": claims})
  208. if not resolve_properties:
  209. return {"entities": out_entities}
  210. if combined_property_list:
  211. self._resolve_properties(list(combined_property_list))
  212. for ent in out_entities:
  213. for claim in ent["claims"]:
  214. prop_cached = claim.pop("propertyCached", False)
  215. val_cached = claim.pop("valueCached", False)
  216. if not prop_cached and claim["property"] in PROPERTY_CACHE:
  217. claim["property"] = PROPERTY_CACHE[claim["property"]]
  218. if not val_cached and claim["value"] in PROPERTY_CACHE:
  219. claim["value"] = PROPERTY_CACHE[claim["value"]]
  220. claim["type"] = "string"
  221. return {"entities": out_entities}
  222. def _resolve_properties(self, property_list: list[str]) -> None:
  223. # Placeholder for batch property label resolution, kept synchronous for now.
  224. # Use wbgetentities in batches of 50.
  225. for prop_id in property_list:
  226. PROPERTY_CACHE.setdefault(prop_id, prop_id)
  227. async def get_entity_data(self, qid: str) -> dict[str, Any]:
  228. client = self._client or httpx.AsyncClient(
  229. timeout=20,
  230. headers={"Accept": "application/json", "User-Agent": WIKIDATA_USER_AGENT},
  231. )
  232. close_client = self._client is None
  233. try:
  234. resp = await client.get(
  235. f"https://www.wikidata.org/wiki/Special:EntityData/{qid}.json",
  236. params={"flavor": "dump"},
  237. )
  238. resp.raise_for_status()
  239. return resp.json()
  240. finally:
  241. if close_client:
  242. await client.aclose()