handlers.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. from fastapi import Request
  2. from mem0 import Memory
  3. from .config import DEFAULT_USER_ID
  4. from .reranker import rerank_results
  5. from .responses import SafeJSONResponse
  6. def sanitize_metadata(meta: dict) -> dict:
  7. """Ensure metadata values are Chroma-compatible primitive types."""
  8. clean = {}
  9. for key, value in meta.items():
  10. if value is None:
  11. continue
  12. if isinstance(value, (str, int, float, bool)):
  13. clean[key] = value
  14. else:
  15. clean[key] = str(value)
  16. return clean
  17. def extract_ids(result: dict | list | None) -> list[str]:
  18. """Best-effort extraction of memory IDs from a mem0 add() response."""
  19. if not result:
  20. return []
  21. if isinstance(result, dict):
  22. if isinstance(result.get("id"), str):
  23. return [result["id"]]
  24. if isinstance(result.get("memory_id"), str):
  25. return [result["memory_id"]]
  26. results = result.get("results") or result.get("memories") or []
  27. if isinstance(results, list):
  28. return [r.get("id") for r in results if isinstance(r, dict) and r.get("id")]
  29. if isinstance(result, list):
  30. return [r.get("id") for r in result if isinstance(r, dict) and r.get("id")]
  31. return []
  32. def override_created_at(mem: Memory, memory_ids: list[str], created_at: str) -> None:
  33. """Force a created_at override in Chroma metadata for the given memories."""
  34. if not memory_ids or not created_at:
  35. return
  36. collection = mem.vector_store.collection
  37. existing = collection.get(ids=memory_ids, include=["metadatas"])
  38. metadatas = existing.get("metadatas", [])
  39. updated = []
  40. for meta in metadatas:
  41. meta = meta or {}
  42. meta["created_at"] = created_at
  43. updated.append(meta)
  44. if updated:
  45. collection.update(ids=memory_ids, metadatas=updated)
  46. def extract_user_id(data: dict) -> str:
  47. """Read user ID from either snake_case or camelCase payload fields."""
  48. return data.get("userId") or data.get("user_id") or DEFAULT_USER_ID
  49. async def handle_add(
  50. req: Request,
  51. mem: Memory,
  52. verbatim_allowed: bool = False,
  53. allow_created_at_override: bool = False,
  54. ):
  55. """Handle add for /memories and /knowledge with optional verbatim mode."""
  56. data = await req.json()
  57. user_id = extract_user_id(data)
  58. raw_meta = data.get("metadata")
  59. metadata = sanitize_metadata(raw_meta) if raw_meta else None
  60. desired_created_at = metadata.get("created_at") if metadata else None
  61. messages = data.get("messages")
  62. text = data.get("text")
  63. if not messages and not text:
  64. return SafeJSONResponse(content={"error": "Provide 'text' or 'messages'"}, status_code=400)
  65. if verbatim_allowed:
  66. content = text or " ".join(m["content"] for m in messages if m.get("role") == "user")
  67. result = mem.add(content, user_id=user_id, metadata=metadata, infer=False)
  68. if allow_created_at_override and desired_created_at:
  69. memory_ids = extract_ids(result)
  70. override_created_at(mem, memory_ids, desired_created_at)
  71. if isinstance(result, dict) and metadata:
  72. result["metadata"] = {**(result.get("metadata") or {}), "created_at": desired_created_at}
  73. print(f"[add verbatim] user={user_id} chars={len(content)} meta={metadata}")
  74. return SafeJSONResponse(content=result)
  75. kwargs = {"user_id": user_id}
  76. if metadata:
  77. kwargs["metadata"] = metadata
  78. result = mem.add(messages or text, **kwargs)
  79. print(f"[add conversational] user={user_id} meta={metadata}")
  80. return SafeJSONResponse(content=result)
  81. async def handle_search(req: Request, mem: Memory):
  82. """Run semantic search, then rerank the candidate list."""
  83. data = await req.json()
  84. query = (data.get("query") or "").strip()
  85. user_id = extract_user_id(data)
  86. limit = int(data.get("limit", 5))
  87. if not query:
  88. return SafeJSONResponse(content={"results": []})
  89. fetch_k = max(limit * 3, 15)
  90. try:
  91. result = mem.search(query, user_id=user_id, limit=fetch_k)
  92. except Exception:
  93. all_res = mem.get_all(user_id=user_id)
  94. items = all_res.get("results", []) if isinstance(all_res, dict) else (all_res if isinstance(all_res, list) else [])
  95. q = query.lower()
  96. items = [r for r in items if q in r.get("memory", "").lower()]
  97. result = {"results": items}
  98. items = rerank_results(query, result.get("results", []), top_k=limit)
  99. print(f"[search] user={user_id} query={query!r} hits={len(items)}")
  100. return SafeJSONResponse(content={"results": items})
  101. async def handle_recent(req: Request, mem: Memory):
  102. """Return most recently created memories for a user."""
  103. data = await req.json()
  104. user_id = extract_user_id(data)
  105. if not user_id:
  106. return SafeJSONResponse(content={"error": "Missing userId"}, status_code=400)
  107. limit = int(data.get("limit", 5))
  108. try:
  109. results = mem.get_all(user_id=user_id)
  110. except Exception:
  111. results = mem.search(query="recent", user_id=user_id)
  112. items = results.get("results", [])
  113. items = sorted(items, key=lambda r: r.get("created_at", ""), reverse=True)
  114. return SafeJSONResponse(content={"results": items[:limit]})