llm_enrich.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. from __future__ import annotations
  2. from fnmatch import fnmatchcase
  3. from typing import Any, Dict
  4. from news_mcp.config import NEWS_ENTITY_BLACKLIST, DEFAULT_TOPICS
  5. from news_mcp.entity_normalize import normalize_entities
  6. from news_mcp.llm import call_extraction, call_summary
  7. from news_mcp.trends_resolution import resolve_entity_via_trends
  8. def _matches_blacklist(value: str, blacklist=None) -> bool:
  9. patterns = [x.strip().lower() for x in (blacklist if blacklist is not None else NEWS_ENTITY_BLACKLIST) if x and x.strip()]
  10. key = str(value).strip().lower()
  11. if not key:
  12. return True
  13. return any(fnmatchcase(key, pattern) for pattern in patterns)
  14. def _filter_entities(entities, blacklist=None):
  15. out = []
  16. for ent in entities or []:
  17. if _matches_blacklist(ent, blacklist=blacklist):
  18. continue
  19. out.append(ent)
  20. return out
  21. async def classify_cluster_llm(cluster: Dict[str, Any]) -> Dict[str, Any]:
  22. parsed = await call_extraction(cluster)
  23. out = dict(cluster)
  24. # Topic: prefer the LLM's classification, fall back to the heuristic topic
  25. # already on the input cluster. Validate against the allowed set so we never
  26. # promote a free-form string into the SQL row column.
  27. raw_topic = parsed.get("topic", cluster.get("topic"))
  28. topic = str(raw_topic).strip().lower() if raw_topic else None
  29. if topic and _matches_blacklist(topic):
  30. topic = "other"
  31. if topic not in {t.lower() for t in DEFAULT_TOPICS}:
  32. # Unknown / hallucinated label -> fall back to whatever the heuristic
  33. # classifier on the headline gave us, else "other".
  34. fallback = str(cluster.get("topic") or "").strip().lower()
  35. topic = fallback if fallback in {t.lower() for t in DEFAULT_TOPICS} else "other"
  36. # IMPORTANT: normalize aliases BEFORE applying the blacklist, otherwise
  37. # blacklisting "bitcoin" misses entries the LLM returned as "btc".
  38. entities = _filter_entities(normalize_entities(parsed.get("entities", [])))
  39. keywords = _filter_entities(normalize_entities(parsed.get("keywords", [])))
  40. out.update({
  41. "topic": topic,
  42. "entities": entities,
  43. "entityResolutions": [resolve_entity_via_trends(e) for e in entities],
  44. "sentiment": parsed.get("sentiment", "neutral"),
  45. "sentimentScore": parsed.get("sentimentScore"),
  46. "keywords": keywords,
  47. })
  48. return out
  49. async def summarize_cluster_llm(cluster: Dict[str, Any]) -> Dict[str, Any]:
  50. parsed = await call_summary(cluster)
  51. return parsed
# Backward-compatible aliases during the transition away from provider-specific naming.
# Each alias is bound to the same coroutine function object as its "_llm"
# counterpart, so callers still using the "groq" names get identical behavior.
classify_cluster_groq = classify_cluster_llm
summarize_cluster_groq = summarize_cluster_llm