# llm.py — LLM provider helpers (Groq / OpenAI chat completions) for news_mcp.
  1. from __future__ import annotations
  2. import json
  3. from pathlib import Path
  4. from typing import Any, Dict, Iterable, List
  5. import httpx
  6. from news_mcp.config import (
  7. NEWS_EXTRACT_PROVIDER,
  8. NEWS_EXTRACT_MODEL,
  9. NEWS_SUMMARY_PROVIDER,
  10. NEWS_SUMMARY_MODEL,
  11. PROMPTS_DIR,
  12. )
# Shared system prompt for the entity-extraction call; instructs the model to
# reply with strict JSON so the response can be json.loads()-ed directly.
SYSTEM_PROMPT = "You are a news signal extraction engine. Return STRICT JSON only."
class LLMError(RuntimeError):
    """Raised for LLM-call failures in this module (missing API key, unsupported provider)."""
    pass
  16. def load_prompt(name: str) -> str:
  17. path = PROMPTS_DIR / name
  18. return path.read_text(encoding="utf-8")
  19. def _render_prompt(template: str, **kwargs: Any) -> str:
  20. rendered = template
  21. for key, value in kwargs.items():
  22. rendered = rendered.replace("{" + key + "}", str(value))
  23. return rendered
  24. def active_llm_config() -> dict[str, str]:
  25. return {
  26. "extract_provider": NEWS_EXTRACT_PROVIDER,
  27. "extract_model": NEWS_EXTRACT_MODEL,
  28. "summary_provider": NEWS_SUMMARY_PROVIDER,
  29. "summary_model": NEWS_SUMMARY_MODEL,
  30. }
  31. async def _call_groq(model: str, messages: List[Dict[str, str]], response_json: bool = True) -> str:
  32. from news_mcp.config import GROQ_API_KEY
  33. if not GROQ_API_KEY:
  34. raise LLMError("GROQ_API_KEY is not configured")
  35. req = {"model": model, "messages": messages, "temperature": 0.2}
  36. if response_json:
  37. req["response_format"] = {"type": "json_object"}
  38. async with httpx.AsyncClient(timeout=45.0) as client:
  39. resp = await client.post(
  40. "https://api.groq.com/openai/v1/chat/completions",
  41. headers={"Authorization": f"Bearer {GROQ_API_KEY}"},
  42. json=req,
  43. )
  44. resp.raise_for_status()
  45. data = resp.json()
  46. return data["choices"][0]["message"]["content"]
  47. async def _call_openai(model: str, messages: List[Dict[str, str]], response_json: bool = True) -> str:
  48. # OpenAI-compatible chat endpoint; uses NEWS_OPENAI_API_KEY.
  49. from news_mcp.config import OPENAI_API_KEY
  50. if not OPENAI_API_KEY:
  51. raise LLMError("OPENAI_API_KEY is not configured")
  52. req = {"model": model, "messages": messages}
  53. if response_json:
  54. req["response_format"] = {"type": "json_object"}
  55. async with httpx.AsyncClient(timeout=45.0) as client:
  56. resp = await client.post(
  57. "https://api.openai.com/v1/chat/completions",
  58. headers={"Authorization": f"Bearer {OPENAI_API_KEY}"},
  59. json=req,
  60. )
  61. resp.raise_for_status()
  62. data = resp.json()
  63. return data["choices"][0]["message"]["content"]
  64. async def call_llm(provider: str, model: str, system_prompt: str, user_prompt: str) -> str:
  65. messages = [
  66. {"role": "system", "content": system_prompt},
  67. {"role": "user", "content": user_prompt},
  68. ]
  69. provider = provider.lower().strip()
  70. if provider == "groq":
  71. return await _call_groq(model, messages)
  72. if provider == "openai":
  73. return await _call_openai(model, messages)
  74. raise LLMError(f"Unsupported provider: {provider}")
  75. def build_extraction_prompt(cluster: Dict[str, Any]) -> str:
  76. prompt = load_prompt("extract_entities.prompt")
  77. return _render_prompt(prompt, cluster_json=json.dumps(cluster, ensure_ascii=False))
  78. async def call_extraction(cluster: Dict[str, Any]) -> Dict[str, Any]:
  79. user_prompt = build_extraction_prompt(cluster)
  80. content = await call_llm(NEWS_EXTRACT_PROVIDER, NEWS_EXTRACT_MODEL, SYSTEM_PROMPT, user_prompt)
  81. return json.loads(content)
  82. async def call_summary(cluster: Dict[str, Any]) -> Dict[str, Any]:
  83. prompt = load_prompt("summarize_cluster.prompt")
  84. user_prompt = _render_prompt(prompt, cluster_json=json.dumps(cluster, ensure_ascii=False))
  85. content = await call_llm(NEWS_SUMMARY_PROVIDER, NEWS_SUMMARY_MODEL, "You are a summarization engine for news clusters. Return strict JSON only.", user_prompt)
  86. return json.loads(content)