| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316 |
- from __future__ import annotations
- import asyncio
- import os
- import json
- from dataclasses import dataclass, field
- from typing import Any, Callable
- from urllib.parse import urlencode
- import httpx
# Process-wide cache mapping Wikidata P/Q ids to resolved labels.  Shared by
# every WikidataSearch instance; cleared via WikidataSearch.clearPropertyCache.
PROPERTY_CACHE: dict[str, str] = {}

# Identifying User-Agent sent with every Wikidata request (Wikimedia asks API
# clients to self-identify); overridable through the environment.
WIKIDATA_USER_AGENT = os.getenv(
    "ATLAS_WIKIDATA_USER_AGENT",
    "Atlas/1.0 (contact: lukas.goldschmidt+atlas@googlemail.com)",
)
# The Wikidata reconciliation endpoint has moved more than once, so keep it configurable.
WIKIDATA_QUICK_RESOLVE_BASE_URL = os.getenv(
    "ATLAS_WIKIDATA_QUICK_RESOLVE_URL",
    "https://wikidata.reconci.link/en/api",
)
# Ollama server and embedding model used by WikidataSearch.embed_text.
# NOTE(review): default points at a private LAN address — confirm deployments
# always override OLLAMA_BASE_URL.
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://192.168.0.200:11434")
OLLAMA_EMBEDDING_MODEL = os.getenv("OLLAMA_EMBEDDING_MODEL", "nomic-embed-text")
@dataclass
class WikidataOptions:
    """Request configuration for WikidataSearch.

    The first five fields are forwarded as ``wbsearchentities`` query
    parameters; the remaining fields select the API actions and endpoint.
    Field names deliberately stay camelCase to match their existing callers.
    """

    search: str = ""  # search term; validateOptions rejects an empty string
    language: str = "en"  # language for labels/descriptions
    strictlanguage: bool = True  # NOTE(review): never sent in any request built here — confirm intended
    type: str = "item"  # entity type passed to wbsearchentities
    limit: int = 7  # result cap; validateOptions enforces 1..50
    searchAction: str = "wbsearchentities"  # action used by search()
    getAction: str = "wbgetentities"  # action used by get_entities()
    apiHost: str = "www.wikidata.org"
    apiPath: str = "/w/api.php"
- def _is_null(value: Any) -> bool:
- return value is None
- def _build_url(opts: WikidataOptions, params: dict[str, Any]) -> str:
- query = urlencode(params)
- return f"https://{opts.apiHost}{opts.apiPath}?{query}"
def _client_kwargs() -> dict[str, Any]:
    """Construction kwargs for an ``httpx.AsyncClient`` aimed at Wikidata."""
    request_headers = {
        "Accept": "application/json",
        "User-Agent": WIKIDATA_USER_AGENT,
    }
    return dict(timeout=20, headers=request_headers, follow_redirects=True)
def _ollama_client_kwargs() -> dict[str, Any]:
    """Construction kwargs for an ``httpx.AsyncClient`` aimed at the Ollama server."""
    return dict(timeout=20, base_url=OLLAMA_BASE_URL)
