| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407 |
- #!/usr/bin/env python3
- """
- Evaluation harness for entity/keyword extraction prompt.
- Usage:
- python scripts/eval_extraction.py # Run against golden samples
- python scripts/eval_extraction.py --verbose # Show per-sample details
- python scripts/eval_extraction.py --model llama-3.1-8b-instant # Test specific model
- python scripts/eval_extraction.py --collect N # Collect N new samples from live DB
- python scripts/eval_extraction.py --prompt-file prompts/extract_entities_fewshot.prompt # Test alternate prompt
- """
- from __future__ import annotations
- import argparse
- import asyncio
- import json
- import sys
- from pathlib import Path
- from typing import Any
- # Add project root to path
- sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
- from news_mcp.llm import call_llm, load_prompt
- from news_mcp.config import (
- NEWS_EXTRACT_PROVIDER,
- NEWS_EXTRACT_MODEL,
- NEWS_ENTITY_BLACKLIST,
- DEFAULT_TOPICS,
- )
- from news_mcp.entity_normalize import normalize_entities
- from news_mcp.enrichment.llm_enrich import _filter_entities
- # ---------------------------------------------------------------------------
- # Load golden samples from JSON file
- # ---------------------------------------------------------------------------
def load_golden_samples(filepath: str = "data/annotated_samples.json") -> list[dict]:
    """Return the annotated evaluation samples.

    ``filepath`` is joined onto the directory two levels above this script.
    If that file exists, its UTF-8 JSON content is parsed and returned;
    otherwise the built-in ``GOLDEN_SAMPLES`` list is returned.
    """
    project_root = Path(__file__).resolve().parent.parent
    samples_path = project_root / filepath
    if samples_path.exists():
        raw_text = samples_path.read_text(encoding="utf-8")
        return json.loads(raw_text)
    # No annotated file on disk: fall back to the samples bundled in this script.
    return GOLDEN_SAMPLES
def _golden_sample(
    name: str,
    headline: str,
    summary: str,
    entities: list[str],
    keywords: list[str],
    topic: str,
) -> dict[str, Any]:
    """Assemble one sample record with ``name``, ``cluster`` and ``expected`` keys."""
    return {
        "name": name,
        "cluster": {"headline": headline, "summary": summary},
        "expected": {"entities": entities, "keywords": keywords, "topic": topic},
    }


