handlers.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. from fastapi import Request
  2. from mem0 import Memory
  3. from .config import DEFAULT_USER_ID
  4. from .reranker import rerank_results
  5. from .responses import SafeJSONResponse
  6. def sanitize_metadata(meta: dict) -> dict:
  7. """Ensure metadata values are Chroma-compatible primitive types."""
  8. clean = {}
  9. for key, value in meta.items():
  10. if value is None:
  11. continue
  12. if isinstance(value, (str, int, float, bool)):
  13. clean[key] = value
  14. else:
  15. clean[key] = str(value)
  16. return clean
  17. def extract_user_id(data: dict) -> str:
  18. """Read user ID from either snake_case or camelCase payload fields."""
  19. return data.get("userId") or data.get("user_id") or DEFAULT_USER_ID
  20. async def handle_add(req: Request, mem: Memory, verbatim_allowed: bool = False):
  21. """Handle add for /memories and /knowledge with optional verbatim mode."""
  22. data = await req.json()
  23. user_id = extract_user_id(data)
  24. raw_meta = data.get("metadata")
  25. metadata = sanitize_metadata(raw_meta) if raw_meta else None
  26. messages = data.get("messages")
  27. text = data.get("text")
  28. if not messages and not text:
  29. return SafeJSONResponse(content={"error": "Provide 'text' or 'messages'"}, status_code=400)
  30. if verbatim_allowed:
  31. content = text or " ".join(m["content"] for m in messages if m.get("role") == "user")
  32. result = mem.add(content, user_id=user_id, metadata=metadata, infer=False)
  33. print(f"[add verbatim] user={user_id} chars={len(content)} meta={metadata}")
  34. return SafeJSONResponse(content=result)
  35. kwargs = {"user_id": user_id}
  36. if metadata:
  37. kwargs["metadata"] = metadata
  38. result = mem.add(messages or text, **kwargs)
  39. print(f"[add conversational] user={user_id} meta={metadata}")
  40. return SafeJSONResponse(content=result)
  41. async def handle_search(req: Request, mem: Memory):
  42. """Run semantic search, then rerank the candidate list."""
  43. data = await req.json()
  44. query = (data.get("query") or "").strip()
  45. user_id = extract_user_id(data)
  46. limit = int(data.get("limit", 5))
  47. if not query:
  48. return SafeJSONResponse(content={"results": []})
  49. fetch_k = max(limit * 3, 15)
  50. try:
  51. result = mem.search(query, user_id=user_id, limit=fetch_k)
  52. except Exception:
  53. all_res = mem.get_all(user_id=user_id)
  54. items = all_res.get("results", []) if isinstance(all_res, dict) else (all_res if isinstance(all_res, list) else [])
  55. q = query.lower()
  56. items = [r for r in items if q in r.get("memory", "").lower()]
  57. result = {"results": items}
  58. items = rerank_results(query, result.get("results", []), top_k=limit)
  59. print(f"[search] user={user_id} query={query!r} hits={len(items)}")
  60. return SafeJSONResponse(content={"results": items})
  61. async def handle_recent(req: Request, mem: Memory):
  62. """Return most recently created memories for a user."""
  63. data = await req.json()
  64. user_id = extract_user_id(data)
  65. if not user_id:
  66. return SafeJSONResponse(content={"error": "Missing userId"}, status_code=400)
  67. limit = int(data.get("limit", 5))
  68. try:
  69. results = mem.get_all(user_id=user_id)
  70. except Exception:
  71. results = mem.search(query="recent", user_id=user_id)
  72. items = results.get("results", [])
  73. items = sorted(items, key=lambda r: r.get("created_at", ""), reverse=True)
  74. return SafeJSONResponse(content={"results": items[:limit]})