|
|
@@ -0,0 +1,378 @@
|
|
|
+#!/usr/bin/env python3
|
|
|
+"""
|
|
|
+Evaluation harness for entity/keyword extraction prompt.
|
|
|
+
|
|
|
+Usage:
|
|
|
+ python scripts/eval_extraction.py # Run against golden samples
|
|
|
+ python scripts/eval_extraction.py --verbose # Show per-sample details
|
|
|
+ python scripts/eval_extraction.py --model llama-3.1-8b-instant # Test specific model
|
|
|
+ python scripts/eval_extraction.py --collect N # Collect N new samples from live DB
|
|
|
+"""
|
|
|
+
|
|
|
+from __future__ import annotations
|
|
|
+import argparse
|
|
|
+import asyncio
|
|
|
+import json
|
|
|
+import sys
|
|
|
+from pathlib import Path
|
|
|
+from typing import Any
|
|
|
+
|
|
|
+# Add project root to path
|
|
|
+sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
|
+
|
|
|
+from news_mcp.llm import build_extraction_prompt, call_llm, load_prompt
|
|
|
+from news_mcp.config import (
|
|
|
+ NEWS_EXTRACT_PROVIDER,
|
|
|
+ NEWS_EXTRACT_MODEL,
|
|
|
+ NEWS_ENTITY_BLACKLIST,
|
|
|
+)
|
|
|
+from news_mcp.entity_normalize import normalize_entities
|
|
|
+from news_mcp.enrichment.llm_enrich import _filter_entities
|
|
|
+
|
|
|
+
|
|
|
+# ---------------------------------------------------------------------------
|
|
|
+# Golden samples (curated from real clusters)
|
|
|
+# ---------------------------------------------------------------------------
|
|
|
def _golden(name, headline, summary, entities, keywords, topic):
    """Assemble one golden sample in the shape the evaluator consumes.

    Key order is fixed (name/cluster/expected, headline/summary,
    entities/keywords/topic) because the cluster dict is serialized to JSON
    and embedded in the prompt.
    """
    return {
        "name": name,
        "cluster": {"headline": headline, "summary": summary},
        "expected": {"entities": entities, "keywords": keywords, "topic": topic},
    }


