from fastapi import Request
from mem0 import Memory

from .config import DEFAULT_USER_ID
from .reranker import rerank_results
from .responses import SafeJSONResponse

# Page size used when the payload carries no usable "limit".
_DEFAULT_LIMIT = 5


def sanitize_metadata(meta: dict) -> dict:
    """Ensure metadata values are Chroma-compatible primitive types.

    ``None`` values are dropped, ``str``/``int``/``float``/``bool`` values
    are kept unchanged, and every other value is replaced by ``str(value)``.
    """
    primitives = (str, int, float, bool)
    return {
        key: value if isinstance(value, primitives) else str(value)
        for key, value in meta.items()
        if value is not None
    }


def extract_user_id(data: dict) -> str:
    """Read user ID from either snake_case or camelCase payload fields.

    ``userId`` takes precedence over ``user_id``; when both are missing or
    falsy, ``DEFAULT_USER_ID`` is returned.
    """
    return data.get("userId") or data.get("user_id") or DEFAULT_USER_ID


def _parse_limit(data: dict, default: int = _DEFAULT_LIMIT) -> int:
    """Read ``limit`` from the payload as a non-negative int.

    A missing, null or non-numeric value falls back to *default* instead of
    raising. Negative values are clamped to 0 so the result is never used as
    a negative slice bound (which would drop items from the end).
    """
    raw = data.get("limit")
    if raw is None:
        return default
    try:
        return max(int(raw), 0)
    except (TypeError, ValueError, OverflowError):
        return default


def _result_items(response) -> list:
    """Return the list of memories carried by a ``mem`` response.

    Accepts a dict holding the list under ``"results"`` or a bare list;
    any other shape yields an empty list.
    """
    if isinstance(response, dict):
        return response.get("results") or []
    if isinstance(response, list):
        return response
    return []


def _verbatim_content(text, messages) -> str:
    """Build the string stored in verbatim mode.

    A truthy *text* wins. Otherwise the non-empty string ``content`` of every
    message whose ``role`` is ``"user"`` is joined with single spaces;
    messages without usable string content are skipped rather than raising.
    """
    if text:
        return text
    if isinstance(messages, str):
        return messages
    parts = []
    for message in messages or []:
        if not isinstance(message, dict) or message.get("role") != "user":
            continue
        content = message.get("content")
        if isinstance(content, str) and content:
            parts.append(content)
    return " ".join(parts)


async def handle_add(req: Request, mem: Memory, verbatim_allowed: bool = False):
    """Handle add for /memories and /knowledge with optional verbatim mode.

    Reads ``text`` / ``messages`` / ``metadata`` and the user ID from the JSON
    body. With ``verbatim_allowed`` the content is stored through
    ``mem.add(..., infer=False)``; otherwise ``messages`` (or ``text``) is
    handed to ``mem.add`` with default inference settings.

    Returns a ``SafeJSONResponse`` with the ``mem.add`` result, or a 400
    response when the payload has no content or ``metadata`` is not an object.
    """
    data = await req.json()
    user_id = extract_user_id(data)

    metadata = None
    raw_meta = data.get("metadata")
    if raw_meta:
        if not isinstance(raw_meta, dict):
            return SafeJSONResponse(
                content={"error": "'metadata' must be an object"},
                status_code=400,
            )
        # An all-None metadata dict sanitizes to {}; treat it as "no metadata"
        # so both add paths see the same value.
        metadata = sanitize_metadata(raw_meta) or None

    messages = data.get("messages")
    text = data.get("text")
    if not messages and not text:
        return SafeJSONResponse(
            content={"error": "Provide 'text' or 'messages'"},
            status_code=400,
        )

    if verbatim_allowed:
        content = _verbatim_content(text, messages)
        if not content:
            # No user-role message carried text; do not store an empty memory.
            return SafeJSONResponse(
                content={"error": "No user message content to store"},
                status_code=400,
            )
        result = mem.add(content, user_id=user_id, metadata=metadata, infer=False)
        print(f"[add verbatim] user={user_id} chars={len(content)} meta={metadata}")
        return SafeJSONResponse(content=result)

    kwargs = {"user_id": user_id}
    if metadata:
        kwargs["metadata"] = metadata
    result = mem.add(messages or text, **kwargs)
    print(f"[add conversational] user={user_id} meta={metadata}")
    return SafeJSONResponse(content=result)


async def handle_search(req: Request, mem: Memory):
    """Run semantic search, then rerank the candidate list.

    Over-fetches ``max(limit * 3, 15)`` candidates from ``mem.search`` and
    passes them to ``rerank_results`` with ``top_k=limit``. If ``mem.search``
    raises, falls back to a case-insensitive substring match over
    ``mem.get_all`` for the same user. An empty query returns an empty result
    list without calling ``mem``.
    """
    data = await req.json()
    query = (data.get("query") or "").strip()
    user_id = extract_user_id(data)
    limit = _parse_limit(data)
    if not query:
        return SafeJSONResponse(content={"results": []})

    fetch_k = max(limit * 3, 15)
    try:
        response = mem.search(query, user_id=user_id, limit=fetch_k)
    except Exception as exc:
        # Deliberate best-effort: keep serving results, but surface the cause.
        print(f"[search] search failed, using substring fallback: {exc!r}")
        needle = query.lower()
        candidates = [
            item
            for item in _result_items(mem.get_all(user_id=user_id))
            if needle in (item.get("memory") or "").lower()
        ]
    else:
        candidates = _result_items(response)

    items = rerank_results(query, candidates, top_k=limit)
    print(f"[search] user={user_id} query={query!r} hits={len(items)}")
    return SafeJSONResponse(content={"results": items})


async def handle_recent(req: Request, mem: Memory):
    """Return most recently created memories for a user.

    Loads the user's memories with ``mem.get_all`` (falling back to
    ``mem.search(query="recent", ...)`` if that raises), sorts them by
    ``created_at`` in descending order and returns the first ``limit`` items.
    """
    data = await req.json()
    user_id = extract_user_id(data)
    # Only reachable when DEFAULT_USER_ID itself is falsy.
    if not user_id:
        return SafeJSONResponse(content={"error": "Missing userId"}, status_code=400)
    limit = _parse_limit(data)

    try:
        response = mem.get_all(user_id=user_id)
    except Exception as exc:
        # Deliberate best-effort: keep serving results, but surface the cause.
        print(f"[recent] get_all failed, falling back to search: {exc!r}")
        response = mem.search(query="recent", user_id=user_id)

    # A missing or None created_at sorts as "" (i.e. oldest) instead of
    # raising on a None-vs-str comparison.
    # NOTE(review): assumes created_at values are mutually comparable strings
    # (e.g. ISO timestamps) - confirm against the mem0 payload.
    items = sorted(
        _result_items(response),
        key=lambda item: item.get("created_at") or "",
        reverse=True,
    )
    return SafeJSONResponse(content={"results": items[:limit]})