# Fallback built-in samples (original 10), used when no annotated JSON file is found.
GOLDEN_SAMPLES = [
    _golden_sample(
        name="sec_binance_lawsuit",
        headline="SEC sues Binance over unregistered securities",
        summary="The Securities and Exchange Commission filed a lawsuit against Binance, the world's largest crypto exchange, alleging it operated as an unregistered securities exchange and commingled customer funds.",
        entities=["SEC", "Binance"],
        keywords=["securities law", "crypto exchange", "enforcement action"],
        topic="regulation",
    ),
    _golden_sample(
        name="fed_rates_inflation",
        headline="Fed holds rates steady as inflation cools",
        summary="The Federal Reserve kept interest rates unchanged at 5.25-5.50%, citing progress on inflation but signaling caution on future cuts.",
        entities=["Federal Reserve"],
        keywords=["interest rates", "inflation", "monetary policy"],
        topic="macro",
    ),
    _golden_sample(
        name="israel_iran_syria_strikes",
        headline="Israel strikes Iranian missile sites in Syria",
        summary="Israeli warplanes targeted Iranian missile depots near Damascus overnight, escalating regional tensions.",
        entities=["Israel", "Iran", "Syria", "Damascus"],
        keywords=["airstrikes", "missile sites", "regional escalation"],
        topic="other",
    ),
    _golden_sample(
        name="bitcoin_etf_flows",
        headline="Bitcoin ETFs see record inflows as BTC tops $70k",
        summary="US spot Bitcoin ETFs attracted $2.3 billion in net inflows this week as Bitcoin surged past $70,000, driven by institutional demand.",
        entities=["Bitcoin", "BTC"],
        keywords=["ETF inflows", "institutional demand", "price surge"],
        topic="crypto",
    ),
    _golden_sample(
        name="ai_regulation_eu",
        headline="EU AI Act enters force with strict rules for high-risk systems",
        summary="The European Union's landmark AI Act took effect today, imposing strict requirements on high-risk AI systems including transparency, human oversight, and risk management.",
        entities=["European Union", "EU AI Act"],
        keywords=["AI regulation", "high-risk systems", "compliance requirements"],
        topic="ai",
    ),
    _golden_sample(
        name="china_economy_stimulus",
        headline="China unveils stimulus package to boost slowing economy",
        summary="Beijing announced a comprehensive stimulus package including infrastructure spending, tax cuts, and monetary easing to counter slowing growth and property sector weakness.",
        entities=["China", "Beijing"],
        keywords=["stimulus package", "infrastructure spending", "monetary easing", "property sector"],
        topic="macro",
    ),
    _golden_sample(
        name="oil_prices_opepcuts",
        headline="Oil jumps after OPEC+ extends production cuts",
        summary="Crude oil prices rose 3% after OPEC+ agreed to extend production cuts through year-end, tightening global supply amid demand concerns.",
        entities=["OPEC+"],
        keywords=["production cuts", "oil prices", "global supply", "demand concerns"],
        topic="macro",
    ),
    _golden_sample(
        name="nvidia_earnings_ai",
        headline="Nvidia beats earnings on AI chip demand",
        summary="Nvidia reported quarterly revenue of $26 billion, up 262% year-over-year, driven by insatiable demand for its H100 and Blackwell AI chips.",
        entities=["Nvidia", "H100", "Blackwell"],
        keywords=["AI chips", "earnings beat", "revenue growth", "chip demand"],
        topic="ai",
    ),
    _golden_sample(
        name="ecb_rate_decision",
        headline="ECB cuts rates as eurozone inflation falls to 2.4%",
        summary="The European Central Bank lowered its deposit rate by 25 basis points to 3.75%, marking its first cut since 2019 as inflation approaches target.",
        entities=["European Central Bank", "ECB"],
        keywords=["rate cut", "eurozone inflation", "deposit rate", "monetary easing"],
        topic="macro",
    ),
    _golden_sample(
        name="trump_legal_cases",
        headline="Trump convicted in hush money trial, first former president found guilty",
        summary="A New York jury found Donald Trump guilty on all 34 counts of falsifying business records in the hush money case, making him the first former US president convicted of a crime.",
        entities=["Donald Trump", "New York"],
        keywords=["criminal conviction", "hush money", "falsifying records", "historic verdict"],
        topic="other",
    ),
]
- # ---------------------------------------------------------------------------
- # Scoring functions
- # ---------------------------------------------------------------------------
def normalize_list(items: list[str]) -> set[str]:
    """Normalize a list of strings for comparison.

    Each item is converted with ``str()``, stripped of surrounding whitespace
    and lower-cased. Falsy items and items that are blank after stripping are
    dropped; duplicates collapse because a set is returned.

    ``None`` is accepted and treated as an empty list, so a JSON ``null`` in
    a prediction or annotation yields an empty set instead of raising
    ``TypeError``.
    """
    if items is None:
        return set()
    cleaned = (str(item).strip().lower() for item in items if item)
    return {text for text in cleaned if text}
def _precision_recall_f1(predicted: set[str], expected: set[str]) -> tuple[float, float, float]:
    """Return exact-match (precision, recall, F1) for two sets of labels.

    Precision is 0.0 when nothing was predicted, recall is 0.0 when nothing
    was expected, and F1 is 0.0 when both precision and recall are 0.
    """
    hits = len(predicted & expected)
    precision = hits / len(predicted) if predicted else 0.0
    recall = hits / len(expected) if expected else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0
    return precision, recall, f1


