"""Request handlers and helpers for adding, searching and listing mem0 memories."""

from fastapi import Request
from mem0 import Memory

from .config import DEFAULT_USER_ID
from .reranker import rerank_results
from .responses import SafeJSONResponse

# Result count used when the payload carries no usable "limit".
_DEFAULT_LIMIT = 5


def sanitize_metadata(meta: dict) -> dict:
    """Ensure metadata values are Chroma-compatible primitive types.

    Entries whose value is ``None`` are dropped. ``str``, ``int``, ``float``
    and ``bool`` values pass through unchanged; anything else is converted
    with ``str()``.
    """
    return {
        key: value if isinstance(value, (str, int, float, bool)) else str(value)
        for key, value in meta.items()
        if value is not None
    }


def _ids_from_items(items: list) -> list[str]:
    """Collect the truthy ``"id"`` of every dict entry in *items*."""
    return [item["id"] for item in items if isinstance(item, dict) and item.get("id")]


def extract_ids(result: dict | list | None) -> list[str]:
    """Best-effort extraction of memory IDs from a mem0 add() response.

    Handles, in order: a bare list of item dicts; a dict with a string
    ``"id"`` or ``"memory_id"``; a dict wrapping a list under ``"results"``
    or ``"memories"``. Any other shape yields an empty list.
    """
    if not result:
        return []
    if isinstance(result, list):
        return _ids_from_items(result)
    if not isinstance(result, dict):
        return []

    for key in ("id", "memory_id"):
        if isinstance(result.get(key), str):
            return [result[key]]

    nested = result.get("results") or result.get("memories") or []
    return _ids_from_items(nested) if isinstance(nested, list) else []


def override_created_at(mem: Memory, memory_ids: list[str], created_at: str) -> None:
    """Force a created_at override in Chroma metadata for the given memories.

    Does nothing when *memory_ids* or *created_at* is empty. Existing
    metadata of each record is preserved; only ``"created_at"`` is replaced.
    """
    if not memory_ids or not created_at:
        return

    collection = mem.vector_store.collection
    existing = collection.get(ids=memory_ids, include=["metadatas"])
    metadatas = existing.get("metadatas") or []
    # Pair metadata with the ids the store actually returned: get() leaves out
    # ids it does not know, so the requested list may be longer (or ordered
    # differently) than the metadata list. Fall back to the requested ids only
    # if the response carries no "ids" entry.
    found_ids = existing.get("ids") or memory_ids

    ids: list[str] = []
    updated: list[dict] = []
    for memory_id, meta in zip(found_ids, metadatas):
        ids.append(memory_id)
        updated.append({**(meta or {}), "created_at": created_at})

    if updated:
        collection.update(ids=ids, metadatas=updated)


def extract_user_id(data: dict) -> str:
    """Read user ID from either snake_case or camelCase payload fields.

    ``"userId"`` wins over ``"user_id"``; ``DEFAULT_USER_ID`` is used when
    both are missing or empty.
    """
    return data.get("userId") or data.get("user_id") or DEFAULT_USER_ID


def _parse_limit(data: dict, default: int = _DEFAULT_LIMIT) -> int:
    """Read ``"limit"`` from the payload as a non-negative int.

    Falls back to *default* when the value is missing or cannot be converted
    (e.g. ``null`` or a non-numeric string). Negative values are clamped to 0
    so they never act as a negative slice bound.
    """
    try:
        limit = int(data.get("limit", default))
    except (TypeError, ValueError, OverflowError):
        return default
    return max(limit, 0)


def _result_items(result) -> list:
    """Normalize a mem0 response to a list of items.

    Accepts either a dict carrying the list under ``"results"`` or a bare
    list; anything else yields an empty list.
    """
    items = result.get("results") if isinstance(result, dict) else result
    return items if isinstance(items, list) else []


def _verbatim_content(text, messages):
    """Pick the content to store verbatim: *text*, else the joined user messages.

    *messages* may be a plain string, a single message dict, or a list of
    message dicts. Only entries with ``role == "user"`` and a string
    ``"content"`` contribute; they are joined with single spaces.
    """
    if text:
        return text
    if isinstance(messages, str):
        return messages
    if isinstance(messages, dict):
        messages = [messages]
    return " ".join(
        m["content"]
        for m in messages
        if isinstance(m, dict)
        and m.get("role") == "user"
        and isinstance(m.get("content"), str)
    )


