# mem0server.py
  1. import os
  2. import httpx
  3. from fastapi import FastAPI, Request
  4. from fastapi.responses import JSONResponse
  5. from mem0 import Memory
  6. # --- Env validation -----------------------------------------------------------
  7. GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
  8. if not GROQ_API_KEY:
  9. raise RuntimeError("GROQ_API_KEY environment variable is not set.")
  10. RERANKER_URL = os.environ.get("RERANKER_URL", "http://192.168.0.200:5200/rerank")
  11. # --- mem0 config --------------------------------------------------------------
  12. config = {
  13. "llm": {
  14. "provider": "groq",
  15. "config": {
  16. "model": "meta-llama/llama-4-scout-17b-16e-instruct",
  17. "temperature": 0.025,
  18. "max_tokens": 1500,
  19. },
  20. },
  21. "vector_store": {
  22. "provider": "chroma",
  23. "config": {
  24. "host": "192.168.0.200",
  25. "port": 8001,
  26. "collection_name": "openclaw_mem",
  27. },
  28. },
  29. "embedder": {
  30. "provider": "ollama",
  31. "config": {
  32. "model": "nomic-embed-text",
  33. "ollama_base_url": "http://192.168.0.200:11434",
  34. },
  35. },
  36. }
  37. memory = Memory.from_config(config)
  38. # --- Patch: Chroma empty-filter crash -----------------------------------------
  39. orig_search = memory.vector_store.search
  40. NOOP_WHERE = {"$and": [
  41. {"user_id": {"$ne": ""}},
  42. {"user_id": {"$ne": ""}},
  43. ]}
  44. def is_effectively_empty(filters):
  45. if not filters:
  46. return True
  47. if filters in ({"AND": []}, {"OR": []}):
  48. return True
  49. return False
  50. def safe_search(query, vectors, limit=10, filters=None):
  51. if is_effectively_empty(filters):
  52. return memory.vector_store.collection.query(
  53. query_embeddings=vectors,
  54. n_results=limit,
  55. where=NOOP_WHERE,
  56. )
  57. try:
  58. return orig_search(query=query, vectors=vectors, limit=limit, filters=filters)
  59. except Exception as e:
  60. if "Expected where" in str(e):
  61. return memory.vector_store.collection.query(
  62. query_embeddings=vectors,
  63. n_results=limit,
  64. where=NOOP_WHERE,
  65. )
  66. raise
  67. memory.vector_store.search = safe_search
  68. # --- Reranker -----------------------------------------------------------------
  69. def rerank_results(query: str, items: list, top_k: int) -> list:
  70. """
  71. Call the local reranker server and re-order mem0 results by score.
  72. Falls back to the original list if the reranker is unavailable.
  73. """
  74. if not items:
  75. return items
  76. documents = [r.get("memory", "") for r in items]
  77. try:
  78. resp = httpx.post(
  79. RERANKER_URL,
  80. json={"query": query, "documents": documents, "top_k": top_k},
  81. timeout=5.0,
  82. )
  83. resp.raise_for_status()
  84. reranked = resp.json()["results"]
  85. except Exception as exc:
  86. print(f"[reranker] unavailable, skipping rerank: {exc}")
  87. return items[:top_k]
  88. # Re-attach original mem0 metadata by matching text
  89. text_to_meta = {r.get("memory", ""): r for r in items}
  90. merged = []
  91. for r in reranked:
  92. meta = text_to_meta.get(r["text"])
  93. if meta:
  94. merged.append({**meta, "rerank_score": r["score"]})
  95. return merged
  96. # --- App ----------------------------------------------------------------------
  97. app = FastAPI(title="mem0 server")
  98. @app.get("/health")
  99. async def health():
  100. return {"status": "ok", "reranker_url": RERANKER_URL}
  101. @app.post("/memories")
  102. async def add_memory(req: Request):
  103. data = await req.json()
  104. text = data.get("text")
  105. user_id = data.get("userId") or data.get("user_id") or "default"
  106. if not text:
  107. return JSONResponse({"error": "Empty 'text' field"}, status_code=400)
  108. result = memory.add(text, user_id=user_id)
  109. print("add_memory:", {"user_id": user_id, "text": text[:80], "result": result})
  110. return result
  111. @app.post("/memories/search")
  112. async def search(req: Request):
  113. data = await req.json()
  114. query = (data.get("query") or "").strip()
  115. user_id = data.get("userId") or data.get("user_id") or "default"
  116. limit = int(data.get("limit", 5))
  117. if not query:
  118. return {"results": []}
  119. # 1. Retrieve candidates from mem0 (fetch more than limit for reranker)
  120. fetch_k = max(limit * 3, 15)
  121. try:
  122. result = memory.search(query, user_id=user_id, limit=fetch_k)
  123. except Exception:
  124. # Fallback: get_all + simple text filter
  125. all_res = memory.get_all(user_id=user_id)
  126. items = all_res.get("results", []) if isinstance(all_res, dict) else (all_res if isinstance(all_res, list) else [])
  127. q = query.lower()
  128. items = [r for r in items if q in r.get("memory", "").lower()]
  129. result = {"results": items}
  130. items = result.get("results", [])
  131. # 2. Rerank
  132. items = rerank_results(query, items, top_k=limit)
  133. result = {"results": items}
  134. print("search:", {"user_id": user_id, "query": query, "count": len(items)})
  135. return result
  136. @app.delete("/memories")
  137. async def delete(req: Request):
  138. data = await req.json()
  139. return memory.delete(data.get("filter", {}))
  140. @app.post("/memories/recent")
  141. async def recent(req: Request):
  142. data = await req.json()
  143. user_id = data.get("userId") or data.get("user_id") or "default"
  144. if not user_id:
  145. return JSONResponse({"error": "Missing userId"}, status_code=400)
  146. limit = int(data.get("limit", 5))
  147. print("recent payload:", data, "user_id:", user_id)
  148. try:
  149. results = memory.get_all(user_id=user_id)
  150. except Exception:
  151. results = memory.search(query="*", user_id=user_id)
  152. items = results.get("results", [])
  153. items = sorted(items, key=lambda r: r.get("created_at", ""), reverse=True)
  154. return {"results": items[:limit]}