llm.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. from __future__ import annotations
  2. import asyncio
  3. import json
  4. from pathlib import Path
  5. from typing import Any, Dict, Iterable, List
  6. import httpx
  7. from news_mcp.config import (
  8. GROQ_API_KEY,
  9. NEWS_EXTRACT_PROVIDER,
  10. NEWS_EXTRACT_MODEL,
  11. NEWS_SUMMARY_PROVIDER,
  12. NEWS_SUMMARY_MODEL,
  13. OPENAI_API_KEY,
  14. OPENROUTER_API_KEY,
  15. PROMPTS_DIR,
  16. llm_rate_limit,
  17. )
  18. SYSTEM_PROMPT = "You are a news signal extraction engine. Return STRICT JSON only."
  19. class LLMError(RuntimeError):
  20. pass
  21. # ---------------------------------------------------------------------------
  22. # Per-provider rate limiter (token bucket).
  23. # ---------------------------------------------------------------------------
  24. _rate_limiters: dict[str, _RateLimiter] = {}
  25. def _get_rate_limiter(provider: str) -> "_RateLimiter | None":
  26. """Return (or lazily create) the rate limiter for *provider*.
  27. Returns None when rate limiting is disabled for this provider.
  28. """
  29. provider = provider.strip().lower()
  30. rl = llm_rate_limit(provider)
  31. if rl <= 0.0:
  32. return None
  33. if provider not in _rate_limiters:
  34. _rate_limiters[provider] = _RateLimiter(rl)
  35. return _rate_limiters[provider]
  36. class _RateLimiter:
  37. """Simple async token-bucket rate limiter shared across all calls for one provider."""
  38. def __init__(self, rate: float):
  39. self._interval = 1.0 / rate # seconds between tokens
  40. self._last_used = 0.0 # monotonic timestamp of last acquire
  41. self._lock = asyncio.Lock()
  42. async def acquire(self):
  43. async with self._lock:
  44. now = asyncio.get_event_loop().time()
  45. wait = self._last_used + self._interval - now
  46. if wait > 0:
  47. await asyncio.sleep(wait)
  48. self._last_used = asyncio.get_event_loop().time()
  49. def load_prompt(name: str) -> str:
  50. path = PROMPTS_DIR / name
  51. return path.read_text(encoding="utf-8")
  52. def _render_prompt(template: str, **kwargs: Any) -> str:
  53. rendered = template
  54. for key, value in kwargs.items():
  55. rendered = rendered.replace("{" + key + "}", str(value))
  56. return rendered
  57. def active_llm_config() -> dict[str, str]:
  58. return {
  59. "extract_provider": NEWS_EXTRACT_PROVIDER,
  60. "extract_model": NEWS_EXTRACT_MODEL,
  61. "summary_provider": NEWS_SUMMARY_PROVIDER,
  62. "summary_model": NEWS_SUMMARY_MODEL,
  63. "openrouter_key_set": bool(OPENROUTER_API_KEY),
  64. }
  65. async def _call_groq(model: str, messages: List[Dict[str, str]], response_json: bool = True, retries: int = 2) -> str:
  66. if not GROQ_API_KEY:
  67. raise LLMError("GROQ_API_KEY is not configured")
  68. req = {"model": model, "messages": messages, "temperature": 0.2}
  69. if response_json:
  70. req["response_format"] = {"type": "json_object"}
  71. last_err = ""
  72. async with httpx.AsyncClient(timeout=45.0) as client:
  73. for attempt in range(1 + retries):
  74. resp = await client.post(
  75. "https://api.groq.com/openai/v1/chat/completions",
  76. headers={"Authorization": f"Bearer {GROQ_API_KEY}"},
  77. json=req,
  78. )
  79. if resp.status_code != 200:
  80. last_err = f"HTTP {resp.status_code}: {resp.text[:300]}"
  81. if resp.status_code in (429, 500, 502, 503):
  82. await asyncio.sleep(2 ** attempt)
  83. continue
  84. resp.raise_for_status()
  85. data = resp.json()
  86. if "error" in data:
  87. last_err = f"API error: {data['error']}"
  88. break
  89. choices = data.get("choices", [])
  90. if not choices:
  91. last_err = f"No choices in response: {str(data)[:300]}"
  92. if attempt < retries:
  93. await asyncio.sleep(2 ** attempt)
  94. continue
  95. break
  96. content = choices[0].get("message", {}).get("content")
  97. if content:
  98. return content
  99. last_err = f"Empty content in choice: {str(choices[0])[:200]}"
  100. if attempt < retries:
  101. await asyncio.sleep(2 ** attempt)
  102. continue
  103. break
  104. raise LLMError(f"Groq failed after {1+retries} attempts: {last_err}")
  105. async def _call_openai(model: str, messages: List[Dict[str, str]], response_json: bool = True, retries: int = 2) -> str:
  106. if not OPENAI_API_KEY:
  107. raise LLMError("OPENAI_API_KEY is not configured")
  108. req = {"model": model, "messages": messages}
  109. if response_json:
  110. req["response_format"] = {"type": "json_object"}
  111. last_err = ""
  112. async with httpx.AsyncClient(timeout=45.0) as client:
  113. for attempt in range(1 + retries):
  114. resp = await client.post(
  115. "https://api.openai.com/v1/chat/completions",
  116. headers={"Authorization": f"Bearer {OPENAI_API_KEY}"},
  117. json=req,
  118. )
  119. if resp.status_code != 200:
  120. last_err = f"HTTP {resp.status_code}: {resp.text[:300]}"
  121. if resp.status_code in (429, 500, 502, 503):
  122. await asyncio.sleep(2 ** attempt)
  123. continue
  124. resp.raise_for_status()
  125. data = resp.json()
  126. if "error" in data:
  127. last_err = f"API error: {data['error']}"
  128. break
  129. choices = data.get("choices", [])
  130. if not choices:
  131. last_err = f"No choices in response: {str(data)[:300]}"
  132. if attempt < retries:
  133. await asyncio.sleep(2 ** attempt)
  134. continue
  135. break
  136. content = choices[0].get("message", {}).get("content")
  137. if content:
  138. return content
  139. last_err = f"Empty content in choice: {str(choices[0])[:200]}"
  140. if attempt < retries:
  141. await asyncio.sleep(2 ** attempt)
  142. continue
  143. break
  144. raise LLMError(f"OpenAI failed after {1+retries} attempts: {last_err}")
  145. OR_OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"
  146. async def _call_openrouter(model: str, messages: List[Dict[str, str]], response_json: bool = True, retries: int = 2) -> str:
  147. if not OPENROUTER_API_KEY:
  148. raise LLMError("OPENROUTER_API_KEY is not configured")
  149. req = {"model": model, "messages": messages, "temperature": 0.2}
  150. if response_json:
  151. req["response_format"] = {"type": "json_object"}
  152. headers = {
  153. "Authorization": f"Bearer {OPENROUTER_API_KEY}",
  154. "HTTP-Referer": "https://github.com/gr1m0/bolt.new-rss",
  155. "X-Title": "news-mcp",
  156. }
  157. last_err = ""
  158. async with httpx.AsyncClient(timeout=45.0) as client:
  159. for attempt in range(1 + retries):
  160. resp = await client.post(
  161. OR_OPENROUTER_URL,
  162. headers=headers,
  163. json=req,
  164. )
  165. if resp.status_code != 200:
  166. last_err = f"HTTP {resp.status_code}: {resp.text[:300]}"
  167. if resp.status_code in (429, 500, 502, 503):
  168. await asyncio.sleep(2 ** attempt)
  169. continue
  170. resp.raise_for_status()
  171. data = resp.json()
  172. if "error" in data:
  173. last_err = f"API error: {data['error']}"
  174. break
  175. choices = data.get("choices", [])
  176. if not choices:
  177. last_err = f"No choices in response: {str(data)[:300]}"
  178. if attempt < retries:
  179. await asyncio.sleep(2 ** attempt)
  180. continue
  181. break
  182. msg = choices[0].get("message", {})
  183. content = msg.get("content")
  184. if content:
  185. return content
  186. last_err = f"Empty content in choice: {str(msg)[:200]}"
  187. break
  188. raise LLMError(f"OpenRouter failed after {1+retries} attempts: {last_err}")
  189. async def call_llm(provider: str, model: str, system_prompt: str, user_prompt: str) -> str:
  190. messages = [
  191. {"role": "system", "content": system_prompt},
  192. {"role": "user", "content": user_prompt},
  193. ]
  194. provider = provider.lower().strip()
  195. # Rate-limit before dispatching to the provider-specific caller.
  196. rl = _get_rate_limiter(provider)
  197. if rl is not None:
  198. await rl.acquire()
  199. if provider == "groq":
  200. return await _call_groq(model, messages)
  201. if provider == "openai":
  202. return await _call_openai(model, messages)
  203. if provider == "openrouter":
  204. return await _call_openrouter(model, messages)
  205. raise LLMError(f"Unsupported provider: {provider}. Valid: groq, openai, openrouter")
  206. def build_extraction_prompt(cluster: Dict[str, Any]) -> str:
  207. prompt = load_prompt("extract_entities.prompt")
  208. return _render_prompt(prompt, cluster_json=json.dumps(cluster, ensure_ascii=False))
  209. async def call_extraction(cluster: Dict[str, Any]) -> Dict[str, Any]:
  210. user_prompt = build_extraction_prompt(cluster)
  211. content = await call_llm(NEWS_EXTRACT_PROVIDER, NEWS_EXTRACT_MODEL, SYSTEM_PROMPT, user_prompt)
  212. return json.loads(content)
  213. async def call_summary(cluster: Dict[str, Any]) -> Dict[str, Any]:
  214. prompt = load_prompt("summarize_cluster.prompt")
  215. user_prompt = _render_prompt(prompt, cluster_json=json.dumps(cluster, ensure_ascii=False))
  216. content = await call_llm(NEWS_SUMMARY_PROVIDER, NEWS_SUMMARY_MODEL, "You are a summarization engine for news clusters. Return strict JSON only.", user_prompt)
  217. return json.loads(content)