class WikidataSearch:
    """Async client combining three backends: the Wikidata action API
    (``search``/``get_entities``), the wikidata.reconci.link reconciliation
    service (``quick_resolve``), and an Ollama embeddings endpoint
    (``embed_text``).

    An ``httpx.AsyncClient`` may be injected via *client* (e.g. for tests);
    otherwise every call creates its own short-lived client and closes it in
    a ``finally`` block.
    """

    def __init__(self, options: dict[str, Any] | None = None, *, client: httpx.AsyncClient | None = None):
        """Initialize options from *options*; keys not present on
        WikidataOptions are silently dropped."""
        self.defaultOptions = WikidataOptions()
        # Filter to known option names so callers may pass a superset dict.
        self.options = WikidataOptions(**{k: v for k, v in (options or {}).items() if hasattr(self.defaultOptions, k)})
        self._client = client

    def set(self, key: str, value: Any) -> None:
        """Set a known option; unknown keys are ignored without error."""
        if hasattr(self.options, key):
            setattr(self.options, key, value)

    def validateOptions(self) -> bool:
        """Return True when a non-empty search term and a limit in 1..50 are set."""
        if len(self.options.search) == 0:
            return False
        if self.options.limit > 50 or self.options.limit < 1:
            return False
        return True

    def clearPropertyCache(self) -> None:
        """Clear the module-level PROPERTY_CACHE (shared across all instances)."""
        PROPERTY_CACHE.clear()

    async def embed_text(self, text: str) -> list[float] | None:
        """Embed *text* with the configured Ollama model.

        Returns the embedding vector, or None when the response has no
        list-valued "embedding" field.

        NOTE(review): the relative "/api/embeddings" path only works when the
        client has an Ollama base_url; an injected self._client configured for
        Wikidata would post to the wrong host — confirm injection sites.
        """
        client = self._client or httpx.AsyncClient(**_ollama_client_kwargs())
        # Only close clients this call created; injected clients are the
        # caller's responsibility.
        close_client = self._client is None
        try:
            resp = await client.post(
                "/api/embeddings",
                json={"model": OLLAMA_EMBEDDING_MODEL, "prompt": text},
            )
            resp.raise_for_status()
            data = resp.json()
            embedding = data.get("embedding")
            return embedding if isinstance(embedding, list) else None
        finally:
            if close_client:
                await client.aclose()

    async def search(self) -> dict[str, Any]:
        """Run ``wbsearchentities`` with the current options.

        Returns ``{"results": [...]}`` where each result carries url/id/label
        (all required) and an optional description, or
        ``{"results": [], "error": "Bad options"}`` when validation fails.
        Raises ``httpx.HTTPStatusError`` on non-2xx responses.
        """
        if not self.validateOptions():
            return {"results": [], "error": "Bad options"}
        params = {
            "action": self.options.searchAction,
            "language": self.options.language,
            "search": self.options.search,
            "type": self.options.type,
            "limit": self.options.limit,
            "format": "json",
        }
        url = _build_url(self.options, params)
        client = self._client or httpx.AsyncClient(**_client_kwargs())
        close_client = self._client is None
        try:
            resp = await client.get(url)
            resp.raise_for_status()
            data = resp.json()
            results = []
            for item in data.get("search", []):
                trimmed = {}
                if item.get("url"):
                    trimmed["url"] = item["url"]
                if item.get("id"):
                    trimmed["id"] = item["id"]
                if item.get("label"):
                    trimmed["label"] = item["label"]
                if item.get("description"):
                    trimmed["description"] = item["description"]
                # Drop partial hits: url, id and label must all be present.
                if {"url", "id", "label"}.issubset(trimmed):
                    results.append(trimmed)
            return {"results": results}
        finally:
            if close_client:
                await client.aclose()

    async def quick_resolve(self, query: str, *, limit: int = 1) -> dict[str, Any]:
        """Use wikidata.reconci.link quick resolve endpoint.
        Returns a payload shaped like:
        {"results": [{"id": "Q..", "label": "..", "description": "..", "type": ".."}, ...]}
        """
        endpoint = WIKIDATA_QUICK_RESOLVE_BASE_URL
        # The reconciliation API accepts a JSON-encoded batch of queries; we
        # send a single query keyed "q0" and read the matching response key.
        params = {
            "queries": json.dumps({"q0": {"query": query, "limit": limit}}),
        }
        client = self._client or httpx.AsyncClient(**_client_kwargs())
        close_client = self._client is None
        try:
            resp = await client.get(endpoint, params=params)
            resp.raise_for_status()
            data = resp.json()
        finally:
            if close_client:
                await client.aclose()
        results = []
        for row in (data.get("q0", {}) or {}).get("result", []) or []:
            # type is an array of {id,name}; pick the first.
            t0 = (row.get("type") or [])
            type_id = t0[0].get("id") if t0 else None
            results.append(
                {
                    "id": row.get("id"),
                    "label": row.get("name"),
                    "description": row.get("description"),
                    "type": type_id,
                }
            )
        return {"results": results}

    async def candidate_embeddings(self, candidates: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Return copies of *candidates* with an added "embedding" key.

        Embeddings are computed sequentially, one embed_text call per
        candidate, from its label, description and aliases joined by " | ".
        """
        out = []
        for cand in candidates:
            text_parts = [cand.get("label") or "", cand.get("description") or "", " ".join(cand.get("aliases") or [])]
            text = " | ".join(part for part in text_parts if part)
            embedding = await self.embed_text(text)
            out.append({**cand, "embedding": embedding})
        return out

    def search_sync(self) -> dict[str, Any]:
        """Blocking wrapper for search(); must not be called from a running event loop."""
        return asyncio.run(self.search())

    async def get_entities(self, entities: list[str], resolve_properties: bool = True) -> dict[str, Any]:
        """Fetch labels, descriptions and claims for up to 50 entity ids.

        Ids beyond the first 50 are silently truncated (wbgetentities batch
        cap).  When *resolve_properties* is true, claim P/Q ids are replaced
        with labels via the shared PROPERTY_CACHE.
        """
        if not isinstance(entities, list):
            return {"error": "Bad |entities| parameter. Must be an array of strings"}
        if len(entities) == 0:
            return {"entities": []}
        if len(entities) > 50:
            entities = entities[:50]
        params = {
            "action": self.options.getAction,
            "languages": self.options.language,
            "redirects": "yes",
            "props": "claims|descriptions|labels",
            "normalize": "true",
            "ids": "|".join(entities),
            "format": "json",
        }
        url = _build_url(self.options, params)
        client = self._client or httpx.AsyncClient(**_client_kwargs())
        close_client = self._client is None
        try:
            resp = await client.get(url)
            resp.raise_for_status()
            data = resp.json()
            return self._parse_entities(data, resolve_properties)
        finally:
            if close_client:
                await client.aclose()

    def get_entities_sync(self, entities: list[str], resolve_properties: bool = True) -> dict[str, Any]:
        """Blocking wrapper for get_entities(); must not be called from a running event loop."""
        return asyncio.run(self.get_entities(entities, resolve_properties))

    def _parse_entities(self, data: dict[str, Any], resolve_properties: bool) -> dict[str, Any]:
        """Flatten a wbgetentities payload into {"entities": [...]}.

        Each output entity has label, description and a list of
        {property, value, type} claims.  Entities missing any of the three
        are dropped, as are claims whose datavalue is absent/falsy or whose
        datatype is unsupported.
        """
        out_entities = []
        # Ids not yet in PROPERTY_CACHE, collected for one batch resolution.
        combined_property_list: set[str] = set()
        for entity in data.get("entities", {}).values():
            description = entity.get("descriptions", {}).get(self.options.language, {}).get("value", "")
            label = entity.get("labels", {}).get(self.options.language, {}).get("value", "")
            claims = []
            for claim_group in entity.get("claims", {}).values():
                for claim in claim_group:
                    snak = claim.get("mainsnak", {})
                    # Only concrete values; "somevalue"/"novalue" snaks are skipped.
                    if snak.get("snaktype") != "value":
                        continue
                    prop = snak.get("property", "")
                    prop_type = snak.get("datatype", "")
                    val = ""
                    dv = snak.get("datavalue", {}).get("value")
                    if not dv:
                        continue
                    # Normalize the datavalue to a string per datatype.
                    if prop_type == "wikibase-item":
                        val = f"Q{dv.get('numeric-id')}"
                    elif prop_type in {"string", "url", "external-id"}:
                        val = dv
                    elif prop_type == "time":
                        val = dv.get("time", "")
                    elif prop_type == "globe-coordinate":
                        val = f"{dv.get('longitude')},{dv.get('latitude')}"
                    elif prop_type == "quantity":
                        val = dv.get("amount", "")
                        # Unit "1" means dimensionless; anything else is appended.
                        if dv.get("unit") and dv.get("unit") != "1":
                            val = f"{val}{dv.get('unit')}"
                    else:
                        continue
                    if prop and val and prop_type:
                        if resolve_properties:
                            # Phase 1: substitute cached labels immediately and
                            # queue cache misses; phase 2 below fills the misses
                            # after _resolve_properties runs.
                            prop_cached = prop in PROPERTY_CACHE
                            if prop_cached:
                                prop = PROPERTY_CACHE[prop]
                            else:
                                combined_property_list.add(prop)
                            value_cached = True
                            if prop_type == "wikibase-item":
                                value_cached = val in PROPERTY_CACHE
                                if value_cached:
                                    val = PROPERTY_CACHE[val]
                                else:
                                    combined_property_list.add(val)
                            claims.append({"property": prop, "value": val, "type": prop_type, "propertyCached": prop_cached, "valueCached": value_cached})
                        else:
                            claims.append({"property": prop, "value": val, "type": prop_type})
            if description and label and claims:
                out_entities.append({"label": label, "description": description, "claims": claims})
        if not resolve_properties:
            return {"entities": out_entities}
        if combined_property_list:
            self._resolve_properties(list(combined_property_list))
        # Phase 2: patch claims whose ids were cache misses in phase 1, then
        # strip the bookkeeping flags and coerce every claim type to "string".
        for ent in out_entities:
            for claim in ent["claims"]:
                prop_cached = claim.pop("propertyCached", False)
                val_cached = claim.pop("valueCached", False)
                if not prop_cached and claim["property"] in PROPERTY_CACHE:
                    claim["property"] = PROPERTY_CACHE[claim["property"]]
                if not val_cached and claim["value"] in PROPERTY_CACHE:
                    claim["value"] = PROPERTY_CACHE[claim["value"]]
                claim["type"] = "string"
        return {"entities": out_entities}

    def _resolve_properties(self, property_list: list[str]) -> None:
        # Placeholder for batch property label resolution, kept synchronous for now.
        # Use wbgetentities in batches of 50.
        # Currently maps each id to itself so downstream substitution is a no-op.
        for prop_id in property_list:
            PROPERTY_CACHE.setdefault(prop_id, prop_id)

    async def get_entity_data(self, qid: str) -> dict[str, Any]:
        """Fetch the full Special:EntityData JSON dump for a single Q-id.

        Raises ``httpx.HTTPStatusError`` on non-2xx responses.
        """
        client = self._client or httpx.AsyncClient(**_client_kwargs())
        close_client = self._client is None
        try:
            resp = await client.get(
                f"https://www.wikidata.org/wiki/Special:EntityData/{qid}.json",
                params={"flavor": "dump"},
            )
            resp.raise_for_status()
            return resp.json()
        finally:
            if close_client:
                await client.aclose()
|