# Hand-curated reference samples: each pairs a cluster (headline + summary)
# with the entities, keywords and topic the extraction is expected to yield.
GOLDEN_SAMPLES = [
    _golden(
        name="sec_binance_lawsuit",
        headline="SEC sues Binance over unregistered securities",
        summary="The Securities and Exchange Commission filed a lawsuit against Binance, the world's largest crypto exchange, alleging it operated as an unregistered securities exchange and commingled customer funds.",
        entities=["SEC", "Binance"],
        keywords=["securities law", "crypto exchange", "enforcement action"],
        topic="regulation",
    ),
    _golden(
        name="fed_rates_inflation",
        headline="Fed holds rates steady as inflation cools",
        summary="The Federal Reserve kept interest rates unchanged at 5.25-5.50%, citing progress on inflation but signaling caution on future cuts.",
        entities=["Federal Reserve"],
        keywords=["interest rates", "inflation", "monetary policy"],
        topic="macro",
    ),
    _golden(
        name="israel_iran_syria_strikes",
        headline="Israel strikes Iranian missile sites in Syria",
        summary="Israeli warplanes targeted Iranian missile depots near Damascus overnight, escalating regional tensions.",
        entities=["Israel", "Iran", "Syria", "Damascus"],
        keywords=["airstrikes", "missile sites", "regional escalation"],
        topic="other",
    ),
    _golden(
        name="bitcoin_etf_flows",
        headline="Bitcoin ETFs see record inflows as BTC tops $70k",
        summary="US spot Bitcoin ETFs attracted $2.3 billion in net inflows this week as Bitcoin surged past $70,000, driven by institutional demand.",
        entities=["Bitcoin", "BTC"],
        keywords=["ETF inflows", "institutional demand", "price surge"],
        topic="crypto",
    ),
    _golden(
        name="ai_regulation_eu",
        headline="EU AI Act enters force with strict rules for high-risk systems",
        summary="The European Union's landmark AI Act took effect today, imposing strict requirements on high-risk AI systems including transparency, human oversight, and risk management.",
        entities=["European Union", "EU AI Act"],
        keywords=["AI regulation", "high-risk systems", "compliance requirements"],
        topic="ai",
    ),
    _golden(
        name="china_economy_stimulus",
        headline="China unveils stimulus package to boost slowing economy",
        summary="Beijing announced a comprehensive stimulus package including infrastructure spending, tax cuts, and monetary easing to counter slowing growth and property sector weakness.",
        entities=["China", "Beijing"],
        keywords=["stimulus package", "infrastructure spending", "monetary easing", "property sector"],
        topic="macro",
    ),
    _golden(
        name="oil_prices_opepcuts",
        headline="Oil jumps after OPEC+ extends production cuts",
        summary="Crude oil prices rose 3% after OPEC+ agreed to extend production cuts through year-end, tightening global supply amid demand concerns.",
        entities=["OPEC+"],
        keywords=["production cuts", "oil prices", "global supply", "demand concerns"],
        topic="macro",
    ),
    _golden(
        name="nvidia_earnings_ai",
        headline="Nvidia beats earnings on AI chip demand",
        summary="Nvidia reported quarterly revenue of $26 billion, up 262% year-over-year, driven by insatiable demand for its H100 and Blackwell AI chips.",
        entities=["Nvidia", "H100", "Blackwell"],
        keywords=["AI chips", "earnings beat", "revenue growth", "chip demand"],
        topic="ai",
    ),
    _golden(
        name="ecb_rate_decision",
        headline="ECB cuts rates as eurozone inflation falls to 2.4%",
        summary="The European Central Bank lowered its deposit rate by 25 basis points to 3.75%, marking its first cut since 2019 as inflation approaches target.",
        entities=["European Central Bank", "ECB"],
        keywords=["rate cut", "eurozone inflation", "deposit rate", "monetary easing"],
        topic="macro",
    ),
    _golden(
        name="trump_legal_cases",
        headline="Trump convicted in hush money trial, first former president found guilty",
        summary="A New York jury found Donald Trump guilty on all 34 counts of falsifying business records in the hush money case, making him the first former US president convicted of a crime.",
        entities=["Donald Trump", "New York"],
        keywords=["criminal conviction", "hush money", "falsifying records", "historic verdict"],
        topic="other",
    ),
]
|
|
|
+
|
|
|
+
|
|
|
+# ---------------------------------------------------------------------------
|
|
|
+# Scoring functions
|
|
|
+# ---------------------------------------------------------------------------
|
|
|
def normalize_list(items: list[str]) -> set[str]:
    """Normalize strings for set comparison: strip, lowercase, dedupe.

    Args:
        items: Values to normalize. ``None`` is treated as empty and a bare
            string as a one-element list, because raw LLM output sometimes
            carries ``null`` or a single string where a list is expected.
            Non-string elements are stringified; falsy or whitespace-only
            elements are dropped.

    Returns:
        The set of distinct normalized strings.
    """
    if items is None:
        return set()
    if isinstance(items, str):
        # Iterating a str would yield single characters, not one item.
        items = [items]
    stripped = (str(item).strip().lower() for item in items if item)
    return {text for text in stripped if text}
|
|
|
+
|
|
|
+
|
|
|
def _precision_recall_f1(pred: set[str], gold: set[str]) -> tuple[float, float, float]:
    """Return exact-match (precision, recall, F1) between two label sets."""
    if not pred and not gold:
        # Nothing expected and nothing predicted is perfect agreement, not a miss.
        return 1.0, 1.0, 1.0
    hits = len(pred & gold)
    precision = hits / len(pred) if pred else 0.0
    recall = hits / len(gold) if gold else 0.0
    total = precision + recall
    f1 = 2 * precision * recall / total if total > 0 else 0.0
    return precision, recall, f1


def score_extraction(pred: dict[str, Any], gold: dict[str, Any]) -> dict[str, float]:
    """Score a single prediction against its gold standard.

    Entities and keywords are compared as sets after ``normalize_list``
    (exact match on the normalized string; there is no fuzzy or partial
    matching). When both the predicted and the gold set are empty the
    precision, recall and F1 are reported as 1.0.

    Args:
        pred: Prediction with optional "entities", "keywords" and "topic".
        gold: Expected values with the same keys.

    Returns:
        Dict with "ent_p"/"ent_r"/"ent_f1" and "kw_p"/"kw_r"/"kw_f1" rounded
        to 3 decimals, "leakage" (number of normalized strings present in
        both the predicted entities and the predicted keywords; ideally 0)
        and "topic_acc" (1.0 when the topics are equal, else 0.0).
    """
    # `or []` guards against an explicit null in the prediction.
    pred_entities = normalize_list(pred.get("entities") or [])
    gold_entities = normalize_list(gold.get("entities") or [])
    pred_keywords = normalize_list(pred.get("keywords") or [])
    gold_keywords = normalize_list(gold.get("keywords") or [])

    ent_p, ent_r, ent_f1 = _precision_recall_f1(pred_entities, gold_entities)
    kw_p, kw_r, kw_f1 = _precision_recall_f1(pred_keywords, gold_keywords)

    # Leakage: entities appearing in keywords (should be 0)
    leakage = len(pred_entities & pred_keywords)

    # Topic accuracy (exact, case-sensitive comparison)
    topic_acc = 1.0 if pred.get("topic") == gold.get("topic") else 0.0

    return {
        "ent_p": round(ent_p, 3),
        "ent_r": round(ent_r, 3),
        "ent_f1": round(ent_f1, 3),
        "kw_p": round(kw_p, 3),
        "kw_r": round(kw_r, 3),
        "kw_f1": round(kw_f1, 3),
        "leakage": leakage,
        "topic_acc": topic_acc,
    }
