Parcourir la source

prompt improvement, eval script

Lukas Goldschmidt il y a 4 jours
Parent
commit
cf0de8f11d
2 fichiers modifiés avec 413 ajouts et 21 suppressions
  1. 35 21
      prompts/extract_entities.prompt
  2. 378 0
      scripts/eval_extraction.py

+ 35 - 21
prompts/extract_entities.prompt

@@ -1,30 +1,44 @@
 Input cluster JSON:
 {cluster_json}
 
-You MUST extract a news signal from the headline AND summary. Do not leave entities empty when the text mentions obvious names.
+You MUST extract a news signal from the headline AND summary. Return STRICT JSON only.
+
 Task:
-1) infer the best top-level topic
-2) extract concise entities from the cluster
-3) assign sentiment from the wording/context
-4) provide short keywords that justify the classification
-
-Entity rules (strict):
-- Use short strings (1-5 words).
-- Include all obvious named entities mentioned in headline or summary: 
-  named people, , named locations, organizations, ministries, presidents, leaders, wars/conflicts if named.
-- Also include finance/crypto entities when present: BTC, ETH, Bitcoin, Ethereum, ETF, SEC, ECB, Fed, euro, inflation, rates.
-- Prefer canonical entity forms over aliases when obvious (for example, use full organization or place names where helpful).
-- Do NOT return empty entities if any such names/places appear.
-
-Keyword rules (strict):
-- Each keyword MUST be 1-2 words. Never 3+.
-- Keywords are thematic search tags, NOT headline restatements or verb phrases.
-- Good keywords: noun phrases or named concepts (e.g. "drone strikes", "energy infrastructure", "nuclear plant", "oil refinery").
-- Bad keywords: full headline fragments, verb-heavy phrases, or anything over 2 words.
-- Keywords should capture the *themes* of the story, not repeat entity names already in the entities list.
+1) infer the best top-level topic (crypto, macro, regulation, ai, other)
+2) extract concise ENTITIES (proper nouns only)
+3) assign sentiment (positive/negative/neutral) + score (-1.0 to 1.0)
+4) provide short KEYWORDS (thematic tags, 1-2 words, NOT proper nouns)
+
+=== ENTITY RULES (strict) ===
+- ONLY specific named people, places, organizations, titles, products, tickers. 1-5 words.
+- Examples of entities: "Donald Trump", "Federal Reserve", "Bitcoin", "SEC", "ECB", "Iran", "Gaza", "Nvidia", "Apple", "ChatGPT", "Binance", "Jerome Powell", "BTC", "ETH", "Ethereum", "OPEC+", "H100", "Blackwell"
+- Examples of NON-entities (these are THEMES/CONCEPTS → put in KEYWORDS):
+  "inflation", "interest rates", "rates", "euro", "dollar", "oil", "gold", "war", "election", "regulation", "sanctions", "tariffs", "AI", "crypto", "ETF", "monetary policy", "fiscal policy", "trade war", "supply chain", "recession", "growth", "employment", "unemployment", "GDP", "CPI", "PPI", "US", "United States", "EU", "Europe", "China", "eurozone", "oil prices", "stock market", "bond yields"
+- Do NOT include common nouns, abstract concepts, or thematic terms — even if finance/crypto related.
+- Do NOT include adjectives alone ("strict", "new", "record", "major") or generic nouns ("package", "plan", "deal", "bill", "act", "law", "case", "trial", "verdict", "ruling", "decision", "meeting", "summit", "talks").
+
+=== KEYWORD RULES (strict) ===
+- Each keyword MUST be 1-2 words. PREFER 2-word phrases. Avoid single words unless they are established compound concepts (e.g. "inflation" is ok alone, "sanctions" is ok alone).
+- Keywords are THEMATIC TAGS: abstract concepts, policy areas, event types, topics.
+- Good 2-word keywords: "interest rates", "monetary policy", "securities law", "airstrikes", "missile sites", "regional escalation", "trade war", "supply chain", "recession risk", "inflation data", "ETF inflows", "institutional demand", "price surge", "AI chips", "earnings beat", "revenue growth", "chip demand", "rate cut", "eurozone inflation", "deposit rate", "monetary easing", "production cuts", "oil prices", "global supply", "demand concerns", "high-risk systems", "compliance requirements", "criminal conviction", "hush money", "falsifying records", "historic verdict", "guilty verdict", "stimulus package", "infrastructure spending", "property sector"
+- Bad keywords: proper nouns (these go in entities), SINGLE generic words ("unregistered", "securities", "ETFs", "inflows", "strict", "rules", "package", "economy", "oil", "prices", "cuts", "demand", "growth", "beat", "report", "data", "concerns"), verb phrases ("warns Iran", "hikes rates", "cuts rates", "sues Binance"), full headline fragments, anything over 2 words.
 - Return 2-4 keywords. Fewer is better than bad ones.
 
