Browse Source

Improve entity lookup fallback and docs

Lukas Goldschmidt 1 month ago
parent
commit
57bb07fdd6

+ 7 - 0
README.md

@@ -31,6 +31,13 @@ Health:
 
 2) `get_events_for_entity(entity, limit)`
 - substring, case-insensitive match over extracted `entities`
+- uses a shallow recent scan first, then falls back to a wider historical scan if needed
+
+### Entity aliasing
+
+The server keeps a conservative alias map in `config/entity_aliases.json` for obvious shorthands
+like `btc -> Bitcoin`, `eth -> Ethereum`, and `ether -> Ethereum`. Keep this map tight; it is meant
+to reduce false misses, not to rewrite every possible name variant.
 
 3) `get_event_summary(event_id)`
 - Groq-written compressed narrative for a given `cluster_id`
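
A minimal sketch of the matching rule the README describes, substring and case-insensitive over a cluster's extracted entities; the helper name and sample data are illustrative, not the server's actual code:

def entity_matches(query: str, entities: list[str]) -> bool:
    # Substring, case-insensitive: "eth" hits "Ethereum Foundation".
    q = query.lower()
    return any(q in entity.lower() for entity in entities)

assert entity_matches("eth", ["Ethereum Foundation", "SEC"])
assert not entity_matches("solana", ["Ethereum Foundation", "SEC"])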

+ 1 - 0
config/entity_aliases.json

@@ -2,6 +2,7 @@
   "btc": "Bitcoin",
   "bitcoin": "Bitcoin",
   "eth": "Ethereum",
+  "ether": "Ethereum",
   "ethereum": "Ethereum",
   "fed": "Federal Reserve",
   "federal reserve": "Federal Reserve",

+ 16 - 7
news_mcp/mcp_server_fastmcp.py

@@ -114,15 +114,24 @@ async def get_events_for_entity(entity: str, limit: int = 10):
 
     # Cache-first: search recent clusters across all topics.
     store = SQLiteClusterStore(DB_PATH)
+
+    def _match_clusters(clusters: list[dict]) -> list[dict]:
+        hits: list[dict] = []
+        for c in clusters:
+            haystack = _cluster_entity_haystack(c)
+            if any(any(term in item for item in haystack) for term in query_terms):
+                hits.append(c)
+            if len(hits) >= limit:
+                break
+        return hits
+
     clusters = store.get_latest_clusters_all_topics(ttl_hours=CLUSTERS_TTL_HOURS, limit=limit * 5)
+    hits = _match_clusters(clusters)
 
-    hits = []
-    for c in clusters:
-        haystack = _cluster_entity_haystack(c)
-        if any(any(term in item for item in haystack) for term in query_terms):
-            hits.append(c)
-        if len(hits) >= limit:
-            break
+    # If the recent slice misses, broaden the search window before giving up.
+    if not hits:
+        clusters = store.get_latest_clusters_all_topics(ttl_hours=24 * 7, limit=500)
+        hits = _match_clusters(clusters)
 
     # Compress to tool response shape.
     out = []
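
A fake-store check of the two-pass contract, assuming the get_latest_clusters_all_topics(ttl_hours, limit) signature from the diff; the store, the 6-hour recent window, and the matcher are stand-ins, and an empty recent slice should trigger the wider historical scan:

class FakeStore:
    def __init__(self, recent: list[dict], historical: list[dict]):
        self.recent = recent
        self.historical = historical

    def get_latest_clusters_all_topics(self, ttl_hours: int, limit: int) -> list[dict]:
        # Windows of a week or more see the historical rows too.
        return self.historical if ttl_hours >= 24 * 7 else self.recent

def lookup(store: FakeStore, term: str, limit: int = 10) -> list[dict]:
    def match(clusters: list[dict]) -> list[dict]:
        return [c for c in clusters
                if any(term in e.lower() for e in c.get("entities", []))][:limit]

    hits = match(store.get_latest_clusters_all_topics(ttl_hours=6, limit=limit * 5))
    if not hits:  # recent slice missed: broaden before giving up
        hits = match(store.get_latest_clusters_all_topics(ttl_hours=24 * 7, limit=500))
    return hits

old = [{"cluster_id": "c1", "entities": ["Ethereum"]}]
assert lookup(FakeStore(recent=[], historical=old), "ethereum") == old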

+ 138 - 0
scripts/backfill_news_entities.py

@@ -0,0 +1,138 @@
+"""One-off backfill for news-mcp cluster entity metadata.
+
+This reprocesses every stored cluster so older rows pick up the current
+entity normalization and trends resolution fields used by lookup tools.
+"""
+
+from __future__ import annotations
+
+import argparse
+import asyncio
+import json
+import logging
+import sys
+from pathlib import Path
+from typing import Any
+
+
+ROOT = Path(__file__).resolve().parents[1]
+if str(ROOT) not in sys.path:
+    sys.path.insert(0, str(ROOT))
+
+from news_mcp.config import DB_PATH
+from news_mcp.entity_normalize import normalize_entities
+from news_mcp.storage.sqlite_store import SQLiteClusterStore
+from news_mcp.trends_resolution import resolve_entity_via_trends
+
+
+def _compute_entity_resolutions(entities: list[str]) -> list[dict[str, Any]]:
+    return [resolve_entity_via_trends(ent) for ent in entities]
+
+
+def _needs_backfill(cluster: dict) -> bool:
+    raw_entities = cluster.get("entities", []) or []
+    normalized_entities = normalize_entities(raw_entities)
+    resolutions = cluster.get("entityResolutions", []) or []
+
+    # Clearly missing or stale metadata.
+    if not resolutions:
+        return bool(normalized_entities or raw_entities)
+    if normalized_entities != raw_entities:
+        return True
+    if len(resolutions) != len(normalized_entities):
+        return True
+    for res in resolutions:
+        if not isinstance(res, dict):
+            return True
+        if not res.get("normalized") or not res.get("canonical_label"):
+            return True
+    return False
+
+
+async def backfill(
+    db_path: Path,
+    limit: int | None = None,
+    dry_run: bool = False,
+    scan_only: bool = False,
+    progress_every: int = 20,
+) -> dict[str, int]:
+    store = SQLiteClusterStore(db_path)
+
+    with store._conn() as conn:  # noqa: SLF001 - intentional one-off maintenance script
+        cur = conn.execute("SELECT cluster_id, topic, payload FROM clusters ORDER BY updated_at ASC")
+        rows = cur.fetchall()
+
+    total = 0
+    updated = 0
+    skipped = 0
+    logger = logging.getLogger("news_mcp.backfill")
+
+    print(f"starting backfill: total_rows={len(rows)} limit={limit or 'all'} dry_run={dry_run} scan_only={scan_only}", flush=True)
+
+    for cluster_id, topic, payload_json in rows:
+        total += 1
+        try:
+            cluster = json.loads(payload_json)
+        except Exception:
+            skipped += 1
+            continue
+
+        if not _needs_backfill(cluster):
+            continue
+
+        raw_entities = cluster.get("entities", []) or []
+        normalized_entities = normalize_entities(raw_entities)
+        entity_resolutions = _compute_entity_resolutions(normalized_entities) if not scan_only else []
+
+        cluster = dict(cluster)
+        cluster["entities"] = normalized_entities
+        if not scan_only:
+            cluster["entityResolutions"] = entity_resolutions
+
+        if not dry_run:
+            store.upsert_clusters([cluster], topic=topic or cluster.get("topic", "other"))
+        updated += 1
+
+        if limit is not None and updated >= limit:
+            break
+
+        print(".", end="", flush=True)
+        if progress_every and total % progress_every == 0:
+            print(flush=True)
+            logger.info("backfill progress total=%s updated=%s skipped=%s", total, updated, skipped)
+
+    # Close the dot progress line unless the last row already ended it.
+    if not progress_every or total % progress_every != 0:
+        print(flush=True)
+
+    return {"total": total, "updated": updated, "skipped": skipped}
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Backfill news-mcp entity normalization and resolutions")
+    parser.add_argument("--db", type=Path, default=DB_PATH, help="Path to the news sqlite DB")
+    parser.add_argument("--limit", type=int, default=None, help="Optional cap on clusters processed")
+    parser.add_argument("--dry-run", action="store_true", help="Compute changes without writing them")
+    parser.add_argument("--scan-only", action="store_true", help="Only detect stale rows; skip trends lookups")
+    parser.add_argument("--progress-every", type=int, default=20, help="Emit a progress line every N processed rows")
+    args = parser.parse_args()
+
+    logging.basicConfig(level=logging.INFO, format="%(message)s")
+    result = asyncio.run(
+        backfill(
+            args.db,
+            limit=args.limit,
+            dry_run=args.dry_run,
+            scan_only=args.scan_only,
+            progress_every=max(0, int(args.progress_every)),
+        )
+    )
+    mode = "DRY RUN" if args.dry_run else "DONE"
+    if args.scan_only:
+        mode = f"{mode} SCAN-ONLY"
+    print(f"{mode} {result}")
+
+
+if __name__ == "__main__":
+    main()
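
If the backfill is driven from Python rather than the CLI, a dry scan-only pass is the safe first step, since it counts stale rows without trends lookups or DB writes; the import path and DB location below are illustrative:

import asyncio
from pathlib import Path

from scripts.backfill_news_entities import backfill

# Dry, scan-only pass: detects stale rows, writes nothing, skips trends calls.
stats = asyncio.run(backfill(Path("data/news.db"), dry_run=True, scan_only=True))
print(stats)  # {"total": ..., "updated": ..., "skipped": ...}

Running scan-only first keeps the pass cheap because resolve_entity_via_trends is skipped entirely; a second run without the flags then performs the actual writes.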