#!/usr/bin/env python3
"""
Conversational memory compactor for the custom mem0-python-server.

Why:
- keep long-term conversational memory useful
- compact clusters into LLM summaries stored verbatim
- preserve safety with a dry-run-first workflow
"""
from __future__ import annotations

import argparse
import dataclasses
import datetime as dt
import json
import os
import re
from typing import Any, Dict, List

import requests


def load_env_file(path: str) -> None:
    """Load simple KEY=VALUE pairs from a .env file into os.environ.

    This keeps cron usage predictable without adding a dependency on python-dotenv.
    """
    if not os.path.exists(path):
        return
    with open(path, "r", encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, value = line.split("=", 1)
            key = key.strip()
            value = value.strip().strip('"').strip("'")
            if key and key not in os.environ:
                os.environ[key] = value
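
# Expected .env format next to this script (illustrative; GROQ_API_KEY is the
# only key this script itself needs, other pairs are exported unchanged, and
# values already present in the environment are never overwritten):
#
#   # comments and blank lines are ignored
#   GROQ_API_KEY=gsk_your_key_here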

EPHEMERAL_HINTS = {
    "weather", "forecast", "temperature", "rain", "raining", "expected to stop",
    "wind", "humidity", "uv index", "clouds", "sunrise", "sunset",
}

DEFAULT_ENV_PATH = os.path.join(os.path.dirname(__file__), ".env")


@dataclasses.dataclass
class MemoryItem:
    id: str
    text: str
    created_at: str | None
    metadata: Dict[str, Any]

    @property
    def created_dt(self) -> dt.datetime:
        if not self.created_at:
            return dt.datetime.fromtimestamp(0, tz=dt.timezone.utc)
        try:
            return dt.datetime.fromisoformat(self.created_at.replace("Z", "+00:00"))
        except Exception:
            return dt.datetime.fromtimestamp(0, tz=dt.timezone.utc)


class Mem0Client:
    def __init__(self, base_url: str, timeout: int = 15):
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

    def all_memories(self, user_id: str) -> List[MemoryItem]:
        r = requests.post(
            f"{self.base_url}/memories/all",
            json={"user_id": user_id},
            timeout=self.timeout,
        )
        r.raise_for_status()
        data = r.json().get("results", [])
        out = []
        for row in data:
            out.append(
                MemoryItem(
                    id=row.get("id", ""),
                    text=row.get("memory", ""),
                    created_at=row.get("created_at"),
                    metadata=row.get("metadata") or {},
                )
            )
        return out

    def write_memory(self, user_id: str, text: str, metadata: Dict[str, Any]) -> Dict[str, Any]:
        r = requests.post(
            f"{self.base_url}/memories/raw",
            json={"text": text, "userId": user_id, "metadata": metadata},
            timeout=self.timeout,
        )
        r.raise_for_status()
        return r.json()

    def delete_memory(self, memory_id: str) -> Dict[str, Any]:
        r = requests.delete(
            f"{self.base_url}/memory/{memory_id}",
            json={"collection": "conversational"},
            timeout=self.timeout,
        )
        r.raise_for_status()
        return r.json() if r.content else {"ok": True}
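
# Endpoints assumed on the custom mem0-python-server (shapes inferred from the
# calls above, not from upstream mem0 docs):
#   POST   /memories/all   {"user_id": ...} -> {"results": [{"id", "memory", "created_at", "metadata"}]}
#   POST   /memories/raw   {"text", "userId", "metadata"} -> created memory
#   DELETE /memory/{id}    body {"collection": "conversational"}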


def normalize(text: str) -> str:
    text = text.lower().strip()
    text = re.sub(r"\s+", " ", text)
    return text


def is_ephemeral_cluster(texts: List[str]) -> bool:
    joined = normalize("\n".join(texts))
    return any(hint in joined for hint in EPHEMERAL_HINTS)


def cluster_by_time(memories: List[MemoryItem], gap_minutes: int) -> List[List[MemoryItem]]:
    if not memories:
        return []
    items = sorted(memories, key=lambda m: m.created_dt)
    clusters: List[List[MemoryItem]] = [[items[0]]]
    for item in items[1:]:
        prev = clusters[-1][-1]
        delta = (item.created_dt - prev.created_dt).total_seconds() / 60
        if delta <= gap_minutes:
            clusters[-1].append(item)
        else:
            clusters.append([item])
    return clusters


def format_segment(cluster: List[MemoryItem]) -> str:
    lines = []
    for item in cluster:
        ts = item.created_at or "unknown"
        text = item.text.strip().replace("\n", " ")
        if text:
            lines.append(f"[{ts}] {text}")
    return "\n".join(lines)


def split_cluster(cluster: List[MemoryItem], max_items: int) -> List[List[MemoryItem]]:
    if max_items <= 0 or len(cluster) <= max_items:
        return [cluster]
    chunks = []
    for i in range(0, len(cluster), max_items):
        chunks.append(cluster[i:i + max_items])
    return chunks


def call_groq_extract(segment_text: str, model: str, timeout: int, base_url: str) -> Dict[str, Any]:
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError("GROQ_API_KEY is not set in the environment.")
    prompt = (
        "You extract structured facts and a concise summary from a chat segment. "
        "Return ONLY raw JSON (no code fences, no markdown) with keys: "
        "facts, summary, segment_kind, resolution. "
        "facts must include: people (list of {name, phone, email}), "
        "projects (list of {name, url}), urls, paths, phones, emails, names. "
        "Only include facts explicitly present in the segment. Do NOT infer or invent. "
        "Never include generic 'user' as a person. Use null for unknown phone/email. "
        "If no facts exist, return empty lists. "
        "summary should be one or two sentences. "
        "segment_kind: implementation|debug_arc|planning|deployment|misc. "
        "resolution: resolved|open|unknown."
    )
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": prompt},
            {"role": "user", "content": segment_text},
        ],
        "temperature": 0.2,
        "max_tokens": 600,
    }
    url = f"{base_url.rstrip('/')}/chat/completions"
    r = requests.post(
        url,
        headers={"Authorization": f"Bearer {api_key}"},
        json=payload,
        timeout=timeout,
    )
    if r.status_code >= 400:
        raise RuntimeError(f"Groq API error {r.status_code}: {r.text}")
    data = r.json()
    content = data["choices"][0]["message"]["content"].strip()
    if content.startswith("```"):
        content = re.sub(r"^```[a-zA-Z]*\n", "", content)
        content = re.sub(r"```$", "", content).strip()
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        return {"parse_error": True, "raw": content}


def is_compacted_memory(item: MemoryItem) -> bool:
    kind = (item.metadata or {}).get("kind")
    return kind in {"segment_summary", "debug_arc_summary"}


def run(args: argparse.Namespace) -> None:
    load_env_file(DEFAULT_ENV_PATH)
    client = Mem0Client(args.base_url, timeout=args.timeout)
    memories = client.all_memories(args.user_id)

    # keep very recent entries untouched, and never re-compact earlier summaries
    cutoff = dt.datetime.now(dt.timezone.utc) - dt.timedelta(days=args.min_age_days)
    candidates = [m for m in memories if m.created_dt < cutoff and not is_compacted_memory(m)]
    clusters = cluster_by_time(candidates, args.gap_minutes)

    report = {
        "user_id": args.user_id,
        "total_memories": len(memories),
        "candidates": len(candidates),
        "clusters": len(clusters),
        "max_summaries": args.max_summaries,
        "actions": [],
    }
    # delete_budget caps how many source memories a single run may purge
    delete_budget = args.max_deletes
    created_count = 0
    for cluster in clusters:
        texts = [c.text.strip() for c in cluster if c.text.strip()]
        if not texts:
            continue
        if len(texts) < args.segment_min_items:
            continue
        if args.skip_ephemeral and is_ephemeral_cluster(texts):
            continue
        for subcluster in split_cluster(cluster, args.segment_max_items):
            if args.max_summaries and created_count >= args.max_summaries:
                break
            segment_text = format_segment(subcluster)
            extraction = call_groq_extract(segment_text, args.model, args.timeout, args.groq_base_url)
            facts = extraction.get("facts") if isinstance(extraction, dict) else None
            summary = extraction.get("summary") if isinstance(extraction, dict) else ""
            parse_error = bool(extraction.get("parse_error")) if isinstance(extraction, dict) else True
            has_facts = bool(facts) and any(
                facts.get(k) for k in ["people", "projects", "urls", "paths", "phones", "emails", "names"]
            )
            if not args.llm_report_all and not parse_error and not summary and not has_facts:
                continue
            ids = [m.id for m in subcluster if m.id]
            segment_start = subcluster[0].created_at if subcluster else None
            segment_end = subcluster[-1].created_at if subcluster else None
            action = {
                "type": "segment_extract",
                "cluster_size": len(subcluster),
                "segment_preview": segment_text[:240],
                "extraction": extraction,
                "source_ids": ids,
                "segment_start": segment_start,
                "segment_end": segment_end,
            }
            report["actions"].append(action)
            can_create = bool(summary)
            if args.apply and not args.dry_run and summary and args.purge_source and len(ids) > delete_budget:
                can_create = False
            if can_create:
                created_count += 1
            # dry-run first: writes and deletes only happen with --apply and without --dry-run
            if args.apply and not args.dry_run and summary:
                if args.purge_source and len(ids) > delete_budget:
                    continue
                metadata = {
                    "compacted_at": dt.datetime.now(dt.timezone.utc).isoformat(),
                    "compactor_version": "0.4",
                    "kind": "segment_summary",
                    "segment_source_ids": ids,
                    "segment_start": segment_start,
                    "segment_end": segment_end,
                    "created_at": segment_end or segment_start,
                    "extraction": extraction,
                    "model": args.model,
                }
                client.write_memory(args.user_id, summary, metadata)
                if args.purge_source:
                    for mid in ids:
                        client.delete_memory(mid)
                    delete_budget -= len(ids)
        if args.max_summaries and created_count >= args.max_summaries:
            break
    print(json.dumps(report, indent=2, ensure_ascii=False))


def parse_args() -> argparse.Namespace:
    examples = """
Examples:
  python3 compactor.py --user-id main
  python3 compactor.py --user-id main --apply
  python3 compactor.py --user-id main --apply --max-summaries 1
  python3 compactor.py --user-id main --segment-max-items 15 --skip-ephemeral
"""

    class HelpFormatter(argparse.RawDescriptionHelpFormatter, argparse.ArgumentDefaultsHelpFormatter):
        pass

    p = argparse.ArgumentParser(
        description="Compacts conversational memories with temporal clustering.",
        formatter_class=HelpFormatter,
        epilog=examples.strip(),
    )
    p.add_argument("--base-url", default="http://192.168.0.200:8420")
    p.add_argument("--user-id", required=True)
    p.add_argument("--apply", action="store_true", help="Apply changes. Default is dry-run.")
    p.add_argument("--dry-run", action="store_true", help="Force dry-run even with --apply.")
    p.add_argument("--gap-minutes", type=int, default=45)
    p.add_argument("--min-age-days", type=int, default=7)
    p.add_argument("--max-deletes", type=int, default=50)
    p.add_argument("--max-summaries", type=int, default=0, help="Limit the number of summaries created (0 = no limit).")
    p.add_argument("--timeout", type=int, default=20)
    p.add_argument("--model", default="meta-llama/llama-4-scout-17b-16e-instruct")
    p.add_argument("--segment-min-items", type=int, default=4)
    p.add_argument("--segment-max-items", type=int, default=15, help="Split large clusters to reduce topic drift.")
    p.add_argument("--skip-ephemeral", action="store_true", help="Skip obvious ephemeral weather-like segments.")
    p.add_argument("--llm-report-all", action="store_true", help="Report all LLM extractions even if empty.")
    p.add_argument("--purge-source", action="store_true", help="Delete source memories after writing a summary.")
    p.add_argument("--groq-base-url", default="https://api.groq.com/openai/v1")
    return p.parse_args()


if __name__ == "__main__":
    run(parse_args())
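
# A minimal cron sketch, assuming the script and its .env live in /opt/mem0
# (path, schedule, and log file are illustrative; adjust to your setup):
#
#   15 3 * * 0  cd /opt/mem0 && python3 compactor.py --user-id main --apply --skip-ephemeral >> compactor.log 2>&1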