compactor.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  1. #!/usr/bin/env python3
  2. """
  3. Conversational memory compactor for custom mem0-python-server.
  4. Why:
  5. - keep long-term conversational memory useful
  6. - compact clusters into LLM summaries stored verbatim
  7. - preserve safety with dry-run-first workflow
  8. """
  9. from __future__ import annotations
  10. import argparse
  11. import dataclasses
  12. import datetime as dt
  13. import json
  14. import os
  15. import re
  16. from typing import Any, Dict, List
  17. import requests
  18. def load_env_file(path: str) -> None:
  19. """Load simple KEY=VALUE pairs from a .env file into os.environ.
  20. This keeps cron usage predictable without adding a dependency on python-dotenv.
  21. """
  22. if not os.path.exists(path):
  23. return
  24. with open(path, "r", encoding="utf-8") as handle:
  25. for line in handle:
  26. line = line.strip()
  27. if not line or line.startswith("#") or "=" not in line:
  28. continue
  29. key, value = line.split("=", 1)
  30. key = key.strip()
  31. value = value.strip().strip('"').strip("'")
  32. if key and key not in os.environ:
  33. os.environ[key] = value
# Substrings that mark a memory cluster as ephemeral small talk (weather
# chatter and the like); clusters matching any hint can be skipped when
# --skip-ephemeral is passed. Matching is done on normalized (lowercased,
# whitespace-collapsed) text in is_ephemeral_cluster().
EPHEMERAL_HINTS = {
    "weather", "forecast", "temperature", "rain", "raining", "expected to stop",
    "wind", "humidity", "uv index", "clouds", "sunrise", "sunset",
}
# .env file sitting next to this script; loaded at the start of run() so
# cron jobs pick up GROQ_API_KEY etc. without extra setup.
DEFAULT_ENV_PATH = os.path.join(os.path.dirname(__file__), ".env")
  39. @dataclasses.dataclass
  40. class MemoryItem:
  41. id: str
  42. text: str
  43. created_at: str | None
  44. metadata: Dict[str, Any]
  45. @property
  46. def created_dt(self) -> dt.datetime:
  47. if not self.created_at:
  48. return dt.datetime.fromtimestamp(0, tz=dt.timezone.utc)
  49. try:
  50. return dt.datetime.fromisoformat(self.created_at.replace("Z", "+00:00"))
  51. except Exception:
  52. return dt.datetime.fromtimestamp(0, tz=dt.timezone.utc)
  53. class Mem0Client:
  54. def __init__(self, base_url: str, timeout: int = 15):
  55. self.base_url = base_url.rstrip("/")
  56. self.timeout = timeout
  57. def all_memories(self, user_id: str) -> List[MemoryItem]:
  58. r = requests.post(
  59. f"{self.base_url}/memories/all",
  60. json={"user_id": user_id},
  61. timeout=self.timeout,
  62. )
  63. r.raise_for_status()
  64. data = r.json().get("results", [])
  65. out = []
  66. for row in data:
  67. out.append(
  68. MemoryItem(
  69. id=row.get("id", ""),
  70. text=row.get("memory", ""),
  71. created_at=row.get("created_at"),
  72. metadata=row.get("metadata") or {},
  73. )
  74. )
  75. return out
  76. def write_memory(self, user_id: str, text: str, metadata: Dict[str, Any]) -> Dict[str, Any]:
  77. r = requests.post(
  78. f"{self.base_url}/memories/raw",
  79. json={"text": text, "userId": user_id, "metadata": metadata},
  80. timeout=self.timeout,
  81. )
  82. r.raise_for_status()
  83. return r.json()
  84. def delete_memory(self, memory_id: str) -> Dict[str, Any]:
  85. r = requests.delete(
  86. f"{self.base_url}/memory/{memory_id}",
  87. json={"collection": "conversational"},
  88. timeout=self.timeout,
  89. )
  90. r.raise_for_status()
  91. return r.json() if r.content else {"ok": True}
  92. def normalize(text: str) -> str:
  93. text = text.lower().strip()
  94. text = re.sub(r"\s+", " ", text)
  95. return text
  96. def is_ephemeral_cluster(texts: List[str]) -> bool:
  97. joined = normalize("\n".join(texts))
  98. return any(hint in joined for hint in EPHEMERAL_HINTS)
  99. def cluster_by_time(memories: List[MemoryItem], gap_minutes: int) -> List[List[MemoryItem]]:
  100. if not memories:
  101. return []
  102. items = sorted(memories, key=lambda m: m.created_dt)
  103. clusters: List[List[MemoryItem]] = [[items[0]]]
  104. for item in items[1:]:
  105. prev = clusters[-1][-1]
  106. delta = (item.created_dt - prev.created_dt).total_seconds() / 60
  107. if delta <= gap_minutes:
  108. clusters[-1].append(item)
  109. else:
  110. clusters.append([item])
  111. return clusters
  112. def format_segment(cluster: List[MemoryItem]) -> str:
  113. lines = []
  114. for item in cluster:
  115. ts = item.created_at or "unknown"
  116. text = item.text.strip().replace("\n", " ")
  117. if text:
  118. lines.append(f"[{ts}] {text}")
  119. return "\n".join(lines)
  120. def split_cluster(cluster: List[MemoryItem], max_items: int) -> List[List[MemoryItem]]:
  121. if max_items <= 0 or len(cluster) <= max_items:
  122. return [cluster]
  123. chunks = []
  124. for i in range(0, len(cluster), max_items):
  125. chunks.append(cluster[i:i + max_items])
  126. return chunks
