# test_news_mcp.py — unit tests for the news_mcp package.
from __future__ import annotations

import tempfile
from pathlib import Path

from news_mcp.dedup.cluster import dedup_and_cluster_articles
from news_mcp.enrichment.importance import compute_importance
from news_mcp.enrichment.llm_enrich import _filter_entities, _matches_blacklist
from news_mcp.entity_normalize import normalize_query, normalize_entities
from news_mcp.llm import build_extraction_prompt, call_llm, load_prompt
from news_mcp.mcp_server_fastmcp import _sort_clusters_by_recency
from news_mcp.storage.sqlite_store import SQLiteClusterStore
from news_mcp.trends_resolution import resolve_entity_via_trends
  12. def _article(title: str, url: str = "https://example.com/x", source: str = "Src", ts: str = "Mon, 30 Mar 2026 12:00:00 GMT"):
  13. return {
  14. "title": title,
  15. "url": url,
  16. "source": source,
  17. "timestamp": ts,
  18. "summary": "summary text",
  19. }
  20. def test_dedup_merges_similar_titles():
  21. articles = [
  22. _article("Trump warns Iran war could spread"),
  23. _article("Trump warns Iran conflict could spread"),
  24. _article("Unrelated sports result"),
  25. ]
  26. clustered = dedup_and_cluster_articles(articles, similarity_threshold=0.75)
  27. # We expect the Trump/Iran items to be merged into one cluster in the same topic bucket.
  28. total_clusters = sum(len(v) for v in clustered.values())
  29. assert total_clusters == 2
  30. def test_sqlite_feed_hash_roundtrip():
  31. with tempfile.TemporaryDirectory() as td:
  32. db = Path(td) / "news.sqlite"
  33. store = SQLiteClusterStore(db)
  34. assert store.get_feed_hash("breakingthenews") is None
  35. store.set_feed_hash("breakingthenews", "abc123")
  36. assert store.get_feed_hash("breakingthenews") == "abc123"
  37. def test_sqlite_summary_cache_roundtrip():
  38. with tempfile.TemporaryDirectory() as td:
  39. db = Path(td) / "news.sqlite"
  40. store = SQLiteClusterStore(db)
  41. # Upsert a base cluster first.
  42. store.upsert_clusters([
  43. {
  44. "cluster_id": "cid1",
  45. "headline": "Headline",
  46. "summary": "Summary",
  47. "entities": ["Iran"],
  48. "sentiment": "negative",
  49. "importance": 0.5,
  50. "sources": ["BreakingTheNews"],
  51. "timestamp": "Mon, 30 Mar 2026 12:00:00 GMT",
  52. "articles": [],
  53. "first_seen": "Mon, 30 Mar 2026 12:00:00 GMT",
  54. "last_updated": "Mon, 30 Mar 2026 12:00:00 GMT",
  55. }
  56. ], topic="other")
  57. store.upsert_cluster_summary(
  58. "cid1",
  59. {
  60. "headline": "Headline",
  61. "mergedSummary": "Merged summary",
  62. "keyFacts": ["Fact 1"],
  63. "sources": ["BreakingTheNews"],
  64. },
  65. )
  66. cached = store.get_cluster_summary("cid1", ttl_hours=24)
  67. assert cached is not None
  68. assert cached["mergedSummary"] == "Merged summary"
  69. assert cached["keyFacts"] == ["Fact 1"]
  70. def test_prune_clusters_deletes_rows_older_than_retention():
  71. with tempfile.TemporaryDirectory() as td:
  72. db = Path(td) / "news.sqlite"
  73. store = SQLiteClusterStore(db)
  74. store.upsert_clusters([
  75. {
  76. "cluster_id": "fresh",
  77. "headline": "Fresh",
  78. "summary": "Fresh summary",
  79. "entities": ["Bitcoin"],
  80. "timestamp": "Wed, 01 Apr 2026 12:00:00 GMT",
  81. "articles": [],
  82. },
  83. {
  84. "cluster_id": "stale",
  85. "headline": "Stale",
  86. "summary": "Stale summary",
  87. "entities": ["Iran"],
  88. "timestamp": "Wed, 01 Apr 2026 11:00:00 GMT",
  89. "articles": [],
  90. },
  91. ], topic="other")
  92. with store._conn() as conn:
  93. conn.execute(
  94. "UPDATE clusters SET updated_at=? WHERE cluster_id=?",
  95. ("2025-01-01T00:00:00+00:00", "stale"),
  96. )
  97. deleted = store.prune_clusters(retention_days=30)
  98. assert deleted == 1
  99. assert store.get_cluster_by_id("stale") is None
  100. assert store.get_cluster_by_id("fresh") is not None
  101. assert store.get_prune_state(pruning_enabled=True, retention_days=30, interval_hours=24)["last_prune_at"] is not None
  102. def test_prune_if_due_skips_deletes_when_pruning_disabled():
  103. with tempfile.TemporaryDirectory() as td:
  104. db = Path(td) / "news.sqlite"
  105. store = SQLiteClusterStore(db)
  106. store.upsert_clusters([
  107. {
  108. "cluster_id": "stale",
  109. "headline": "Stale",
  110. "summary": "Stale summary",
  111. "entities": ["Iran"],
  112. "timestamp": "Wed, 01 Apr 2026 11:00:00 GMT",
  113. "articles": [],
  114. }
  115. ], topic="other")
  116. with store._conn() as conn:
  117. conn.execute(
  118. "UPDATE clusters SET updated_at=? WHERE cluster_id=?",
  119. ("2025-01-01T00:00:00+00:00", "stale"),
  120. )
  121. result = store.prune_if_due(pruning_enabled=False, retention_days=30, interval_hours=24)
  122. assert result["enabled"] is False
  123. assert result["deleted"] == 0
  124. assert store.get_cluster_by_id("stale") is not None
  125. def test_blacklist_filters_entities_case_insensitively():
  126. entities = ["Bloomberg", "Reuters", "bloomberg", "CoinDesk"]
  127. filtered = _filter_entities(entities, blacklist=["bloomberg"])
  128. assert filtered == ["Reuters", "CoinDesk"]
  129. def test_blacklist_supports_wildcards():
  130. assert _matches_blacklist("Bloomberg Economics", blacklist=["bloomberg*"])
  131. assert _matches_blacklist("bloomberg", blacklist=["*berg"])
  132. assert not _matches_blacklist("Reuters", blacklist=["bloomberg*"])
  133. def test_query_normalization_keeps_common_shorthand_working():
  134. assert normalize_query("btc") == "Bitcoin"
  135. assert normalize_query("Trump") == "Donald Trump"
  136. assert normalize_query("nvidia") == "nvidia"
  137. def test_entity_normalization_deduplicates_aliases():
  138. assert normalize_entities(["btc", "Bitcoin", "BTC", "Ethereum"]) == ["Bitcoin", "Ethereum"]
  139. def test_load_prompt_reads_prompt_files():
  140. text = load_prompt("extract_entities.prompt")
  141. assert "Return STRICT JSON" in text
  142. def test_resolve_entity_falls_back_cleanly_when_provider_unavailable(monkeypatch):
  143. import news_mcp.trends_resolution as trends_resolution
  144. trends_resolution.resolve_entity_via_trends.cache_clear()
  145. trends_resolution._provider.cache_clear()
  146. monkeypatch.setattr(trends_resolution, "_provider", lambda: None)
  147. resolved = resolve_entity_via_trends("btc")
  148. assert resolved["normalized"] == "Bitcoin"
  149. assert resolved["canonical_label"] == "Bitcoin"
  150. assert resolved["mid"] is None
  151. assert resolved["candidates"] == []
  152. assert resolved["source"] == "fallback"
  153. trends_resolution.resolve_entity_via_trends.cache_clear()
  154. def test_sort_clusters_by_recency_prefers_newer_timestamp_over_importance():
  155. clusters = [
  156. {"headline": "older", "timestamp": "Wed, 01 Apr 2026 10:00:00 GMT", "importance": 0.9},
  157. {"headline": "newer", "timestamp": "Wed, 01 Apr 2026 11:00:00 GMT", "importance": 0.1},
  158. ]
  159. sorted_clusters = _sort_clusters_by_recency(clusters)
  160. assert [c["headline"] for c in sorted_clusters] == ["newer", "older"]
  161. def test_build_extraction_prompt_is_stable_without_blacklist():
  162. cluster = {
  163. "headline": "Bloomberg reports Bitcoin rallies after US rate comments",
  164. "summary": "A report from Bloomberg says Bitcoin moved higher after comments from the Fed.",
  165. "articles": [],
  166. }
  167. prompt = build_extraction_prompt(cluster)
  168. assert "Bloomberg reports Bitcoin rallies" in prompt
  169. assert "Do NOT return empty entities" in prompt
  170. assert "Bloomberg" in prompt # present in the input, not filtered here
  171. def test_call_llm_dispatches_to_selected_provider(monkeypatch):
  172. async def fake_groq(model, messages, response_json=True):
  173. return '{"ok": true, "provider": "groq"}'
  174. async def fake_openai(model, messages, response_json=True):
  175. return '{"ok": true, "provider": "openai"}'
  176. monkeypatch.setattr("news_mcp.llm._call_groq", fake_groq)
  177. monkeypatch.setattr("news_mcp.llm._call_openai", fake_openai)
  178. import asyncio
  179. groq = asyncio.run(call_llm("groq", "x", "sys", "user"))
  180. openai = asyncio.run(call_llm("openai", "x", "sys", "user"))
  181. assert '"provider": "groq"' in groq
  182. assert '"provider": "openai"' in openai
  183. def test_importance_prefers_llm_signal():
  184. # Two clusters with same coverage but different sentiment magnitude.
  185. base = {
  186. "sources": ["A", "B"],
  187. "articles": [{}, {}],
  188. "sentiment": "neutral",
  189. "sentimentScore": 0.0,
  190. }
  191. pos = dict(base, sentimentScore=0.9)
  192. neg = dict(base, sentimentScore=-0.8)
  193. imp_base = compute_importance(base)
  194. imp_pos = compute_importance(pos)
  195. imp_neg = compute_importance(neg)
  196. assert imp_pos >= imp_base
  197. assert imp_neg >= imp_base