|
|
@@ -0,0 +1,113 @@
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+import json
|
|
|
+from pathlib import Path
|
|
|
+from typing import Any, Dict, Iterable, List
|
|
|
+
|
|
|
+import httpx
|
|
|
+
|
|
|
+from news_mcp.config import (
|
|
|
+ NEWS_EXTRACT_PROVIDER,
|
|
|
+ NEWS_EXTRACT_MODEL,
|
|
|
+ NEWS_SUMMARY_PROVIDER,
|
|
|
+ NEWS_SUMMARY_MODEL,
|
|
|
+ PROMPTS_DIR,
|
|
|
+)
|
|
|
+
|
|
|
+
|
|
|
# Shared system prompt for the extraction call; steers the model to emit JSON only
# (pairs with the response_format={"type": "json_object"} request option below).
SYSTEM_PROMPT = "You are a news signal extraction engine. Return STRICT JSON only."
|
|
|
+
|
|
|
+
|
|
|
class LLMError(RuntimeError):
    """Raised for LLM-call failures this module can detect itself.

    Covers a missing API key and an unsupported provider name. Transport and
    HTTP errors from httpx are propagated unchanged.
    """
|
|
|
+
|
|
|
+
|
|
|
def load_prompt(name: str) -> str:
    """Read the prompt template *name* from the configured prompts directory.

    Raises FileNotFoundError if no such template exists.
    """
    return (PROMPTS_DIR / name).read_text(encoding="utf-8")
|
|
|
+
|
|
|
+
|
|
|
+def _render_prompt(template: str, **kwargs: Any) -> str:
|
|
|
+ rendered = template
|
|
|
+ for key, value in kwargs.items():
|
|
|
+ rendered = rendered.replace("{" + key + "}", str(value))
|
|
|
+ return rendered
|
|
|
+
|
|
|
+
|
|
|
def active_llm_config() -> dict[str, str]:
    """Report the provider/model pair configured for each LLM task.

    Useful for surfacing the effective configuration in diagnostics.
    """
    return dict(
        extract_provider=NEWS_EXTRACT_PROVIDER,
        extract_model=NEWS_EXTRACT_MODEL,
        summary_provider=NEWS_SUMMARY_PROVIDER,
        summary_model=NEWS_SUMMARY_MODEL,
    )
|
|
|
+
|
|
|
+
|
|
|
async def _call_groq(model: str, messages: List[Dict[str, str]], response_json: bool = True) -> str:
    """POST a chat completion to Groq's OpenAI-compatible endpoint.

    Args:
        model: Groq model identifier.
        messages: Chat messages in OpenAI role/content form.
        response_json: When True, request strict JSON-object output.

    Returns:
        The assistant message content as a string.

    Raises:
        LLMError: if GROQ_API_KEY is not configured.
        httpx.HTTPStatusError: on non-2xx responses (via raise_for_status).
    """
    # Imported lazily so config is read at call time, not module import time.
    from news_mcp.config import GROQ_API_KEY

    if not GROQ_API_KEY:
        raise LLMError("GROQ_API_KEY is not configured")

    payload: Dict[str, Any] = {
        "model": model,
        "messages": messages,
        # Low temperature for near-deterministic extraction output.
        "temperature": 0.2,
    }
    if response_json:
        payload["response_format"] = {"type": "json_object"}

    async with httpx.AsyncClient(timeout=45.0) as client:
        response = await client.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers={"Authorization": f"Bearer {GROQ_API_KEY}"},
            json=payload,
        )
        response.raise_for_status()
        body = response.json()
        return body["choices"][0]["message"]["content"]
|
|
|
+
|
|
|
+
|
|
|
async def _call_openai(model: str, messages: List[Dict[str, str]], response_json: bool = True) -> str:
    """POST a chat completion to the OpenAI chat endpoint.

    # NOTE(review): the original comment mentioned NEWS_OPENAI_API_KEY but the
    # code imports OPENAI_API_KEY from news_mcp.config — confirm the env var
    # mapping in config. Unlike the Groq path, no temperature is sent here.

    Returns:
        The assistant message content as a string.

    Raises:
        LLMError: if OPENAI_API_KEY is not configured.
        httpx.HTTPStatusError: on non-2xx responses (via raise_for_status).
    """
    # Imported lazily so config is read at call time, not module import time.
    from news_mcp.config import OPENAI_API_KEY

    if not OPENAI_API_KEY:
        raise LLMError("OPENAI_API_KEY is not configured")

    payload: Dict[str, Any] = {"model": model, "messages": messages}
    if response_json:
        payload["response_format"] = {"type": "json_object"}

    async with httpx.AsyncClient(timeout=45.0) as client:
        response = await client.post(
            "https://api.openai.com/v1/chat/completions",
            headers={"Authorization": f"Bearer {OPENAI_API_KEY}"},
            json=payload,
        )
        response.raise_for_status()
        body = response.json()
        return body["choices"][0]["message"]["content"]
|
|
|
+
|
|
|
+
|
|
|
async def call_llm(provider: str, model: str, system_prompt: str, user_prompt: str) -> str:
    """Dispatch a system+user chat to the named provider and return raw content.

    Args:
        provider: "groq" or "openai" (case/whitespace-insensitive).
        model: Provider-specific model identifier.
        system_prompt: System-role instruction.
        user_prompt: User-role message body.

    Raises:
        LLMError: for any provider name other than "groq" or "openai".
    """
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    normalized = provider.lower().strip()
    backends = {"groq": _call_groq, "openai": _call_openai}
    backend = backends.get(normalized)
    if backend is None:
        # Matches the original behavior: the normalized name goes in the message.
        raise LLMError(f"Unsupported provider: {normalized}")
    return await backend(model, chat)
|
|
|
+
|
|
|
+
|
|
|
def build_extraction_prompt(cluster: Dict[str, Any]) -> str:
    """Fill the entity-extraction template with *cluster* serialized as JSON."""
    template = load_prompt("extract_entities.prompt")
    cluster_payload = json.dumps(cluster, ensure_ascii=False)
    return _render_prompt(template, cluster_json=cluster_payload)
|
|
|
+
|
|
|
+
|
|
|
async def call_extraction(cluster: Dict[str, Any]) -> Dict[str, Any]:
    """Run the entity-extraction prompt against the configured extract LLM.

    Returns:
        The model response parsed as a JSON object.

    Raises:
        json.JSONDecodeError: if the model response is not valid JSON.
    """
    rendered = build_extraction_prompt(cluster)
    raw = await call_llm(NEWS_EXTRACT_PROVIDER, NEWS_EXTRACT_MODEL, SYSTEM_PROMPT, rendered)
    return json.loads(raw)
|
|
|
+
|
|
|
+
|
|
|
async def call_summary(cluster: Dict[str, Any]) -> Dict[str, Any]:
    """Summarize *cluster* with the configured summary LLM.

    Returns:
        The model response parsed as a JSON object.

    Raises:
        json.JSONDecodeError: if the model response is not valid JSON.
    """
    template = load_prompt("summarize_cluster.prompt")
    rendered = _render_prompt(template, cluster_json=json.dumps(cluster, ensure_ascii=False))
    raw = await call_llm(
        NEWS_SUMMARY_PROVIDER,
        NEWS_SUMMARY_MODEL,
        "You are a summarization engine for news clusters. Return strict JSON only.",
        rendered,
    )
    return json.loads(raw)
|