|
|
|
+
|
|
|
+
|
|
|
def aggregate_scores(scores: list[dict[str, float]]) -> dict[str, float]:
    """Macro-average every metric over the per-sample score dicts.

    Metric names come from the first entry, so each entry must carry those
    keys. Returns an empty dict for empty input; each mean is rounded to
    3 decimals.
    """
    if not scores:
        return {}
    sample_count = len(scores)
    averages = {}
    for metric in scores[0]:
        metric_total = sum(entry[metric] for entry in scores)
        averages[metric] = round(metric_total / sample_count, 3)
    return averages
|
|
|
+
|
|
|
+
|
|
|
def print_sample_result(name: str, pred: dict, gold: dict, scores: dict, verbose: bool) -> None:
    """Print the result for a single sample to stdout.

    Args:
        name: Sample name shown in the header.
        pred: Post-processed prediction ("topic", "entities", "keywords",
            and optionally "sentiment" / "sentimentScore").
        gold: Expected values for the same sample.
        scores: Score dict for the sample; must contain "topic_acc",
            "ent_p", "ent_r", "ent_f1", "kw_p", "kw_r", "kw_f1", "leakage".
        verbose: When true, also print the predicted sentiment.
    """
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"Sample: {name}")
    print(rule)

    topic_mark = "✓" if scores["topic_acc"] else "✗"
    print(f" Topic: pred={pred.get('topic')} | gold={gold.get('topic')} | {topic_mark}")

    print(f" Entities: P={scores['ent_p']:.2f} R={scores['ent_r']:.2f} F1={scores['ent_f1']:.2f}")
    print(f" Pred: {pred.get('entities', [])}")
    print(f" Gold: {gold.get('entities', [])}")

    print(f" Keywords: P={scores['kw_p']:.2f} R={scores['kw_r']:.2f} F1={scores['kw_f1']:.2f}")
    print(f" Pred: {pred.get('keywords', [])}")
    print(f" Gold: {gold.get('keywords', [])}")

    if scores["leakage"] > 0:
        # Strip as well as lowercase so the reported set agrees with how the
        # leakage count is computed; str() tolerates non-string items.
        entity_keys = {str(e).strip().lower() for e in pred.get("entities") or []}
        keyword_keys = {str(k).strip().lower() for k in pred.get("keywords") or []}
        leaked = entity_keys & keyword_keys
        print(f" ⚠ LEAKAGE: {leaked} (entities in keywords)")

    if verbose:
        print(f" Sentiment: {pred.get('sentiment')} ({pred.get('sentimentScore')})")
|
|
|
+
|
|
|
+
|
|
|
def _parse_llm_json(content: Any) -> dict[str, Any]:
    """Parse an LLM reply into a dict, tolerating a Markdown code fence.

    Raises:
        ValueError: If the reply is not valid JSON (``json.JSONDecodeError``
            is a ``ValueError`` subclass) or is valid JSON but not an object.
    """
    if isinstance(content, (bytes, bytearray)):
        content = content.decode("utf-8")
    text = content.strip()
    if text.startswith("```"):
        # Models sometimes wrap the payload as ```json ... ``` despite the
        # "STRICT JSON only" instruction; drop the fence and the language tag.
        text = text.strip("`").strip()
        if text[:4].lower() == "json":
            text = text[4:]
    parsed = json.loads(text)
    if not isinstance(parsed, dict):
        raise ValueError(
            f"Expected a JSON object from the LLM, got {type(parsed).__name__}"
        )
    return parsed