def score_extraction(pred: dict[str, Any], gold: dict[str, Any]) -> dict[str, float]:
    """Score a single prediction against the gold standard.

    Args:
        pred: Prediction with optional ``entities``, ``keywords`` and
            ``topic`` keys.
        gold: Annotation with the same keys.

    Returns:
        A dict with ``ent_p``/``ent_r``/``ent_f1`` and ``kw_p``/``kw_r``/
        ``kw_f1`` (each rounded to 3 decimals), ``leakage`` (number of
        normalized strings present in both predicted entities and predicted
        keywords) and ``topic_acc`` (1.0 on an exact topic match, else 0.0).
    """
    # ``or []`` covers a key that is present but null, which ``.get(key, [])``
    # would pass through as None.
    pred_entities = normalize_list(pred.get("entities") or [])
    gold_entities = normalize_list(gold.get("entities") or [])
    pred_keywords = normalize_list(pred.get("keywords") or [])
    gold_keywords = normalize_list(gold.get("keywords") or [])

    # Both entity and keyword matching are exact on the normalized strings;
    # no partial or fuzzy credit is given.
    ent_p, ent_r, ent_f1 = _precision_recall_f1(pred_entities, gold_entities)
    kw_p, kw_r, kw_f1 = _precision_recall_f1(pred_keywords, gold_keywords)

    # Leakage: entities appearing in keywords (should be 0)
    leakage = len(pred_entities & pred_keywords)

    # Topic accuracy
    topic_acc = 1.0 if pred.get("topic") == gold.get("topic") else 0.0

    return {
        "ent_p": round(ent_p, 3),
        "ent_r": round(ent_r, 3),
        "ent_f1": round(ent_f1, 3),
        "kw_p": round(kw_p, 3),
        "kw_r": round(kw_r, 3),
        "kw_f1": round(kw_f1, 3),
        "leakage": leakage,
        "topic_acc": topic_acc,
    }
def aggregate_scores(scores: list[dict[str, float]]) -> dict[str, float]:
    """Average every metric across the given per-sample score dicts.

    The metric names are taken from the first entry; each mean is rounded to
    three decimals. An empty input yields an empty dict.
    """
    if not scores:
        return {}
    sample_count = len(scores)
    averages = {}
    for metric in scores[0]:
        total = sum(entry[metric] for entry in scores)
        averages[metric] = round(total / sample_count, 3)
    return averages
