#!/usr/bin/env python3
"""
Evaluation harness for entity/keyword extraction prompt.

Usage:
    python scripts/eval_extraction.py                               # Run against golden samples
    python scripts/eval_extraction.py --verbose                     # Show per-sample details
    python scripts/eval_extraction.py --model llama-3.1-8b-instant  # Test specific model
    python scripts/eval_extraction.py --collect N                   # Collect N new samples from live DB
    python scripts/eval_extraction.py --prompt-file prompts/extract_entities_fewshot.prompt  # Test alternate prompt
    python scripts/eval_extraction.py --samples-file path/to/samples.json                   # Use another annotated samples file

The script exits with status 1 when aggregate quality falls below the
thresholds checked in main(), and with status 0 otherwise.
"""

from __future__ import annotations

import argparse
import asyncio
import json
import sys
from pathlib import Path
from typing import Any

# Add project root to path so the news_mcp package below is importable when
# this script is run directly.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from news_mcp.llm import call_llm, load_prompt  # noqa: E402
from news_mcp.config import (  # noqa: E402
    NEWS_EXTRACT_PROVIDER,
    NEWS_EXTRACT_MODEL,
    NEWS_ENTITY_BLACKLIST,
    DEFAULT_TOPICS,
)
from news_mcp.entity_normalize import normalize_entities  # noqa: E402
from news_mcp.enrichment.llm_enrich import _filter_entities  # noqa: E402


# ---------------------------------------------------------------------------
# Load golden samples from JSON file
# ---------------------------------------------------------------------------


def load_golden_samples(filepath: str = "data/annotated_samples.json") -> list[dict]:
    """Load annotated golden samples from a JSON file.

    ``filepath`` is joined onto the project root (the parent of this script's
    directory). Per ``pathlib`` join semantics, an absolute ``filepath`` is
    used as-is.

    Args:
        filepath: Path to a JSON array of samples. Each sample is expected to
            have ``name``, ``cluster`` (``headline``/``summary``) and
            ``expected`` (``entities``/``keywords``/``topic``) keys.

    Returns:
        The parsed list of samples. If the file does not exist, the built-in
        ``GOLDEN_SAMPLES`` fallback is returned, and a warning is written to
        stderr so that a mistyped ``--samples-file`` is not silently ignored.
    """
    path = Path(__file__).resolve().parent.parent / filepath
    if not path.exists():
        print(
            f"WARNING: samples file {path} not found; "
            f"falling back to {len(GOLDEN_SAMPLES)} built-in samples",
            file=sys.stderr,
        )
        return GOLDEN_SAMPLES
    return json.loads(path.read_text(encoding="utf-8"))


