Browse Source

concurrent: async RSS fetching, embedding, and LLM enrichment

- config.py: add llm_concurrency() with per-provider defaults (openrouter=2, openai=5, groq=8) and NEWS_LLM_CONCURRENCY_<PROVIDER> env overrides; add RSS/OLLAMA semaphore env vars
- news_feeds.py: rewrite fetch_news_articles as fully async with httpx + asyncio.gather for concurrent RSS feed fetching (bounded by NEWS_RSS_MAX_CONCURRENCY=10)
- embedding_support.py: rewrite ollama_embed as async with httpx and a global asyncio.Semaphore bounded by NEWS_OLLAMA_MAX_CONCURRENCY=4
- cluster.py: dedup_and_cluster_articles stays sync at the API boundary but pre-computes all Ollama embeddings concurrently via asyncio.gather before the CPU-bound clustering loop
- poller.py: LLM enrichment uses asyncio.Semaphore per provider config; all clusters within a topic are enriched concurrently via asyncio.gather; all topic enrichment phases run concurrently; clustering runs via asyncio.to_thread to avoid blocking the event loop
- test_news_mcp.py: update poller tests for async fetch_news_articles mocks
Lukas Goldschmidt 1 week ago
parent
commit
f4dc0998eb

+ 33 - 0
news_mcp/config.py

