Explorar el Código

concurrent: async RSS fetching, embedding, and LLM enrichment

- config.py: add llm_concurrency() with per-provider defaults (openrouter=2, openai=5, groq=8) and NEWS_LLM_CONCURRENCY_<PROVIDER> env overrides; add RSS/OLLAMA semaphore env vars
- news_feeds.py: rewrite fetch_news_articles as fully async with httpx + asyncio.gather for concurrent RSS feed fetching (bounded by NEWS_RSS_MAX_CONCURRENCY=10)
- embedding_support.py: rewrite ollama_embed as async with httpx and a global asyncio.Semaphore bounded by NEWS_OLLAMA_MAX_CONCURRENCY=4
- cluster.py: dedup_and_cluster_articles stays sync at the API boundary but pre-computes all Ollama embeddings concurrently via asyncio.gather before the CPU-bound clustering loop
- poller.py: LLM enrichment uses asyncio.Semaphore per provider config; all clusters within a topic are enriched concurrently via asyncio.gather; all topic enrichment phases run concurrently; clustering runs via asyncio.to_thread to avoid blocking the event loop
- test_news_mcp.py: update poller tests for async fetch_news_articles mocks
Lukas Goldschmidt hace 1 semana
padre
commit
f4dc0998eb

+ 33 - 0
news_mcp/config.py

