llm.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. from __future__ import annotations
  2. import asyncio
  3. import json
  4. from pathlib import Path
  5. from typing import Any, Dict, Iterable, List
  6. import httpx
  7. from news_mcp.config import (
  8. GROQ_API_KEY,
  9. NEWS_EXTRACT_PROVIDER,
  10. NEWS_EXTRACT_MODEL,
  11. NEWS_SUMMARY_PROVIDER,
  12. NEWS_SUMMARY_MODEL,
  13. OPENAI_API_KEY,
  14. OPENROUTER_API_KEY,
  15. PROMPTS_DIR,
  16. )
  # System message used for entity-extraction requests (see call_extraction).
  SYSTEM_PROMPT = (
      "You are a news signal extraction engine. "
      "Return STRICT JSON only."
  )
  class LLMError(RuntimeError):
      """Raised by this module when an LLM provider call cannot be completed."""
  def load_prompt(name: str) -> str:
      """Read a prompt template stored under ``PROMPTS_DIR``.

      Args:
          name: File name of the template, joined onto ``PROMPTS_DIR``.

      Returns:
          The file contents decoded as UTF-8.
      """
      return (PROMPTS_DIR / name).read_text(encoding="utf-8")
  def _render_prompt(template: str, **kwargs: Any) -> str:
      """Fill ``{key}`` placeholders in *template* with the given values.

      Only placeholders whose key was passed are touched; every other brace in
      the template (for example literal JSON) is left as-is. Substitution is
      done in a single pass over the template, so text inserted for one key is
      never re-scanned: a value that happens to contain ``{other_key}`` is
      emitted verbatim instead of being rewritten by a later key.

      Args:
          template: Prompt text containing ``{key}`` markers.
          **kwargs: Placeholder names mapped to values; each value is
              converted with ``str()``.

      Returns:
          The template with all known placeholders replaced.
      """
      import re  # local import keeps this block self-contained

      if not kwargs:
          return template

      replacements = {"{" + key + "}": str(value) for key, value in kwargs.items()}
      # Longest token first so an unusual key can never shadow a longer one.
      tokens = sorted(replacements, key=len, reverse=True)
      pattern = re.compile("|".join(re.escape(token) for token in tokens))
      return pattern.sub(lambda match: replacements[match.group(0)], template)
  def active_llm_config() -> dict[str, str | bool]:
      """Describe the provider/model configuration currently in effect.

      Returns:
          A dict holding the extraction and summary provider/model settings
          plus one ``*_key_set`` boolean per supported provider telling whether
          its API key is non-empty. The key values themselves are not included.
      """
      return {
          "extract_provider": NEWS_EXTRACT_PROVIDER,
          "extract_model": NEWS_EXTRACT_MODEL,
          "summary_provider": NEWS_SUMMARY_PROVIDER,
          "summary_model": NEWS_SUMMARY_MODEL,
          "openrouter_key_set": bool(OPENROUTER_API_KEY),
          # Report the other two supported providers as well, so a missing
          # key is visible whichever provider is configured.
          "groq_key_set": bool(GROQ_API_KEY),
          "openai_key_set": bool(OPENAI_API_KEY),
      }
  async def _call_groq(
      model: str,
      messages: List[Dict[str, str]],
      response_json: bool = True,
      retries: int = 2,
  ) -> str:
      """Send a chat-completion request to Groq and return the reply text.

      HTTP 429/500/502/503 responses, responses without choices, and choices
      with empty content are retried with a ``2 ** attempt`` second pause
      between attempts. A response body carrying an ``error`` field stops the
      loop immediately.

      Args:
          model: Model identifier placed in the request body.
          messages: Chat messages forwarded unchanged.
          response_json: When true, ask for ``{"type": "json_object"}`` output.
          retries: Number of extra attempts allowed after the first one.

      Returns:
          The non-empty ``content`` of the first returned choice.

      Raises:
          LLMError: If ``GROQ_API_KEY`` is unset or no usable content was
              obtained.
          httpx.HTTPStatusError: From ``raise_for_status`` for error statuses
              outside the retryable set.
      """
      if not GROQ_API_KEY:
          raise LLMError("GROQ_API_KEY is not configured")

      req: Dict[str, Any] = {"model": model, "messages": messages, "temperature": 0.2}
      if response_json:
          req["response_format"] = {"type": "json_object"}

      last_err = ""
      attempts_made = 0  # reported in the final error instead of the maximum
      async with httpx.AsyncClient(timeout=45.0) as client:
          for attempt in range(1 + retries):
              attempts_made = attempt + 1
              final_attempt = attempt >= retries
              resp = await client.post(
                  "https://api.groq.com/openai/v1/chat/completions",
                  headers={"Authorization": f"Bearer {GROQ_API_KEY}"},
                  json=req,
              )
              if resp.status_code != 200:
                  last_err = f"HTTP {resp.status_code}: {resp.text[:300]}"
                  if resp.status_code in (429, 500, 502, 503):
                      # Transient failure: back off, but never sleep after the
                      # final attempt because nothing follows it.
                      if not final_attempt:
                          await asyncio.sleep(2 ** attempt)
                      continue
                  resp.raise_for_status()

              data = resp.json()
              if "error" in data:
                  last_err = f"API error: {data['error']}"
                  break

              choices = data.get("choices") or []
              if choices:
                  # "message" may be missing or null; treat both as empty.
                  content = (choices[0].get("message") or {}).get("content")
                  if content:
                      return content
                  last_err = f"Empty content in choice: {str(choices[0])[:200]}"
              else:
                  last_err = f"No choices in response: {str(data)[:300]}"

              if final_attempt:
                  break
              await asyncio.sleep(2 ** attempt)

      raise LLMError(f"Groq failed after {attempts_made} attempts: {last_err}")
  async def _call_openai(
      model: str,
      messages: List[Dict[str, str]],
      response_json: bool = True,
      retries: int = 2,
  ) -> str:
      """Send a chat-completion request to OpenAI and return the reply text.

      HTTP 429/500/502/503 responses, responses without choices, and choices
      with empty content are retried with a ``2 ** attempt`` second pause
      between attempts. A response body carrying an ``error`` field stops the
      loop immediately. No ``temperature`` is sent in the request body.

      Args:
          model: Model identifier placed in the request body.
          messages: Chat messages forwarded unchanged.
          response_json: When true, ask for ``{"type": "json_object"}`` output.
          retries: Number of extra attempts allowed after the first one.

      Returns:
          The non-empty ``content`` of the first returned choice.

      Raises:
          LLMError: If ``OPENAI_API_KEY`` is unset or no usable content was
              obtained.
          httpx.HTTPStatusError: From ``raise_for_status`` for error statuses
              outside the retryable set.
      """
      if not OPENAI_API_KEY:
          raise LLMError("OPENAI_API_KEY is not configured")

      req: Dict[str, Any] = {"model": model, "messages": messages}
      if response_json:
          req["response_format"] = {"type": "json_object"}

      last_err = ""
      attempts_made = 0  # reported in the final error instead of the maximum
      async with httpx.AsyncClient(timeout=45.0) as client:
          for attempt in range(1 + retries):
              attempts_made = attempt + 1
              final_attempt = attempt >= retries
              resp = await client.post(
                  "https://api.openai.com/v1/chat/completions",
                  headers={"Authorization": f"Bearer {OPENAI_API_KEY}"},
                  json=req,
              )
              if resp.status_code != 200:
                  last_err = f"HTTP {resp.status_code}: {resp.text[:300]}"
                  if resp.status_code in (429, 500, 502, 503):
                      # Transient failure: back off, but never sleep after the
                      # final attempt because nothing follows it.
                      if not final_attempt:
                          await asyncio.sleep(2 ** attempt)
                      continue
                  resp.raise_for_status()

              data = resp.json()
              if "error" in data:
                  last_err = f"API error: {data['error']}"
                  break

              choices = data.get("choices") or []
              if choices:
                  # "message" may be missing or null; treat both as empty.
                  content = (choices[0].get("message") or {}).get("content")
                  if content:
                      return content
                  last_err = f"Empty content in choice: {str(choices[0])[:200]}"
              else:
                  last_err = f"No choices in response: {str(data)[:300]}"

              if final_attempt:
                  break
              await asyncio.sleep(2 ** attempt)

      raise LLMError(f"OpenAI failed after {attempts_made} attempts: {last_err}")
  # Chat-completions endpoint used by _call_openrouter.
  OR_OPENROUTER_URL = (
      "https://openrouter.ai"
      "/api/v1/chat/completions"
  )
  async def _call_openrouter(
      model: str,
      messages: List[Dict[str, str]],
      response_json: bool = True,
      retries: int = 2,
  ) -> str:
      """Send a chat-completion request to OpenRouter and return the reply text.

      HTTP 429/500/502/503 responses and responses without choices are retried
      with a ``2 ** attempt`` second pause between attempts. A response body
      carrying an ``error`` field, or a choice with empty content, stops the
      loop immediately.

      Args:
          model: Model identifier placed in the request body.
          messages: Chat messages forwarded unchanged.
          response_json: When true, ask for ``{"type": "json_object"}`` output.
          retries: Number of extra attempts allowed after the first one.

      Returns:
          The non-empty ``content`` of the first returned choice.

      Raises:
          LLMError: If ``OPENROUTER_API_KEY`` is unset or no usable content was
              obtained.
          httpx.HTTPStatusError: From ``raise_for_status`` for error statuses
              outside the retryable set.
      """
      if not OPENROUTER_API_KEY:
          raise LLMError("OPENROUTER_API_KEY is not configured")

      req: Dict[str, Any] = {"model": model, "messages": messages, "temperature": 0.2}
      if response_json:
          req["response_format"] = {"type": "json_object"}

      headers = {
          "Authorization": f"Bearer {OPENROUTER_API_KEY}",
          "HTTP-Referer": "https://github.com/gr1m0/bolt.new-rss",
          "X-Title": "news-mcp",
      }

      last_err = ""
      attempts_made = 0  # reported in the final error instead of the maximum
      async with httpx.AsyncClient(timeout=45.0) as client:
          for attempt in range(1 + retries):
              attempts_made = attempt + 1
              final_attempt = attempt >= retries
              resp = await client.post(OR_OPENROUTER_URL, headers=headers, json=req)
              if resp.status_code != 200:
                  last_err = f"HTTP {resp.status_code}: {resp.text[:300]}"
                  if resp.status_code in (429, 500, 502, 503):
                      # Transient failure: back off, but never sleep after the
                      # final attempt because nothing follows it.
                      if not final_attempt:
                          await asyncio.sleep(2 ** attempt)
                      continue
                  resp.raise_for_status()

              data = resp.json()
              if "error" in data:
                  last_err = f"API error: {data['error']}"
                  break

              choices = data.get("choices") or []
              if not choices:
                  last_err = f"No choices in response: {str(data)[:300]}"
                  if final_attempt:
                      break
                  await asyncio.sleep(2 ** attempt)
                  continue

              # "message" may be missing or null; treat both as empty.
              msg = choices[0].get("message") or {}
              content = msg.get("content")
              if content:
                  return content
              last_err = f"Empty content in choice: {str(msg)[:200]}"
              # NOTE(review): unlike _call_groq/_call_openai, empty content is
              # not retried here; kept as-is, confirm this is intentional.
              break

      raise LLMError(f"OpenRouter failed after {attempts_made} attempts: {last_err}")
  async def call_llm(provider: str, model: str, system_prompt: str, user_prompt: str) -> str:
      """Route a system/user prompt pair to the named provider's caller.

      Args:
          provider: ``groq``, ``openai`` or ``openrouter``; matched after
              lower-casing and stripping surrounding whitespace.
          model: Model identifier handed to the provider caller.
          system_prompt: Content of the ``system`` message.
          user_prompt: Content of the ``user`` message.

      Returns:
          The text returned by the selected provider caller.

      Raises:
          LLMError: If the provider name is not one of the supported three.
      """
      backends = {
          "groq": _call_groq,
          "openai": _call_openai,
          "openrouter": _call_openrouter,
      }
      provider = provider.lower().strip()
      backend = backends.get(provider)
      if backend is None:
          raise LLMError(f"Unsupported provider: {provider}. Valid: groq, openai, openrouter")

      conversation = [
          {"role": "system", "content": system_prompt},
          {"role": "user", "content": user_prompt},
      ]
      return await backend(model, conversation)
  def build_extraction_prompt(cluster: Dict[str, Any]) -> str:
      """Build the user prompt for entity extraction on one cluster.

      Loads ``extract_entities.prompt`` and fills its ``{cluster_json}``
      placeholder with the cluster serialized as JSON (non-ASCII characters
      are kept unescaped).

      Args:
          cluster: Cluster data to embed in the prompt.

      Returns:
          The rendered prompt text.
      """
      template = load_prompt("extract_entities.prompt")
      cluster_payload = json.dumps(cluster, ensure_ascii=False)
      return _render_prompt(template, cluster_json=cluster_payload)
  def _parse_json_object(content: str) -> Dict[str, Any]:
      """Decode a model reply that is expected to contain one JSON object.

      The reply is parsed as-is first. If that fails, the text between the
      first ``{`` and the last ``}`` is parsed instead, which tolerates
      Markdown code fences or stray prose around the object.

      Raises:
          json.JSONDecodeError: If no parseable JSON can be found.
          LLMError: If the decoded value is not a JSON object.
      """
      try:
          parsed = json.loads(content)
      except json.JSONDecodeError:
          start, end = content.find("{"), content.rfind("}")
          if start == -1 or end < start:
              raise
          parsed = json.loads(content[start : end + 1])
      if not isinstance(parsed, dict):
          raise LLMError(f"Expected a JSON object from the model, got {type(parsed).__name__}")
      return parsed


  async def call_extraction(cluster: Dict[str, Any]) -> Dict[str, Any]:
      """Run entity extraction for *cluster* and return the decoded JSON object.

      Uses ``NEWS_EXTRACT_PROVIDER`` / ``NEWS_EXTRACT_MODEL`` with
      ``SYSTEM_PROMPT`` as the system message.

      Raises:
          LLMError: If the provider call fails or the reply is not a JSON object.
          json.JSONDecodeError: If the reply contains no parseable JSON.
      """
      user_prompt = build_extraction_prompt(cluster)
      content = await call_llm(NEWS_EXTRACT_PROVIDER, NEWS_EXTRACT_MODEL, SYSTEM_PROMPT, user_prompt)
      return _parse_json_object(content)


  async def call_summary(cluster: Dict[str, Any]) -> Dict[str, Any]:
      """Summarize *cluster* and return the decoded JSON object.

      Renders ``summarize_cluster.prompt`` with the cluster serialized as JSON
      and sends it via ``NEWS_SUMMARY_PROVIDER`` / ``NEWS_SUMMARY_MODEL``.

      Raises:
          LLMError: If the provider call fails or the reply is not a JSON object.
          json.JSONDecodeError: If the reply contains no parseable JSON.
      """
      prompt = load_prompt("summarize_cluster.prompt")
      user_prompt = _render_prompt(prompt, cluster_json=json.dumps(cluster, ensure_ascii=False))
      content = await call_llm(
          NEWS_SUMMARY_PROVIDER,
          NEWS_SUMMARY_MODEL,
          "You are a summarization engine for news clusters. Return strict JSON only.",
          user_prompt,
      )
      return _parse_json_object(content)
  184. return json.loads(content)