# wikidata.py — Wikidata search / reconciliation / Ollama-embedding helpers.
from __future__ import annotations

import asyncio
import json
import os
from dataclasses import dataclass, field
from typing import Any, Callable
from urllib.parse import urlencode

import httpx
  9. PROPERTY_CACHE: dict[str, str] = {}
  10. WIKIDATA_USER_AGENT = os.getenv(
  11. "ATLAS_WIKIDATA_USER_AGENT",
  12. "Atlas/1.0 (contact: lukas.goldschmidt+atlas@googlemail.com)",
  13. )
  14. # The Wikidata reconciliation endpoint has moved more than once, so keep it configurable.
  15. WIKIDATA_QUICK_RESOLVE_BASE_URL = os.getenv(
  16. "ATLAS_WIKIDATA_QUICK_RESOLVE_URL",
  17. "https://wikidata.reconci.link/en/api",
  18. )
  19. OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://192.168.0.200:11434")
  20. OLLAMA_EMBEDDING_MODEL = os.getenv("OLLAMA_EMBEDDING_MODEL", "nomic-embed-text")
  21. @dataclass
  22. class WikidataOptions:
  23. search: str = ""
  24. language: str = "en"
  25. strictlanguage: bool = True
  26. type: str = "item"
  27. limit: int = 7
  28. searchAction: str = "wbsearchentities"
  29. getAction: str = "wbgetentities"
  30. apiHost: str = "www.wikidata.org"
  31. apiPath: str = "/w/api.php"
  32. def _is_null(value: Any) -> bool:
  33. return value is None
  34. def _build_url(opts: WikidataOptions, params: dict[str, Any]) -> str:
  35. query = urlencode(params)
  36. return f"https://{opts.apiHost}{opts.apiPath}?{query}"
  37. def _client_kwargs() -> dict[str, Any]:
  38. return {
  39. "timeout": 20,
  40. "headers": {"Accept": "application/json", "User-Agent": WIKIDATA_USER_AGENT},
  41. "follow_redirects": True,
  42. }
  43. def _ollama_client_kwargs() -> dict[str, Any]:
  44. return {
  45. "timeout": 20,
  46. "base_url": OLLAMA_BASE_URL,
  47. }
  48. class WikidataSearch:
  49. def __init__(self, options: dict[str, Any] | None = None, *, client: httpx.AsyncClient | None = None):
  50. self.defaultOptions = WikidataOptions()
  51. self.options = WikidataOptions(**{k: v for k, v in (options or {}).items() if hasattr(self.defaultOptions, k)})
  52. self._client = client
  53. def set(self, key: str, value: Any) -> None:
  54. if hasattr(self.options, key):
  55. setattr(self.options, key, value)
  56. def validateOptions(self) -> bool:
  57. if len(self.options.search) == 0:
  58. return False
  59. if self.options.limit > 50 or self.options.limit < 1:
  60. return False
  61. return True
  62. def clearPropertyCache(self) -> None:
  63. PROPERTY_CACHE.clear()
  64. async def embed_text(self, text: str) -> list[float] | None:
  65. client = self._client or httpx.AsyncClient(**_ollama_client_kwargs())
  66. close_client = self._client is None
  67. try:
  68. resp = await client.post(
  69. "/api/embeddings",
  70. json={"model": OLLAMA_EMBEDDING_MODEL, "prompt": text},
  71. )
  72. resp.raise_for_status()
  73. data = resp.json()
  74. embedding = data.get("embedding")
  75. return embedding if isinstance(embedding, list) else None
  76. finally:
  77. if close_client:
  78. await client.aclose()
  79. async def search(self) -> dict[str, Any]:
  80. if not self.validateOptions():
  81. return {"results": [], "error": "Bad options"}
  82. params = {
  83. "action": self.options.searchAction,
  84. "language": self.options.language,
  85. "search": self.options.search,
  86. "type": self.options.type,
  87. "limit": self.options.limit,
  88. "format": "json",
  89. }
  90. url = _build_url(self.options, params)
  91. client = self._client or httpx.AsyncClient(**_client_kwargs())
  92. close_client = self._client is None
  93. try:
  94. resp = await client.get(url)
  95. resp.raise_for_status()
  96. data = resp.json()
  97. results = []
  98. for item in data.get("search", []):
  99. trimmed = {}
  100. if item.get("url"):
  101. trimmed["url"] = item["url"]
  102. if item.get("id"):
  103. trimmed["id"] = item["id"]
  104. if item.get("label"):
  105. trimmed["label"] = item["label"]
  106. if item.get("description"):
  107. trimmed["description"] = item["description"]
  108. if {"url", "id", "label"}.issubset(trimmed):
  109. results.append(trimmed)
  110. return {"results": results}
  111. finally:
  112. if close_client:
  113. await client.aclose()
  114. async def quick_resolve(self, query: str, *, limit: int = 1) -> dict[str, Any]:
  115. """Use wikidata.reconci.link quick resolve endpoint.
  116. Returns a payload shaped like:
  117. {"results": [{"id": "Q..", "label": "..", "description": "..", "type": ".."}, ...]}
  118. """
  119. endpoint = WIKIDATA_QUICK_RESOLVE_BASE_URL
  120. params = {
  121. "queries": json.dumps({"q0": {"query": query, "limit": limit}}),
  122. }
  123. client = self._client or httpx.AsyncClient(**_client_kwargs())
  124. close_client = self._client is None
  125. try:
  126. resp = await client.get(endpoint, params=params)
  127. resp.raise_for_status()
  128. data = resp.json()
  129. finally:
  130. if close_client:
  131. await client.aclose()
  132. results = []
  133. for row in (data.get("q0", {}) or {}).get("result", []) or []:
  134. # type is an array of {id,name}; pick the first.
  135. t0 = (row.get("type") or [])
  136. type_id = t0[0].get("id") if t0 else None
  137. results.append(
  138. {
  139. "id": row.get("id"),
  140. "label": row.get("name"),
  141. "description": row.get("description"),
  142. "type": type_id,
  143. }
  144. )
  145. return {"results": results}
  146. async def candidate_embeddings(self, candidates: list[dict[str, Any]]) -> list[dict[str, Any]]:
  147. out = []
  148. for cand in candidates:
  149. text_parts = [cand.get("label") or "", cand.get("description") or "", " ".join(cand.get("aliases") or [])]
  150. text = " | ".join(part for part in text_parts if part)
  151. embedding = await self.embed_text(text)
  152. out.append({**cand, "embedding": embedding})
  153. return out
  154. def search_sync(self) -> dict[str, Any]:
  155. return asyncio.run(self.search())
  156. async def get_entities(self, entities: list[str], resolve_properties: bool = True) -> dict[str, Any]:
  157. if not isinstance(entities, list):
  158. return {"error": "Bad |entities| parameter. Must be an array of strings"}
  159. if len(entities) == 0:
  160. return {"entities": []}
  161. if len(entities) > 50:
  162. entities = entities[:50]
  163. params = {
  164. "action": self.options.getAction,
  165. "languages": self.options.language,
  166. "redirects": "yes",
  167. "props": "claims|descriptions|labels",
  168. "normalize": "true",
  169. "ids": "|".join(entities),
  170. "format": "json",
  171. }
  172. url = _build_url(self.options, params)
  173. client = self._client or httpx.AsyncClient(**_client_kwargs())
  174. close_client = self._client is None
  175. try:
  176. resp = await client.get(url)
  177. resp.raise_for_status()
  178. data = resp.json()
  179. return self._parse_entities(data, resolve_properties)
  180. finally:
  181. if close_client:
  182. await client.aclose()
  183. def get_entities_sync(self, entities: list[str], resolve_properties: bool = True) -> dict[str, Any]:
  184. return asyncio.run(self.get_entities(entities, resolve_properties))
  185. def _parse_entities(self, data: dict[str, Any], resolve_properties: bool) -> dict[str, Any]:
  186. out_entities = []
  187. combined_property_list: set[str] = set()
  188. for entity in data.get("entities", {}).values():
  189. description = entity.get("descriptions", {}).get(self.options.language, {}).get("value", "")
  190. label = entity.get("labels", {}).get(self.options.language, {}).get("value", "")
  191. claims = []
  192. for claim_group in entity.get("claims", {}).values():
  193. for claim in claim_group:
  194. snak = claim.get("mainsnak", {})
  195. if snak.get("snaktype") != "value":
  196. continue
  197. prop = snak.get("property", "")
  198. prop_type = snak.get("datatype", "")
  199. val = ""
  200. dv = snak.get("datavalue", {}).get("value")
  201. if not dv:
  202. continue
  203. if prop_type == "wikibase-item":
  204. val = f"Q{dv.get('numeric-id')}"
  205. elif prop_type in {"string", "url", "external-id"}:
  206. val = dv
  207. elif prop_type == "time":
  208. val = dv.get("time", "")
  209. elif prop_type == "globe-coordinate":
  210. val = f"{dv.get('longitude')},{dv.get('latitude')}"
  211. elif prop_type == "quantity":
  212. val = dv.get("amount", "")
  213. if dv.get("unit") and dv.get("unit") != "1":
  214. val = f"{val}{dv.get('unit')}"
  215. else:
  216. continue
  217. if prop and val and prop_type:
  218. if resolve_properties:
  219. prop_cached = prop in PROPERTY_CACHE
  220. if prop_cached:
  221. prop = PROPERTY_CACHE[prop]
  222. else:
  223. combined_property_list.add(prop)
  224. value_cached = True
  225. if prop_type == "wikibase-item":
  226. value_cached = val in PROPERTY_CACHE
  227. if value_cached:
  228. val = PROPERTY_CACHE[val]
  229. else:
  230. combined_property_list.add(val)
  231. claims.append({"property": prop, "value": val, "type": prop_type, "propertyCached": prop_cached, "valueCached": value_cached})
  232. else:
  233. claims.append({"property": prop, "value": val, "type": prop_type})
  234. if description and label and claims:
  235. out_entities.append({"label": label, "description": description, "claims": claims})
  236. if not resolve_properties:
  237. return {"entities": out_entities}
  238. if combined_property_list:
  239. self._resolve_properties(list(combined_property_list))
  240. for ent in out_entities:
  241. for claim in ent["claims"]:
  242. prop_cached = claim.pop("propertyCached", False)
  243. val_cached = claim.pop("valueCached", False)
  244. if not prop_cached and claim["property"] in PROPERTY_CACHE:
  245. claim["property"] = PROPERTY_CACHE[claim["property"]]
  246. if not val_cached and claim["value"] in PROPERTY_CACHE:
  247. claim["value"] = PROPERTY_CACHE[claim["value"]]
  248. claim["type"] = "string"
  249. return {"entities": out_entities}
  250. def _resolve_properties(self, property_list: list[str]) -> None:
  251. # Placeholder for batch property label resolution, kept synchronous for now.
  252. # Use wbgetentities in batches of 50.
  253. for prop_id in property_list:
  254. PROPERTY_CACHE.setdefault(prop_id, prop_id)
  255. async def get_entity_data(self, qid: str) -> dict[str, Any]:
  256. client = self._client or httpx.AsyncClient(**_client_kwargs())
  257. close_client = self._client is None
  258. try:
  259. resp = await client.get(
  260. f"https://www.wikidata.org/wiki/Special:EntityData/{qid}.json",
  261. params={"flavor": "dump"},
  262. )
  263. resp.raise_for_status()
  264. return resp.json()
  265. finally:
  266. if close_client:
  267. await client.aclose()