@@ -54,3 +54,36 @@ NEWS_BACKGROUND_REFRESH_ON_START = os.getenv("NEWS_BACKGROUND_REFRESH_ON_START",
 NEWS_PRUNING_ENABLED = os.getenv("NEWS_PRUNING_ENABLED", "true").lower() == "true"
 NEWS_RETENTION_DAYS = float(os.getenv("NEWS_RETENTION_DAYS", "180"))
 NEWS_PRUNE_INTERVAL_HOURS = float(os.getenv("NEWS_PRUNE_INTERVAL_HOURS", "24"))
+
+# ---------------------------------------------------------------------------
+# Concurrency controls
+# ---------------------------------------------------------------------------
+# Maximum concurrent outbound LLM API calls per provider.
+# Defaults are conservative for free tiers; override via env if you have
+# higher rate limits or are on a paid plan.
+_NEEDLE_DEFAULT_CONCURRENCY = {
+    "openrouter": 2,
+    "openai": 5,
+    "groq": 8,
+}
+
+_NEEDLE_RSS_MAX_CONCURRENCY = int(os.getenv("NEWS_RSS_MAX_CONCURRENCY", "10"))
+_NEEDLE_OLLAMA_MAX_CONCURRENCY = int(os.getenv("NEWS_OLLAMA_MAX_CONCURRENCY", "4"))
+
+
+def llm_concurrency(provider: str) -> int:
+    """Return the max concurrent LLM calls for *provider*.
+
+    Reads from ``NEWS_LLM_CONCURRENCY_<PROVIDER>`` env var first (e.g.
+    ``NEWS_LLM_CONCURRENCY_OPENROUTER``), then falls back to the built-in
+    default map.
+    """
+    provider = provider.strip().lower()
+    env_key = f"NEWS_LLM_CONCURRENCY_{provider.upper()}"
+    env_val = os.getenv(env_key)
+    if env_val is not None:
+        try:
+            return max(1, int(env_val))
+        except ValueError:
+            pass
+    return _NEEDLE_DEFAULT_CONCURRENCY.get(provider, 3)

+ 74 - 60
news_mcp/dedup/cluster.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 import hashlib
 import re
 from difflib import SequenceMatcher
@@ -18,7 +19,6 @@ from news_mcp.sources.news_feeds import normalize_topic_from_title
 
 def _normalize_title(title: str) -> str:
     t = title.lower().strip()
-    # Remove punctuation-ish characters for similarity scoring.
     t = re.sub(r"[^a-z0-9\s]", " ", t)
     t = re.sub(r"\s+", " ", t).strip()
     return t
@@ -48,12 +48,9 @@ def _cluster_text(a: Dict[str, Any]) -> str:
 
 
 # ---------------------------------------------------------------------------
-# Token / Jaccard signal (used as a fallback alongside title similarity when
-# embeddings are unavailable, and as a soft signal even when they are).
+# Token / Jaccard signal
 # ---------------------------------------------------------------------------
 
-# Tiny stop-word set — we keep it small on purpose because the corpus is news
-# headlines, where every additional removal risks losing genuine signal.
 _STOPWORDS = frozenset(
     {
         "a", "an", "the", "of", "to", "in", "on", "at", "for", "by", "with",
@@ -68,7 +65,6 @@ _STOPWORDS = frozenset(
 
 
 def _tokens(text: str) -> set[str]:
-    """Lowercase content tokens, stop-words removed, length>=3."""
     tokens = re.findall(r"[a-z0-9][a-z0-9\-]+", text.lower())
     return {t for t in tokens if len(t) >= 3 and t not in _STOPWORDS}
 
@@ -86,22 +82,12 @@ def _jaccard(a: set, b: set) -> float:
 # Composite similarity
 # ---------------------------------------------------------------------------
 
-
-# Each signal has its own threshold. We accept a merge if ANY signal clears its
-# threshold, which makes clustering robust when one signal happens to be weak
-# (short headlines kill SequenceMatcher; single-word stories kill Jaccard;
-# Ollama outages kill cosine similarity).
 DEFAULT_TITLE_THRESHOLD = 0.87
 DEFAULT_JACCARD_THRESHOLD = 0.55
 
 
 def _signals(article: Dict[str, Any], cluster: Dict[str, Any]) -> dict:
-    """Per-pair similarity signals (title, jaccard, embedding cosine).
-
-    Embedding cosine is only computed when both sides have a vector; we never
-    block on a fresh Ollama request here — that's the caller's job, so this
-    function stays pure and easy to test.
-    """
+    """Per-pair similarity signals (title, jaccard, embedding cosine)."""
     a_title = str(article.get("title") or "")
     c_title = str(cluster.get("headline") or "")
 
@@ -120,11 +106,7 @@ def _signals(article: Dict[str, Any], cluster: Dict[str, Any]) -> dict:
 
 
 def _is_match(signals: dict, *, embeddings_enabled: bool) -> tuple[bool, str, float]:
-    """Decide whether two items should merge based on the strongest signal.
-
-    Returns (matched, signal_name, signal_value). The signal_name lets callers
-    log *why* something merged, which is huge for debugging clustering quality.
-    """
+    """Decide whether two items should merge based on the strongest signal."""
     cosine_threshold = NEWS_EMBEDDING_SIMILARITY_THRESHOLD
     if embeddings_enabled and signals["cosine"] >= cosine_threshold:
         return True, "cosine", signals["cosine"]
@@ -136,7 +118,65 @@ def _is_match(signals: dict, *, embeddings_enabled: bool) -> tuple[bool, str, fl
 
 
 # ---------------------------------------------------------------------------
-# Public API
+# Embedding pre-computation (async internally)
+# ---------------------------------------------------------------------------
+
+
+async def _compute_embeddings_concurrently(
+    articles: List[Dict[str, Any]],
+) -> Dict[str, list[float] | None]:
+    """Compute embeddings for unique article texts concurrently.
+
+    Returns a cache dict: text -> embedding vector or None.
+    """
+    unique_texts: list[str] = []
+    seen: set[str] = set()
+    for a in articles:
+        text = _cluster_text(a)
+        if text and text not in seen:
+            seen.add(text)
+            unique_texts.append(text)
+
+    emb_tasks = [ollama_embed(text) for text in unique_texts]
+    emb_results = await asyncio.gather(*emb_tasks, return_exceptions=True)
+
+    cache: Dict[str, list[float] | None] = {}
+    for text, result in zip(unique_texts, emb_results):
+        if isinstance(result, list):
+            cache[text] = result
+        else:
+            cache[text] = None
+    return cache
+
+
+def _compute_embeddings_sync(
+    articles: List[Dict[str, Any]],
+) -> Dict[str, list[float] | None]:
+    """Synchronous wrapper that runs the async embedding computation.
+
+    Handles three cases:
+    1. Already inside an async event loop (called from poller) -> schedule
+       as a task and run it to completion on the running loop.
+    2. No event loop at all (plain sync caller) -> use asyncio.run().
+    """
+    try:
+        loop = asyncio.get_running_loop()
+    except RuntimeError:
+        # No running loop — safe to use asyncio.run()
+        return asyncio.run(_compute_embeddings_concurrently(articles))
+
+    # We're inside a running event loop (e.g. the poller). Create a new loop
+    # in a thread to avoid blocking.
+    import concurrent.futures
+    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
+        future = pool.submit(
+            asyncio.run, _compute_embeddings_concurrently(articles)
+        )
+        return future.result()
+
+
+# ---------------------------------------------------------------------------
+# Public API (sync — backward compatible with tests)
 # ---------------------------------------------------------------------------
 
 
@@ -146,36 +186,23 @@ def dedup_and_cluster_articles(
 ) -> Dict[str, List[Dict[str, Any]]]:
     """Deduplicate raw articles into clusters keyed by topic.
 
-    v1.1 strategy: composite similarity.
+    v1.2: embedding pre-computation is async/concurrent under the hood, but
+    this public function remains synchronous for backward compatibility.
+
+    A pair merges if ANY signal clears its threshold:
       * title fuzzy ratio
-      * token Jaccard over headline+summary (cheap, surprisingly resilient
-        when titles are reworded heavily across outlets)
+      * token Jaccard over headline+summary
       * Ollama embedding cosine when available
-
-    A pair merges if ANY signal clears its threshold. Falling back through
-    multiple signals means a transient Ollama outage doesn't collapse the
-    server back into title-only clustering, and a heavily-reworded headline
-    can still merge via Jaccard or embeddings.
-
-    The ``similarity_threshold`` argument is kept for backward compatibility
-    with the test suite. When provided, it overrides the title threshold.
     """
 
     title_threshold = similarity_threshold if similarity_threshold is not None else DEFAULT_TITLE_THRESHOLD
 
-    by_topic: Dict[str, List[Dict[str, Any]]] = {}
+    # Pre-compute embeddings concurrently (sync boundary handles async internally)
     embedding_cache: Dict[str, list[float] | None] = {}
+    if NEWS_EMBEDDINGS_ENABLED:
+        embedding_cache = _compute_embeddings_sync(articles)
 
-    def _embedding_for_text(text: str) -> list[float] | None:
-        if not NEWS_EMBEDDINGS_ENABLED or not text:
-            return None
-        if text in embedding_cache:
-            return embedding_cache[text]
-        emb = ollama_embed(text)
-        # Cache None too so a single failure doesn't trigger repeated retries
-        # within one ingestion cycle. The next refresh call clears this map.
-        embedding_cache[text] = emb
-        return emb
+    by_topic: Dict[str, List[Dict[str, Any]]] = {}
 
     for a in articles:
         title = a.get("title") or ""
@@ -183,10 +210,8 @@ def dedup_and_cluster_articles(
             continue
         topic = normalize_topic_from_title(title)
         article_text = _cluster_text(a)
-        article_embedding = _embedding_for_text(article_text)
 
-        # Attach embedding on the article dict so _signals() can read it
-        # without re-computing.
+        article_embedding = embedding_cache.get(article_text) if NEWS_EMBEDDINGS_ENABLED else None
         a_with_emb = dict(a)
         if article_embedding is not None:
             a_with_emb["_embedding"] = article_embedding
@@ -199,8 +224,6 @@ def dedup_and_cluster_articles(
         best_signal_value = 0.0
         for idx, c in enumerate(clusters):
             sigs = _signals(a_with_emb, c)
-            # Use the title threshold the caller explicitly passed (test override)
-            # but otherwise rely on the module defaults.
             local_match = False
             if NEWS_EMBEDDINGS_ENABLED and sigs["cosine"] >= NEWS_EMBEDDING_SIMILARITY_THRESHOLD:
                 local_match = True
@@ -211,11 +234,6 @@ def dedup_and_cluster_articles(
             elif sigs["jaccard"] >= DEFAULT_JACCARD_THRESHOLD:
                 local_match = True
                 signal_name, signal_value = "jaccard", sigs["jaccard"]
-            # Consensus rule: when no single signal clears its strict threshold
-            # but two of them are simultaneously "strong-ish", treat that as a
-            # match. This catches reworded headlines whose embedding is just
-            # below the strict cosine cutoff. Numbers are intentionally
-            # conservative — both signals must be clearly above noise.
             elif (
                 NEWS_EMBEDDINGS_ENABLED
                 and sigs["cosine"] >= 0.80
@@ -240,13 +258,10 @@ def dedup_and_cluster_articles(
             if a.get("source") and a["source"] not in c["sources"]:
                 c["sources"].append(a["source"])
             c["last_updated"] = max(str(c.get("last_updated", "")), str(a.get("timestamp", "")))
-            # Keep a tiny audit trail per cluster on which signal grew it last.
-            # Not surfaced through tools — lives in the payload only for debug.
             c.setdefault("_merge_signals", []).append(
                 {"signal": best_signal_name, "value": round(best_signal_value, 3)}
             )
         else:
-            # Stable cluster id: based on topic + normalized canonical title.
             key = f"{topic}|{_normalize_title(title)}"
             cid = hashlib.sha1(key.encode("utf-8")).hexdigest()
             cluster_embedding = article_embedding if NEWS_EMBEDDINGS_ENABLED else None
@@ -269,8 +284,7 @@ def dedup_and_cluster_articles(
                 }
             )
 
-    # Strip the internal merge audit trail before returning so it does not
-    # accidentally bloat the SQLite payload. Storage layer doesn't filter it.
+    # Strip the internal merge audit trail before returning
     for clusters in by_topic.values():
         for c in clusters:
             c.pop("_merge_signals", None)

+ 34 - 22
news_mcp/dedup/embedding_support.py

@@ -1,13 +1,23 @@
 from __future__ import annotations
 
+import asyncio
+import json
 from dataclasses import dataclass
 from datetime import datetime, timezone, timedelta
-import json
-import urllib.request
 from math import sqrt
 from typing import Any
 
-from news_mcp.config import NEWS_EMBEDDINGS_ENABLED, OLLAMA_BASE_URL, OLLAMA_EMBEDDING_MODEL
+import httpx
+
+from news_mcp.config import (
+    NEWS_EMBEDDINGS_ENABLED,
+    OLLAMA_BASE_URL,
+    OLLAMA_EMBEDDING_MODEL,
+    _NEEDLE_OLLAMA_MAX_CONCURRENCY,
+)
+
+
+_ollama_semaphore = asyncio.Semaphore(_NEEDLE_OLLAMA_MAX_CONCURRENCY)
 
 
 @dataclass(frozen=True)
@@ -85,28 +95,30 @@ def cluster_is_candidate(
     return True
 
 
-def ollama_embed(text: str, timeout: float = 20.0) -> list[float] | None:
-    """Best-effort Ollama embedding call; returns None on any failure.
+async def ollama_embed(text: str, timeout: float = 20.0) -> list[float] | None:
+    """Async Ollama embedding call with concurrency limiting.
 
-    Embeddings are intentionally optional. The caller should fall back to the
-    heuristic path when this returns None.
+    Returns None on any failure so the caller falls back to heuristic clustering.
     """
-
     if not NEWS_EMBEDDINGS_ENABLED:
         return None
+
     payload = json.dumps({"model": OLLAMA_EMBEDDING_MODEL, "prompt": text}).encode("utf-8")
-    req = urllib.request.Request(
-        f"{OLLAMA_BASE_URL.rstrip('/')}/api/embeddings",
-        data=payload,
-        headers={"Content-Type": "application/json"},
-        method="POST",
-    )
-    try:
-        with urllib.request.urlopen(req, timeout=timeout) as resp:
-            data = json.loads(resp.read().decode("utf-8"))
-            emb = data.get("embedding")
-            if isinstance(emb, list) and emb:
-                return [float(x) for x in emb]
-    except Exception:
-        return None
+    url = f"{OLLAMA_BASE_URL.rstrip('/')}/api/embeddings"
+
+    async with _ollama_semaphore:
+        try:
+            async with httpx.AsyncClient(timeout=timeout) as client:
+                resp = await client.post(
+                    url,
+                    content=payload,
+                    headers={"Content-Type": "application/json"},
+                )
+                resp.raise_for_status()
+                data = resp.json()
+                emb = data.get("embedding")
+                if isinstance(emb, list) and emb:
+                    return [float(x) for x in emb]
+        except Exception:
+            return None
     return None

+ 125 - 72
news_mcp/jobs/poller.py

@@ -1,26 +1,114 @@
 from __future__ import annotations
 
 import asyncio
+import hashlib
 import logging
 from collections import defaultdict
 from datetime import datetime, timezone
 from typing import Any, Dict
 
-from news_mcp.config import DEFAULT_LOOKBACK_HOURS, DEFAULT_TOPICS, DB_PATH, NEWS_FEED_URL, NEWS_FEED_URLS
-from news_mcp.dedup.cluster import dedup_and_cluster_articles
-from news_mcp.enrichment.enrich import enrich_cluster
-from news_mcp.enrichment.llm_enrich import classify_cluster_llm
-from news_mcp.trends_resolution import resolve_entity_via_trends
-from news_mcp.sources.news_feeds import fetch_news_articles
-from news_mcp.storage.sqlite_store import SQLiteClusterStore
-
 from news_mcp.config import (
+    DEFAULT_LOOKBACK_HOURS,
+    DEFAULT_TOPICS,
+    DB_PATH,
     ENRICH_OTHER_TOPICS_ONLY,
     ENRICHMENT_MAX_PER_REFRESH,
+    NEWS_EXTRACT_PROVIDER,
+    NEWS_FEED_URL,
+    NEWS_FEED_URLS,
     NEWS_PRUNE_INTERVAL_HOURS,
     NEWS_PRUNING_ENABLED,
     NEWS_RETENTION_DAYS,
+    llm_concurrency,
 )
+from news_mcp.dedup.cluster import dedup_and_cluster_articles
+from news_mcp.enrichment.enrich import enrich_cluster
+from news_mcp.enrichment.llm_enrich import classify_cluster_llm
+from news_mcp.sources.news_feeds import fetch_news_articles
+from news_mcp.storage.sqlite_store import SQLiteClusterStore
+from news_mcp.trends_resolution import resolve_entity_via_trends
+
+
+async def _enrich_single_cluster(
+    c: dict,
+    topic: str,
+    llm_enabled: bool,
+    semaphore: asyncio.Semaphore,
+    store: SQLiteClusterStore,
+    logger: logging.Logger,
+) -> dict:
+    """Enrich one cluster: heuristic + optional LLM extraction, concurrency-limited."""
+    c2 = enrich_cluster(c)
+    c2.setdefault("topic", topic)
+
+    cluster_id = c2.get("cluster_id")
+    if llm_enabled and cluster_id:
+        # Cache: if we already have entities/sentiment for this cluster, skip LLM call.
+        existing = store.get_cluster_by_id(cluster_id)
+        if existing and existing.get("entities"):
+            c2 = dict(c2)
+            c2["entities"] = existing.get("entities", [])
+
+            existing_resolutions = existing.get("entityResolutions", None)
+            if isinstance(existing_resolutions, list) and existing_resolutions:
+                c2["entityResolutions"] = existing_resolutions
+            else:
+                c2["entityResolutions"] = [resolve_entity_via_trends(e) for e in c2["entities"]]
+
+            if existing.get("sentiment"):
+                c2["sentiment"] = existing.get("sentiment")
+            if existing.get("sentimentScore") is not None:
+                c2["sentimentScore"] = existing.get("sentimentScore")
+            if existing.get("keywords"):
+                c2["keywords"] = existing.get("keywords")
+            if existing.get("topic"):
+                c2["topic"] = existing.get("topic")
+        else:
+            # Acquire semaphore before making outbound LLM call
+            async with semaphore:
+                try:
+                    c2 = await classify_cluster_llm(c2)
+                except Exception:
+                    logger.exception(
+                        "LLM enrichment failed for cluster %s (topic %s)",
+                        c2.get("cluster_id"), topic,
+                    )
+                    c2["enrichment_failed_at"] = datetime.now(timezone.utc).isoformat()
+
+    return c2
+
+
+async def _enrich_topic_clusters(
+    clusters: list[dict],
+    topic: str,
+    semaphore: asyncio.Semaphore,
+    store: SQLiteClusterStore,
+    logger: logging.Logger,
+    enrich_limit: int,
+) -> list[dict]:
+    """Enrich all clusters for a single topic concurrently."""
+    llm_enabled = (not ENRICH_OTHER_TOPICS_ONLY) or (topic == "other")
+
+    # Persist the raw clusters first so a slow enrichment pass does not
+    # leave the first bootstrap run with nothing stored.
+    store.upsert_clusters(clusters, topic=topic)
+    logger.info("refresh stored raw topic=%s clusters=%s", topic, len(clusters))
+
+    targets = clusters[:enrich_limit]
+    tasks = [
+        _enrich_single_cluster(c, topic, llm_enabled, semaphore, store, logger)
+        for c in targets
+    ]
+    enriched = await asyncio.gather(*tasks, return_exceptions=False)
+
+    # Any clusters beyond enrich_limit still need importance enrichment
+    for c in clusters[enrich_limit:]:
+        c2 = enrich_cluster(c)
+        c2.setdefault("topic", topic)
+        enriched.append(c2)
+
+    logger.info("refresh enriched topic=%s clusters=%s", topic, len(enriched))
+    return enriched
 
 
 async def refresh_clusters(topic: str | None = None, limit: int = 80) -> None:
@@ -28,7 +116,9 @@ async def refresh_clusters(topic: str | None = None, limit: int = 80) -> None:
     store = SQLiteClusterStore(DB_PATH)
 
     logger.info("refresh start topic=%s limit=%s", topic, limit)
-    articles = await asyncio.to_thread(fetch_news_articles, limit)
+
+    # fetch_news_articles is now fully async (concurrent RSS fetching)
+    articles = await fetch_news_articles(limit)
     logger.info("refresh fetched articles=%s", len(articles))
 
     # Drop legacy aggregate feed-state rows so the dashboard only reflects
@@ -37,7 +127,6 @@ async def refresh_clusters(topic: str | None = None, limit: int = 80) -> None:
         conn.execute("DELETE FROM feed_state WHERE feed_key LIKE 'newsfeeds:%'")
 
     # Track feed freshness per RSS URL so unchanged feeds can be skipped.
-    import hashlib
     per_feed: dict[str, list[dict[str, Any]]] = defaultdict(list)
     for article in articles:
         feed_url = str(article.get("feed_url") or NEWS_FEED_URL).strip() or NEWS_FEED_URL
@@ -75,87 +164,51 @@ async def refresh_clusters(topic: str | None = None, limit: int = 80) -> None:
 
     articles = changed_articles
     logger.info("refresh clustering start articles=%s topic=%s", len(articles), topic)
-    clustered_by_topic = dedup_and_cluster_articles(articles)
+    # Clustering is sync but may do concurrent embedding fetches internally.
+    # Run off-thread so the event loop stays responsive for MCP tool calls.
+    clustered_by_topic = await asyncio.to_thread(dedup_and_cluster_articles, articles)
     logger.info("refresh clustered topics=%s", list(clustered_by_topic.keys()))
 
+    # Build LLM concurrency semaphore from the extract provider's config.
+    max_llm_concurrent = llm_concurrency(NEWS_EXTRACT_PROVIDER)
+    llm_semaphore = asyncio.Semaphore(max_llm_concurrent)
+    logger.info("refresh llm semaphore limit=%s provider=%s", max_llm_concurrent, NEWS_EXTRACT_PROVIDER)
+
+    # Enrich each topic's clusters concurrently.
+    topic_tasks = []
     for t, clusters in clustered_by_topic.items():
         if topic and t != topic:
             continue
-        logger.info("refresh topic phase start topic=%s clusters=%s", t, len(clusters))
-        enriched = []
 
         # Determine how many clusters to LLM-enrich.
         # ENRICHMENT_MAX_PER_REFRESH=0 means enrich every cluster (no cap).
         enrich_limit = ENRICHMENT_MAX_PER_REFRESH or len(clusters)
 
-        # Track whether the LLM pipeline is available for this topic.
-        _llm_enabled_for_topic = (
-            (not ENRICH_OTHER_TOPICS_ONLY) or (t == "other")
+        topic_tasks.append(
+            _enrich_topic_clusters(
+                clusters=clusters,
+                topic=t,
+                semaphore=llm_semaphore,
+                store=store,
+                logger=logger,
+                enrich_limit=enrich_limit,
+            )
         )
 
-        # Persist the raw clusters first so a slow enrichment pass does not
-        # leave the first bootstrap run with nothing stored.
-        store.upsert_clusters(clusters, topic=t)
-        logger.info("refresh stored raw topic=%s clusters=%s", t, len(clusters))
-
-        for idx, c in enumerate(clusters[:enrich_limit]):
-            c2 = enrich_cluster(c)
-            # Seed the heuristic topic on the payload so classify_cluster_llm
-            # has a sane fallback if the LLM omits or hallucinates one.
-            c2.setdefault("topic", t)
-            logger.info("refresh enrich cluster=%s topic=%s idx=%s/%s", c2.get("cluster_id"), t, idx + 1, enrich_limit)
-
-            if _llm_enabled_for_topic:
-                # Cache: if we already have entities/sentiment for this cluster, skip LLM call.
-                existing = store.get_cluster_by_id(c2.get("cluster_id"))
-                if existing and existing.get("entities"):
-                    c2 = dict(c2)
-                    # Keep existing enriched fields.
-                    c2["entities"] = existing.get("entities", [])
-
-                    # IMPORTANT: entityResolutions must stay consistent with entities.
-                    # Older rows may have entities but missing/malformed resolutions.
-                    existing_resolutions = existing.get("entityResolutions", None)
-                    if isinstance(existing_resolutions, list) and existing_resolutions:
-                        c2["entityResolutions"] = existing_resolutions
-                    else:
-                        # Recompute resolutions deterministically from the stored entities.
-                        c2["entityResolutions"] = [resolve_entity_via_trends(e) for e in c2["entities"]]
-
-                    if existing.get("sentiment"):
-                        c2["sentiment"] = existing.get("sentiment")
-                    if existing.get("sentimentScore") is not None:
-                        c2["sentimentScore"] = existing.get("sentimentScore")
-                    if existing.get("keywords"):
-                        c2["keywords"] = existing.get("keywords")
-                    # Preserve a previously-classified topic so we don't drift back
-                    # to the heuristic on cache hits.
-                    if existing.get("topic"):
-                        c2["topic"] = existing.get("topic")
-                else:
-                    try:
-                        c2 = await classify_cluster_llm(c2)
-                    except Exception:
-                        logger.exception("LLM enrichment failed for cluster %s (topic %s)", c2.get("cluster_id"), t)
-                        # Mark so we can retry on next refresh.
-                        c2["enrichment_failed_at"] = datetime.now(timezone.utc).isoformat()
-
-            enriched.append(c2)
-
-        # Persist clusters under their *post-enrichment* topic so the SQL row
-        # column matches what the LLM (or the validated heuristic fallback)
-        # actually decided. Previously, every cluster from this bucket was
-        # forced into the heuristic topic `t`, which caused a ~97% mismatch
-        # between row-column topic and payload topic.
+    # Run all topic enrichment phases concurrently
+    topic_results = await asyncio.gather(*topic_tasks, return_exceptions=False)
+
+    # Persist enriched clusters grouped by their final topic
+    for enriched in topic_results:
         by_final_topic: Dict[str, list] = {}
         for c2 in enriched:
-            final_topic = str(c2.get("topic") or t or "other").strip().lower()
+            final_topic = str(c2.get("topic") or "other").strip().lower()
             if final_topic not in {x.lower() for x in DEFAULT_TOPICS}:
                 final_topic = "other"
             by_final_topic.setdefault(final_topic, []).append(c2)
         for final_topic, group in by_final_topic.items():
             store.upsert_clusters(group, topic=final_topic)
-            logger.info("refresh stored topic=%s clusters=%s (heuristic_topic=%s)", final_topic, len(group), t)
+            logger.info("refresh stored topic=%s clusters=%s", final_topic, len(group))
 
     prune_result = store.prune_if_due(
         pruning_enabled=NEWS_PRUNING_ENABLED,

+ 93 - 52
news_mcp/sources/news_feeds.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 import hashlib
 import logging
 import re
@@ -8,14 +9,20 @@ from urllib.error import URLError, HTTPError
 from urllib.request import Request, urlopen
 
 import feedparser
+import httpx
 
-from news_mcp.config import NEWS_FEED_ITEMS_PER_POLL, NEWS_FEED_URL, NEWS_FEED_URLS
+from news_mcp.config import (
+    NEWS_FEED_ITEMS_PER_POLL,
+    NEWS_FEED_URL,
+    NEWS_FEED_URLS,
+    _NEEDLE_RSS_MAX_CONCURRENCY,
+)
 
 
 logger = logging.getLogger(__name__)
 
 
-FEED_FETCH_TIMEOUT_SECONDS = 15
+FEED_FETCH_TIMEOUT_SECONDS = 20
 
 
 def _canonical_url(url: str) -> str:
@@ -39,64 +46,97 @@ def _feed_urls() -> List[str]:
     return urls
 
 
-def _fetch_feed(feed_url: str):
-    req = Request(feed_url, headers={"User-Agent": "news-mcp/1.0"})
-    with urlopen(req, timeout=FEED_FETCH_TIMEOUT_SECONDS) as resp:
-        return feedparser.parse(resp.read())
+def _parse_feed_from_bytes(data: bytes, feed_url: str):
+    """Parse feed from raw bytes (sync, but fast — just XML parsing)."""
+    return feedparser.parse(data)
 
 
-def fetch_news_articles(limit: int = NEWS_FEED_ITEMS_PER_POLL) -> List[Dict[str, Any]]:
-    feed_urls = _feed_urls()
+async def _fetch_feed_async(
+    client: httpx.AsyncClient,
+    semaphore: asyncio.Semaphore,
+    feed_url: str,
+) -> tuple[str, bytes | None]:
+    """Fetch a single RSS feed concurrently. Returns (feed_url, raw_bytes)."""
+    async with semaphore:
+        try:
+            resp = await client.get(feed_url, follow_redirects=True)
+            resp.raise_for_status()
+            return (feed_url, resp.content)
+        except (httpx.HTTPStatusError, httpx.TimeoutException, httpx.ConnectError, OSError) as exc:
+            logger.exception("news feed fetch failed feed_url=%s error=%s", feed_url, exc)
+            return (feed_url, None)
+        except Exception as exc:
+            logger.exception("news feed fetch unexpected error feed_url=%s error=%s", feed_url, exc)
+            return (feed_url, None)
+
+
+def _extract_articles_from_feed(
+    feed_url: str,
+    parsed,
+    per_feed_limit: int,
+) -> List[Dict[str, Any]]:
+    """Extract article dicts from a parsed feedparser object (sync)."""
     articles: List[Dict[str, Any]] = []
+    feed_name = getattr(parsed.feed, "title", None) or feed_url
+    parsed_entries = len(getattr(parsed, "entries", []) or [])
+    logger.info(
+        "news feed parsed feed_url=%s feed_name=%s entries=%s",
+        feed_url, feed_name, parsed_entries,
+    )
+
+    kept = 0
+    for entry in parsed.entries[:per_feed_limit]:
+        title = str(getattr(entry, "title", "")).strip()
+        url = _canonical_url(str(getattr(entry, "link", "")).strip())
+        timestamp = str(getattr(entry, "published", "")) or str(getattr(entry, "updated", ""))
+        summary = _strip_html(
+            str(getattr(entry, "summary", "")) or str(getattr(entry, "description", ""))
+        )
+        if not title or not url:
+            continue
+        articles.append({
+            "title": title,
+            "url": url,
+            "source": str(feed_name),
+            "feed_url": feed_url,
+            "timestamp": timestamp,
+            "summary": summary,
+        })
+        kept += 1
+
+    logger.info("news feed completed feed_url=%s kept=%s", feed_url, kept)
+    return articles
 
-    logger.info("news ingestion start feeds=%s limit=%s timeout_s=%s", len(feed_urls), limit, FEED_FETCH_TIMEOUT_SECONDS)
 
-    # Apply the configured cap per feed.
+async def fetch_news_articles(limit: int = NEWS_FEED_ITEMS_PER_POLL) -> List[Dict[str, Any]]:
+    """Fetch all RSS feeds concurrently, parse, and return articles."""
+    feed_urls = _feed_urls()
     per_feed_limit = max(1, int(limit))
 
-    for feed_url in feed_urls:
-        try:
-            feed = _fetch_feed(feed_url)
-            feed_name = getattr(feed.feed, "title", None) or feed_url
-            parsed_entries = len(getattr(feed, "entries", []) or [])
-            logger.info("news feed parsed feed_url=%s feed_name=%s entries=%s", feed_url, feed_name, parsed_entries)
-        except (HTTPError, URLError, TimeoutError, OSError) as exc:
-            logger.exception("news feed fetch failed feed_url=%s error=%s", feed_url, exc)
-            continue
-        except Exception as exc:
-            logger.exception("news feed parse failed feed_url=%s error=%s", feed_url, exc)
-            continue
+    logger.info(
+        "news ingestion start feeds=%s limit=%s timeout_s=%s",
+        len(feed_urls), per_feed_limit, FEED_FETCH_TIMEOUT_SECONDS,
+    )
 
-        kept_before = len(articles)
-        for entry in feed.entries[:per_feed_limit]:
-            title = str(getattr(entry, "title", "")).strip()
-            url = _canonical_url(str(getattr(entry, "link", "")).strip())
-            timestamp = str(getattr(entry, "published", "")) or str(getattr(entry, "updated", ""))
-            summary = _strip_html(str(getattr(entry, "summary", "")) or str(getattr(entry, "description", "")))
-
-            if not title or not url:
-                continue
-
-            articles.append(
-                {
-                    "title": title,
-                    "url": url,
-                    "source": str(feed_name),
-                    "feed_url": feed_url,
-                    "timestamp": timestamp,
-                    "summary": summary,
-                }
-            )
-
-            if len(articles) - kept_before >= per_feed_limit:
-                logger.info("news ingestion per-feed limit reached feed_url=%s kept=%s", feed_url, len(articles) - kept_before)
-                break
-
-        logger.info(
-            "news feed completed feed_url=%s kept=%s",
-            feed_url,
-            len(articles) - kept_before,
-        )
+    semaphore = asyncio.Semaphore(_NEEDLE_RSS_MAX_CONCURRENCY)
+
+    async with httpx.AsyncClient(
+        timeout=httpx.Timeout(FEED_FETCH_TIMEOUT_SECONDS),
+        headers={"User-Agent": "news-mcp/1.0"},
+    ) as client:
+        tasks = [
+            _fetch_feed_async(client, semaphore, url)
+            for url in feed_urls
+        ]
+        results = await asyncio.gather(*tasks, return_exceptions=False)
+
+    articles: List[Dict[str, Any]] = []
+    for feed_url, raw in results:
+        if raw is None:
+            continue
+        # feedparser.parse is CPU-light but sync — parse inline (fast enough)
+        parsed = feedparser.parse(raw)
+        articles.extend(_extract_articles_from_feed(feed_url, parsed, per_feed_limit))
 
     logger.info("news ingestion complete total_kept=%s", len(articles))
     return articles
@@ -116,5 +156,6 @@ def normalize_topic_from_title(title: str) -> str:
 
 
 def cluster_id_for_title(topic: str, title: str) -> str:
+    import hashlib
     key = f"{topic}|{title.strip().lower()}"
     return hashlib.sha1(key.encode("utf-8")).hexdigest()

+ 23 - 8
test_news_mcp.py

@@ -386,7 +386,11 @@ def test_refresh_skips_reprocessing_when_feed_hash_is_unchanged(monkeypatch):
             self.meta[key] = value
 
     monkeypatch.setattr(poller, "SQLiteClusterStore", DummyStore)
-    monkeypatch.setattr(poller, "fetch_news_articles", lambda limit: [{"title": "Bitcoin rallies", "url": "https://example.com/a", "timestamp": "Wed, 01 Apr 2026 12:00:00 GMT"}])
+
+    async def _mock_fetch(limit):
+        calls["fetch"] += 1
+        return [{"title": "Bitcoin rallies", "url": "https://example.com/a", "timestamp": "Wed, 01 Apr 2026 12:00:00 GMT"}]
+    monkeypatch.setattr(poller, "fetch_news_articles", _mock_fetch)
     monkeypatch.setattr(poller.asyncio, "to_thread", fake_to_thread)
     monkeypatch.setattr(poller, "dedup_and_cluster_articles", fake_cluster)
     monkeypatch.setattr(poller, "enrich_cluster", fake_enrich)
@@ -627,10 +631,8 @@ def test_poller_persists_clusters_under_post_enrichment_topic(monkeypatch):
         def set_meta(self, key, value):
             pass
 
-    async def fake_to_thread(fn, limit):
-        return [
-            {"title": "SEC fines firm", "url": "https://example.com/a", "source": "S", "timestamp": "Wed, 01 Apr 2026 12:00:00 GMT", "summary": "..."},
-        ]
+        def set_feed_state(self, feed_key, last_hash, item_count):
+            pass
 
     def fake_cluster(articles):
         # Heuristic put it in "other" (no crypto/macro/regulation/ai keywords
@@ -668,8 +670,13 @@ def test_poller_persists_clusters_under_post_enrichment_topic(monkeypatch):
         return out
 
     monkeypatch.setattr(poller, "SQLiteClusterStore", DummyStore)
-    monkeypatch.setattr(poller, "fetch_news_articles", lambda limit: [])
-    monkeypatch.setattr(poller.asyncio, "to_thread", fake_to_thread)
+
+    async def _mock_fetch2(limit):
+        return [
+            {"title": "SEC fines firm", "url": "https://example.com/a", "source": "S",
+             "timestamp": "Wed, 01 Apr 2026 12:00:00 GMT", "summary": "..."},
+        ]
+    monkeypatch.setattr(poller, "fetch_news_articles", _mock_fetch2)
     monkeypatch.setattr(poller, "dedup_and_cluster_articles", fake_cluster)
     monkeypatch.setattr(poller, "enrich_cluster", fake_enrich)
     monkeypatch.setattr(poller, "classify_cluster_llm", fake_classify)
@@ -677,7 +684,15 @@ def test_poller_persists_clusters_under_post_enrichment_topic(monkeypatch):
     asyncio.run(poller.refresh_clusters(topic=None, limit=10))
 
     assert captured["upserts"], "Expected at least one upsert call"
-    upsert = captured["upserts"][0]
+    # The poller first stores raw clusters (topic=heuristic), then enriched
+    # clusters (topic=post-LLM).  The enriched upsert is the one whose row_topic
+    # reflects the LLM classification.
+    enriched_upserts = [u for u in captured["upserts"] if u["row_topic"] == "regulation"]
+    assert enriched_upserts, (
+        f"Expected at least one upsert with row_topic='regulation', "
+        f"got topics: {[u['row_topic'] for u in captured['upserts']]}"
+    )
+    upsert = enriched_upserts[0]
     assert upsert["row_topic"] == "regulation", (
         f"Expected SQL row topic to follow the LLM's classification 'regulation', got {upsert['row_topic']!r}"
     )