-Sentiment rules:
+=== DECISION PROCEDURE ===
+For each candidate term in the text:
+1. Is it a specific named person/place/org/product/ticker? → ENTITY
+2. Is it a theme, topic, policy area, or event type? → KEYWORD
+3. Can you form a meaningful 2-word phrase? → KEYWORD (use the phrase)
+4. Unclear? Default to KEYWORD (safer to miss an entity than pollute entities with themes)
+
+=== TOPIC CLASSIFICATION ===
+- crypto: Bitcoin, Ethereum, crypto exchanges, DeFi, tokens, mining, ETFs
+- macro: central banks (Fed, ECB, BoE, BoJ), interest rates, inflation, GDP, employment, fiscal/monetary policy, oil, commodities, China economy
+- regulation: SEC, CFTC, lawsuits, enforcement, legislation, compliance, legal rulings, EU AI Act, financial regulation
+- ai: AI models, chips (Nvidia, AMD), LLMs, generative AI, AI companies, AI regulation (but prefer 'regulation' if legal focus)
+- other: geopolitics, war, politics, elections, corporate earnings (non-AI), general business
+
+=== SENTIMENT RULES ===
 - positive: clearly encouraging, improving, or supportive tone
 - negative: clearly alarming, worsening, severe, conflict, loss, risk, warning tone
 - neutral: factual, balanced, or mixed

+ 378 - 0
scripts/eval_extraction.py

