wikidata.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
from __future__ import annotations

import asyncio
import os
from dataclasses import dataclass, field, fields
from typing import Any, Callable
from urllib.parse import quote, urlencode

import httpx
  8. PROPERTY_CACHE: dict[str, str] = {}
  9. WIKIDATA_USER_AGENT = os.getenv(
  10. "ATLAS_WIKIDATA_USER_AGENT",
  11. "Atlas/1.0 (contact: lukas.goldschmidt+atlas@googlemail.com)",
  12. )
  13. @dataclass
  14. class WikidataOptions:
  15. search: str = ""
  16. language: str = "en"
  17. strictlanguage: bool = True
  18. type: str = "item"
  19. limit: int = 7
  20. searchAction: str = "wbsearchentities"
  21. getAction: str = "wbgetentities"
  22. apiHost: str = "www.wikidata.org"
  23. apiPath: str = "/w/api.php"
  24. def _is_null(value: Any) -> bool:
  25. return value is None
  26. def _build_url(opts: WikidataOptions, params: dict[str, Any]) -> str:
  27. query = urlencode(params)
  28. return f"https://{opts.apiHost}{opts.apiPath}?{query}"
  29. class WikidataSearch:
  30. def __init__(self, options: dict[str, Any] | None = None, *, client: httpx.AsyncClient | None = None):
  31. self.defaultOptions = WikidataOptions()
  32. self.options = WikidataOptions(**{k: v for k, v in (options or {}).items() if hasattr(self.defaultOptions, k)})
  33. self._client = client
  34. def set(self, key: str, value: Any) -> None:
  35. if hasattr(self.options, key):
  36. setattr(self.options, key, value)
  37. def validateOptions(self) -> bool:
  38. if len(self.options.search) == 0:
  39. return False
  40. if self.options.limit > 50 or self.options.limit < 1:
  41. return False
  42. return True
  43. def clearPropertyCache(self) -> None:
  44. PROPERTY_CACHE.clear()
  45. async def search(self) -> dict[str, Any]:
  46. if not self.validateOptions():
  47. return {"results": [], "error": "Bad options"}
  48. params = {
  49. "action": self.options.searchAction,
  50. "language": self.options.language,
  51. "search": self.options.search,
  52. "type": self.options.type,
  53. "limit": self.options.limit,
  54. "format": "json",
  55. }
  56. url = _build_url(self.options, params)
  57. client = self._client or httpx.AsyncClient(
  58. timeout=20,
  59. headers={"Accept": "application/json", "User-Agent": WIKIDATA_USER_AGENT},
  60. )
  61. close_client = self._client is None
  62. try:
  63. resp = await client.get(url)
  64. resp.raise_for_status()
  65. data = resp.json()
  66. results = []
  67. for item in data.get("search", []):
  68. trimmed = {}
  69. if item.get("url"):
  70. trimmed["url"] = item["url"]
  71. if item.get("id"):
  72. trimmed["id"] = item["id"]
  73. if item.get("label"):
  74. trimmed["label"] = item["label"]
  75. if item.get("description"):
  76. trimmed["description"] = item["description"]
  77. if {"url", "id", "label"}.issubset(trimmed):
  78. results.append(trimmed)
  79. return {"results": results}
  80. finally:
  81. if close_client:
  82. await client.aclose()
  83. def search_sync(self) -> dict[str, Any]:
  84. return asyncio.run(self.search())
  85. async def get_entities(self, entities: list[str], resolve_properties: bool = True) -> dict[str, Any]:
  86. if not isinstance(entities, list):
  87. return {"error": "Bad |entities| parameter. Must be an array of strings"}
  88. if len(entities) == 0:
  89. return {"entities": []}
  90. if len(entities) > 50:
  91. entities = entities[:50]
  92. params = {
  93. "action": self.options.getAction,
  94. "languages": self.options.language,
  95. "redirects": "yes",
  96. "props": "claims|descriptions|labels",
  97. "normalize": "true",
  98. "ids": "|".join(entities),
  99. "format": "json",
  100. }
  101. url = _build_url(self.options, params)
  102. client = self._client or httpx.AsyncClient(
  103. timeout=20,
  104. headers={"Accept": "application/json", "User-Agent": WIKIDATA_USER_AGENT},
  105. )
  106. close_client = self._client is None
  107. try:
  108. resp = await client.get(url)
  109. resp.raise_for_status()
  110. data = resp.json()
  111. return self._parse_entities(data, resolve_properties)
  112. finally:
  113. if close_client:
  114. await client.aclose()
  115. def get_entities_sync(self, entities: list[str], resolve_properties: bool = True) -> dict[str, Any]:
  116. return asyncio.run(self.get_entities(entities, resolve_properties))
  117. def _parse_entities(self, data: dict[str, Any], resolve_properties: bool) -> dict[str, Any]:
  118. out_entities = []
  119. combined_property_list: set[str] = set()
  120. for entity in data.get("entities", {}).values():
  121. description = entity.get("descriptions", {}).get(self.options.language, {}).get("value", "")
  122. label = entity.get("labels", {}).get(self.options.language, {}).get("value", "")
  123. claims = []
  124. for claim_group in entity.get("claims", {}).values():
  125. for claim in claim_group:
  126. snak = claim.get("mainsnak", {})
  127. if snak.get("snaktype") != "value":
  128. continue
  129. prop = snak.get("property", "")
  130. prop_type = snak.get("datatype", "")
  131. val = ""
  132. dv = snak.get("datavalue", {}).get("value")
  133. if not dv:
  134. continue
  135. if prop_type == "wikibase-item":
  136. val = f"Q{dv.get('numeric-id')}"
  137. elif prop_type in {"string", "url", "external-id"}:
  138. val = dv
  139. elif prop_type == "time":
  140. val = dv.get("time", "")
  141. elif prop_type == "globe-coordinate":
  142. val = f"{dv.get('longitude')},{dv.get('latitude')}"
  143. elif prop_type == "quantity":
  144. val = dv.get("amount", "")
  145. if dv.get("unit") and dv.get("unit") != "1":
  146. val = f"{val}{dv.get('unit')}"
  147. else:
  148. continue
  149. if prop and val and prop_type:
  150. if resolve_properties:
  151. prop_cached = prop in PROPERTY_CACHE
  152. if prop_cached:
  153. prop = PROPERTY_CACHE[prop]
  154. else:
  155. combined_property_list.add(prop)
  156. value_cached = True
  157. if prop_type == "wikibase-item":
  158. value_cached = val in PROPERTY_CACHE
  159. if value_cached:
  160. val = PROPERTY_CACHE[val]
  161. else:
  162. combined_property_list.add(val)
  163. claims.append({"property": prop, "value": val, "type": prop_type, "propertyCached": prop_cached, "valueCached": value_cached})
  164. else:
  165. claims.append({"property": prop, "value": val, "type": prop_type})
  166. if description and label and claims:
  167. out_entities.append({"label": label, "description": description, "claims": claims})
  168. if not resolve_properties:
  169. return {"entities": out_entities}
  170. if combined_property_list:
  171. self._resolve_properties(list(combined_property_list))
  172. for ent in out_entities:
  173. for claim in ent["claims"]:
  174. prop_cached = claim.pop("propertyCached", False)
  175. val_cached = claim.pop("valueCached", False)
  176. if not prop_cached and claim["property"] in PROPERTY_CACHE:
  177. claim["property"] = PROPERTY_CACHE[claim["property"]]
  178. if not val_cached and claim["value"] in PROPERTY_CACHE:
  179. claim["value"] = PROPERTY_CACHE[claim["value"]]
  180. claim["type"] = "string"
  181. return {"entities": out_entities}
  182. def _resolve_properties(self, property_list: list[str]) -> None:
  183. # Placeholder for batch property label resolution, kept synchronous for now.
  184. # Use wbgetentities in batches of 50.
  185. for prop_id in property_list:
  186. PROPERTY_CACHE.setdefault(prop_id, prop_id)
  187. async def get_entity_data(self, qid: str) -> dict[str, Any]:
  188. client = self._client or httpx.AsyncClient(
  189. timeout=20,
  190. headers={"Accept": "application/json", "User-Agent": WIKIDATA_USER_AGENT},
  191. )
  192. close_client = self._client is None
  193. try:
  194. resp = await client.get(
  195. f"https://www.wikidata.org/wiki/Special:EntityData/{qid}.json",
  196. params={"flavor": "dump"},
  197. )
  198. resp.raise_for_status()
  199. return resp.json()
  200. finally:
  201. if close_client:
  202. await client.aclose()