def print_sample_result(name: str, pred: dict, gold: dict, scores: dict, verbose: bool) -> None:
    """Print the result for a single sample.

    Shows topic, entity and keyword comparisons with their scores. When
    ``scores["leakage"]`` is positive, the overlapping strings are listed;
    when ``verbose`` is true the predicted sentiment is printed as well.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Sample: {name}")
    print(banner)
    print(f" Topic: pred={pred.get('topic')} | gold={gold.get('topic')} | {'✓' if scores['topic_acc'] else '✗'}")
    print(f" Entities: P={scores['ent_p']:.2f} R={scores['ent_r']:.2f} F1={scores['ent_f1']:.2f}")
    print(f" Pred: {pred.get('entities', [])}")
    print(f" Gold: {gold.get('entities', [])}")
    print(f" Keywords: P={scores['kw_p']:.2f} R={scores['kw_r']:.2f} F1={scores['kw_f1']:.2f}")
    print(f" Pred: {pred.get('keywords', [])}")
    print(f" Gold: {gold.get('keywords', [])}")
    if scores["leakage"] > 0:
        # Use the same normalization as the leakage count so the displayed
        # set always matches the number that was scored.
        leaked = normalize_list(pred.get("entities") or []) & normalize_list(pred.get("keywords") or [])
        print(f" ⚠ LEAKAGE: {leaked} (entities in keywords)")
    if verbose:
        print(f" Sentiment: {pred.get('sentiment')} ({pred.get('sentimentScore')})")
async def run_extraction(cluster: dict[str, Any], provider: str, model: str, prompt_text: str) -> dict[str, Any]:
    """Run extraction on a single cluster.

    The cluster is serialized to JSON and substituted for the
    ``{cluster_json}`` placeholder in ``prompt_text``; the result is sent to
    ``call_llm`` together with a fixed system prompt.

    Returns:
        The model reply parsed as a JSON object.

    Raises:
        json.JSONDecodeError: If the reply is not valid JSON.
        ValueError: If the reply is valid JSON but not an object, since the
            caller reads it with ``.get``.
    """
    cluster_json = json.dumps(cluster, ensure_ascii=False)
    user_prompt = prompt_text.replace("{cluster_json}", cluster_json)
    system_prompt = "You are a news signal extraction engine. Return STRICT JSON only."
    content = await call_llm(provider, model, system_prompt, user_prompt)
    parsed = json.loads(content)
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a JSON object from the model, got {type(parsed).__name__}")
    return parsed
def _postprocess_prediction(pred: dict[str, Any], topic_labels: set[str]) -> dict[str, Any]:
    """Clean one raw prediction before scoring.

    Entities and keywords go through ``normalize_entities`` and
    ``_filter_entities`` with the configured blacklist. Keywords are then
    dropped when they equal a topic label, have more than two
    whitespace-separated words, or duplicate an entity.
    """
    # ``or []`` covers keys that are present but null in the model reply.
    entities = _filter_entities(normalize_entities(pred.get("entities") or []), blacklist=NEWS_ENTITY_BLACKLIST)
    keywords = _filter_entities(normalize_entities(pred.get("keywords") or []), blacklist=NEWS_ENTITY_BLACKLIST)
    entity_keys = {e.strip().lower() for e in entities}
    keywords = [
        k
        for k in keywords
        if k.lower() not in topic_labels  # filter topic labels from keywords
        and len(k.split()) <= 2  # length cap
        and k.strip().lower() not in entity_keys  # de-duplicate entities vs keywords
    ]
    return {
        "topic": pred.get("topic", "other"),
        "entities": entities,
        "keywords": keywords,
        "sentiment": pred.get("sentiment", "neutral"),
        "sentimentScore": pred.get("sentimentScore", 0.0),
    }


def _print_aggregate_report(agg: dict[str, float], results: list[dict]) -> None:
    """Print the macro averages and a per-gold-topic breakdown."""
    print(f"\n{'='*60}")
    print("AGGREGATE SCORES (macro average)")
    print(f"{'='*60}")
    print(f" Entity F1: {agg.get('ent_f1', 0):.3f} (P={agg.get('ent_p', 0):.3f} R={agg.get('ent_r', 0):.3f})")
    print(f" Keyword F1: {agg.get('kw_f1', 0):.3f} (P={agg.get('kw_p', 0):.3f} R={agg.get('kw_r', 0):.3f})")
    print(f" Leakage (avg): {agg.get('leakage', 0):.3f} (entities in keywords)")
    print(f" Topic Acc: {agg.get('topic_acc', 0):.3f}")
    # Errored samples are not part of the averages above; say so explicitly
    # so a partially failed run is not mistaken for a clean one.
    failed = sum(1 for r in results if "error" in r)
    if failed:
        print(f" Failed samples: {failed}/{len(results)} (excluded from averages)")
    # Per-topic breakdown
    print("\n Per-topic breakdown:")
    by_topic: dict[str, list] = {}
    for r in results:
        if "scores" in r:
            by_topic.setdefault(r["gold"]["topic"], []).append(r["scores"])
    for topic, scores_list in sorted(by_topic.items()):
        topic_agg = aggregate_scores(scores_list)
        print(f" {topic:12s}: Ent_F1={topic_agg.get('ent_f1',0):.2f} Kw_F1={topic_agg.get('kw_f1',0):.2f} Topic={topic_agg.get('topic_acc',0):.2f} (n={len(scores_list)})")


async def evaluate_samples(samples: list[dict], provider: str, model: str, prompt_text: str, verbose: bool) -> dict:
    """Run evaluation on all samples.

    Each sample is extracted with ``run_extraction``, post-processed, scored
    against its ``expected`` annotation and reported on stdout. A sample
    that raises is recorded as ``{"name", "error"}`` and skipped in the
    averages.

    Returns:
        ``{"aggregate": <macro averages>, "per_sample": <list of results>}``.
    """
    total = len(samples)
    print(f"\nEvaluating {total} samples with {provider}/{model}...")
    print("-" * 60)
    # Loop-invariant: build the topic-label set once, not once per sample.
    topic_labels = {t.lower() for t in DEFAULT_TOPICS}
    all_scores = []
    results = []
    for i, sample in enumerate(samples, 1):
        name = sample["name"]
        cluster = sample["cluster"]
        gold = sample["expected"]
        print(f"[{i}/{total}] {name}...", end=" ", flush=True)
        try:
            pred = await run_extraction(cluster, provider, model, prompt_text)
            # Apply same post-processing as production pipeline
            pred_processed = _postprocess_prediction(pred, topic_labels)
            scores = score_extraction(pred_processed, gold)
            all_scores.append(scores)
            results.append({"name": name, "pred": pred_processed, "gold": gold, "scores": scores})
            if verbose:
                print_sample_result(name, pred_processed, gold, scores, verbose)
            else:
                print(f"Ent_F1={scores['ent_f1']:.2f} Kw_F1={scores['kw_f1']:.2f} Leak={scores['leakage']} Topic={'✓' if scores['topic_acc'] else '✗'}")
        except Exception as e:
            # Best-effort: one failing sample must not abort the whole run.
            print(f"ERROR: {e}")
            results.append({"name": name, "error": str(e)})
    # Aggregate
    agg = aggregate_scores(all_scores)
    _print_aggregate_report(agg, results)
    return {"aggregate": agg, "per_sample": results}
async def collect_samples_from_db(n: int, output_file: str) -> None:
    """Collect up to N recent clusters from the live DB as un-annotated samples.

    Each cluster becomes a sample with empty ``entities``/``keywords`` and
    the cluster's stored topic (``"other"`` when absent). The list is written
    to ``output_file`` as UTF-8 JSON for manual annotation.
    """
    from news_mcp.storage.sqlite_store import SQLiteClusterStore
    from news_mcp.config import DB_PATH

    print(f"Collecting {n} samples from {DB_PATH}...")
    store = SQLiteClusterStore(DB_PATH)
    # NOTE(review): ttl_hours=24*30 presumably means a 30-day lookback — confirm against the store.
    clusters = store.get_latest_clusters_all_topics(ttl_hours=24*30, limit=n)
    samples = [
        {
            "name": f"live_{c['cluster_id'][:8]}",
            "cluster": {"headline": c["headline"], "summary": c["summary"]},
            "expected": {"entities": [], "keywords": [], "topic": c.get("topic", "other")},
        }
        for c in clusters
    ]
    # Save to file for manual annotation. The encoding is explicit because
    # ensure_ascii=False emits raw non-ASCII text, which the platform default
    # encoding may not be able to represent.
    output_path = Path(output_file)
    output_path.write_text(json.dumps(samples, indent=2, ensure_ascii=False), encoding="utf-8")
    print(f"Saved {len(samples)} samples to {output_path}")
    print("Edit the file to fill in expected entities/keywords, then add to GOLDEN_SAMPLES.")
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the evaluation script."""
    parser = argparse.ArgumentParser(description="Evaluate extraction prompt")
    parser.add_argument("--provider", default=NEWS_EXTRACT_PROVIDER, help="LLM provider")
    parser.add_argument("--model", default=NEWS_EXTRACT_MODEL, help="LLM model")
    parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
    parser.add_argument("--collect", type=int, metavar="N", help="Collect N samples from live DB")
    parser.add_argument("--output", default="new_samples.json", help="Output file for collected samples")
    parser.add_argument("--prompt-file", default="prompts/extract_entities.prompt", help="Prompt file to test")
    parser.add_argument("--samples-file", default="data/annotated_samples.json", help="Annotated samples JSON")
    return parser