@@ -54,3 +54,36 @@ NEWS_BACKGROUND_REFRESH_ON_START = os.getenv("NEWS_BACKGROUND_REFRESH_ON_START",
 NEWS_PRUNING_ENABLED = os.getenv("NEWS_PRUNING_ENABLED", "true").lower() == "true"
 NEWS_PRUNING_ENABLED = os.getenv("NEWS_PRUNING_ENABLED", "true").lower() == "true"
 NEWS_RETENTION_DAYS = float(os.getenv("NEWS_RETENTION_DAYS", "180"))
 NEWS_RETENTION_DAYS = float(os.getenv("NEWS_RETENTION_DAYS", "180"))
 NEWS_PRUNE_INTERVAL_HOURS = float(os.getenv("NEWS_PRUNE_INTERVAL_HOURS", "24"))
 NEWS_PRUNE_INTERVAL_HOURS = float(os.getenv("NEWS_PRUNE_INTERVAL_HOURS", "24"))
+
+# ---------------------------------------------------------------------------
+# Concurrency controls
+# ---------------------------------------------------------------------------
+# Maximum concurrent outbound LLM API calls per provider.
+# Defaults are conservative for free tiers; override via env if you have
+# higher rate limits or are on a paid plan.
+_NEEDLE_DEFAULT_CONCURRENCY = {
+    "openrouter": 2,
+    "openai": 5,
+    "groq": 8,
+}
+
+_NEEDLE_RSS_MAX_CONCURRENCY = int(os.getenv("NEWS_RSS_MAX_CONCURRENCY", "10"))
+_NEEDLE_OLLAMA_MAX_CONCURRENCY = int(os.getenv("NEWS_OLLAMA_MAX_CONCURRENCY", "4"))
+
+
+def llm_concurrency(provider: str) -> int:
+    """Return the max concurrent LLM calls for *provider*.
+
+    Reads from ``NEWS_LLM_CONCURRENCY_<PROVIDER>`` env var first (e.g.
+    ``NEWS_LLM_CONCURRENCY_OPENROUTER``), then falls back to the built-in
+    default map.
+    """
+    provider = provider.strip().lower()
+    env_key = f"NEWS_LLM_CONCURRENCY_{provider.upper()}"
+    env_val = os.getenv(env_key)
+    if env_val is not None:
+        try:
+            return max(1, int(env_val))
+        except ValueError:
+            pass
+    return _NEEDLE_DEFAULT_CONCURRENCY.get(provider, 3)

+ 74 - 60
news_mcp/dedup/cluster.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
+import asyncio
 import hashlib
 import hashlib
 import re
 import re
 from difflib import SequenceMatcher
 from difflib import SequenceMatcher
@@ -18,7 +19,6 @@ from news_mcp.sources.news_feeds import normalize_topic_from_title
 
 
 def _normalize_title(title: str) -> str:
 def _normalize_title(title: str) -> str:
     t = title.lower().strip()
     t = title.lower().strip()
-    # Remove punctuation-ish characters for similarity scoring.
     t = re.sub(r"[^a-z0-9\s]", " ", t)
     t = re.sub(r"[^a-z0-9\s]", " ", t)
     t = re.sub(r"\s+", " ", t).strip()
     t = re.sub(r"\s+", " ", t).strip()
     return t
     return t
@@ -48,12 +48,9 @@ def _cluster_text(a: Dict[str, Any]) -> str:
 
 
 
 
 # ---------------------------------------------------------------------------
 # ---------------------------------------------------------------------------
-# Token / Jaccard signal (used as a fallback alongside title similarity when
-# embeddings are unavailable, and as a soft signal even when they are).
+# Token / Jaccard signal
 # ---------------------------------------------------------------------------
 # ---------------------------------------------------------------------------
 
 
-# Tiny stop-word set — we keep it small on purpose because the corpus is news
-# headlines, where every additional removal risks losing genuine signal.
 _STOPWORDS = frozenset(
 _STOPWORDS = frozenset(
     {
     {
         "a", "an", "the", "of", "to", "in", "on", "at", "for", "by", "with",
         "a", "an", "the", "of", "to", "in", "on", "at", "for", "by", "with",
@@ -68,7 +65,6 @@ _STOPWORDS = frozenset(
 
 
 
 
 def _tokens(text: str) -> set[str]:
 def _tokens(text: str) -> set[str]:
-    """Lowercase content tokens, stop-words removed, length>=3."""
     tokens = re.findall(r"[a-z0-9][a-z0-9\-]+", text.lower())
     tokens = re.findall(r"[a-z0-9][a-z0-9\-]+", text.lower())
     return {t for t in tokens if len(t) >= 3 and t not in _STOPWORDS}
     return {t for t in tokens if len(t) >= 3 and t not in _STOPWORDS}
 
 
@@ -86,22 +82,12 @@ def _jaccard(a: set, b: set) -> float:
 # Composite similarity
 # Composite similarity
 # ---------------------------------------------------------------------------
 # ---------------------------------------------------------------------------
 
 
-
-# Each signal has its own threshold. We accept a merge if ANY signal clears its
-# threshold, which makes clustering robust when one signal happens to be weak
-# (short headlines kill SequenceMatcher; single-word stories kill Jaccard;
-# Ollama outages kill cosine similarity).
 DEFAULT_TITLE_THRESHOLD = 0.87
 DEFAULT_TITLE_THRESHOLD = 0.87
 DEFAULT_JACCARD_THRESHOLD = 0.55
 DEFAULT_JACCARD_THRESHOLD = 0.55
 
 
 
 
 def _signals(article: Dict[str, Any], cluster: Dict[str, Any]) -> dict:
 def _signals(article: Dict[str, Any], cluster: Dict[str, Any]) -> dict:
-    """Per-pair similarity signals (title, jaccard, embedding cosine).
-
-    Embedding cosine is only computed when both sides have a vector; we never
-    block on a fresh Ollama request here — that's the caller's job, so this
-    function stays pure and easy to test.
-    """
+    """Per-pair similarity signals (title, jaccard, embedding cosine)."""
     a_title = str(article.get("title") or "")
     a_title = str(article.get("title") or "")
     c_title = str(cluster.get("headline") or "")
     c_title = str(cluster.get("headline") or "")
 
 
@@ -120,11 +106,7 @@ def _signals(article: Dict[str, Any], cluster: Dict[str, Any]) -> dict:
 
 
 
 
 def _is_match(signals: dict, *, embeddings_enabled: bool) -> tuple[bool, str, float]:
 def _is_match(signals: dict, *, embeddings_enabled: bool) -> tuple[bool, str, float]:
-    """Decide whether two items should merge based on the strongest signal.
-
-    Returns (matched, signal_name, signal_value). The signal_name lets callers
-    log *why* something merged, which is huge for debugging clustering quality.
-    """
+    """Decide whether two items should merge based on the strongest signal."""
     cosine_threshold = NEWS_EMBEDDING_SIMILARITY_THRESHOLD
     cosine_threshold = NEWS_EMBEDDING_SIMILARITY_THRESHOLD
     if embeddings_enabled and signals["cosine"] >= cosine_threshold:
     if embeddings_enabled and signals["cosine"] >= cosine_threshold:
         return True, "cosine", signals["cosine"]
         return True, "cosine", signals["cosine"]
@@ -136,7 +118,65 @@ def _is_match(signals: dict, *, embeddings_enabled: bool) -> tuple[bool, str, fl
 
 
 
 
 # ---------------------------------------------------------------------------
 # ---------------------------------------------------------------------------
-# Public API
+# Embedding pre-computation (async internally)
+# ---------------------------------------------------------------------------
+
+
+async def _compute_embeddings_concurrently(
+    articles: List[Dict[str, Any]],
+) -> Dict[str, list[float] | None]:
+    """Compute embeddings for unique article texts concurrently.
+
+    Returns a cache dict: text -> embedding vector or None.
+    """
+    unique_texts: list[str] = []
+    seen: set[str] = set()
+    for a in articles:
+        text = _cluster_text(a)
+        if text and text not in seen:
+            seen.add(text)
+            unique_texts.append(text)
+
+    emb_tasks = [ollama_embed(text) for text in unique_texts]
+    emb_results = await asyncio.gather(*emb_tasks, return_exceptions=True)
+
+    cache: Dict[str, list[float] | None] = {}
+    for text, result in zip(unique_texts, emb_results):
+        if isinstance(result, list):
+            cache[text] = result
+        else:
+            cache[text] = None
+    return cache
+
+
+def _compute_embeddings_sync(
+    articles: List[Dict[str, Any]],
+) -> Dict[str, list[float] | None]:
+    """Synchronous wrapper that runs the async embedding computation.
+
+    Handles three cases:
+    1. Already inside an async event loop (called from poller) -> schedule
+       as a task and run it to completion on the running loop.
+    2. No event loop at all (plain sync caller) -> use asyncio.run().
+    """
+    try:
+        loop = asyncio.get_running_loop()
+    except RuntimeError:
+        # No running loop — safe to use asyncio.run()
+        return asyncio.run(_compute_embeddings_concurrently(articles))
+
+    # We're inside a running event loop (e.g. the poller). Create a new loop
+    # in a thread to avoid blocking.
+    import concurrent.futures
+    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
+        future = pool.submit(
+            asyncio.run, _compute_embeddings_concurrently(articles)
+        )
+        return future.result()
+
+
+# ---------------------------------------------------------------------------
+# Public API (sync — backward compatible with tests)
 # ---------------------------------------------------------------------------
 # ---------------------------------------------------------------------------
 
 
 
 
@@ -146,36 +186,23 @@ def dedup_and_cluster_articles(
 ) -> Dict[str, List[Dict[str, Any]]]:
 ) -> Dict[str, List[Dict[str, Any]]]:
     """Deduplicate raw articles into clusters keyed by topic.
     """Deduplicate raw articles into clusters keyed by topic.
 
 
-    v1.1 strategy: composite similarity.
+    v1.2: embedding pre-computation is async/concurrent under the hood, but
+    this public function remains synchronous for backward compatibility.
+
+    A pair merges if ANY signal clears its threshold:
       * title fuzzy ratio
       * title fuzzy ratio
-      * token Jaccard over headline+summary (cheap, surprisingly resilient
-        when titles are reworded heavily across outlets)
+      * token Jaccard over headline+summary
       * Ollama embedding cosine when available
       * Ollama embedding cosine when available
-
-    A pair merges if ANY signal clears its threshold. Falling back through
-    multiple signals means a transient Ollama outage doesn't collapse the
-    server back into title-only clustering, and a heavily-reworded headline
-    can still merge via Jaccard or embeddings.
-
-    The ``similarity_threshold`` argument is kept for backward compatibility
-    with the test suite. When provided, it overrides the title threshold.
     """
     """
 
 
     title_threshold = similarity_threshold if similarity_threshold is not None else DEFAULT_TITLE_THRESHOLD
     title_threshold = similarity_threshold if similarity_threshold is not None else DEFAULT_TITLE_THRESHOLD
 
 
-    by_topic: Dict[str, List[Dict[str, Any]]] = {}
+    # Pre-compute embeddings concurrently (sync boundary handles async internally)
     embedding_cache: Dict[str, list[float] | None] = {}
     embedding_cache: Dict[str, list[float] | None] = {}
+    if NEWS_EMBEDDINGS_ENABLED:
+        embedding_cache = _compute_embeddings_sync(articles)
 
 
-    def _embedding_for_text(text: str) -> list[float] | None:
-        if not NEWS_EMBEDDINGS_ENABLED or not text:
-            return None
-        if text in embedding_cache:
-            return embedding_cache[text]
-        emb = ollama_embed(text)
-        # Cache None too so a single failure doesn't trigger repeated retries
-        # within one ingestion cycle. The next refresh call clears this map.
-        embedding_cache[text] = emb
-        return emb
+    by_topic: Dict[str, List[Dict[str, Any]]] = {}
 
 
     for a in articles:
     for a in articles:
         title = a.get("title") or ""
         title = a.get("title") or ""
@@ -183,10 +210,8 @@ def dedup_and_cluster_articles(
             continue
             continue
         topic = normalize_topic_from_title(title)
         topic = normalize_topic_from_title(title)
         article_text = _cluster_text(a)
         article_text = _cluster_text(a)
-        article_embedding = _embedding_for_text(article_text)
 
 
-        # Attach embedding on the article dict so _signals() can read it
-        # without re-computing.
+        article_embedding = embedding_cache.get(article_text) if NEWS_EMBEDDINGS_ENABLED else None
         a_with_emb = dict(a)
         a_with_emb = dict(a)
         if article_embedding is not None:
         if article_embedding is not None:
             a_with_emb["_embedding"] = article_embedding
             a_with_emb["_embedding"] = article_embedding
@@ -199,8 +224,6 @@ def dedup_and_cluster_articles(
         best_signal_value = 0.0
         best_signal_value = 0.0
         for idx, c in enumerate(clusters):
         for idx, c in enumerate(clusters):
             sigs = _signals(a_with_emb, c)
             sigs = _signals(a_with_emb, c)
-            # Use the title threshold the caller explicitly passed (test override)
-            # but otherwise rely on the module defaults.
             local_match = False
             local_match = False
             if NEWS_EMBEDDINGS_ENABLED and sigs["cosine"] >= NEWS_EMBEDDING_SIMILARITY_THRESHOLD:
             if NEWS_EMBEDDINGS_ENABLED and sigs["cosine"] >= NEWS_EMBEDDING_SIMILARITY_THRESHOLD:
                 local_match = True
                 local_match = True
@@ -211,11 +234,6 @@ def dedup_and_cluster_articles(
             elif sigs["jaccard"] >= DEFAULT_JACCARD_THRESHOLD:
             elif sigs["jaccard"] >= DEFAULT_JACCARD_THRESHOLD:
                 local_match = True
                 local_match = True
                 signal_name, signal_value = "jaccard", sigs["jaccard"]
                 signal_name, signal_value = "jaccard", sigs["jaccard"]
-            # Consensus rule: when no single signal clears its strict threshold
-            # but two of them are simultaneously "strong-ish", treat that as a
-            # match. This catches reworded headlines whose embedding is just
-            # below the strict cosine cutoff. Numbers are intentionally
-            # conservative — both signals must be clearly above noise.
             elif (
             elif (
                 NEWS_EMBEDDINGS_ENABLED
                 NEWS_EMBEDDINGS_ENABLED
                 and sigs["cosine"] >= 0.80
                 and sigs["cosine"] >= 0.80
@@ -240,13 +258,10 @@ def dedup_and_cluster_articles(
             if a.get("source") and a["source"] not in c["sources"]:
             if a.get("source") and a["source"] not in c["sources"]:
                 c["sources"].append(a["source"])
                 c["sources"].append(a["source"])
             c["last_updated"] = max(str(c.get("last_updated", "")), str(a.get("timestamp", "")))
             c["last_updated"] = max(str(c.get("last_updated", "")), str(a.get("timestamp", "")))
-            # Keep a tiny audit trail per cluster on which signal grew it last.
-            # Not surfaced through tools — lives in the payload only for debug.
             c.setdefault("_merge_signals", []).append(
             c.setdefault("_merge_signals", []).append(
                 {"signal": best_signal_name, "value": round(best_signal_value, 3)}
                 {"signal": best_signal_name, "value": round(best_signal_value, 3)}
             )
             )
         else:
         else:
-            # Stable cluster id: based on topic + normalized canonical title.
             key = f"{topic}|{_normalize_title(title)}"
             key = f"{topic}|{_normalize_title(title)}"
             cid = hashlib.sha1(key.encode("utf-8")).hexdigest()
             cid = hashlib.sha1(key.encode("utf-8")).hexdigest()
             cluster_embedding = article_embedding if NEWS_EMBEDDINGS_ENABLED else None
             cluster_embedding = article_embedding if NEWS_EMBEDDINGS_ENABLED else None
@@ -269,8 +284,7 @@ def dedup_and_cluster_articles(
                 }
                 }
             )
             )
 
 
-    # Strip the internal merge audit trail before returning so it does not
-    # accidentally bloat the SQLite payload. Storage layer doesn't filter it.
+    # Strip the internal merge audit trail before returning
     for clusters in by_topic.values():
     for clusters in by_topic.values():
         for c in clusters:
         for c in clusters:
             c.pop("_merge_signals", None)
             c.pop("_merge_signals", None)

+ 34 - 22
news_mcp/dedup/embedding_support.py

@@ -1,13 +1,23 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
+import asyncio
+import json
 from dataclasses import dataclass
 from dataclasses import dataclass
 from datetime import datetime, timezone, timedelta
 from datetime import datetime, timezone, timedelta
-import json
-import urllib.request
 from math import sqrt
 from math import sqrt
 from typing import Any
 from typing import Any
 
 
-from news_mcp.config import NEWS_EMBEDDINGS_ENABLED, OLLAMA_BASE_URL, OLLAMA_EMBEDDING_MODEL
+import httpx
+
+from news_mcp.config import (
+    NEWS_EMBEDDINGS_ENABLED,
+    OLLAMA_BASE_URL,
+    OLLAMA_EMBEDDING_MODEL,
+    _NEEDLE_OLLAMA_MAX_CONCURRENCY,
+)
+
+
+_ollama_semaphore = asyncio.Semaphore(_NEEDLE_OLLAMA_MAX_CONCURRENCY)
 
 
 
 
 @dataclass(frozen=True)
 @dataclass(frozen=True)
@@ -85,28 +95,30 @@ def cluster_is_candidate(
     return True
     return True
 
 
 
 
-def ollama_embed(text: str, timeout: float = 20.0) -> list[float] | None:
-    """Best-effort Ollama embedding call; returns None on any failure.
+async def ollama_embed(text: str, timeout: float = 20.0) -> list[float] | None:
+    """Async Ollama embedding call with concurrency limiting.
 
 
-    Embeddings are intentionally optional. The caller should fall back to the
-    heuristic path when this returns None.
+    Returns None on any failure so the caller falls back to heuristic clustering.
     """
     """
-
     if not NEWS_EMBEDDINGS_ENABLED:
     if not NEWS_EMBEDDINGS_ENABLED:
         return None
         return None
+
     payload = json.dumps({"model": OLLAMA_EMBEDDING_MODEL, "prompt": text}).encode("utf-8")
     payload = json.dumps({"model": OLLAMA_EMBEDDING_MODEL, "prompt": text}).encode("utf-8")
-    req = urllib.request.Request(
-        f"{OLLAMA_BASE_URL.rstrip('/')}/api/embeddings",
-        data=payload,
-        headers={"Content-Type": "application/json"},
-        method="POST",
-    )
-    try:
-        with urllib.request.urlopen(req, timeout=timeout) as resp:
-            data = json.loads(resp.read().decode("utf-8"))
-            emb = data.get("embedding")
-            if isinstance(emb, list) and emb:
-                return [float(x) for x in emb]
-    except Exception:
-        return None
+    url = f"{OLLAMA_BASE_URL.rstrip('/')}/api/embeddings"
+
+    async with _ollama_semaphore:
+        try:
+            async with httpx.AsyncClient(timeout=timeout) as client:
+                resp = await client.post(
+                    url,
+                    content=payload,
+                    headers={"Content-Type": "application/json"},
+                )
+                resp.raise_for_status()
+                data = resp.json()
+                emb = data.get("embedding")
+                if isinstance(emb, list) and emb:
+                    return [float(x) for x in emb]
+        except Exception:
+            return None
     return None
     return None

+ 125 - 72
news_mcp/jobs/poller.py

@@ -1,26 +1,114 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 import asyncio
 import asyncio
+import hashlib
 import logging
 import logging
 from collections import defaultdict
 from collections import defaultdict
 from datetime import datetime, timezone
 from datetime import datetime, timezone
 from typing import Any, Dict
 from typing import Any, Dict
 
 
-from news_mcp.config import DEFAULT_LOOKBACK_HOURS, DEFAULT_TOPICS, DB_PATH, NEWS_FEED_URL, NEWS_FEED_URLS
-from news_mcp.dedup.cluster import dedup_and_cluster_articles
-from news_mcp.enrichment.enrich import enrich_cluster
-from news_mcp.enrichment.llm_enrich import classify_cluster_llm
-from news_mcp.trends_resolution import resolve_entity_via_trends
-from news_mcp.sources.news_feeds import fetch_news_articles
-from news_mcp.storage.sqlite_store import SQLiteClusterStore
-
 from news_mcp.config import (
 from news_mcp.config import (
+    DEFAULT_LOOKBACK_HOURS,
+    DEFAULT_TOPICS,
+    DB_PATH,
     ENRICH_OTHER_TOPICS_ONLY,
     ENRICH_OTHER_TOPICS_ONLY,
     ENRICHMENT_MAX_PER_REFRESH,
     ENRICHMENT_MAX_PER_REFRESH,
+    NEWS_EXTRACT_PROVIDER,
+    NEWS_FEED_URL,
+    NEWS_FEED_URLS,
     NEWS_PRUNE_INTERVAL_HOURS,
     NEWS_PRUNE_INTERVAL_HOURS,
     NEWS_PRUNING_ENABLED,
     NEWS_PRUNING_ENABLED,
     NEWS_RETENTION_DAYS,
     NEWS_RETENTION_DAYS,
+    llm_concurrency,
 )
 )
+from news_mcp.dedup.cluster import dedup_and_cluster_articles
+from news_mcp.enrichment.enrich import enrich_cluster
+from news_mcp.enrichment.llm_enrich import classify_cluster_llm
+from news_mcp.sources.news_feeds import fetch_news_articles
+from news_mcp.storage.sqlite_store import SQLiteClusterStore
+from news_mcp.trends_resolution import resolve_entity_via_trends
+
+
+async def _enrich_single_cluster(
+    c: dict,
+    topic: str,
+    llm_enabled: bool,
+    semaphore: asyncio.Semaphore,
+    store: SQLiteClusterStore,
+    logger: logging.Logger,
+) -> dict:
+    """Enrich one cluster: heuristic + optional LLM extraction, concurrency-limited."""
+    c2 = enrich_cluster(c)
+    c2.setdefault("topic", topic)
+
+    cluster_id = c2.get("cluster_id")
+    if llm_enabled and cluster_id:
+        # Cache: if we already have entities/sentiment for this cluster, skip LLM call.
+        existing = store.get_cluster_by_id(cluster_id)
+        if existing and existing.get("entities"):
+            c2 = dict(c2)
+            c2["entities"] = existing.get("entities", [])
+
+            existing_resolutions = existing.get("entityResolutions", None)
+            if isinstance(existing_resolutions, list) and existing_resolutions:
+                c2["entityResolutions"] = existing_resolutions
+            else:
+                c2["entityResolutions"] = [resolve_entity_via_trends(e) for e in c2["entities"]]
+
+            if existing.get("sentiment"):
+                c2["sentiment"] = existing.get("sentiment")
+            if existing.get("sentimentScore") is not None:
+                c2["sentimentScore"] = existing.get("sentimentScore")
+            if existing.get("keywords"):
+                c2["keywords"] = existing.get("keywords")
+            if existing.get("topic"):
+                c2["topic"] = existing.get("topic")
+        else:
+            # Acquire semaphore before making outbound LLM call
+            async with semaphore:
+                try:
+                    c2 = await classify_cluster_llm(c2)
+                except Exception:
+                    logger.exception(
+                        "LLM enrichment failed for cluster %s (topic %s)",
+                        c2.get("cluster_id"), topic,
+                    )
+                    c2["enrichment_failed_at"] = datetime.now(timezone.utc).isoformat()
+
+    return c2
+
+
+async def _enrich_topic_clusters(
+    clusters: list[dict],
+    topic: str,
+    semaphore: asyncio.Semaphore,
+    store: SQLiteClusterStore,
+    logger: logging.Logger,
+    enrich_limit: int,
+) -> list[dict]:
+    """Enrich all clusters for a single topic concurrently."""
+    llm_enabled = (not ENRICH_OTHER_TOPICS_ONLY) or (topic == "other")
+
+    # Persist the raw clusters first so a slow enrichment pass does not
+    # leave the first bootstrap run with nothing stored.
+    store.upsert_clusters(clusters, topic=topic)
+    logger.info("refresh stored raw topic=%s clusters=%s", topic, len(clusters))
+
+    targets = clusters[:enrich_limit]
+    tasks = [
+        _enrich_single_cluster(c, topic, llm_enabled, semaphore, store, logger)
+        for c in targets
+    ]
+    enriched = await asyncio.gather(*tasks, return_exceptions=False)
+
+    # Any clusters beyond enrich_limit still need importance enrichment
+    for c in clusters[enrich_limit:]:
+        c2 = enrich_cluster(c)
+        c2.setdefault("topic", topic)
+        enriched.append(c2)
+
+    logger.info("refresh enriched topic=%s clusters=%s", topic, len(enriched))
+    return enriched
 
 
 
 
 async def refresh_clusters(topic: str | None = None, limit: int = 80) -> None:
 async def refresh_clusters(topic: str | None = None, limit: int = 80) -> None:
@@ -28,7 +116,9 @@ async def refresh_clusters(topic: str | None = None, limit: int = 80) -> None:
     store = SQLiteClusterStore(DB_PATH)
     store = SQLiteClusterStore(DB_PATH)
 
 
     logger.info("refresh start topic=%s limit=%s", topic, limit)
     logger.info("refresh start topic=%s limit=%s", topic, limit)
-    articles = await asyncio.to_thread(fetch_news_articles, limit)
+
+    # fetch_news_articles is now fully async (concurrent RSS fetching)
+    articles = await fetch_news_articles(limit)
     logger.info("refresh fetched articles=%s", len(articles))
     logger.info("refresh fetched articles=%s", len(articles))
 
 
     # Drop legacy aggregate feed-state rows so the dashboard only reflects
     # Drop legacy aggregate feed-state rows so the dashboard only reflects
@@ -37,7 +127,6 @@ async def refresh_clusters(topic: str | None = None, limit: int = 80) -> None:
         conn.execute("DELETE FROM feed_state WHERE feed_key LIKE 'newsfeeds:%'")
         conn.execute("DELETE FROM feed_state WHERE feed_key LIKE 'newsfeeds:%'")
 
 
     # Track feed freshness per RSS URL so unchanged feeds can be skipped.
     # Track feed freshness per RSS URL so unchanged feeds can be skipped.
-    import hashlib
     per_feed: dict[str, list[dict[str, Any]]] = defaultdict(list)
     per_feed: dict[str, list[dict[str, Any]]] = defaultdict(list)
     for article in articles:
     for article in articles:
         feed_url = str(article.get("feed_url") or NEWS_FEED_URL).strip() or NEWS_FEED_URL
         feed_url = str(article.get("feed_url") or NEWS_FEED_URL).strip() or NEWS_FEED_URL
@@ -75,87 +164,51 @@ async def refresh_clusters(topic: str | None = None, limit: int = 80) -> None:
 
 
     articles = changed_articles
     articles = changed_articles
     logger.info("refresh clustering start articles=%s topic=%s", len(articles), topic)
     logger.info("refresh clustering start articles=%s topic=%s", len(articles), topic)
-    clustered_by_topic = dedup_and_cluster_articles(articles)
+    # Clustering is sync but may do concurrent embedding fetches internally.
+    # Run off-thread so the event loop stays responsive for MCP tool calls.
+    clustered_by_topic = await asyncio.to_thread(dedup_and_cluster_articles, articles)
     logger.info("refresh clustered topics=%s", list(clustered_by_topic.keys()))
     logger.info("refresh clustered topics=%s", list(clustered_by_topic.keys()))
 
 
+    # Build LLM concurrency semaphore from the extract provider's config.
+    max_llm_concurrent = llm_concurrency(NEWS_EXTRACT_PROVIDER)
+    llm_semaphore = asyncio.Semaphore(max_llm_concurrent)
+    logger.info("refresh llm semaphore limit=%s provider=%s", max_llm_concurrent, NEWS_EXTRACT_PROVIDER)
+
+    # Enrich each topic's clusters concurrently.
+    topic_tasks = []
     for t, clusters in clustered_by_topic.items():
     for t, clusters in clustered_by_topic.items():
         if topic and t != topic:
         if topic and t != topic:
             continue
             continue
-        logger.info("refresh topic phase start topic=%s clusters=%s", t, len(clusters))
-        enriched = []
 
 
         # Determine how many clusters to LLM-enrich.
         # Determine how many clusters to LLM-enrich.
         # ENRICHMENT_MAX_PER_REFRESH=0 means enrich every cluster (no cap).
         # ENRICHMENT_MAX_PER_REFRESH=0 means enrich every cluster (no cap).
         enrich_limit = ENRICHMENT_MAX_PER_REFRESH or len(clusters)
         enrich_limit = ENRICHMENT_MAX_PER_REFRESH or len(clusters)
 
 
-        # Track whether the LLM pipeline is available for this topic.
-        _llm_enabled_for_topic = (
-            (not ENRICH_OTHER_TOPICS_ONLY) or (t == "other")
+        topic_tasks.append(
+            _enrich_topic_clusters(
+                clusters=clusters,
+                topic=t,
+                semaphore=llm_semaphore,
+                store=store,
+                logger=logger,
+                enrich_limit=enrich_limit,
+            )
         )
         )
 
 
-        # Persist the raw clusters first so a slow enrichment pass does not
-        # leave the first bootstrap run with nothing stored.
-        store.upsert_clusters(clusters, topic=t)
-        logger.info("refresh stored raw topic=%s clusters=%s", t, len(clusters))
-
-        for idx, c in enumerate(clusters[:enrich_limit]):
-            c2 = enrich_cluster(c)
-            # Seed the heuristic topic on the payload so classify_cluster_llm
-            # has a sane fallback if the LLM omits or hallucinates one.
-            c2.setdefault("topic", t)
-            logger.info("refresh enrich cluster=%s topic=%s idx=%s/%s", c2.get("cluster_id"), t, idx + 1, enrich_limit)
-
-            if _llm_enabled_for_topic:
-                # Cache: if we already have entities/sentiment for this cluster, skip LLM call.
-                existing = store.get_cluster_by_id(c2.get("cluster_id"))
-                if existing and existing.get("entities"):
-                    c2 = dict(c2)
-                    # Keep existing enriched fields.
-                    c2["entities"] = existing.get("entities", [])
-
-                    # IMPORTANT: entityResolutions must stay consistent with entities.
-                    # Older rows may have entities but missing/malformed resolutions.
-                    existing_resolutions = existing.get("entityResolutions", None)
-                    if isinstance(existing_resolutions, list) and existing_resolutions:
-                        c2["entityResolutions"] = existing_resolutions
-                    else:
-                        # Recompute resolutions deterministically from the stored entities.
-                        c2["entityResolutions"] = [resolve_entity_via_trends(e) for e in c2["entities"]]
-
-                    if existing.get("sentiment"):
-                        c2["sentiment"] = existing.get("sentiment")
-                    if existing.get("sentimentScore") is not None:
-                        c2["sentimentScore"] = existing.get("sentimentScore")
-                    if existing.get("keywords"):
-                        c2["keywords"] = existing.get("keywords")
-                    # Preserve a previously-classified topic so we don't drift back
-                    # to the heuristic on cache hits.
-                    if existing.get("topic"):
-                        c2["topic"] = existing.get("topic")
-                else:
-                    try:
-                        c2 = await classify_cluster_llm(c2)
-                    except Exception:
-                        logger.exception("LLM enrichment failed for cluster %s (topic %s)", c2.get("cluster_id"), t)
-                        # Mark so we can retry on next refresh.
-                        c2["enrichment_failed_at"] = datetime.now(timezone.utc).isoformat()
-
-            enriched.append(c2)
-
-        # Persist clusters under their *post-enrichment* topic so the SQL row
-        # column matches what the LLM (or the validated heuristic fallback)
-        # actually decided. Previously, every cluster from this bucket was
-        # forced into the heuristic topic `t`, which caused a ~97% mismatch
-        # between row-column topic and payload topic.
+    # Run all topic enrichment phases concurrently
+    topic_results = await asyncio.gather(*topic_tasks, return_exceptions=False)
+
+    # Persist enriched clusters grouped by their final topic
+    for enriched in topic_results:
         by_final_topic: Dict[str, list] = {}
         by_final_topic: Dict[str, list] = {}
         for c2 in enriched:
         for c2 in enriched:
-            final_topic = str(c2.get("topic") or t or "other").strip().lower()
+            final_topic = str(c2.get("topic") or "other").strip().lower()
             if final_topic not in {x.lower() for x in DEFAULT_TOPICS}:
             if final_topic not in {x.lower() for x in DEFAULT_TOPICS}:
                 final_topic = "other"
                 final_topic = "other"
             by_final_topic.setdefault(final_topic, []).append(c2)
             by_final_topic.setdefault(final_topic, []).append(c2)
         for final_topic, group in by_final_topic.items():
         for final_topic, group in by_final_topic.items():
             store.upsert_clusters(group, topic=final_topic)
             store.upsert_clusters(group, topic=final_topic)
-            logger.info("refresh stored topic=%s clusters=%s (heuristic_topic=%s)", final_topic, len(group), t)
+            logger.info("refresh stored topic=%s clusters=%s", final_topic, len(group))
 
 
     prune_result = store.prune_if_due(
     prune_result = store.prune_if_due(
         pruning_enabled=NEWS_PRUNING_ENABLED,
         pruning_enabled=NEWS_PRUNING_ENABLED,

+ 93 - 52
news_mcp/sources/news_feeds.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
+import asyncio
 import hashlib
 import hashlib
 import logging
 import logging
 import re
 import re
@@ -8,14 +9,20 @@ from urllib.error import URLError, HTTPError
 from urllib.request import Request, urlopen
 from urllib.request import Request, urlopen
 
 
 import feedparser
 import feedparser
+import httpx
 
 
-from news_mcp.config import NEWS_FEED_ITEMS_PER_POLL, NEWS_FEED_URL, NEWS_FEED_URLS
+from news_mcp.config import (
+    NEWS_FEED_ITEMS_PER_POLL,
+    NEWS_FEED_URL,
+    NEWS_FEED_URLS,
+    _NEEDLE_RSS_MAX_CONCURRENCY,
+)
 
 
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
 
 
-FEED_FETCH_TIMEOUT_SECONDS = 15
+FEED_FETCH_TIMEOUT_SECONDS = 20
 
 
 
 
 def _canonical_url(url: str) -> str:
 def _canonical_url(url: str) -> str:
@@ -39,64 +46,97 @@ def _feed_urls() -> List[str]:
     return urls
     return urls
 
 
 
 
-def _fetch_feed(feed_url: str):
-    req = Request(feed_url, headers={"User-Agent": "news-mcp/1.0"})
-    with urlopen(req, timeout=FEED_FETCH_TIMEOUT_SECONDS) as resp:
-        return feedparser.parse(resp.read())
+def _parse_feed_from_bytes(data: bytes, feed_url: str):
+    """Parse feed from raw bytes (sync, but fast — just XML parsing)."""
+    return feedparser.parse(data)
 
 
 
 
-def fetch_news_articles(limit: int = NEWS_FEED_ITEMS_PER_POLL) -> List[Dict[str, Any]]:
-    feed_urls = _feed_urls()
+async def _fetch_feed_async(
+    client: httpx.AsyncClient,
+    semaphore: asyncio.Semaphore,
+    feed_url: str,
+) -> tuple[str, bytes | None]:
+    """Fetch a single RSS feed concurrently. Returns (feed_url, raw_bytes)."""
+    async with semaphore:
+        try:
+            resp = await client.get(feed_url, follow_redirects=True)
+            resp.raise_for_status()
+            return (feed_url, resp.content)
+        except (httpx.HTTPStatusError, httpx.TimeoutException, httpx.ConnectError, OSError) as exc:
+            logger.exception("news feed fetch failed feed_url=%s error=%s", feed_url, exc)
+            return (feed_url, None)
+        except Exception as exc:
+            logger.exception("news feed fetch unexpected error feed_url=%s error=%s", feed_url, exc)
+            return (feed_url, None)
+
+
+def _extract_articles_from_feed(
+    feed_url: str,
+    parsed,
+    per_feed_limit: int,
+) -> List[Dict[str, Any]]:
+    """Extract article dicts from a parsed feedparser object (sync)."""
     articles: List[Dict[str, Any]] = []
     articles: List[Dict[str, Any]] = []
+    feed_name = getattr(parsed.feed, "title", None) or feed_url
+    parsed_entries = len(getattr(parsed, "entries", []) or [])
+    logger.info(
+        "news feed parsed feed_url=%s feed_name=%s entries=%s",
+        feed_url, feed_name, parsed_entries,
+    )
+
+    kept = 0
+    for entry in parsed.entries[:per_feed_limit]:
+        title = str(getattr(entry, "title", "")).strip()
+        url = _canonical_url(str(getattr(entry, "link", "")).strip())
+        timestamp = str(getattr(entry, "published", "")) or str(getattr(entry, "updated", ""))
+        summary = _strip_html(
+            str(getattr(entry, "summary", "")) or str(getattr(entry, "description", ""))
+        )
+        if not title or not url:
+            continue
+        articles.append({
+            "title": title,
+            "url": url,
+            "source": str(feed_name),
+            "feed_url": feed_url,
+            "timestamp": timestamp,
+            "summary": summary,
+        })
+        kept += 1
+
+    logger.info("news feed completed feed_url=%s kept=%s", feed_url, kept)
+    return articles
 
 
-    logger.info("news ingestion start feeds=%s limit=%s timeout_s=%s", len(feed_urls), limit, FEED_FETCH_TIMEOUT_SECONDS)
 
 
-    # Apply the configured cap per feed.
+async def fetch_news_articles(limit: int = NEWS_FEED_ITEMS_PER_POLL) -> List[Dict[str, Any]]:
+    """Fetch all RSS feeds concurrently, parse, and return articles."""
+    feed_urls = _feed_urls()
     per_feed_limit = max(1, int(limit))
     per_feed_limit = max(1, int(limit))
 
 
-    for feed_url in feed_urls:
-        try:
-            feed = _fetch_feed(feed_url)
-            feed_name = getattr(feed.feed, "title", None) or feed_url
-            parsed_entries = len(getattr(feed, "entries", []) or [])
-            logger.info("news feed parsed feed_url=%s feed_name=%s entries=%s", feed_url, feed_name, parsed_entries)
-        except (HTTPError, URLError, TimeoutError, OSError) as exc:
-            logger.exception("news feed fetch failed feed_url=%s error=%s", feed_url, exc)
-            continue
-        except Exception as exc:
-            logger.exception("news feed parse failed feed_url=%s error=%s", feed_url, exc)
-            continue
+    logger.info(
+        "news ingestion start feeds=%s limit=%s timeout_s=%s",
+        len(feed_urls), per_feed_limit, FEED_FETCH_TIMEOUT_SECONDS,
+    )
 
 
-        kept_before = len(articles)
-        for entry in feed.entries[:per_feed_limit]:
-            title = str(getattr(entry, "title", "")).strip()
-            url = _canonical_url(str(getattr(entry, "link", "")).strip())
-            timestamp = str(getattr(entry, "published", "")) or str(getattr(entry, "updated", ""))
-            summary = _strip_html(str(getattr(entry, "summary", "")) or str(getattr(entry, "description", "")))
-
-            if not title or not url:
-                continue
-
-            articles.append(
-                {
-                    "title": title,
-                    "url": url,
-                    "source": str(feed_name),
-                    "feed_url": feed_url,
-                    "timestamp": timestamp,
-                    "summary": summary,
-                }
-            )
-
-            if len(articles) - kept_before >= per_feed_limit:
-                logger.info("news ingestion per-feed limit reached feed_url=%s kept=%s", feed_url, len(articles) - kept_before)
-                break
-
-        logger.info(
-            "news feed completed feed_url=%s kept=%s",
-            feed_url,
-            len(articles) - kept_before,
-        )
+    semaphore = asyncio.Semaphore(_NEEDLE_RSS_MAX_CONCURRENCY)
+
+    async with httpx.AsyncClient(
+        timeout=httpx.Timeout(FEED_FETCH_TIMEOUT_SECONDS),
+        headers={"User-Agent": "news-mcp/1.0"},
+    ) as client:
+        tasks = [
+            _fetch_feed_async(client, semaphore, url)
+            for url in feed_urls
+        ]
+        results = await asyncio.gather(*tasks, return_exceptions=False)
+
+    articles: List[Dict[str, Any]] = []
+    for feed_url, raw in results:
+        if raw is None:
+            continue
+        # feedparser.parse is CPU-light but sync — parse inline (fast enough)
+        parsed = feedparser.parse(raw)
+        articles.extend(_extract_articles_from_feed(feed_url, parsed, per_feed_limit))
 
 
     logger.info("news ingestion complete total_kept=%s", len(articles))
     logger.info("news ingestion complete total_kept=%s", len(articles))
     return articles
     return articles
@@ -116,5 +156,6 @@ def normalize_topic_from_title(title: str) -> str:
 
 
 
 
 def cluster_id_for_title(topic: str, title: str) -> str:
 def cluster_id_for_title(topic: str, title: str) -> str:
+    import hashlib
     key = f"{topic}|{title.strip().lower()}"
     key = f"{topic}|{title.strip().lower()}"
     return hashlib.sha1(key.encode("utf-8")).hexdigest()
     return hashlib.sha1(key.encode("utf-8")).hexdigest()

+ 23 - 8
test_news_mcp.py

@@ -386,7 +386,11 @@ def test_refresh_skips_reprocessing_when_feed_hash_is_unchanged(monkeypatch):
             self.meta[key] = value
             self.meta[key] = value
 
 
     monkeypatch.setattr(poller, "SQLiteClusterStore", DummyStore)
     monkeypatch.setattr(poller, "SQLiteClusterStore", DummyStore)
-    monkeypatch.setattr(poller, "fetch_news_articles", lambda limit: [{"title": "Bitcoin rallies", "url": "https://example.com/a", "timestamp": "Wed, 01 Apr 2026 12:00:00 GMT"}])
+
+    async def _mock_fetch(limit):
+        calls["fetch"] += 1
+        return [{"title": "Bitcoin rallies", "url": "https://example.com/a", "timestamp": "Wed, 01 Apr 2026 12:00:00 GMT"}]
+    monkeypatch.setattr(poller, "fetch_news_articles", _mock_fetch)
     monkeypatch.setattr(poller.asyncio, "to_thread", fake_to_thread)
     monkeypatch.setattr(poller.asyncio, "to_thread", fake_to_thread)
     monkeypatch.setattr(poller, "dedup_and_cluster_articles", fake_cluster)
     monkeypatch.setattr(poller, "dedup_and_cluster_articles", fake_cluster)
     monkeypatch.setattr(poller, "enrich_cluster", fake_enrich)
     monkeypatch.setattr(poller, "enrich_cluster", fake_enrich)
@@ -627,10 +631,8 @@ def test_poller_persists_clusters_under_post_enrichment_topic(monkeypatch):
         def set_meta(self, key, value):
         def set_meta(self, key, value):
             pass
             pass
 
 
-    async def fake_to_thread(fn, limit):
-        return [
-            {"title": "SEC fines firm", "url": "https://example.com/a", "source": "S", "timestamp": "Wed, 01 Apr 2026 12:00:00 GMT", "summary": "..."},
-        ]
+        def set_feed_state(self, feed_key, last_hash, item_count):
+            pass
 
 
     def fake_cluster(articles):
     def fake_cluster(articles):
         # Heuristic put it in "other" (no crypto/macro/regulation/ai keywords
         # Heuristic put it in "other" (no crypto/macro/regulation/ai keywords
@@ -668,8 +670,13 @@ def test_poller_persists_clusters_under_post_enrichment_topic(monkeypatch):
         return out
         return out
 
 
     monkeypatch.setattr(poller, "SQLiteClusterStore", DummyStore)
     monkeypatch.setattr(poller, "SQLiteClusterStore", DummyStore)
-    monkeypatch.setattr(poller, "fetch_news_articles", lambda limit: [])
-    monkeypatch.setattr(poller.asyncio, "to_thread", fake_to_thread)
+
+    async def _mock_fetch2(limit):
+        return [
+            {"title": "SEC fines firm", "url": "https://example.com/a", "source": "S",
+             "timestamp": "Wed, 01 Apr 2026 12:00:00 GMT", "summary": "..."},
+        ]
+    monkeypatch.setattr(poller, "fetch_news_articles", _mock_fetch2)
     monkeypatch.setattr(poller, "dedup_and_cluster_articles", fake_cluster)
     monkeypatch.setattr(poller, "dedup_and_cluster_articles", fake_cluster)
     monkeypatch.setattr(poller, "enrich_cluster", fake_enrich)
     monkeypatch.setattr(poller, "enrich_cluster", fake_enrich)
     monkeypatch.setattr(poller, "classify_cluster_llm", fake_classify)
     monkeypatch.setattr(poller, "classify_cluster_llm", fake_classify)
@@ -677,7 +684,15 @@ def test_poller_persists_clusters_under_post_enrichment_topic(monkeypatch):
     asyncio.run(poller.refresh_clusters(topic=None, limit=10))
     asyncio.run(poller.refresh_clusters(topic=None, limit=10))
 
 
     assert captured["upserts"], "Expected at least one upsert call"
     assert captured["upserts"], "Expected at least one upsert call"
-    upsert = captured["upserts"][0]
+    # The poller first stores raw clusters (topic=heuristic), then enriched
+    # clusters (topic=post-LLM).  The enriched upsert is the one whose row_topic
+    # reflects the LLM classification.
+    enriched_upserts = [u for u in captured["upserts"] if u["row_topic"] == "regulation"]
+    assert enriched_upserts, (
+        f"Expected at least one upsert with row_topic='regulation', "
+        f"got topics: {[u['row_topic'] for u in captured['upserts']]}"
+    )
+    upsert = enriched_upserts[0]
     assert upsert["row_topic"] == "regulation", (
     assert upsert["row_topic"] == "regulation", (
         f"Expected SQL row topic to follow the LLM's classification 'regulation', got {upsert['row_topic']!r}"
         f"Expected SQL row topic to follow the LLM's classification 'regulation', got {upsert['row_topic']!r}"
     )
     )