async def handle_add(
    req: Request,
    mem: Memory,
    verbatim_allowed: bool = False,
    allow_created_at_override: bool = False,
):
    """Handle add for /memories and /knowledge with optional verbatim mode.

    Args:
        req: Incoming request; its JSON body may carry ``text`` or
            ``messages``, an optional ``metadata`` object and a user id.
        mem: Memory instance the content is added to.
        verbatim_allowed: When true, store the content as-is with
            ``infer=False`` instead of the conversational add.
        allow_created_at_override: When true (verbatim mode only), a
            ``created_at`` found in the metadata is written onto the stored
            records and echoed in the response metadata.

    Returns:
        A ``SafeJSONResponse`` with the ``mem.add`` result, or a 400 error
        payload when the request body is unusable.
    """
    data = await req.json()
    user_id = extract_user_id(data)

    raw_meta = data.get("metadata")
    if raw_meta and not isinstance(raw_meta, dict):
        # sanitize_metadata needs a mapping; reject instead of crashing on .items().
        return SafeJSONResponse(
            content={"error": "'metadata' must be an object"}, status_code=400
        )
    metadata = sanitize_metadata(raw_meta) if raw_meta else None
    desired_created_at = metadata.get("created_at") if metadata else None

    messages = data.get("messages")
    text = data.get("text")
    if not messages and not text:
        return SafeJSONResponse(content={"error": "Provide 'text' or 'messages'"}, status_code=400)

    if verbatim_allowed:
        content = _verbatim_content(text, messages)
        if not content:
            # messages held no user-authored text; do not store an empty memory.
            return SafeJSONResponse(
                content={"error": "No user content found in 'messages'"}, status_code=400
            )

        result = mem.add(content, user_id=user_id, metadata=metadata, infer=False)
        if allow_created_at_override and desired_created_at:
            override_created_at(mem, extract_ids(result), desired_created_at)
            if isinstance(result, dict):
                result["metadata"] = {**(result.get("metadata") or {}), "created_at": desired_created_at}
        print(f"[add verbatim] user={user_id} chars={len(content)} meta={metadata}")
        return SafeJSONResponse(content=result)

    kwargs = {"user_id": user_id}
    if metadata:
        kwargs["metadata"] = metadata
    result = mem.add(messages or text, **kwargs)
    print(f"[add conversational] user={user_id} meta={metadata}")
    return SafeJSONResponse(content=result)


async def handle_search(req: Request, mem: Memory):
    """Run semantic search, then rerank the candidate list.

    Over-fetches candidates (three times the limit, at least 15) so the
    reranker has room to reorder, then returns the top ``limit`` items. If
    the semantic search raises, falls back to a case-insensitive substring
    match over all of the user's memories.
    """
    data = await req.json()
    query = (data.get("query") or "").strip()
    user_id = extract_user_id(data)
    limit = _parse_limit(data)
    if not query:
        return SafeJSONResponse(content={"results": []})

    fetch_k = max(limit * 3, 15)
    try:
        candidates = _result_items(mem.search(query, user_id=user_id, limit=fetch_k))
    except Exception as exc:  # deliberate best-effort: degrade to substring match
        print(f"[search] semantic search failed, falling back to substring match: {exc!r}")
        needle = query.lower()
        candidates = [
            r
            for r in _result_items(mem.get_all(user_id=user_id))
            if needle in (r.get("memory") or "").lower()
        ]

    items = rerank_results(query, candidates, top_k=limit)
    print(f"[search] user={user_id} query={query!r} hits={len(items)}")
    return SafeJSONResponse(content={"results": items})


async def handle_recent(req: Request, mem: Memory):
    """Return most recently created memories for a user.

    Items are sorted by their ``created_at`` value in descending order
    (missing or null values sort last) and truncated to ``limit``. If
    ``get_all`` raises, a search for ``"recent"`` is used as the source
    instead.
    """
    data = await req.json()
    user_id = extract_user_id(data)
    if not user_id:
        return SafeJSONResponse(content={"error": "Missing userId"}, status_code=400)

    limit = _parse_limit(data)
    try:
        results = mem.get_all(user_id=user_id)
    except Exception as exc:  # deliberate best-effort: degrade to a search
        print(f"[recent] get_all failed, falling back to search: {exc!r}")
        results = mem.search(query="recent", user_id=user_id)

    # `or ""` keeps the key comparable when created_at is present but null.
    items = sorted(
        _result_items(results),
        key=lambda r: r.get("created_at") or "",
        reverse=True,
    )
    return SafeJSONResponse(content={"results": items[:limit]})