llm_enrich.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. from __future__ import annotations
  2. from fnmatch import fnmatchcase
  3. from typing import Any, Dict
  4. from news_mcp.config import NEWS_ENTITY_BLACKLIST, DEFAULT_TOPICS
  5. from news_mcp.entity_normalize import normalize_entities
  6. from news_mcp.llm import call_extraction, call_summary
  7. from news_mcp.trends_resolution import resolve_entity_via_trends
  8. def _matches_blacklist(value: str, blacklist=None) -> bool:
  9. patterns = [x.strip().lower() for x in (blacklist if blacklist is not None else NEWS_ENTITY_BLACKLIST) if x and x.strip()]
  10. key = str(value).strip().lower()
  11. if not key:
  12. return True
  13. return any(fnmatchcase(key, pattern) for pattern in patterns)
  14. def _filter_entities(entities, blacklist=None):
  15. out = []
  16. for ent in entities or []:
  17. if _matches_blacklist(ent, blacklist=blacklist):
  18. continue
  19. out.append(ent)
  20. return out
  21. async def classify_cluster_llm(cluster: Dict[str, Any]) -> Dict[str, Any]:
  22. parsed = await call_extraction(cluster)
  23. out = dict(cluster)
  24. # Topic: prefer the LLM's classification, fall back to the heuristic topic
  25. # already on the input cluster. Validate against the allowed set so we never
  26. # promote a free-form string into the SQL row column.
  27. raw_topic = parsed.get("topic", cluster.get("topic"))
  28. topic = str(raw_topic).strip().lower() if raw_topic else None
  29. if topic and _matches_blacklist(topic):
  30. topic = "other"
  31. if topic not in {t.lower() for t in DEFAULT_TOPICS}:
  32. # Unknown / hallucinated label -> fall back to whatever the heuristic
  33. # classifier on the headline gave us, else "other".
  34. fallback = str(cluster.get("topic") or "").strip().lower()
  35. topic = fallback if fallback in {t.lower() for t in DEFAULT_TOPICS} else "other"
  36. # IMPORTANT: normalize aliases BEFORE applying the blacklist, otherwise
  37. # blacklisting "bitcoin" misses entries the LLM returned as "btc".
  38. entities = _filter_entities(normalize_entities(parsed.get("entities", [])))
  39. keywords = _filter_entities(normalize_entities(parsed.get("keywords", [])))
  40. # Filter out topic labels from keywords. The LLM often returns the
  41. # topic (e.g. "crypto", "macro", "regulation", "ai") as a keyword
  42. # since the prompt asks for "keywords that justify the classification".
  43. # These are already captured by the cluster topic field and should not
  44. # pollute keyword search/scoring/frequencies.
  45. _topic_labels = {t.lower() for t in DEFAULT_TOPICS}
  46. keywords = [k for k in keywords if k.lower() not in _topic_labels]
  47. # Enforce per-keyword length cap (max 2 words) as a hard guard.
  48. # The prompt requests this but the LLM occasionally ignores it.
  49. keywords = [k for k in keywords if len(k.split()) <= 2]
  50. # De-duplicate entities vs keywords — entities list is the
  51. # authoritative source for proper nouns; keywords should be the
  52. # thematic complement, not a repeat.
  53. _entity_keys = {e.strip().lower() for e in entities}
  54. keywords = [k for k in keywords if k.strip().lower() not in _entity_keys]
  55. out.update({
  56. "topic": topic,
  57. "entities": entities,
  58. "entityResolutions": [resolve_entity_via_trends(e) for e in entities],
  59. "sentiment": parsed.get("sentiment", "neutral"),
  60. "sentimentScore": parsed.get("sentimentScore"),
  61. "keywords": keywords,
  62. })
  63. return out
  64. async def summarize_cluster_llm(cluster: Dict[str, Any]) -> Dict[str, Any]:
  65. parsed = await call_summary(cluster)
  66. return parsed
  67. # Backward-compatible aliases during the transition away from provider-specific naming.
  68. classify_cluster_groq = classify_cluster_llm
  69. summarize_cluster_groq = summarize_cluster_llm