def call_groq_extract(segment_text: str, model: str, timeout: int, base_url: str) -> Dict[str, Any]:
    """Ask the Groq chat-completions API to extract facts and a summary from a segment.

    Returns the model's JSON reply parsed into a dict on success, or
    {"parse_error": True, "raw": <content>} when the reply is not valid JSON.

    Raises:
        RuntimeError: when GROQ_API_KEY is unset or the API answers with
            a status code >= 400.
    """
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError("GROQ_API_KEY is not set in the environment.")
    # System prompt pins the output schema; kept as one literal so the exact
    # wording sent to the model is easy to audit and version.
    prompt = (
        "You extract structured facts and a detailed summary from a chat segment. "
        "Return ONLY raw JSON (no code fences, no markdown) with keys: "
        "facts, summary, segment_kind, resolution. "
        "facts must include: "
        "people (list of {name, phone, email}), projects (list of {name, url}), "
        "urls, paths, commands, packages, services, env_vars, ips, ports, hosts, "
        "phones, emails, names. "
        "Only include facts explicitly present in the segment. Do NOT infer or invent. "
        "Never include generic 'user' as a person. Use null for unknown phone/email. "
        "If no facts exist, return empty lists. "
        "summary should be concise (1–3 sentences max) and include all concrete details (IP addresses, ports, commands, files, URLs). "
        "Use 1 sentence for short segments, 2 for medium, 3 for long. Avoid redundancy. "
        "segment_kind: implementation|debug_arc|planning|deployment|misc. "
        "resolution: resolved|open|unknown."
    )
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": prompt},
            {"role": "user", "content": segment_text},
        ],
        "temperature": 0.2,  # low temperature: extraction should be near-deterministic
        "max_tokens": 600,
    }
    url = f"{base_url.rstrip('/')}/chat/completions"
    r = requests.post(
        url,
        headers={"Authorization": f"Bearer {api_key}"},
        json=payload,
        timeout=timeout,
    )
    if r.status_code >= 400:
        # Include the response body: Groq error payloads carry the reason.
        raise RuntimeError(f"Groq API error {r.status_code}: {r.text}")
    data = r.json()
    content = data["choices"][0]["message"]["content"].strip()
    # Models sometimes wrap JSON in markdown fences despite the instructions;
    # strip a leading ```lang line and a trailing ``` before parsing.
    if content.startswith("```"):
        content = re.sub(r"^```[a-zA-Z]*\n", "", content)
        content = re.sub(r"```$", "", content).strip()
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        # Surface the unparsable reply instead of raising; the caller records
        # it as a parse_error action in the report.
        return {"parse_error": True, "raw": content}
  174. def is_compacted_memory(item: MemoryItem) -> bool:
  175. kind = (item.metadata or {}).get("kind")
  176. return kind in {"segment_summary", "debug_arc_summary"}
  177. def _limit_items(items: List[Any], max_items: int = 5) -> List[Any]:
  178. if len(items) <= max_items:
  179. return items
  180. return items[:max_items]
  181. def _format_people(people: List[Dict[str, Any]]) -> List[str]:
  182. out = []
  183. for person in people:
  184. name = (person or {}).get("name")
  185. phone = (person or {}).get("phone")
  186. email = (person or {}).get("email")
  187. bits = [b for b in [name, phone, email] if b]
  188. if bits:
  189. out.append("/".join(bits))
  190. return out
  191. def _format_projects(projects: List[Dict[str, Any]]) -> List[str]:
  192. out = []
  193. for proj in projects:
  194. name = (proj or {}).get("name")
  195. url = (proj or {}).get("url")
  196. bits = [b for b in [name, url] if b]
  197. if bits:
  198. out.append("/".join(bits))
  199. return out
  200. def format_facts_inline(facts: Dict[str, Any]) -> str:
  201. if not isinstance(facts, dict):
  202. return ""
  203. parts = []
  204. people = _format_people(facts.get("people") or [])
  205. projects = _format_projects(facts.get("projects") or [])
  206. fields = [
  207. ("people", people),
  208. ("projects", projects),
  209. ("urls", facts.get("urls") or []),
  210. ("paths", facts.get("paths") or []),
  211. ("commands", facts.get("commands") or []),
  212. ("packages", facts.get("packages") or []),
  213. ("services", facts.get("services") or []),
  214. ("env_vars", facts.get("env_vars") or []),
  215. ("ips", facts.get("ips") or []),
  216. ("ports", facts.get("ports") or []),
  217. ("hosts", facts.get("hosts") or []),
  218. ("phones", facts.get("phones") or []),
  219. ("emails", facts.get("emails") or []),
  220. ("names", facts.get("names") or []),
  221. ]
  222. for key, value in fields:
  223. if not value:
  224. continue
  225. trimmed = _limit_items(value)
  226. parts.append(f"{key}={trimmed}")
  227. if not parts:
  228. return ""
  229. return "Facts: " + "; ".join(parts)
  230. def build_summary_metadata(
  231. *,
  232. segment_ids: List[str],
  233. segment_start: str | None,
  234. segment_end: str | None,
  235. extraction: Dict[str, Any],
  236. model: str,
  237. ) -> Dict[str, Any]:
  238. # Keep summaries sortable without embedding timestamps in the text itself.
  239. created_at = segment_end or segment_start
  240. return {
  241. "compacted_at": dt.datetime.now(dt.timezone.utc).isoformat(),
  242. "compactor_version": "0.6",
  243. "kind": "segment_summary",
  244. "segment_source_ids": segment_ids,
  245. "segment_start": segment_start,
  246. "segment_end": segment_end,
  247. "created_at": created_at,
  248. "extraction": extraction,
  249. "model": model,
  250. "source": "memory-compactor",
  251. "scope": "compacted",
  252. }