# Fallback built-in samples (original 10)
GOLDEN_SAMPLES = [
    {
        "name": "sec_binance_lawsuit",
        "cluster": {
            "headline": "SEC sues Binance over unregistered securities",
            "summary": "The Securities and Exchange Commission filed a lawsuit against Binance, the world's largest crypto exchange, alleging it operated as an unregistered securities exchange and commingled customer funds.",
        },
        "expected": {
            "entities": ["SEC", "Binance"],
            "keywords": ["securities law", "crypto exchange", "enforcement action"],
            "topic": "regulation",
        },
    },
    {
        "name": "fed_rates_inflation",
        "cluster": {
            "headline": "Fed holds rates steady as inflation cools",
            "summary": "The Federal Reserve kept interest rates unchanged at 5.25-5.50%, citing progress on inflation but signaling caution on future cuts.",
        },
        "expected": {
            "entities": ["Federal Reserve"],
            "keywords": ["interest rates", "inflation", "monetary policy"],
            "topic": "macro",
        },
    },
    {
        "name": "israel_iran_syria_strikes",
        "cluster": {
            "headline": "Israel strikes Iranian missile sites in Syria",
            "summary": "Israeli warplanes targeted Iranian missile depots near Damascus overnight, escalating regional tensions.",
        },
        "expected": {
            "entities": ["Israel", "Iran", "Syria", "Damascus"],
            "keywords": ["airstrikes", "missile sites", "regional escalation"],
            "topic": "other",
        },
    },
    {
        "name": "bitcoin_etf_flows",
        "cluster": {
            "headline": "Bitcoin ETFs see record inflows as BTC tops $70k",
            "summary": "US spot Bitcoin ETFs attracted $2.3 billion in net inflows this week as Bitcoin surged past $70,000, driven by institutional demand.",
        },
        "expected": {
            "entities": ["Bitcoin", "BTC"],
            "keywords": ["ETF inflows", "institutional demand", "price surge"],
            "topic": "crypto",
        },
    },
    {
        "name": "ai_regulation_eu",
        "cluster": {
            "headline": "EU AI Act enters force with strict rules for high-risk systems",
            "summary": "The European Union's landmark AI Act took effect today, imposing strict requirements on high-risk AI systems including transparency, human oversight, and risk management.",
        },
        "expected": {
            "entities": ["European Union", "EU AI Act"],
            "keywords": ["AI regulation", "high-risk systems", "compliance requirements"],
            "topic": "ai",
        },
    },
    {
        "name": "china_economy_stimulus",
        "cluster": {
            "headline": "China unveils stimulus package to boost slowing economy",
            "summary": "Beijing announced a comprehensive stimulus package including infrastructure spending, tax cuts, and monetary easing to counter slowing growth and property sector weakness.",
        },
        "expected": {
            "entities": ["China", "Beijing"],
            "keywords": ["stimulus package", "infrastructure spending", "monetary easing", "property sector"],
            "topic": "macro",
        },
    },
    {
        "name": "oil_prices_opepcuts",
        "cluster": {
            "headline": "Oil jumps after OPEC+ extends production cuts",
            "summary": "Crude oil prices rose 3% after OPEC+ agreed to extend production cuts through year-end, tightening global supply amid demand concerns.",
        },
        "expected": {
            "entities": ["OPEC+"],
            "keywords": ["production cuts", "oil prices", "global supply", "demand concerns"],
            "topic": "macro",
        },
    },
    {
        "name": "nvidia_earnings_ai",
        "cluster": {
            "headline": "Nvidia beats earnings on AI chip demand",
            "summary": "Nvidia reported quarterly revenue of $26 billion, up 262% year-over-year, driven by insatiable demand for its H100 and Blackwell AI chips.",
        },
        "expected": {
            "entities": ["Nvidia", "H100", "Blackwell"],
            "keywords": ["AI chips", "earnings beat", "revenue growth", "chip demand"],
            "topic": "ai",
        },
    },
    {
        "name": "ecb_rate_decision",
        "cluster": {
            "headline": "ECB cuts rates as eurozone inflation falls to 2.4%",
            "summary": "The European Central Bank lowered its deposit rate by 25 basis points to 3.75%, marking its first cut since 2019 as inflation approaches target.",
        },
        "expected": {
            "entities": ["European Central Bank", "ECB"],
            "keywords": ["rate cut", "eurozone inflation", "deposit rate", "monetary easing"],
            "topic": "macro",
        },
    },
    {
        "name": "trump_legal_cases",
        "cluster": {
            "headline": "Trump convicted in hush money trial, first former president found guilty",
            "summary": "A New York jury found Donald Trump guilty on all 34 counts of falsifying business records in the hush money case, making him the first former US president convicted of a crime.",
        },
        "expected": {
            "entities": ["Donald Trump", "New York"],
            "keywords": ["criminal conviction", "hush money", "falsifying records", "historic verdict"],
            "topic": "other",
        },
    },
]
# ---------------------------------------------------------------------------
# Scoring functions
# ---------------------------------------------------------------------------


def normalize_list(items: list[str]) -> set[str]:
    """Normalize list of strings for comparison: lowercase, strip, dedupe.

    Falsy items (``None``, ``""``) and items that are blank after stripping are
    dropped. Non-string items are converted with ``str()`` first.
    """
    return {str(x).strip().lower() for x in items if x and str(x).strip()}


def _prf(pred: set[str], gold: set[str]) -> tuple[float, float, float]:
    """Return exact-match (precision, recall, F1), using 0.0 when a denominator is zero."""
    hits = len(pred & gold)
    precision = hits / len(pred) if pred else 0.0
    recall = hits / len(gold) if gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    return precision, recall, f1


def _leaked_terms(pred: dict[str, Any]) -> set[str]:
    """Return normalized terms that appear in both ``entities`` and ``keywords`` of *pred*."""
    return normalize_list(pred.get("entities") or []) & normalize_list(pred.get("keywords") or [])


