# llm.py
  1. from __future__ import annotations
  2. import json
  3. from pathlib import Path
  4. from typing import Any, Dict, Iterable, List
  5. import httpx
  6. from news_mcp.config import (
  7. GROQ_API_KEY,
  8. NEWS_EXTRACT_PROVIDER,
  9. NEWS_EXTRACT_MODEL,
  10. NEWS_SUMMARY_PROVIDER,
  11. NEWS_SUMMARY_MODEL,
  12. OPENAI_API_KEY,
  13. OPENROUTER_API_KEY,
  14. PROMPTS_DIR,
  15. )
# Shared system prompt for the entity-extraction call (used by call_extraction);
# call_summary supplies its own system prompt inline.
SYSTEM_PROMPT = "You are a news signal extraction engine. Return STRICT JSON only."
class LLMError(RuntimeError):
    """Raised when an LLM provider is misconfigured (missing API key) or unsupported."""
    pass
  19. def load_prompt(name: str) -> str:
  20. path = PROMPTS_DIR / name
  21. return path.read_text(encoding="utf-8")
  22. def _render_prompt(template: str, **kwargs: Any) -> str:
  23. rendered = template
  24. for key, value in kwargs.items():
  25. rendered = rendered.replace("{" + key + "}", str(value))
  26. return rendered
  27. def active_llm_config() -> dict[str, str]:
  28. return {
  29. "extract_provider": NEWS_EXTRACT_PROVIDER,
  30. "extract_model": NEWS_EXTRACT_MODEL,
  31. "summary_provider": NEWS_SUMMARY_PROVIDER,
  32. "summary_model": NEWS_SUMMARY_MODEL,
  33. "openrouter_key_set": bool(OPENROUTER_API_KEY),
  34. }
  35. async def _call_groq(model: str, messages: List[Dict[str, str]], response_json: bool = True) -> str:
  36. if not GROQ_API_KEY:
  37. raise LLMError("GROQ_API_KEY is not configured")
  38. req = {"model": model, "messages": messages, "temperature": 0.2}
  39. if response_json:
  40. req["response_format"] = {"type": "json_object"}
  41. async with httpx.AsyncClient(timeout=45.0) as client:
  42. resp = await client.post(
  43. "https://api.groq.com/openai/v1/chat/completions",
  44. headers={"Authorization": f"Bearer {GROQ_API_KEY}"},
  45. json=req,
  46. )
  47. resp.raise_for_status()
  48. data = resp.json()
  49. return data["choices"][0]["message"]["content"]
  50. async def _call_openai(model: str, messages: List[Dict[str, str]], response_json: bool = True) -> str:
  51. # OpenAI-compatible chat endpoint; uses NEWS_OPENAI_API_KEY.
  52. if not OPENAI_API_KEY:
  53. raise LLMError("OPENAI_API_KEY is not configured")
  54. req = {"model": model, "messages": messages}
  55. if response_json:
  56. req["response_format"] = {"type": "json_object"}
  57. async with httpx.AsyncClient(timeout=45.0) as client:
  58. resp = await client.post(
  59. "https://api.openai.com/v1/chat/completions",
  60. headers={"Authorization": f"Bearer {OPENAI_API_KEY}"},
  61. json=req,
  62. )
  63. resp.raise_for_status()
  64. data = resp.json()
  65. return data["choices"][0]["message"]["content"]
  66. OR_OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"
  67. async def _call_openrouter(model: str, messages: List[Dict[str, str]], response_json: bool = True) -> str:
  68. if not OPENROUTER_API_KEY:
  69. raise LLMError("OPENROUTER_API_KEY is not configured")
  70. req = {"model": model, "messages": messages, "temperature": 0.2}
  71. if response_json:
  72. req["response_format"] = {"type": "json_object"}
  73. headers = {
  74. "Authorization": f"Bearer {OPENROUTER_API_KEY}",
  75. "HTTP-Referer": "https://github.com/gr1m0/bolt.new-rss",
  76. "X-Title": "news-mcp",
  77. }
  78. async with httpx.AsyncClient(timeout=45.0) as client:
  79. resp = await client.post(
  80. OR_OPENROUTER_URL,
  81. headers=headers,
  82. json=req,
  83. )
  84. resp.raise_for_status()
  85. data = resp.json()
  86. return data["choices"][0]["message"]["content"]
  87. async def call_llm(provider: str, model: str, system_prompt: str, user_prompt: str) -> str:
  88. messages = [
  89. {"role": "system", "content": system_prompt},
  90. {"role": "user", "content": user_prompt},
  91. ]
  92. provider = provider.lower().strip()
  93. if provider == "groq":
  94. return await _call_groq(model, messages)
  95. if provider == "openai":
  96. return await _call_openai(model, messages)
  97. if provider == "openrouter":
  98. return await _call_openrouter(model, messages)
  99. raise LLMError(f"Unsupported provider: {provider}. Valid: groq, openai, openrouter")
  100. def build_extraction_prompt(cluster: Dict[str, Any]) -> str:
  101. prompt = load_prompt("extract_entities.prompt")
  102. return _render_prompt(prompt, cluster_json=json.dumps(cluster, ensure_ascii=False))
  103. async def call_extraction(cluster: Dict[str, Any]) -> Dict[str, Any]:
  104. user_prompt = build_extraction_prompt(cluster)
  105. content = await call_llm(NEWS_EXTRACT_PROVIDER, NEWS_EXTRACT_MODEL, SYSTEM_PROMPT, user_prompt)
  106. return json.loads(content)
  107. async def call_summary(cluster: Dict[str, Any]) -> Dict[str, Any]:
  108. prompt = load_prompt("summarize_cluster.prompt")
  109. user_prompt = _render_prompt(prompt, cluster_json=json.dumps(cluster, ensure_ascii=False))
  110. content = await call_llm(NEWS_SUMMARY_PROVIDER, NEWS_SUMMARY_MODEL, "You are a summarization engine for news clusters. Return strict JSON only.", user_prompt)
  111. return json.loads(content)