llm_enrich.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. from __future__ import annotations
  2. from fnmatch import fnmatchcase
  3. from typing import Any, Dict
  4. from news_mcp.config import NEWS_ENTITY_BLACKLIST, DEFAULT_TOPICS
  5. from news_mcp.entity_normalize import normalize_entities
  6. from news_mcp.llm import call_extraction, call_summary
  7. from news_mcp.trends_resolution import resolve_entity_via_trends
  8. def _matches_blacklist(value: str, blacklist=None) -> bool:
  9. patterns = [x.strip().lower() for x in (blacklist if blacklist is not None else NEWS_ENTITY_BLACKLIST) if x and x.strip()]
  10. key = str(value).strip().lower()
  11. if not key:
  12. return True
  13. return any(fnmatchcase(key, pattern) for pattern in patterns)
  14. def _filter_entities(entities, blacklist=None):
  15. out = []
  16. for ent in entities or []:
  17. if _matches_blacklist(ent, blacklist=blacklist):
  18. continue
  19. out.append(ent)
  20. return out
  21. async def classify_cluster_llm(cluster: Dict[str, Any]) -> Dict[str, Any]:
  22. parsed = await call_extraction(cluster)
  23. out = dict(cluster)
  24. # Topic: prefer the LLM's classification, fall back to the heuristic topic
  25. # already on the input cluster. Validate against the allowed set so we never
  26. # promote a free-form string into the SQL row column.
  27. raw_topic = parsed.get("topic", cluster.get("topic"))
  28. topic = str(raw_topic).strip().lower() if raw_topic else None
  29. if topic and _matches_blacklist(topic):
  30. topic = "other"
  31. if topic not in {t.lower() for t in DEFAULT_TOPICS}:
  32. # Unknown / hallucinated label -> fall back to whatever the heuristic
  33. # classifier on the headline gave us, else "other".
  34. fallback = str(cluster.get("topic") or "").strip().lower()
  35. topic = fallback if fallback in {t.lower() for t in DEFAULT_TOPICS} else "other"
  36. # IMPORTANT: normalize aliases BEFORE applying the blacklist, otherwise
  37. # blacklisting "bitcoin" misses entries the LLM returned as "btc".
  38. entities = _filter_entities(normalize_entities(parsed.get("entities", [])))
  39. keywords = _filter_entities(normalize_entities(parsed.get("keywords", [])))
  40. # Filter out topic labels from keywords. The LLM often returns the
  41. # topic (e.g. "crypto", "macro", "regulation", "ai") as a keyword
  42. # since the prompt asks for "keywords that justify the classification".
  43. # These are already captured by the cluster topic field and should not
  44. # pollute keyword search/scoring/frequencies.
  45. _topic_labels = {t.lower() for t in DEFAULT_TOPICS}
  46. keywords = [k for k in keywords if k.lower() not in _topic_labels]
  47. out.update({
  48. "topic": topic,
  49. "entities": entities,
  50. "entityResolutions": [resolve_entity_via_trends(e) for e in entities],
  51. "sentiment": parsed.get("sentiment", "neutral"),
  52. "sentimentScore": parsed.get("sentimentScore"),
  53. "keywords": keywords,
  54. })
  55. return out
  56. async def summarize_cluster_llm(cluster: Dict[str, Any]) -> Dict[str, Any]:
  57. parsed = await call_summary(cluster)
  58. return parsed
  59. # Backward-compatible aliases during the transition away from provider-specific naming.
  60. classify_cluster_groq = classify_cluster_llm
  61. summarize_cluster_groq = summarize_cluster_llm