def score_extraction(
    pred: dict[str, Any],
    gold: dict[str, Any],
    raw_pred: dict[str, Any] | None = None,
) -> dict[str, float]:
    """Score a single prediction against the gold standard.

    Entities and keywords are compared as exact matches after
    ``normalize_list`` (case-insensitive, whitespace-trimmed, deduplicated).
    An empty gold list always yields recall 0.0, and therefore F1 0.0, even
    when the prediction is also empty.

    Args:
        pred: Post-processed prediction with ``entities``, ``keywords`` and
            ``topic`` keys. Missing or null lists count as empty.
        gold: Expected values with the same keys.
        raw_pred: Optional raw model output (before post-processing). When
            given, ``leakage`` is measured on it; otherwise on *pred*.

    Returns:
        A dict with ``ent_p``/``ent_r``/``ent_f1`` and ``kw_p``/``kw_r``/``kw_f1``
        (rounded to 3 decimals), ``leakage`` (number of terms present in both
        entities and keywords) and ``topic_acc`` (1.0 on an exact topic match,
        else 0.0).
    """
    ent_p, ent_r, ent_f1 = _prf(
        normalize_list(pred.get("entities") or []),
        normalize_list(gold.get("entities") or []),
    )
    # Keywords are thematic, so exact-match scoring is strict; kept for parity
    # with the original metric.
    kw_p, kw_r, kw_f1 = _prf(
        normalize_list(pred.get("keywords") or []),
        normalize_list(gold.get("keywords") or []),
    )

    # Leakage (entities repeated as keywords, should be 0) is measured on the
    # raw model output when available. The post-processing in evaluate_samples
    # already drops keywords that duplicate an entity, so on the processed
    # prediction this count would effectively always be 0.
    leakage = len(_leaked_terms(raw_pred if raw_pred is not None else pred))

    topic_acc = 1.0 if pred.get("topic") == gold.get("topic") else 0.0

    return {
        "ent_p": round(ent_p, 3),
        "ent_r": round(ent_r, 3),
        "ent_f1": round(ent_f1, 3),
        "kw_p": round(kw_p, 3),
        "kw_r": round(kw_r, 3),
        "kw_f1": round(kw_f1, 3),
        "leakage": leakage,
        "topic_acc": topic_acc,
    }


def aggregate_scores(scores: list[dict[str, float]]) -> dict[str, float]:
    """Compute macro averages (rounded to 3 decimals) across samples.

    Keys are taken from the first score dict. An empty input returns ``{}``.
    """
    if not scores:
        return {}
    n = len(scores)
    return {k: round(sum(s[k] for s in scores) / n, 3) for k in scores[0]}


def print_sample_result(
    name: str,
    pred: dict,
    gold: dict,
    scores: dict,
    verbose: bool,
    raw_pred: dict | None = None,
) -> None:
    """Print a human-readable comparison of one prediction against its gold sample.

    Args:
        name: Sample name used in the header.
        pred: Post-processed prediction.
        gold: Expected values.
        scores: Output of ``score_extraction`` for this sample.
        verbose: When true, also print the predicted sentiment and its score.
        raw_pred: Raw model output. If given, leaked terms are listed from it,
            matching how ``scores["leakage"]`` is computed.
    """
    print(f"\n{'='*60}")
    print(f"Sample: {name}")
    print(f"{'='*60}")
    print(f" Topic: pred={pred.get('topic')} | gold={gold.get('topic')} | {'✓' if scores['topic_acc'] else '✗'}")
    print(f" Entities: P={scores['ent_p']:.2f} R={scores['ent_r']:.2f} F1={scores['ent_f1']:.2f}")
    print(f" Pred: {pred.get('entities', [])}")
    print(f" Gold: {gold.get('entities', [])}")
    print(f" Keywords: P={scores['kw_p']:.2f} R={scores['kw_r']:.2f} F1={scores['kw_f1']:.2f}")
    print(f" Pred: {pred.get('keywords', [])}")
    print(f" Gold: {gold.get('keywords', [])}")
    if scores["leakage"] > 0:
        leaked = _leaked_terms(raw_pred if raw_pred is not None else pred)
        print(f" ⚠ LEAKAGE: {leaked} (entities in keywords)")
    if verbose:
        print(f" Sentiment: {pred.get('sentiment')} ({pred.get('sentimentScore')})")