@@ -0,0 +1,378 @@
+#!/usr/bin/env python3
+"""
+Evaluation harness for entity/keyword extraction prompt.
+
+Usage:
+  python scripts/eval_extraction.py                    # Run against golden samples
+  python scripts/eval_extraction.py --verbose          # Show per-sample details
+  python scripts/eval_extraction.py --model llama-3.1-8b-instant  # Test specific model
+  python scripts/eval_extraction.py --collect N        # Collect N new samples from live DB
+"""
+
+from __future__ import annotations
+import argparse
+import asyncio
+import json
+import sys
+from pathlib import Path
+from typing import Any
+
+# Add project root to path
+sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
+
+from news_mcp.llm import build_extraction_prompt, call_llm, load_prompt
+from news_mcp.config import (
+    NEWS_EXTRACT_PROVIDER,
+    NEWS_EXTRACT_MODEL,
+    NEWS_ENTITY_BLACKLIST,
+)
+from news_mcp.entity_normalize import normalize_entities
+from news_mcp.enrichment.llm_enrich import _filter_entities
+
+
+# ---------------------------------------------------------------------------
+# Golden samples (curated from real clusters)
+# ---------------------------------------------------------------------------
+GOLDEN_SAMPLES = [
+    {
+        "name": "sec_binance_lawsuit",
+        "cluster": {
+            "headline": "SEC sues Binance over unregistered securities",
+            "summary": "The Securities and Exchange Commission filed a lawsuit against Binance, the world's largest crypto exchange, alleging it operated as an unregistered securities exchange and commingled customer funds.",
+        },
+        "expected": {
+            "entities": ["SEC", "Binance"],
+            "keywords": ["securities law", "crypto exchange", "enforcement action"],
+            "topic": "regulation",
+        },
+    },
+    {
+        "name": "fed_rates_inflation",
+        "cluster": {
+            "headline": "Fed holds rates steady as inflation cools",
+            "summary": "The Federal Reserve kept interest rates unchanged at 5.25-5.50%, citing progress on inflation but signaling caution on future cuts.",
+        },
+        "expected": {
+            "entities": ["Federal Reserve"],
+            "keywords": ["interest rates", "inflation", "monetary policy"],
+            "topic": "macro",
+        },
+    },
+    {
+        "name": "israel_iran_syria_strikes",
+        "cluster": {
+            "headline": "Israel strikes Iranian missile sites in Syria",
+            "summary": "Israeli warplanes targeted Iranian missile depots near Damascus overnight, escalating regional tensions.",
+        },
+        "expected": {
+            "entities": ["Israel", "Iran", "Syria", "Damascus"],
+            "keywords": ["airstrikes", "missile sites", "regional escalation"],
+            "topic": "other",
+        },
+    },
+    {
+        "name": "bitcoin_etf_flows",
+        "cluster": {
+            "headline": "Bitcoin ETFs see record inflows as BTC tops $70k",
+            "summary": "US spot Bitcoin ETFs attracted $2.3 billion in net inflows this week as Bitcoin surged past $70,000, driven by institutional demand.",
+        },
+        "expected": {
+            "entities": ["Bitcoin", "BTC"],
+            "keywords": ["ETF inflows", "institutional demand", "price surge"],
+            "topic": "crypto",
+        },
+    },
+    {
+        "name": "ai_regulation_eu",
+        "cluster": {
+            "headline": "EU AI Act enters force with strict rules for high-risk systems",
+            "summary": "The European Union's landmark AI Act took effect today, imposing strict requirements on high-risk AI systems including transparency, human oversight, and risk management.",
+        },
+        "expected": {
+            "entities": ["European Union", "EU AI Act"],
+            "keywords": ["AI regulation", "high-risk systems", "compliance requirements"],
+            "topic": "ai",
+        },
+    },
+    {
+        "name": "china_economy_stimulus",
+        "cluster": {
+            "headline": "China unveils stimulus package to boost slowing economy",
+            "summary": "Beijing announced a comprehensive stimulus package including infrastructure spending, tax cuts, and monetary easing to counter slowing growth and property sector weakness.",
+        },
+        "expected": {
+            "entities": ["China", "Beijing"],
+            "keywords": ["stimulus package", "infrastructure spending", "monetary easing", "property sector"],
+            "topic": "macro",
+        },
+    },
+    {
+        "name": "oil_prices_opepcuts",
+        "cluster": {
+            "headline": "Oil jumps after OPEC+ extends production cuts",
+            "summary": "Crude oil prices rose 3% after OPEC+ agreed to extend production cuts through year-end, tightening global supply amid demand concerns.",
+        },
+        "expected": {
+            "entities": ["OPEC+"],
+            "keywords": ["production cuts", "oil prices", "global supply", "demand concerns"],
+            "topic": "macro",
+        },
+    },
+    {
+        "name": "nvidia_earnings_ai",
+        "cluster": {
+            "headline": "Nvidia beats earnings on AI chip demand",
+            "summary": "Nvidia reported quarterly revenue of $26 billion, up 262% year-over-year, driven by insatiable demand for its H100 and Blackwell AI chips.",
+        },
+        "expected": {
+            "entities": ["Nvidia", "H100", "Blackwell"],
+            "keywords": ["AI chips", "earnings beat", "revenue growth", "chip demand"],
+            "topic": "ai",
+        },
+    },
+    {
+        "name": "ecb_rate_decision",
+        "cluster": {
+            "headline": "ECB cuts rates as eurozone inflation falls to 2.4%",
+            "summary": "The European Central Bank lowered its deposit rate by 25 basis points to 3.75%, marking its first cut since 2019 as inflation approaches target.",
+        },
+        "expected": {
+            "entities": ["European Central Bank", "ECB"],
+            "keywords": ["rate cut", "eurozone inflation", "deposit rate", "monetary easing"],
+            "topic": "macro",
+        },
+    },
+    {
+        "name": "trump_legal_cases",
+        "cluster": {
+            "headline": "Trump convicted in hush money trial, first former president found guilty",
+            "summary": "A New York jury found Donald Trump guilty on all 34 counts of falsifying business records in the hush money case, making him the first former US president convicted of a crime.",
+        },
+        "expected": {
+            "entities": ["Donald Trump", "New York"],
+            "keywords": ["criminal conviction", "hush money", "falsifying records", "historic verdict"],
+            "topic": "other",
+        },
+    },
+]
+
+
+# ---------------------------------------------------------------------------
+# Scoring functions
+# ---------------------------------------------------------------------------
+def normalize_list(items: list[str]) -> set[str]:
+    """Normalize list of strings for comparison: lowercase, strip, dedupe."""
+    return {str(x).strip().lower() for x in items if x and str(x).strip()}
+
+
+def score_extraction(pred: dict[str, Any], gold: dict[str, Any]) -> dict[str, float]:
+    """Score a single prediction against gold standard."""
+    pred_entities = normalize_list(pred.get("entities", []))
+    gold_entities = normalize_list(gold.get("entities", []))
+    pred_keywords = normalize_list(pred.get("keywords", []))
+    gold_keywords = normalize_list(gold.get("keywords", []))
+
+    # Entity precision/recall/F1
+    if pred_entities:
+        ent_p = len(pred_entities & gold_entities) / len(pred_entities)
+    else:
+        ent_p = 0.0
+    if gold_entities:
+        ent_r = len(pred_entities & gold_entities) / len(gold_entities)
+    else:
+        ent_r = 0.0
+    ent_f1 = 2 * ent_p * ent_r / (ent_p + ent_r) if (ent_p + ent_r) > 0 else 0.0
+
+    # Keyword precision/recall/F1 (allow partial overlap since keywords are thematic)
+    if pred_keywords:
+        kw_p = len(pred_keywords & gold_keywords) / len(pred_keywords)
+    else:
+        kw_p = 0.0
+    if gold_keywords:
+        kw_r = len(pred_keywords & gold_keywords) / len(gold_keywords)
+    else:
+        kw_r = 0.0
+    kw_f1 = 2 * kw_p * kw_r / (kw_p + kw_r) if (kw_p + kw_r) > 0 else 0.0
+
+    # Leakage: entities appearing in keywords (should be 0)
+    leakage = len(pred_entities & pred_keywords)
+
+    # Topic accuracy
+    topic_acc = 1.0 if pred.get("topic") == gold.get("topic") else 0.0
+
+    return {
+        "ent_p": round(ent_p, 3),
+        "ent_r": round(ent_r, 3),
+        "ent_f1": round(ent_f1, 3),
+        "kw_p": round(kw_p, 3),
+        "kw_r": round(kw_r, 3),
+        "kw_f1": round(kw_f1, 3),
+        "leakage": leakage,
+        "topic_acc": topic_acc,
+    }
+
+
+def aggregate_scores(scores: list[dict[str, float]]) -> dict[str, float]:
+    """Compute macro averages across samples."""
+    if not scores:
+        return {}
+    keys = scores[0].keys()
+    return {k: round(sum(s[k] for s in scores) / len(scores), 3) for k in keys}
+
+
+def print_sample_result(name: str, pred: dict, gold: dict, scores: dict, verbose: bool):
+    """Print result for a single sample."""
+    print(f"\n{'='*60}")
+    print(f"Sample: {name}")
+    print(f"{'='*60}")
+    print(f"  Topic: pred={pred.get('topic')} | gold={gold.get('topic')} | {'✓' if scores['topic_acc'] else '✗'}")
+    print(f"  Entities:  P={scores['ent_p']:.2f} R={scores['ent_r']:.2f} F1={scores['ent_f1']:.2f}")
+    print(f"    Pred: {pred.get('entities', [])}")
+    print(f"    Gold: {gold.get('entities', [])}")
+    print(f"  Keywords:  P={scores['kw_p']:.2f} R={scores['kw_r']:.2f} F1={scores['kw_f1']:.2f}")
+    print(f"    Pred: {pred.get('keywords', [])}")
+    print(f"    Gold: {gold.get('keywords', [])}")
+    if scores["leakage"] > 0:
+        leaked = set(e.lower() for e in pred.get("entities", [])) & set(k.lower() for k in pred.get("keywords", []))
+        print(f"  ⚠ LEAKAGE: {leaked} (entities in keywords)")
+    if verbose:
+        print(f"  Sentiment: {pred.get('sentiment')} ({pred.get('sentimentScore')})")
+
+
+async def run_extraction(cluster: dict[str, Any], provider: str, model: str) -> dict[str, Any]:
+    """Run extraction on a single cluster."""
+    prompt = load_prompt("extract_entities.prompt")
+    # Build the full user prompt
+    import json as json_lib
+    user_prompt = prompt.replace("{cluster_json}", json_lib.dumps(cluster, ensure_ascii=False))
+
+    from news_mcp.config import GROQ_API_KEY, OPENAI_API_KEY, OPENROUTER_API_KEY
+
+    system_prompt = "You are a news signal extraction engine. Return STRICT JSON only."
+
+    content = await call_llm(provider, model, system_prompt, user_prompt)
+    return json.loads(content)
+
+
+async def evaluate_samples(samples: list[dict], provider: str, model: str, verbose: bool) -> dict:
+    """Run evaluation on all samples."""
+    print(f"\nEvaluating {len(samples)} samples with {provider}/{model}...")
+    print("-" * 60)
+
+    all_scores = []
+    results = []
+
+    for i, sample in enumerate(samples, 1):
+        name = sample["name"]
+        cluster = sample["cluster"]
+        gold = sample["expected"]
+
+        print(f"[{i}/{len(samples)}] {name}...", end=" ", flush=True)
+
+        try:
+            pred = await run_extraction(cluster, provider, model)
+
+            # Apply same post-processing as production pipeline
+            from news_mcp.enrichment.llm_enrich import _filter_entities, normalize_entities
+            from news_mcp.config import DEFAULT_TOPICS, NEWS_ENTITY_BLACKLIST
+
+            entities = _filter_entities(normalize_entities(pred.get("entities", [])), blacklist=NEWS_ENTITY_BLACKLIST)
+            keywords = _filter_entities(normalize_entities(pred.get("keywords", [])), blacklist=NEWS_ENTITY_BLACKLIST)
+
+            # Filter topic labels from keywords
+            topic_labels = {t.lower() for t in DEFAULT_TOPICS}
+            keywords = [k for k in keywords if k.lower() not in topic_labels]
+
+            # Length cap
+            keywords = [k for k in keywords if len(k.split()) <= 2]
+
+            # De-duplicate entities vs keywords
+            entity_keys = {e.strip().lower() for e in entities}
+            keywords = [k for k in keywords if k.strip().lower() not in entity_keys]
+
+            pred_processed = {
+                "topic": pred.get("topic", "other"),
+                "entities": entities,
+                "keywords": keywords,
+                "sentiment": pred.get("sentiment", "neutral"),
+                "sentimentScore": pred.get("sentimentScore", 0.0),
+            }
+
+            scores = score_extraction(pred_processed, gold)
+            all_scores.append(scores)
+            results.append({"name": name, "pred": pred_processed, "gold": gold, "scores": scores})
+
+            if verbose:
+                print_sample_result(name, pred_processed, gold, scores, verbose)
+            else:
+                print(f"Ent_F1={scores['ent_f1']:.2f} Kw_F1={scores['kw_f1']:.2f} Leak={scores['leakage']} Topic={'✓' if scores['topic_acc'] else '✗'}")
+
+        except Exception as e:
+            print(f"ERROR: {e}")
+            results.append({"name": name, "error": str(e)})
+
+    # Aggregate
+    agg = aggregate_scores(all_scores)
+    print(f"\n{'='*60}")
+    print("AGGREGATE SCORES (macro average)")
+    print(f"{'='*60}")
+    print(f"  Entity F1:     {agg.get('ent_f1', 0):.3f}  (P={agg.get('ent_p', 0):.3f} R={agg.get('ent_r', 0):.3f})")
+    print(f"  Keyword F1:    {agg.get('kw_f1', 0):.3f}  (P={agg.get('kw_p', 0):.3f} R={agg.get('kw_r', 0):.3f})")
+    print(f"  Leakage (avg): {agg.get('leakage', 0):.3f}  (entities in keywords)")
+    print(f"  Topic Acc:     {agg.get('topic_acc', 0):.3f}")
+
+    return {"aggregate": agg, "per_sample": results}
+
+
+async def collect_samples_from_db(n: int, output_file: str):
+    """Collect N recent clusters from live DB as new golden samples."""
+    from news_mcp.storage.sqlite_store import SQLiteClusterStore
+    from news_mcp.config import DB_PATH
+
+    print(f"Collecting {n} samples from {DB_PATH}...")
+    store = SQLiteClusterStore(DB_PATH)
+    clusters = store.get_latest_clusters_all_topics(ttl_hours=24*30, limit=n)
+
+    samples = []
+    for c in clusters:
+        samples.append({
+            "name": f"live_{c['cluster_id'][:8]}",
+            "cluster": {"headline": c["headline"], "summary": c["summary"]},
+            "expected": {"entities": [], "keywords": [], "topic": c.get("topic", "other")},
+        })
+
+    # Save to file for manual annotation
+    output_path = Path(output_file)
+    output_path.write_text(json.dumps(samples, indent=2, ensure_ascii=False))
+    print(f"Saved {len(samples)} samples to {output_path}")
+    print("Edit the file to fill in expected entities/keywords, then add to GOLDEN_SAMPLES.")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Evaluate extraction prompt")
+    parser.add_argument("--provider", default=NEWS_EXTRACT_PROVIDER, help="LLM provider")
+    parser.add_argument("--model", default=NEWS_EXTRACT_MODEL, help="LLM model")
+    parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
+    parser.add_argument("--collect", type=int, metavar="N", help="Collect N samples from live DB")
+    parser.add_argument("--output", default="new_samples.json", help="Output file for collected samples")
+    args = parser.parse_args()
+
+    if args.collect:
+        asyncio.run(collect_samples_from_db(args.collect, args.output))
+        return
+
+    # Run evaluation
+    result = asyncio.run(evaluate_samples(GOLDEN_SAMPLES, args.provider, args.model, args.verbose))
+
+    # Exit code based on quality threshold
+    agg = result["aggregate"]
+    if agg.get("ent_f1", 0) < 0.5 or agg.get("kw_f1", 0) < 0.4 or agg.get("leakage", 99) > 0.5:
+        print("\n⚠ Quality below threshold!")
+        sys.exit(1)
+    else:
+        print("\n✓ Quality acceptable")
+        sys.exit(0)
+
+
+if __name__ == "__main__":
+    main()