# mem0server.py — FastAPI service exposing a mem0 Memory store over HTTP.
import os

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from mem0 import Memory
  4. app = FastAPI()
  5. config = {
  6. "llm": {
  7. "provider": "groq",
  8. "config": {
  9. "model": "llama-3.1-8b-instant",
  10. "temperature": 0.1,
  11. "max_tokens": 1500
  12. }
  13. },
  14. "vector_store": {
  15. "provider": "chroma",
  16. "config": {
  17. "host": "192.168.0.200",
  18. "port": 8001,
  19. "collection_name": "openclaw_mem"
  20. }
  21. },
  22. "embedder": {
  23. "provider": "ollama",
  24. "config": {
  25. "model": "nomic-embed-text",
  26. "ollama_base_url": "http://192.168.0.200:11434"
  27. }
  28. }
  29. }
  30. memory = Memory.from_config(config)
# Patch Chroma empty-filter crash: mem0 sometimes calls search with {} filters,
# which the Chroma backend fails on. Keep a reference to the store's original
# bound `search` so safe_search() can delegate to it for non-empty filters.
orig_search = memory.vector_store.search
  33. def is_effectively_empty(filters):
  34. if not filters:
  35. return True
  36. if filters == {"AND": []} or filters == {"OR": []}:
  37. return True
  38. return False
  39. NOOP_WHERE = {"$and": [
  40. {"user_id": {"$ne": ""}},
  41. {"user_id": {"$ne": ""}}
  42. ]}
  43. def safe_search(query, vectors, limit=10, filters=None):
  44. if is_effectively_empty(filters):
  45. return memory.vector_store.collection.query(
  46. query_embeddings=vectors,
  47. n_results=limit,
  48. where=NOOP_WHERE
  49. )
  50. try:
  51. return orig_search(query=query, vectors=vectors, limit=limit, filters=filters)
  52. except Exception as e:
  53. if "Expected where" in str(e):
  54. return memory.vector_store.collection.query(
  55. query_embeddings=vectors,
  56. n_results=limit,
  57. where=NOOP_WHERE
  58. )
  59. raise
# Install the wrapper on the store instance. Any code that looks up `search`
# on this vector_store object (presumably including mem0's own internal calls)
# now goes through safe_search.
memory.vector_store.search = safe_search
  61. @app.post("/memories")
  62. async def add_memory(req: Request):
  63. data = await req.json()
  64. text = data.get("text")
  65. user_id = data.get("userId") or data.get("user_id") or "default"
  66. if not text:
  67. return JSONResponse({"error": "Empty 'text' field"}, status_code=400)
  68. result = memory.add(text, user_id=user_id)
  69. print("add_memory:", {"user_id": user_id, "text": text[:80], "result": result})
  70. return result
  71. @app.post("/memories/search")
  72. async def search(req: Request):
  73. data = await req.json()
  74. query = (data.get("query") or "").strip()
  75. user_id = data.get("userId") or data.get("user_id") or "default"
  76. if not query:
  77. return {"results": []}
  78. try:
  79. result = memory.search(query, user_id=user_id)
  80. except Exception:
  81. # fallback: get_all + simple text filter
  82. all_res = memory.get_all(user_id=user_id)
  83. if isinstance(all_res, dict):
  84. items = all_res.get("results", [])
  85. elif isinstance(all_res, list):
  86. items = all_res
  87. else:
  88. items = []
  89. q = query.lower()
  90. items = [r for r in items if q in (r.get("memory", "").lower())]
  91. result = {"results": items}
  92. print("search:", {"user_id": user_id, "query": query, "count": len(result.get("results", []))})
  93. limit = int(data.get("limit", 5))
  94. items = result.get("results", [])
  95. items = sorted(items, key=lambda r: r.get("score", float("inf")))[:limit]
  96. result = {"results": items}
  97. print("search:", {"user_id": user_id, "query": query, "count": len(result.get("results", []))})
  98. return result
  99. @app.delete("/memories")
  100. async def delete(req: Request):
  101. data = await req.json()
  102. return memory.delete(data.get("filter", {}))
  103. @app.post("/memories/recent")
  104. async def recent(req: Request):
  105. data = await req.json()
  106. user_id = data.get("userId") or data.get("user_id") or "default"
  107. if not user_id:
  108. return JSONResponse({"error":"Missing userId"}, status_code=400)
  109. print("recent payload:", data, "user_id:", user_id)
  110. limit = int(data.get("limit", 5))
  111. try:
  112. results = memory.get_all(user_id=user_id)
  113. except Exception:
  114. results = memory.search(query="*", user_id=user_id)
  115. items = results.get("results", [])
  116. items = sorted(items, key=lambda r: r.get("created_at", ""), reverse=True)
  117. return {"results": items[:limit]}