async def run_extraction(cluster: dict[str, Any], provider: str, model: str, prompt_text: str) -> dict[str, Any]:
    """Run the extraction prompt on a single cluster and parse the model's JSON reply.

    The cluster is JSON-encoded and substituted for the ``{cluster_json}``
    placeholder in *prompt_text*. ``str.replace`` is used instead of
    ``str.format``, presumably because the template contains literal JSON
    braces.

    Raises:
        json.JSONDecodeError: If the reply is not valid JSON.
        ValueError: If the reply is valid JSON but not a JSON object.
    """
    user_prompt = prompt_text.replace("{cluster_json}", json.dumps(cluster, ensure_ascii=False))
    system_prompt = "You are a news signal extraction engine. Return STRICT JSON only."
    content = await call_llm(provider, model, system_prompt, user_prompt)
    parsed = json.loads(content)
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a JSON object from the model, got {type(parsed).__name__}")
    return parsed


def _postprocess(pred: dict[str, Any], topic_labels: set[str]) -> dict[str, Any]:
    """Apply the same entity/keyword clean-up as the production pipeline to a raw prediction."""
    entities = _filter_entities(normalize_entities(pred.get("entities", [])), blacklist=NEWS_ENTITY_BLACKLIST)
    keywords = _filter_entities(normalize_entities(pred.get("keywords", [])), blacklist=NEWS_ENTITY_BLACKLIST)
    # Filter topic labels from keywords.
    keywords = [k for k in keywords if k.lower() not in topic_labels]
    # Length cap: keep keywords of at most two words.
    keywords = [k for k in keywords if len(k.split()) <= 2]
    # De-duplicate entities vs keywords (the entity wins).
    entity_keys = {e.strip().lower() for e in entities}
    keywords = [k for k in keywords if k.strip().lower() not in entity_keys]
    return {
        "topic": pred.get("topic", "other"),
        "entities": entities,
        "keywords": keywords,
        "sentiment": pred.get("sentiment", "neutral"),
        "sentimentScore": pred.get("sentimentScore", 0.0),
    }


async def evaluate_samples(samples: list[dict], provider: str, model: str, prompt_text: str, verbose: bool) -> dict:
    """Run the extraction on every sample, print per-sample and aggregate scores.

    Samples whose extraction or scoring raises are reported as ``ERROR`` and
    excluded from the averages; their count is printed and returned.

    Returns:
        ``{"aggregate": <macro averages>, "per_sample": <list of results>,
        "errors": <number of failed samples>}``.
    """
    print(f"\nEvaluating {len(samples)} samples with {provider}/{model}...")
    print("-" * 60)

    # Loop-invariant: topic labels that must not appear as keywords.
    topic_labels = {t.lower() for t in DEFAULT_TOPICS}
    all_scores = []
    results = []
    errors = 0

    for i, sample in enumerate(samples, 1):
        name = sample["name"]
        cluster = sample["cluster"]
        gold = sample["expected"]
        print(f"[{i}/{len(samples)}] {name}...", end=" ", flush=True)

        # Only the LLM call, post-processing and scoring are guarded. This way a
        # sample is never recorded as both scored and errored.
        try:
            raw_pred = await run_extraction(cluster, provider, model, prompt_text)
            pred_processed = _postprocess(raw_pred, topic_labels)
            scores = score_extraction(pred_processed, gold, raw_pred=raw_pred)
        except Exception as e:  # per-sample boundary: report and keep going
            errors += 1
            print(f"ERROR: {e}")
            results.append({"name": name, "error": str(e)})
            continue

        all_scores.append(scores)
        results.append({"name": name, "pred": pred_processed, "gold": gold, "scores": scores})
        if verbose:
            print_sample_result(name, pred_processed, gold, scores, verbose, raw_pred=raw_pred)
        else:
            print(f"Ent_F1={scores['ent_f1']:.2f} Kw_F1={scores['kw_f1']:.2f} Leak={scores['leakage']} Topic={'✓' if scores['topic_acc'] else '✗'}")

    agg = aggregate_scores(all_scores)
    print(f"\n{'='*60}")
    print("AGGREGATE SCORES (macro average)")
    print(f"{'='*60}")
    print(f" Entity F1: {agg.get('ent_f1', 0):.3f} (P={agg.get('ent_p', 0):.3f} R={agg.get('ent_r', 0):.3f})")
    print(f" Keyword F1: {agg.get('kw_f1', 0):.3f} (P={agg.get('kw_p', 0):.3f} R={agg.get('kw_r', 0):.3f})")
    print(f" Leakage (avg): {agg.get('leakage', 0):.3f} (entities in keywords)")
    print(f" Topic Acc: {agg.get('topic_acc', 0):.3f}")
    if errors:
        print(f" Errors: {errors}/{len(samples)} samples failed and are excluded from the averages")

    # Per-topic breakdown, grouped by the gold topic.
    print("\n Per-topic breakdown:")
    by_topic: dict[str, list[dict[str, float]]] = {}
    for r in results:
        if "scores" in r:
            by_topic.setdefault(r["gold"].get("topic", "other"), []).append(r["scores"])
    for topic, scores_list in sorted(by_topic.items()):
        topic_agg = aggregate_scores(scores_list)
        print(f" {topic:12s}: Ent_F1={topic_agg.get('ent_f1',0):.2f} Kw_F1={topic_agg.get('kw_f1',0):.2f} Topic={topic_agg.get('topic_acc',0):.2f} (n={len(scores_list)})")

    return {"aggregate": agg, "per_sample": results, "errors": errors}


