llm.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. from __future__ import annotations
  2. import asyncio
  3. import json
  4. from pathlib import Path
  5. from typing import Any, Dict, Iterable, List
  6. import httpx
  7. from news_mcp.config import (
  8. GROQ_API_KEY,
  9. NEWS_EXTRACT_PROVIDER,
  10. NEWS_EXTRACT_MODEL,
  11. NEWS_SUMMARY_PROVIDER,
  12. NEWS_SUMMARY_MODEL,
  13. OPENAI_API_KEY,
  14. OPENROUTER_API_KEY,
  15. PROMPTS_DIR,
  16. )
# Shared system prompt for extraction calls; instructs the model to emit strict JSON.
SYSTEM_PROMPT = "You are a news signal extraction engine. Return STRICT JSON only."
class LLMError(RuntimeError):
    """Raised when an LLM provider is unconfigured or a call ultimately fails."""
    pass
  20. def load_prompt(name: str) -> str:
  21. path = PROMPTS_DIR / name
  22. return path.read_text(encoding="utf-8")
  23. def _render_prompt(template: str, **kwargs: Any) -> str:
  24. rendered = template
  25. for key, value in kwargs.items():
  26. rendered = rendered.replace("{" + key + "}", str(value))
  27. return rendered
  28. def active_llm_config() -> dict[str, str]:
  29. return {
  30. "extract_provider": NEWS_EXTRACT_PROVIDER,
  31. "extract_model": NEWS_EXTRACT_MODEL,
  32. "summary_provider": NEWS_SUMMARY_PROVIDER,
  33. "summary_model": NEWS_SUMMARY_MODEL,
  34. "openrouter_key_set": bool(OPENROUTER_API_KEY),
  35. }
  36. async def _call_groq(model: str, messages: List[Dict[str, str]], response_json: bool = True) -> str:
  37. if not GROQ_API_KEY:
  38. raise LLMError("GROQ_API_KEY is not configured")
  39. req = {"model": model, "messages": messages, "temperature": 0.2}
  40. if response_json:
  41. req["response_format"] = {"type": "json_object"}
  42. async with httpx.AsyncClient(timeout=45.0) as client:
  43. resp = await client.post(
  44. "https://api.groq.com/openai/v1/chat/completions",
  45. headers={"Authorization": f"Bearer {GROQ_API_KEY}"},
  46. json=req,
  47. )
  48. resp.raise_for_status()
  49. data = resp.json()
  50. return data["choices"][0]["message"]["content"]
  51. async def _call_openai(model: str, messages: List[Dict[str, str]], response_json: bool = True) -> str:
  52. # OpenAI-compatible chat endpoint; uses NEWS_OPENAI_API_KEY.
  53. if not OPENAI_API_KEY:
  54. raise LLMError("OPENAI_API_KEY is not configured")
  55. req = {"model": model, "messages": messages}
  56. if response_json:
  57. req["response_format"] = {"type": "json_object"}
  58. async with httpx.AsyncClient(timeout=45.0) as client:
  59. resp = await client.post(
  60. "https://api.openai.com/v1/chat/completions",
  61. headers={"Authorization": f"Bearer {OPENAI_API_KEY}"},
  62. json=req,
  63. )
  64. resp.raise_for_status()
  65. data = resp.json()
  66. return data["choices"][0]["message"]["content"]
# OpenRouter's OpenAI-compatible chat-completions endpoint.
OR_OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"
async def _call_openrouter(model: str, messages: List[Dict[str, str]], response_json: bool = True, retries: int = 2) -> str:
    """Call OpenRouter with retry/backoff and return the first choice's content.

    Retries (up to ``retries`` extra attempts, exponential backoff) on
    transient HTTP statuses and on empty-choices responses. Raises LLMError
    with the last observed error once attempts are exhausted or on a
    non-retryable API error; non-retryable HTTP errors propagate via
    ``raise_for_status``.
    """
    if not OPENROUTER_API_KEY:
        raise LLMError("OPENROUTER_API_KEY is not configured")
    req = {"model": model, "messages": messages, "temperature": 0.2}
    if response_json:
        req["response_format"] = {"type": "json_object"}
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        # OpenRouter attribution headers (referer + app title).
        "HTTP-Referer": "https://github.com/gr1m0/bolt.new-rss",
        "X-Title": "news-mcp",
    }
    last_err = ""
    async with httpx.AsyncClient(timeout=45.0) as client:
        for attempt in range(1 + retries):
            resp = await client.post(
                OR_OPENROUTER_URL,
                headers=headers,
                json=req,
            )
            if resp.status_code != 200:
                last_err = f"HTTP {resp.status_code}: {resp.text[:300]}"
                # Retry only on rate-limit / server-side statuses; anything
                # else raises immediately via raise_for_status below.
                if resp.status_code in (429, 500, 502, 503):
                    await asyncio.sleep(2 ** attempt)
                    continue
                resp.raise_for_status()
            data = resp.json()
            if "error" in data:
                # Explicit API-level error payload: not retryable, give up.
                last_err = f"API error: {data['error']}"
                break
            choices = data.get("choices", [])
            if not choices:
                # Occasionally the API returns 200 with no choices; retry
                # with backoff while attempts remain.
                last_err = f"No choices in response: {str(data)[:300]}"
                if attempt < retries:
                    await asyncio.sleep(2 ** attempt)
                    continue
                break
            msg = choices[0].get("message", {})
            content = msg.get("content")
            if content:
                return content
            last_err = f"Empty content in choice: {str(msg)[:200]}"
            break
    raise LLMError(f"OpenRouter failed after {1+retries} attempts: {last_err}")
  111. async def call_llm(provider: str, model: str, system_prompt: str, user_prompt: str) -> str:
  112. messages = [
  113. {"role": "system", "content": system_prompt},
  114. {"role": "user", "content": user_prompt},
  115. ]
  116. provider = provider.lower().strip()
  117. if provider == "groq":
  118. return await _call_groq(model, messages)
  119. if provider == "openai":
  120. return await _call_openai(model, messages)
  121. if provider == "openrouter":
  122. return await _call_openrouter(model, messages)
  123. raise LLMError(f"Unsupported provider: {provider}. Valid: groq, openai, openrouter")
  124. def build_extraction_prompt(cluster: Dict[str, Any]) -> str:
  125. prompt = load_prompt("extract_entities.prompt")
  126. return _render_prompt(prompt, cluster_json=json.dumps(cluster, ensure_ascii=False))
  127. async def call_extraction(cluster: Dict[str, Any]) -> Dict[str, Any]:
  128. user_prompt = build_extraction_prompt(cluster)
  129. content = await call_llm(NEWS_EXTRACT_PROVIDER, NEWS_EXTRACT_MODEL, SYSTEM_PROMPT, user_prompt)
  130. return json.loads(content)
  131. async def call_summary(cluster: Dict[str, Any]) -> Dict[str, Any]:
  132. prompt = load_prompt("summarize_cluster.prompt")
  133. user_prompt = _render_prompt(prompt, cluster_json=json.dumps(cluster, ensure_ascii=False))
  134. content = await call_llm(NEWS_SUMMARY_PROVIDER, NEWS_SUMMARY_MODEL, "You are a summarization engine for news clusters. Return strict JSON only.", user_prompt)
  135. return json.loads(content)