eval_extraction.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. #!/usr/bin/env python3
  2. """
  3. Evaluation harness for entity/keyword extraction prompt.
  4. Usage:
  5. python scripts/eval_extraction.py # Run against golden samples
  6. python scripts/eval_extraction.py --verbose # Show per-sample details
  7. python scripts/eval_extraction.py --model llama-3.1-8b-instant # Test specific model
  8. python scripts/eval_extraction.py --collect N # Collect N new samples from live DB
  9. """
  10. from __future__ import annotations
  11. import argparse
  12. import asyncio
  13. import json
  14. import sys
  15. from pathlib import Path
  16. from typing import Any
  17. # Add project root to path
  18. sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
  19. from news_mcp.llm import build_extraction_prompt, call_llm, load_prompt
  20. from news_mcp.config import (
  21. NEWS_EXTRACT_PROVIDER,
  22. NEWS_EXTRACT_MODEL,
  23. NEWS_ENTITY_BLACKLIST,
  24. )
  25. from news_mcp.entity_normalize import normalize_entities
  26. from news_mcp.enrichment.llm_enrich import _filter_entities
  27. # ---------------------------------------------------------------------------
  28. # Golden samples (curated from real clusters)
  29. # ---------------------------------------------------------------------------
# Each sample is a dict with three keys, all read by evaluate_samples():
#   "name"     - identifier used in progress output and per-sample results.
#   "cluster"  - the extraction input: a "headline" and a "summary" string.
#   "expected" - the gold annotation: "entities", "keywords" and "topic".
# Gold entities/keywords are compared to predictions as exact strings after
# trimming and lower-casing (see score_extraction / normalize_list).
# NOTE: evaluate_samples() discards predicted keywords longer than two
# whitespace-separated words, so a gold keyword of three or more words could
# never be matched.
GOLDEN_SAMPLES = [
    {
        "name": "sec_binance_lawsuit",
        "cluster": {
            "headline": "SEC sues Binance over unregistered securities",
            "summary": "The Securities and Exchange Commission filed a lawsuit against Binance, the world's largest crypto exchange, alleging it operated as an unregistered securities exchange and commingled customer funds.",
        },
        "expected": {
            "entities": ["SEC", "Binance"],
            "keywords": ["securities law", "crypto exchange", "enforcement action"],
            "topic": "regulation",
        },
    },
    {
        "name": "fed_rates_inflation",
        "cluster": {
            "headline": "Fed holds rates steady as inflation cools",
            "summary": "The Federal Reserve kept interest rates unchanged at 5.25-5.50%, citing progress on inflation but signaling caution on future cuts.",
        },
        "expected": {
            "entities": ["Federal Reserve"],
            "keywords": ["interest rates", "inflation", "monetary policy"],
            "topic": "macro",
        },
    },
    {
        "name": "israel_iran_syria_strikes",
        "cluster": {
            "headline": "Israel strikes Iranian missile sites in Syria",
            "summary": "Israeli warplanes targeted Iranian missile depots near Damascus overnight, escalating regional tensions.",
        },
        "expected": {
            "entities": ["Israel", "Iran", "Syria", "Damascus"],
            "keywords": ["airstrikes", "missile sites", "regional escalation"],
            "topic": "other",
        },
    },
    {
        "name": "bitcoin_etf_flows",
        "cluster": {
            "headline": "Bitcoin ETFs see record inflows as BTC tops $70k",
            "summary": "US spot Bitcoin ETFs attracted $2.3 billion in net inflows this week as Bitcoin surged past $70,000, driven by institutional demand.",
        },
        "expected": {
            "entities": ["Bitcoin", "BTC"],
            "keywords": ["ETF inflows", "institutional demand", "price surge"],
            "topic": "crypto",
        },
    },
    {
        "name": "ai_regulation_eu",
        "cluster": {
            "headline": "EU AI Act enters force with strict rules for high-risk systems",
            "summary": "The European Union's landmark AI Act took effect today, imposing strict requirements on high-risk AI systems including transparency, human oversight, and risk management.",
        },
        "expected": {
            "entities": ["European Union", "EU AI Act"],
            "keywords": ["AI regulation", "high-risk systems", "compliance requirements"],
            "topic": "ai",
        },
    },
    {
        "name": "china_economy_stimulus",
        "cluster": {
            "headline": "China unveils stimulus package to boost slowing economy",
            "summary": "Beijing announced a comprehensive stimulus package including infrastructure spending, tax cuts, and monetary easing to counter slowing growth and property sector weakness.",
        },
        "expected": {
            "entities": ["China", "Beijing"],
            "keywords": ["stimulus package", "infrastructure spending", "monetary easing", "property sector"],
            "topic": "macro",
        },
    },
    {
        "name": "oil_prices_opepcuts",
        "cluster": {
            "headline": "Oil jumps after OPEC+ extends production cuts",
            "summary": "Crude oil prices rose 3% after OPEC+ agreed to extend production cuts through year-end, tightening global supply amid demand concerns.",
        },
        "expected": {
            "entities": ["OPEC+"],
            "keywords": ["production cuts", "oil prices", "global supply", "demand concerns"],
            "topic": "macro",
        },
    },
    {
        "name": "nvidia_earnings_ai",
        "cluster": {
            "headline": "Nvidia beats earnings on AI chip demand",
            "summary": "Nvidia reported quarterly revenue of $26 billion, up 262% year-over-year, driven by insatiable demand for its H100 and Blackwell AI chips.",
        },
        "expected": {
            "entities": ["Nvidia", "H100", "Blackwell"],
            "keywords": ["AI chips", "earnings beat", "revenue growth", "chip demand"],
            "topic": "ai",
        },
    },
    {
        "name": "ecb_rate_decision",
        "cluster": {
            "headline": "ECB cuts rates as eurozone inflation falls to 2.4%",
            "summary": "The European Central Bank lowered its deposit rate by 25 basis points to 3.75%, marking its first cut since 2019 as inflation approaches target.",
        },
        "expected": {
            "entities": ["European Central Bank", "ECB"],
            "keywords": ["rate cut", "eurozone inflation", "deposit rate", "monetary easing"],
            "topic": "macro",
        },
    },
    {
        "name": "trump_legal_cases",
        "cluster": {
            "headline": "Trump convicted in hush money trial, first former president found guilty",
            "summary": "A New York jury found Donald Trump guilty on all 34 counts of falsifying business records in the hush money case, making him the first former US president convicted of a crime.",
        },
        "expected": {
            "entities": ["Donald Trump", "New York"],
            "keywords": ["criminal conviction", "hush money", "falsifying records", "historic verdict"],
            "topic": "other",
        },
    },
]
  152. # ---------------------------------------------------------------------------
  153. # Scoring functions
  154. # ---------------------------------------------------------------------------
  155. def normalize_list(items: list[str]) -> set[str]:
  156. """Normalize list of strings for comparison: lowercase, strip, dedupe."""
  157. return {str(x).strip().lower() for x in items if x and str(x).strip()}
  158. def score_extraction(pred: dict[str, Any], gold: dict[str, Any]) -> dict[str, float]:
  159. """Score a single prediction against gold standard."""
  160. pred_entities = normalize_list(pred.get("entities", []))
  161. gold_entities = normalize_list(gold.get("entities", []))
  162. pred_keywords = normalize_list(pred.get("keywords", []))
  163. gold_keywords = normalize_list(gold.get("keywords", []))
  164. # Entity precision/recall/F1
  165. if pred_entities:
  166. ent_p = len(pred_entities & gold_entities) / len(pred_entities)
  167. else:
  168. ent_p = 0.0
  169. if gold_entities:
  170. ent_r = len(pred_entities & gold_entities) / len(gold_entities)
  171. else:
  172. ent_r = 0.0
  173. ent_f1 = 2 * ent_p * ent_r / (ent_p + ent_r) if (ent_p + ent_r) > 0 else 0.0
  174. # Keyword precision/recall/F1 (allow partial overlap since keywords are thematic)
  175. if pred_keywords:
  176. kw_p = len(pred_keywords & gold_keywords) / len(pred_keywords)
  177. else:
  178. kw_p = 0.0
  179. if gold_keywords:
  180. kw_r = len(pred_keywords & gold_keywords) / len(gold_keywords)
  181. else:
  182. kw_r = 0.0
  183. kw_f1 = 2 * kw_p * kw_r / (kw_p + kw_r) if (kw_p + kw_r) > 0 else 0.0
  184. # Leakage: entities appearing in keywords (should be 0)
  185. leakage = len(pred_entities & pred_keywords)
  186. # Topic accuracy
  187. topic_acc = 1.0 if pred.get("topic") == gold.get("topic") else 0.0
  188. return {
  189. "ent_p": round(ent_p, 3),
  190. "ent_r": round(ent_r, 3),
  191. "ent_f1": round(ent_f1, 3),
  192. "kw_p": round(kw_p, 3),
  193. "kw_r": round(kw_r, 3),
  194. "kw_f1": round(kw_f1, 3),
  195. "leakage": leakage,
  196. "topic_acc": topic_acc,
  197. }
  198. def aggregate_scores(scores: list[dict[str, float]]) -> dict[str, float]:
  199. """Compute macro averages across samples."""
  200. if not scores:
  201. return {}
  202. keys = scores[0].keys()
  203. return {k: round(sum(s[k] for s in scores) / len(scores), 3) for k in keys}
  204. def print_sample_result(name: str, pred: dict, gold: dict, scores: dict, verbose: bool):
  205. """Print result for a single sample."""
  206. print(f"\n{'='*60}")
  207. print(f"Sample: {name}")
  208. print(f"{'='*60}")
  209. print(f" Topic: pred={pred.get('topic')} | gold={gold.get('topic')} | {'✓' if scores['topic_acc'] else '✗'}")
  210. print(f" Entities: P={scores['ent_p']:.2f} R={scores['ent_r']:.2f} F1={scores['ent_f1']:.2f}")
  211. print(f" Pred: {pred.get('entities', [])}")
  212. print(f" Gold: {gold.get('entities', [])}")
  213. print(f" Keywords: P={scores['kw_p']:.2f} R={scores['kw_r']:.2f} F1={scores['kw_f1']:.2f}")
  214. print(f" Pred: {pred.get('keywords', [])}")
  215. print(f" Gold: {gold.get('keywords', [])}")
  216. if scores["leakage"] > 0:
  217. leaked = set(e.lower() for e in pred.get("entities", [])) & set(k.lower() for k in pred.get("keywords", []))
  218. print(f" ⚠ LEAKAGE: {leaked} (entities in keywords)")
  219. if verbose:
  220. print(f" Sentiment: {pred.get('sentiment')} ({pred.get('sentimentScore')})")
  221. async def run_extraction(cluster: dict[str, Any], provider: str, model: str) -> dict[str, Any]:
  222. """Run extraction on a single cluster."""
  223. prompt = load_prompt("extract_entities.prompt")
  224. # Build the full user prompt
  225. import json as json_lib
  226. user_prompt = prompt.replace("{cluster_json}", json_lib.dumps(cluster, ensure_ascii=False))
  227. from news_mcp.config import GROQ_API_KEY, OPENAI_API_KEY, OPENROUTER_API_KEY
  228. system_prompt = "You are a news signal extraction engine. Return STRICT JSON only."
  229. content = await call_llm(provider, model, system_prompt, user_prompt)
  230. return json.loads(content)
  231. async def evaluate_samples(samples: list[dict], provider: str, model: str, verbose: bool) -> dict:
  232. """Run evaluation on all samples."""
  233. print(f"\nEvaluating {len(samples)} samples with {provider}/{model}...")
  234. print("-" * 60)
  235. all_scores = []
  236. results = []
  237. for i, sample in enumerate(samples, 1):
  238. name = sample["name"]
  239. cluster = sample["cluster"]
  240. gold = sample["expected"]
  241. print(f"[{i}/{len(samples)}] {name}...", end=" ", flush=True)
  242. try:
  243. pred = await run_extraction(cluster, provider, model)
  244. # Apply same post-processing as production pipeline
  245. from news_mcp.enrichment.llm_enrich import _filter_entities, normalize_entities
  246. from news_mcp.config import DEFAULT_TOPICS, NEWS_ENTITY_BLACKLIST
  247. entities = _filter_entities(normalize_entities(pred.get("entities", [])), blacklist=NEWS_ENTITY_BLACKLIST)
  248. keywords = _filter_entities(normalize_entities(pred.get("keywords", [])), blacklist=NEWS_ENTITY_BLACKLIST)
  249. # Filter topic labels from keywords
  250. topic_labels = {t.lower() for t in DEFAULT_TOPICS}
  251. keywords = [k for k in keywords if k.lower() not in topic_labels]
  252. # Length cap
  253. keywords = [k for k in keywords if len(k.split()) <= 2]
  254. # De-duplicate entities vs keywords
  255. entity_keys = {e.strip().lower() for e in entities}
  256. keywords = [k for k in keywords if k.strip().lower() not in entity_keys]
  257. pred_processed = {
  258. "topic": pred.get("topic", "other"),
  259. "entities": entities,
  260. "keywords": keywords,
  261. "sentiment": pred.get("sentiment", "neutral"),
  262. "sentimentScore": pred.get("sentimentScore", 0.0),
  263. }
  264. scores = score_extraction(pred_processed, gold)
  265. all_scores.append(scores)
  266. results.append({"name": name, "pred": pred_processed, "gold": gold, "scores": scores})
  267. if verbose:
  268. print_sample_result(name, pred_processed, gold, scores, verbose)
  269. else:
  270. print(f"Ent_F1={scores['ent_f1']:.2f} Kw_F1={scores['kw_f1']:.2f} Leak={scores['leakage']} Topic={'✓' if scores['topic_acc'] else '✗'}")
  271. except Exception as e:
  272. print(f"ERROR: {e}")
  273. results.append({"name": name, "error": str(e)})
  274. # Aggregate
  275. agg = aggregate_scores(all_scores)
  276. print(f"\n{'='*60}")
  277. print("AGGREGATE SCORES (macro average)")
  278. print(f"{'='*60}")
  279. print(f" Entity F1: {agg.get('ent_f1', 0):.3f} (P={agg.get('ent_p', 0):.3f} R={agg.get('ent_r', 0):.3f})")
  280. print(f" Keyword F1: {agg.get('kw_f1', 0):.3f} (P={agg.get('kw_p', 0):.3f} R={agg.get('kw_r', 0):.3f})")
  281. print(f" Leakage (avg): {agg.get('leakage', 0):.3f} (entities in keywords)")
  282. print(f" Topic Acc: {agg.get('topic_acc', 0):.3f}")
  283. return {"aggregate": agg, "per_sample": results}
  284. async def collect_samples_from_db(n: int, output_file: str):
  285. """Collect N recent clusters from live DB as new golden samples."""
  286. from news_mcp.storage.sqlite_store import SQLiteClusterStore
  287. from news_mcp.config import DB_PATH
  288. print(f"Collecting {n} samples from {DB_PATH}...")
  289. store = SQLiteClusterStore(DB_PATH)
  290. clusters = store.get_latest_clusters_all_topics(ttl_hours=24*30, limit=n)
  291. samples = []
  292. for c in clusters:
  293. samples.append({
  294. "name": f"live_{c['cluster_id'][:8]}",
  295. "cluster": {"headline": c["headline"], "summary": c["summary"]},
  296. "expected": {"entities": [], "keywords": [], "topic": c.get("topic", "other")},
  297. })
  298. # Save to file for manual annotation
  299. output_path = Path(output_file)
  300. output_path.write_text(json.dumps(samples, indent=2, ensure_ascii=False))
  301. print(f"Saved {len(samples)} samples to {output_path}")
  302. print("Edit the file to fill in expected entities/keywords, then add to GOLDEN_SAMPLES.")
  303. def main():
  304. parser = argparse.ArgumentParser(description="Evaluate extraction prompt")
  305. parser.add_argument("--provider", default=NEWS_EXTRACT_PROVIDER, help="LLM provider")
  306. parser.add_argument("--model", default=NEWS_EXTRACT_MODEL, help="LLM model")
  307. parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
  308. parser.add_argument("--collect", type=int, metavar="N", help="Collect N samples from live DB")
  309. parser.add_argument("--output", default="new_samples.json", help="Output file for collected samples")
  310. args = parser.parse_args()
  311. if args.collect:
  312. asyncio.run(collect_samples_from_db(args.collect, args.output))
  313. return
  314. # Run evaluation
  315. result = asyncio.run(evaluate_samples(GOLDEN_SAMPLES, args.provider, args.model, args.verbose))
  316. # Exit code based on quality threshold
  317. agg = result["aggregate"]
  318. if agg.get("ent_f1", 0) < 0.5 or agg.get("kw_f1", 0) < 0.4 or agg.get("leakage", 99) > 0.5:
  319. print("\n⚠ Quality below threshold!")
  320. sys.exit(1)
  321. else:
  322. print("\n✓ Quality acceptable")
  323. sys.exit(0)
  324. if __name__ == "__main__":
  325. main()