async def collect_samples_from_db(n: int, output_file: str) -> None:
    """Write up to *n* clusters from the live DB to *output_file* as annotation stubs.

    Clusters are fetched with ``store.get_latest_clusters_all_topics(ttl_hours=24*30,
    limit=n)``. Each stub carries the cluster's headline and summary, empty
    expected entities/keywords, and the stored topic (default ``"other"``).
    """
    from news_mcp.storage.sqlite_store import SQLiteClusterStore
    from news_mcp.config import DB_PATH

    print(f"Collecting {n} samples from {DB_PATH}...")
    store = SQLiteClusterStore(DB_PATH)
    clusters = store.get_latest_clusters_all_topics(ttl_hours=24*30, limit=n)

    samples = [
        {
            "name": f"live_{c['cluster_id'][:8]}",
            "cluster": {"headline": c["headline"], "summary": c["summary"]},
            "expected": {"entities": [], "keywords": [], "topic": c.get("topic", "other")},
        }
        for c in clusters
    ]

    # Save to file for manual annotation. Write UTF-8 explicitly, because
    # ensure_ascii=False keeps non-ASCII text and load_golden_samples reads
    # files back as UTF-8; the platform default encoding may differ.
    output_path = Path(output_file)
    output_path.write_text(json.dumps(samples, indent=2, ensure_ascii=False), encoding="utf-8")
    print(f"Saved {len(samples)} samples to {output_path}")
    print("Edit the file to fill in expected entities/keywords, then add them to the annotated samples file (default: data/annotated_samples.json).")


def main() -> None:
    """Parse CLI arguments, then either collect samples or run the evaluation.

    The evaluation exits with status 1 when aggregate entity F1 < 0.5, keyword
    F1 < 0.4 or average leakage > 0.5. This includes the case where every
    sample failed and the aggregate is empty. Otherwise it exits with status 0.
    """
    parser = argparse.ArgumentParser(description="Evaluate extraction prompt")
    parser.add_argument("--provider", default=NEWS_EXTRACT_PROVIDER, help="LLM provider")
    parser.add_argument("--model", default=NEWS_EXTRACT_MODEL, help="LLM model")
    parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
    parser.add_argument("--collect", type=int, metavar="N", help="Collect N samples from live DB")
    parser.add_argument("--output", default="new_samples.json", help="Output file for collected samples")
    parser.add_argument("--prompt-file", default="prompts/extract_entities.prompt", help="Prompt file to test")
    parser.add_argument("--samples-file", default="data/annotated_samples.json", help="Annotated samples JSON")
    args = parser.parse_args()

    if args.collect:
        asyncio.run(collect_samples_from_db(args.collect, args.output))
        return

    # NOTE(review): only the file name is passed to load_prompt, so the
    # directory part of --prompt-file is ignored. The prompt must exist
    # wherever load_prompt looks (presumably the project's prompts/ dir —
    # TODO confirm).
    prompt_name = Path(args.prompt_file).name
    prompt_text = load_prompt(prompt_name)
    print(f"Using prompt: {args.prompt_file}")

    samples = load_golden_samples(args.samples_file)
    print(f"Loaded {len(samples)} samples from {args.samples_file}")

    result = asyncio.run(evaluate_samples(samples, args.provider, args.model, prompt_text, args.verbose))

    # Exit code based on quality threshold.
    agg = result["aggregate"]
    if agg.get("ent_f1", 0) < 0.5 or agg.get("kw_f1", 0) < 0.4 or agg.get("leakage", 99) > 0.5:
        print("\n⚠ Quality below threshold!")
        sys.exit(1)
    print("\n✓ Quality acceptable")
    sys.exit(0)


if __name__ == "__main__":
    main()