async def run_extraction(cluster: dict[str, Any], provider: str, model: str) -> dict[str, Any]:
    """Run extraction on a single cluster.

    Loads the "extract_entities.prompt" template, substitutes the
    JSON-serialized cluster for the ``{cluster_json}`` placeholder, sends it
    to the LLM via ``call_llm`` and parses the reply.

    Args:
        cluster: Cluster payload to embed in the prompt.
        provider: LLM provider passed through to ``call_llm``.
        model: Model name passed through to ``call_llm``.

    Returns:
        The decoded JSON object returned by the model.

    Raises:
        ValueError: If the reply cannot be parsed as a JSON object.
    """
    template = load_prompt("extract_entities.prompt")
    user_prompt = template.replace(
        "{cluster_json}", json.dumps(cluster, ensure_ascii=False)
    )
    system_prompt = "You are a news signal extraction engine. Return STRICT JSON only."

    content = await call_llm(provider, model, system_prompt, user_prompt)
    return _parse_llm_json(content)
|
|
|
+
|
|
|
+
|
|
|
async def evaluate_samples(samples: list[dict], provider: str, model: str, verbose: bool) -> dict:
    """Run the extraction on every sample, score it and print a report.

    Each raw prediction is post-processed before scoring (entity/keyword
    normalization and blacklist filtering, removal of topic labels from
    keywords, a two-word keyword cap, and removal of keywords that duplicate
    an entity); this is intended to mirror the production pipeline.

    A sample whose extraction or scoring raises is recorded with an "error"
    entry and left out of the averages; the number of such samples is
    printed in the summary and returned under "errors".

    Args:
        samples: Dicts with "name", "cluster" and "expected" keys.
        provider: LLM provider passed to ``run_extraction``.
        model: Model name passed to ``run_extraction``.
        verbose: Print the full per-sample breakdown instead of one line.

    Returns:
        {"aggregate": macro averages over the successfully scored samples,
         "per_sample": per-sample results or errors,
         "errors": number of failed samples}.
    """
    # Imported once, up front: a broken import now fails the run immediately
    # instead of being reported as the same error for every sample.
    from news_mcp.enrichment.llm_enrich import _filter_entities, normalize_entities
    from news_mcp.config import DEFAULT_TOPICS, NEWS_ENTITY_BLACKLIST

    # Loop-invariant: topic labels are not allowed to appear as keywords.
    topic_labels = {t.lower() for t in DEFAULT_TOPICS}
    total = len(samples)

    print(f"\nEvaluating {total} samples with {provider}/{model}...")
    print("-" * 60)

    all_scores = []
    results = []
    error_count = 0

    for i, sample in enumerate(samples, 1):
        name = sample["name"]
        cluster = sample["cluster"]
        gold = sample["expected"]

        print(f"[{i}/{total}] {name}...", end=" ", flush=True)

        try:
            pred = await run_extraction(cluster, provider, model)

            # Apply same post-processing as production pipeline.
            # `or []` guards against an explicit null from the model.
            entities = _filter_entities(
                normalize_entities(pred.get("entities") or []),
                blacklist=NEWS_ENTITY_BLACKLIST,
            )
            keywords = _filter_entities(
                normalize_entities(pred.get("keywords") or []),
                blacklist=NEWS_ENTITY_BLACKLIST,
            )

            # Filter topic labels from keywords
            keywords = [k for k in keywords if k.lower() not in topic_labels]

            # Length cap
            keywords = [k for k in keywords if len(k.split()) <= 2]

            # De-duplicate entities vs keywords
            entity_keys = {e.strip().lower() for e in entities}
            keywords = [k for k in keywords if k.strip().lower() not in entity_keys]

            pred_processed = {
                "topic": pred.get("topic", "other"),
                "entities": entities,
                "keywords": keywords,
                "sentiment": pred.get("sentiment", "neutral"),
                "sentimentScore": pred.get("sentimentScore", 0.0),
            }

            scores = score_extraction(pred_processed, gold)
            all_scores.append(scores)
            results.append({"name": name, "pred": pred_processed, "gold": gold, "scores": scores})

            if verbose:
                print_sample_result(name, pred_processed, gold, scores, verbose)
            else:
                print(f"Ent_F1={scores['ent_f1']:.2f} Kw_F1={scores['kw_f1']:.2f} Leak={scores['leakage']} Topic={'✓' if scores['topic_acc'] else '✗'}")

        except Exception as exc:
            # Per-sample boundary: one bad reply must not abort the whole run.
            error_count += 1
            print(f"ERROR: {exc}")
            results.append({"name": name, "error": str(exc)})

    # Aggregate
    agg = aggregate_scores(all_scores)
    print(f"\n{'='*60}")
    print("AGGREGATE SCORES (macro average)")
    print(f"{'='*60}")
    print(f" Entity F1: {agg.get('ent_f1', 0):.3f} (P={agg.get('ent_p', 0):.3f} R={agg.get('ent_r', 0):.3f})")
    print(f" Keyword F1: {agg.get('kw_f1', 0):.3f} (P={agg.get('kw_p', 0):.3f} R={agg.get('kw_r', 0):.3f})")
    print(f" Leakage (avg): {agg.get('leakage', 0):.3f} (entities in keywords)")
    print(f" Topic Acc: {agg.get('topic_acc', 0):.3f}")
    if error_count:
        # Make it visible that the averages cover only part of the sample set.
        print(f" Errors: {error_count}/{total} samples failed (excluded from averages)")

    return {"aggregate": agg, "per_sample": results, "errors": error_count}