def main():
    """Command-line entry point.

    With ``--collect`` the script only dumps samples from the DB and returns.
    Otherwise it loads the prompt and samples, runs the evaluation, and exits
    with status 1 when entity F1 < 0.5, keyword F1 < 0.4 or average leakage
    > 0.5, and with status 0 otherwise.
    """
    args = _build_parser().parse_args()

    if args.collect:
        asyncio.run(collect_samples_from_db(args.collect, args.output))
        return

    # Only the file name of --prompt-file is handed to load_prompt.
    prompt_text = load_prompt(Path(args.prompt_file).name)
    print(f"Using prompt: {args.prompt_file}")

    samples = load_golden_samples(args.samples_file)
    print(f"Loaded {len(samples)} samples from {args.samples_file}")

    outcome = asyncio.run(
        evaluate_samples(samples, args.provider, args.model, prompt_text, args.verbose)
    )

    # Exit code reflects the quality thresholds; missing metrics count as failing.
    agg = outcome["aggregate"]
    below_threshold = (
        agg.get("ent_f1", 0) < 0.5
        or agg.get("kw_f1", 0) < 0.4
        or agg.get("leakage", 99) > 0.5
    )
    if not below_threshold:
        print("\n✓ Quality acceptable")
        sys.exit(0)
    print("\n⚠ Quality below threshold!")
    sys.exit(1)


if __name__ == "__main__":
    main()
|