def run(args: argparse.Namespace) -> None:
    """Execute one compaction pass: cluster, extract, report, optionally write/purge.

    Dry-run by default: writes and deletes only happen with --apply and
    without --dry-run. Always prints a JSON report of planned/performed
    actions to stdout.
    """
    load_env_file(DEFAULT_ENV_PATH)
    client = Mem0Client(args.base_url, timeout=args.timeout)
    memories = client.all_memories(args.user_id)
    # keep very recent entries untouched
    cutoff = dt.datetime.now(dt.timezone.utc) - dt.timedelta(days=args.min_age_days)
    # Candidates: old enough AND not already a compactor-produced summary.
    candidates = [m for m in memories if m.created_dt < cutoff and not is_compacted_memory(m)]
    clusters = cluster_by_time(candidates, args.gap_minutes)
    report = {
        "user_id": args.user_id,
        "total_memories": len(memories),
        "candidates": len(candidates),
        "clusters": len(clusters),
        "max_summaries": args.max_summaries,
        "actions": [],
    }
    # Safety cap on total source deletions per run (--max-deletes).
    delete_budget = args.max_deletes
    created_count = 0
    for cluster in clusters:
        texts = [c.text.strip() for c in cluster if c.text.strip()]
        if not texts:
            continue
        # Skip clusters too small to be worth summarizing.
        if len(texts) < args.segment_min_items:
            continue
        if args.skip_ephemeral and is_ephemeral_cluster(texts):
            continue
        # Large clusters are split to reduce topic drift in one summary.
        for subcluster in split_cluster(cluster, args.segment_max_items):
            if args.max_summaries and created_count >= args.max_summaries:
                break
            segment_text = format_segment(subcluster)
            extraction = call_groq_extract(segment_text, args.model, args.timeout, args.groq_base_url)
            facts = extraction.get("facts") if isinstance(extraction, dict) else {}
            summary_raw = extraction.get("summary") if isinstance(extraction, dict) else ""
            # A non-dict reply is treated like a parse error.
            parse_error = bool(extraction.get("parse_error")) if isinstance(extraction, dict) else True
            has_facts = bool(facts) and any(
                facts.get(k)
                for k in [
                    "people",
                    "projects",
                    "urls",
                    "paths",
                    "commands",
                    "packages",
                    "services",
                    "env_vars",
                    "ips",
                    "ports",
                    "hosts",
                    "phones",
                    "emails",
                    "names",
                ]
            )
            # Empty extractions are dropped from the report unless
            # --llm-report-all asks for them (parse errors are always kept).
            if not args.llm_report_all and not parse_error and not summary_raw and not has_facts:
                continue
            facts_inline = format_facts_inline(facts)
            summary_text = summary_raw
            if facts_inline:
                # Append the inline facts to the summary (or use them alone).
                summary_text = f"{summary_raw} {facts_inline}".strip() if summary_raw else facts_inline
            ids = [m.id for m in subcluster if m.id]
            segment_start = subcluster[0].created_at if subcluster else None
            segment_end = subcluster[-1].created_at if subcluster else None
            action = {
                "type": "segment_extract",
                "cluster_size": len(subcluster),
                "segment_preview": segment_text[:240],
                "extraction": extraction,
                "source_ids": ids,
                "segment_start": segment_start,
                "segment_end": segment_end,
            }
            metadata = None
            if summary_text:
                metadata = build_summary_metadata(
                    segment_ids=ids,
                    segment_start=segment_start,
                    segment_end=segment_end,
                    extraction=extraction,
                    model=args.model,
                )
                action["summary_raw"] = summary_raw
                action["summary_text"] = summary_text
                action["summary_metadata"] = metadata
                # Exact payload that would be (or was) written; aids dry-run review.
                action["write_payload"] = {"text": summary_text, "metadata": metadata}
            report["actions"].append(action)
            # Count a summary as "created" only if the purge budget would
            # not block it in apply mode, so dry-run and apply agree.
            can_create = bool(summary_text)
            if args.apply and not args.dry_run and summary_text and args.purge_source and len(ids) > delete_budget:
                can_create = False
            if can_create:
                created_count += 1
            if args.apply and not args.dry_run and summary_text:
                # Refuse to write a summary whose source purge would bust the budget.
                if args.purge_source and len(ids) > delete_budget:
                    continue
                client.write_memory(args.user_id, summary_text, metadata or {})
                if args.purge_source:
                    for mid in ids:
                        client.delete_memory(mid)
                    delete_budget -= len(ids)
        if args.max_summaries and created_count >= args.max_summaries:
            break
    print(json.dumps(report, indent=2, ensure_ascii=False))
  354. def parse_args() -> argparse.Namespace:
  355. examples = """
  356. Examples:
  357. python3 compactor.py --user-id main
  358. python3 compactor.py --user-id main --apply
  359. python3 compactor.py --user-id main --apply --max-summaries 1
  360. python3 compactor.py --user-id main --segment-max-items 15 --skip-ephemeral
  361. """
  362. class HelpFormatter(argparse.RawDescriptionHelpFormatter, argparse.ArgumentDefaultsHelpFormatter):
  363. pass
  364. p = argparse.ArgumentParser(
  365. description="Compacts conversational memories with temporal clustering.",
  366. formatter_class=HelpFormatter,
  367. epilog=examples.strip(),
  368. )
  369. p.add_argument("--base-url", default="http://192.168.0.200:8420")
  370. p.add_argument("--user-id", required=True)
  371. p.add_argument("--apply", action="store_true", help="Apply changes. Default is dry-run.")
  372. p.add_argument("--dry-run", action="store_true", help="Force dry-run even with --apply.")
  373. p.add_argument("--gap-minutes", type=int, default=45)
  374. p.add_argument("--min-age-days", type=int, default=7)
  375. p.add_argument("--max-deletes", type=int, default=50)
  376. p.add_argument("--max-summaries", type=int, default=0, help="Limit the number of summaries created (0 = no limit).")
  377. p.add_argument("--timeout", type=int, default=20)
  378. p.add_argument("--model", default="meta-llama/llama-4-scout-17b-16e-instruct")
  379. p.add_argument("--segment-min-items", type=int, default=4)
  380. p.add_argument("--segment-max-items", type=int, default=15, help="Split large clusters to reduce topic drift.")
  381. p.add_argument("--skip-ephemeral", action="store_true", help="Skip obvious ephemeral weather-like segments.")
  382. p.add_argument("--llm-report-all", action="store_true", help="Report all LLM extractions even if empty.")
  383. p.add_argument("--purge-source", action="store_true", help="Delete source memories after writing a summary.")
  384. p.add_argument("--groq-base-url", default="https://api.groq.com/openai/v1")
  385. return p.parse_args()
if __name__ == "__main__":
    # CLI entry point: dry-run by default; pass --apply to persist changes.
    run(parse_args())