|
|
|
+
|
|
|
+
|
|
|
async def collect_samples_from_db(n: int, output_file: str) -> None:
    """Dump recent clusters from the live DB as sample stubs for annotation.

    Queries the cluster store for up to ``n`` clusters (``ttl_hours=24*30``,
    i.e. a 30-day window) and writes them to ``output_file`` as JSON in the
    same shape as ``GOLDEN_SAMPLES``, with empty expected entities/keywords
    to be filled in by hand.

    Args:
        n: Maximum number of clusters to request from the store.
        output_file: Path of the JSON file to write (overwritten if present).
    """
    from news_mcp.storage.sqlite_store import SQLiteClusterStore
    from news_mcp.config import DB_PATH

    print(f"Collecting {n} samples from {DB_PATH}...")
    store = SQLiteClusterStore(DB_PATH)
    clusters = store.get_latest_clusters_all_topics(ttl_hours=24*30, limit=n)

    samples = [
        {
            "name": f"live_{c['cluster_id'][:8]}",
            "cluster": {"headline": c["headline"], "summary": c["summary"]},
            "expected": {"entities": [], "keywords": [], "topic": c.get("topic", "other")},
        }
        for c in clusters
    ]

    # Save to file for manual annotation. The encoding is explicit because
    # ensure_ascii=False emits raw non-ASCII characters, which the platform
    # default encoding may be unable to represent.
    output_path = Path(output_file)
    output_path.write_text(
        json.dumps(samples, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    print(f"Saved {len(samples)} samples to {output_path}")
    print("Edit the file to fill in expected entities/keywords, then add to GOLDEN_SAMPLES.")
|
|
|
+
|
|
|
+
|
|
|
def main() -> None:
    """Command-line entry point.

    With ``--collect N`` the script writes N sample stubs from the live DB
    to ``--output`` and returns. Otherwise it evaluates ``GOLDEN_SAMPLES``
    and exits with status 0 when quality is acceptable, or 1 when the
    aggregate scores miss the thresholds or any sample failed to evaluate.
    """
    parser = argparse.ArgumentParser(description="Evaluate extraction prompt")
    parser.add_argument("--provider", default=NEWS_EXTRACT_PROVIDER, help="LLM provider")
    parser.add_argument("--model", default=NEWS_EXTRACT_MODEL, help="LLM model")
    parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
    parser.add_argument("--collect", type=int, metavar="N", help="Collect N samples from live DB")
    parser.add_argument("--output", default="new_samples.json", help="Output file for collected samples")
    args = parser.parse_args()

    # Test against None, not truthiness: `--collect 0` must not silently fall
    # through to a full (LLM-calling) evaluation run.
    if args.collect is not None:
        if args.collect <= 0:
            parser.error("--collect requires a positive integer")
        asyncio.run(collect_samples_from_db(args.collect, args.output))
        return

    # Run evaluation
    result = asyncio.run(evaluate_samples(GOLDEN_SAMPLES, args.provider, args.model, args.verbose))

    # Samples that raised are excluded from the averages, so they are checked
    # separately; otherwise a partially failed run could still pass the gate.
    failed = [r["name"] for r in result.get("per_sample", []) if "error" in r]

    # Exit code based on quality threshold
    agg = result["aggregate"]
    below_threshold = (
        agg.get("ent_f1", 0) < 0.5
        or agg.get("kw_f1", 0) < 0.4
        or agg.get("leakage", 99) > 0.5
    )

    if failed:
        print(f"\n⚠ {len(failed)} sample(s) failed to evaluate: {', '.join(failed)}")
    if below_threshold:
        print("\n⚠ Quality below threshold!")
    if failed or below_threshold:
        sys.exit(1)

    print("\n✓ Quality acceptable")
    sys.exit(0)
|
|
|
+
|
|
|
+
|
|
|
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|