eval_extraction.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
  1. #!/usr/bin/env python3
  2. """
  3. Evaluation harness for entity/keyword extraction prompt.
  4. Usage:
  5. python scripts/eval_extraction.py # Run against golden samples
  6. python scripts/eval_extraction.py --verbose # Show per-sample details
  7. python scripts/eval_extraction.py --model llama-3.1-8b-instant # Test specific model
  8. python scripts/eval_extraction.py --collect N # Collect N new samples from live DB
  9. python scripts/eval_extraction.py --prompt-file prompts/extract_entities_fewshot.prompt # Test alternate prompt
  10. """
  11. from __future__ import annotations
  12. import argparse
  13. import asyncio
  14. import json
  15. import sys
  16. from pathlib import Path
  17. from typing import Any
  18. # Add project root to path
  19. sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
  20. from news_mcp.llm import call_llm, load_prompt
  21. from news_mcp.config import (
  22. NEWS_EXTRACT_PROVIDER,
  23. NEWS_EXTRACT_MODEL,
  24. NEWS_ENTITY_BLACKLIST,
  25. DEFAULT_TOPICS,
  26. )
  27. from news_mcp.entity_normalize import normalize_entities
  28. from news_mcp.enrichment.llm_enrich import _filter_entities
  29. # ---------------------------------------------------------------------------
  30. # Load golden samples from JSON file
  31. # ---------------------------------------------------------------------------
  32. def load_golden_samples(filepath: str = "data/annotated_samples.json") -> list[dict]:
  33. """Load annotated samples from JSON file."""
  34. path = Path(__file__).resolve().parent.parent / filepath
  35. if not path.exists():
  36. # Fallback to built-in samples
  37. return GOLDEN_SAMPLES
  38. return json.loads(path.read_text(encoding="utf-8"))
  39. # Fallback built-in samples (original 10)
  40. GOLDEN_SAMPLES = [
  41. {
  42. "name": "sec_binance_lawsuit",
  43. "cluster": {
  44. "headline": "SEC sues Binance over unregistered securities",
  45. "summary": "The Securities and Exchange Commission filed a lawsuit against Binance, the world's largest crypto exchange, alleging it operated as an unregistered securities exchange and commingled customer funds.",
  46. },
  47. "expected": {
  48. "entities": ["SEC", "Binance"],
  49. "keywords": ["securities law", "crypto exchange", "enforcement action"],
  50. "topic": "regulation",
  51. },
  52. },
  53. {
  54. "name": "fed_rates_inflation",
  55. "cluster": {
  56. "headline": "Fed holds rates steady as inflation cools",
  57. "summary": "The Federal Reserve kept interest rates unchanged at 5.25-5.50%, citing progress on inflation but signaling caution on future cuts.",
  58. },
  59. "expected": {
  60. "entities": ["Federal Reserve"],
  61. "keywords": ["interest rates", "inflation", "monetary policy"],
  62. "topic": "macro",
  63. },
  64. },
  65. {
  66. "name": "israel_iran_syria_strikes",
  67. "cluster": {
  68. "headline": "Israel strikes Iranian missile sites in Syria",
  69. "summary": "Israeli warplanes targeted Iranian missile depots near Damascus overnight, escalating regional tensions.",
  70. },
  71. "expected": {
  72. "entities": ["Israel", "Iran", "Syria", "Damascus"],
  73. "keywords": ["airstrikes", "missile sites", "regional escalation"],
  74. "topic": "other",
  75. },
  76. },
  77. {
  78. "name": "bitcoin_etf_flows",
  79. "cluster": {
  80. "headline": "Bitcoin ETFs see record inflows as BTC tops $70k",
  81. "summary": "US spot Bitcoin ETFs attracted $2.3 billion in net inflows this week as Bitcoin surged past $70,000, driven by institutional demand.",
  82. },
  83. "expected": {
  84. "entities": ["Bitcoin", "BTC"],
  85. "keywords": ["ETF inflows", "institutional demand", "price surge"],
  86. "topic": "crypto",
  87. },
  88. },
  89. {
  90. "name": "ai_regulation_eu",
  91. "cluster": {
  92. "headline": "EU AI Act enters force with strict rules for high-risk systems",
  93. "summary": "The European Union's landmark AI Act took effect today, imposing strict requirements on high-risk AI systems including transparency, human oversight, and risk management.",
  94. },
  95. "expected": {
  96. "entities": ["European Union", "EU AI Act"],
  97. "keywords": ["AI regulation", "high-risk systems", "compliance requirements"],
  98. "topic": "ai",
  99. },
  100. },
  101. {
  102. "name": "china_economy_stimulus",
  103. "cluster": {
  104. "headline": "China unveils stimulus package to boost slowing economy",
  105. "summary": "Beijing announced a comprehensive stimulus package including infrastructure spending, tax cuts, and monetary easing to counter slowing growth and property sector weakness.",
  106. },
  107. "expected": {
  108. "entities": ["China", "Beijing"],
  109. "keywords": ["stimulus package", "infrastructure spending", "monetary easing", "property sector"],
  110. "topic": "macro",
  111. },
  112. },
  113. {
  114. "name": "oil_prices_opepcuts",
  115. "cluster": {
  116. "headline": "Oil jumps after OPEC+ extends production cuts",
  117. "summary": "Crude oil prices rose 3% after OPEC+ agreed to extend production cuts through year-end, tightening global supply amid demand concerns.",
  118. },
  119. "expected": {
  120. "entities": ["OPEC+"],
  121. "keywords": ["production cuts", "oil prices", "global supply", "demand concerns"],
  122. "topic": "macro",
  123. },
  124. },
  125. {
  126. "name": "nvidia_earnings_ai",
  127. "cluster": {
  128. "headline": "Nvidia beats earnings on AI chip demand",
  129. "summary": "Nvidia reported quarterly revenue of $26 billion, up 262% year-over-year, driven by insatiable demand for its H100 and Blackwell AI chips.",
  130. },
  131. "expected": {
  132. "entities": ["Nvidia", "H100", "Blackwell"],
  133. "keywords": ["AI chips", "earnings beat", "revenue growth", "chip demand"],
  134. "topic": "ai",
  135. },
  136. },
  137. {
  138. "name": "ecb_rate_decision",
  139. "cluster": {
  140. "headline": "ECB cuts rates as eurozone inflation falls to 2.4%",
  141. "summary": "The European Central Bank lowered its deposit rate by 25 basis points to 3.75%, marking its first cut since 2019 as inflation approaches target.",
  142. },
  143. "expected": {
  144. "entities": ["European Central Bank", "ECB"],
  145. "keywords": ["rate cut", "eurozone inflation", "deposit rate", "monetary easing"],
  146. "topic": "macro",
  147. },
  148. },
  149. {
  150. "name": "trump_legal_cases",
  151. "cluster": {
  152. "headline": "Trump convicted in hush money trial, first former president found guilty",
  153. "summary": "A New York jury found Donald Trump guilty on all 34 counts of falsifying business records in the hush money case, making him the first former US president convicted of a crime.",
  154. },
  155. "expected": {
  156. "entities": ["Donald Trump", "New York"],
  157. "keywords": ["criminal conviction", "hush money", "falsifying records", "historic verdict"],
  158. "topic": "other",
  159. },
  160. },
  161. ]
  162. # ---------------------------------------------------------------------------
  163. # Scoring functions
  164. # ---------------------------------------------------------------------------
  165. def normalize_list(items: list[str]) -> set[str]:
  166. """Normalize list of strings for comparison: lowercase, strip, dedupe."""
  167. return {str(x).strip().lower() for x in items if x and str(x).strip()}
  168. def score_extraction(pred: dict[str, Any], gold: dict[str, Any]) -> dict[str, float]:
  169. """Score a single prediction against gold standard."""
  170. pred_entities = normalize_list(pred.get("entities", []))
  171. gold_entities = normalize_list(gold.get("entities", []))
  172. pred_keywords = normalize_list(pred.get("keywords", []))
  173. gold_keywords = normalize_list(gold.get("keywords", []))
  174. # Entity precision/recall/F1
  175. if pred_entities:
  176. ent_p = len(pred_entities & gold_entities) / len(pred_entities)
  177. else:
  178. ent_p = 0.0
  179. if gold_entities:
  180. ent_r = len(pred_entities & gold_entities) / len(gold_entities)
  181. else:
  182. ent_r = 0.0
  183. ent_f1 = 2 * ent_p * ent_r / (ent_p + ent_r) if (ent_p + ent_r) > 0 else 0.0
  184. # Keyword precision/recall/F1 (allow partial overlap since keywords are thematic)
  185. if pred_keywords:
  186. kw_p = len(pred_keywords & gold_keywords) / len(pred_keywords)
  187. else:
  188. kw_p = 0.0
  189. if gold_keywords:
  190. kw_r = len(pred_keywords & gold_keywords) / len(gold_keywords)
  191. else:
  192. kw_r = 0.0
  193. kw_f1 = 2 * kw_p * kw_r / (kw_p + kw_r) if (kw_p + kw_r) > 0 else 0.0
  194. # Leakage: entities appearing in keywords (should be 0)
  195. leakage = len(pred_entities & pred_keywords)
  196. # Topic accuracy
  197. topic_acc = 1.0 if pred.get("topic") == gold.get("topic") else 0.0
  198. return {
  199. "ent_p": round(ent_p, 3),
  200. "ent_r": round(ent_r, 3),
  201. "ent_f1": round(ent_f1, 3),
  202. "kw_p": round(kw_p, 3),
  203. "kw_r": round(kw_r, 3),
  204. "kw_f1": round(kw_f1, 3),
  205. "leakage": leakage,
  206. "topic_acc": topic_acc,
  207. }
  208. def aggregate_scores(scores: list[dict[str, float]]) -> dict[str, float]:
  209. """Compute macro averages across samples."""
  210. if not scores:
  211. return {}
  212. keys = scores[0].keys()
  213. return {k: round(sum(s[k] for s in scores) / len(scores), 3) for k in keys}
  214. def print_sample_result(name: str, pred: dict, gold: dict, scores: dict, verbose: bool):
  215. """Print result for a single sample."""
  216. print(f"\n{'='*60}")
  217. print(f"Sample: {name}")
  218. print(f"{'='*60}")
  219. print(f" Topic: pred={pred.get('topic')} | gold={gold.get('topic')} | {'✓' if scores['topic_acc'] else '✗'}")
  220. print(f" Entities: P={scores['ent_p']:.2f} R={scores['ent_r']:.2f} F1={scores['ent_f1']:.2f}")
  221. print(f" Pred: {pred.get('entities', [])}")
  222. print(f" Gold: {gold.get('entities', [])}")
  223. print(f" Keywords: P={scores['kw_p']:.2f} R={scores['kw_r']:.2f} F1={scores['kw_f1']:.2f}")
  224. print(f" Pred: {pred.get('keywords', [])}")
  225. print(f" Gold: {gold.get('keywords', [])}")
  226. if scores["leakage"] > 0:
  227. leaked = set(e.lower() for e in pred.get("entities", [])) & set(k.lower() for k in pred.get("keywords", []))
  228. print(f" ⚠ LEAKAGE: {leaked} (entities in keywords)")
  229. if verbose:
  230. print(f" Sentiment: {pred.get('sentiment')} ({pred.get('sentimentScore')})")
  231. async def run_extraction(cluster: dict[str, Any], provider: str, model: str, prompt_text: str) -> dict[str, Any]:
  232. """Run extraction on a single cluster."""
  233. import json as json_lib
  234. user_prompt = prompt_text.replace("{cluster_json}", json_lib.dumps(cluster, ensure_ascii=False))
  235. system_prompt = "You are a news signal extraction engine. Return STRICT JSON only."
  236. content = await call_llm(provider, model, system_prompt, user_prompt)
  237. return json.loads(content)
  238. async def evaluate_samples(samples: list[dict], provider: str, model: str, prompt_text: str, verbose: bool) -> dict:
  239. """Run evaluation on all samples."""
  240. print(f"\nEvaluating {len(samples)} samples with {provider}/{model}...")
  241. print("-" * 60)
  242. all_scores = []
  243. results = []
  244. for i, sample in enumerate(samples, 1):
  245. name = sample["name"]
  246. cluster = sample["cluster"]
  247. gold = sample["expected"]
  248. print(f"[{i}/{len(samples)}] {name}...", end=" ", flush=True)
  249. try:
  250. pred = await run_extraction(cluster, provider, model, prompt_text)
  251. # Apply same post-processing as production pipeline
  252. entities = _filter_entities(normalize_entities(pred.get("entities", [])), blacklist=NEWS_ENTITY_BLACKLIST)
  253. keywords = _filter_entities(normalize_entities(pred.get("keywords", [])), blacklist=NEWS_ENTITY_BLACKLIST)
  254. # Filter topic labels from keywords
  255. topic_labels = {t.lower() for t in DEFAULT_TOPICS}
  256. keywords = [k for k in keywords if k.lower() not in topic_labels]
  257. # Length cap
  258. keywords = [k for k in keywords if len(k.split()) <= 2]
  259. # De-duplicate entities vs keywords
  260. entity_keys = {e.strip().lower() for e in entities}
  261. keywords = [k for k in keywords if k.strip().lower() not in entity_keys]
  262. pred_processed = {
  263. "topic": pred.get("topic", "other"),
  264. "entities": entities,
  265. "keywords": keywords,
  266. "sentiment": pred.get("sentiment", "neutral"),
  267. "sentimentScore": pred.get("sentimentScore", 0.0),
  268. }
  269. scores = score_extraction(pred_processed, gold)
  270. all_scores.append(scores)
  271. results.append({"name": name, "pred": pred_processed, "gold": gold, "scores": scores})
  272. if verbose:
  273. print_sample_result(name, pred_processed, gold, scores, verbose)
  274. else:
  275. print(f"Ent_F1={scores['ent_f1']:.2f} Kw_F1={scores['kw_f1']:.2f} Leak={scores['leakage']} Topic={'✓' if scores['topic_acc'] else '✗'}")
  276. except Exception as e:
  277. print(f"ERROR: {e}")
  278. results.append({"name": name, "error": str(e)})
  279. # Aggregate
  280. agg = aggregate_scores(all_scores)
  281. print(f"\n{'='*60}")
  282. print("AGGREGATE SCORES (macro average)")
  283. print(f"{'='*60}")
  284. print(f" Entity F1: {agg.get('ent_f1', 0):.3f} (P={agg.get('ent_p', 0):.3f} R={agg.get('ent_r', 0):.3f})")
  285. print(f" Keyword F1: {agg.get('kw_f1', 0):.3f} (P={agg.get('kw_p', 0):.3f} R={agg.get('kw_r', 0):.3f})")
  286. print(f" Leakage (avg): {agg.get('leakage', 0):.3f} (entities in keywords)")
  287. print(f" Topic Acc: {agg.get('topic_acc', 0):.3f}")
  288. # Per-topic breakdown
  289. print("\n Per-topic breakdown:")
  290. by_topic = {}
  291. for r in results:
  292. if "scores" in r:
  293. t = r["gold"]["topic"]
  294. if t not in by_topic:
  295. by_topic[t] = []
  296. by_topic[t].append(r["scores"])
  297. for topic, scores_list in sorted(by_topic.items()):
  298. topic_agg = aggregate_scores(scores_list)
  299. print(f" {topic:12s}: Ent_F1={topic_agg.get('ent_f1',0):.2f} Kw_F1={topic_agg.get('kw_f1',0):.2f} Topic={topic_agg.get('topic_acc',0):.2f} (n={len(scores_list)})")
  300. return {"aggregate": agg, "per_sample": results}
  301. async def collect_samples_from_db(n: int, output_file: str):
  302. """Collect N recent clusters from live DB as new golden samples."""
  303. from news_mcp.storage.sqlite_store import SQLiteClusterStore
  304. from news_mcp.config import DB_PATH
  305. print(f"Collecting {n} samples from {DB_PATH}...")
  306. store = SQLiteClusterStore(DB_PATH)
  307. clusters = store.get_latest_clusters_all_topics(ttl_hours=24*30, limit=n)
  308. samples = []
  309. for c in clusters:
  310. samples.append({
  311. "name": f"live_{c['cluster_id'][:8]}",
  312. "cluster": {"headline": c["headline"], "summary": c["summary"]},
  313. "expected": {"entities": [], "keywords": [], "topic": c.get("topic", "other")},
  314. })
  315. # Save to file for manual annotation
  316. output_path = Path(output_file)
  317. output_path.write_text(json.dumps(samples, indent=2, ensure_ascii=False))
  318. print(f"Saved {len(samples)} samples to {output_path}")
  319. print("Edit the file to fill in expected entities/keywords, then add to GOLDEN_SAMPLES.")
  320. def main():
  321. parser = argparse.ArgumentParser(description="Evaluate extraction prompt")
  322. parser.add_argument("--provider", default=NEWS_EXTRACT_PROVIDER, help="LLM provider")
  323. parser.add_argument("--model", default=NEWS_EXTRACT_MODEL, help="LLM model")
  324. parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
  325. parser.add_argument("--collect", type=int, metavar="N", help="Collect N samples from live DB")
  326. parser.add_argument("--output", default="new_samples.json", help="Output file for collected samples")
  327. parser.add_argument("--prompt-file", default="prompts/extract_entities.prompt", help="Prompt file to test")
  328. parser.add_argument("--samples-file", default="data/annotated_samples.json", help="Annotated samples JSON")
  329. args = parser.parse_args()
  330. if args.collect:
  331. asyncio.run(collect_samples_from_db(args.collect, args.output))
  332. return
  333. # Load prompt
  334. prompt_name = Path(args.prompt_file).name
  335. prompt_text = load_prompt(prompt_name)
  336. print(f"Using prompt: {args.prompt_file}")
  337. # Load samples
  338. samples = load_golden_samples(args.samples_file)
  339. print(f"Loaded {len(samples)} samples from {args.samples_file}")
  340. # Run evaluation
  341. result = asyncio.run(evaluate_samples(samples, args.provider, args.model, prompt_text, args.verbose))
  342. # Exit code based on quality threshold
  343. agg = result["aggregate"]
  344. if agg.get("ent_f1", 0) < 0.5 or agg.get("kw_f1", 0) < 0.4 or agg.get("leakage", 99) > 0.5:
  345. print("\n⚠ Quality below threshold!")
  346. sys.exit(1)
  347. else:
  348. print("\n✓ Quality acceptable")
  349. sys.exit(0)
  350. if __name__ == "__main__":
  351. main()