| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293 |
- from fastapi import Request
- from mem0 import Memory
- from .config import DEFAULT_USER_ID
- from .reranker import rerank_results
- from .responses import SafeJSONResponse
def sanitize_metadata(meta: dict) -> dict:
    """Return a copy of *meta* holding only Chroma-compatible primitive values.

    Entries whose value is ``None`` are dropped. Values that are already
    ``str``, ``int``, ``float`` or ``bool`` are kept as they are; every other
    value is replaced by its ``str()`` representation.
    """
    primitive_types = (str, int, float, bool)
    return {
        name: item if isinstance(item, primitive_types) else str(item)
        for name, item in meta.items()
        if item is not None
    }
def extract_user_id(data: dict) -> str:
    """Resolve the user ID from a request payload.

    The camelCase ``userId`` field is checked first, then the snake_case
    ``user_id`` field. The first truthy value wins; when neither field holds
    one, ``DEFAULT_USER_ID`` is returned.
    """
    for field in ("userId", "user_id"):
        candidate = data.get(field)
        if candidate:
            return candidate
    return DEFAULT_USER_ID
async def handle_add(req: Request, mem: Memory, verbatim_allowed: bool = False):
    """Handle add for /memories and /knowledge with optional verbatim mode.

    Reads the JSON body of ``req`` and stores its content through ``mem.add``.

    Payload fields read:
        userId / user_id: resolved by ``extract_user_id``.
        metadata: optional object; cleaned with ``sanitize_metadata``.
        text: plain text to store.
        messages: list of message dicts (``role`` / ``content``).

    Args:
        req: Incoming request whose body is a JSON object.
        mem: Memory store that receives the content.
        verbatim_allowed: When true, the content is stored with
            ``infer=False``. ``text`` is preferred; otherwise the ``content``
            strings of the messages whose ``role`` is ``"user"`` are joined
            with single spaces. When false, ``messages`` (or ``text`` if there
            are no messages) is handed to ``mem.add`` unchanged.

    Returns:
        A ``SafeJSONResponse`` wrapping the result of ``mem.add``, or a 400
        response when neither ``text`` nor ``messages`` is given, when
        ``metadata`` is not an object, or when verbatim mode finds no user
        content to store.
    """
    data = await req.json()
    user_id = extract_user_id(data)

    raw_meta = data.get("metadata")
    # sanitize_metadata iterates .items(); reject other truthy shapes up front
    # instead of failing inside it.
    if raw_meta and not isinstance(raw_meta, dict):
        return SafeJSONResponse(
            content={"error": "'metadata' must be an object"}, status_code=400
        )
    metadata = sanitize_metadata(raw_meta) if raw_meta else None

    messages = data.get("messages")
    text = data.get("text")
    if not messages and not text:
        return SafeJSONResponse(content={"error": "Provide 'text' or 'messages'"}, status_code=400)

    if verbatim_allowed:
        if text:
            content = text
        else:
            # Tolerate malformed entries: skip anything that is not a dict,
            # has no "content" key, or carries non-string content, rather
            # than raising KeyError/TypeError.
            user_parts = [
                m.get("content")
                for m in messages
                if isinstance(m, dict) and m.get("role") == "user"
            ]
            content = " ".join(part for part in user_parts if isinstance(part, str))
        if not content:
            # No user-authored text at all: refuse instead of storing "".
            return SafeJSONResponse(
                content={"error": "No user message content to store"}, status_code=400
            )
        result = mem.add(content, user_id=user_id, metadata=metadata, infer=False)
        print(f"[add verbatim] user={user_id} chars={len(content)} meta={metadata}")
        return SafeJSONResponse(content=result)

    kwargs = {"user_id": user_id}
    # Only forward metadata when something survived sanitizing.
    if metadata:
        kwargs["metadata"] = metadata
    result = mem.add(messages or text, **kwargs)
    print(f"[add conversational] user={user_id} meta={metadata}")
    return SafeJSONResponse(content=result)
async def handle_search(req: Request, mem: Memory):
    """Run semantic search, then rerank the candidate list.

    Payload fields read:
        query: search text; surrounding whitespace is stripped.
        userId / user_id: resolved by ``extract_user_id``.
        limit: number of results to return (default 5). Negative values are
            treated as 0.

    The store is asked for ``max(limit * 3, 15)`` candidates so that
    ``rerank_results`` has more than ``limit`` items to choose from. If
    ``mem.search`` raises, the handler falls back to a case-insensitive
    substring match of the query against the ``memory`` field of everything
    returned by ``mem.get_all``.

    Returns:
        A ``SafeJSONResponse`` with ``{"results": [...]}``; the list is empty
        for a blank query. A 400 response is returned when ``limit`` is not
        an integer.
    """

    def results_of(payload):
        # Accept both the {"results": [...]} dict shape and a bare list.
        if isinstance(payload, dict):
            return payload.get("results") or []
        return payload if isinstance(payload, list) else []

    data = await req.json()
    query = (data.get("query") or "").strip()
    user_id = extract_user_id(data)

    raw_limit = data.get("limit")
    try:
        # An explicit null behaves like an absent field.
        limit = 5 if raw_limit is None else max(int(raw_limit), 0)
    except (TypeError, ValueError):
        return SafeJSONResponse(
            content={"error": "'limit' must be an integer"}, status_code=400
        )

    if not query:
        return SafeJSONResponse(content={"results": []})

    fetch_k = max(limit * 3, 15)
    try:
        candidates = results_of(mem.search(query, user_id=user_id, limit=fetch_k))
    except Exception as exc:
        # Best-effort fallback, but leave a trace of why it was needed.
        print(f"[search] semantic search failed ({exc!r}); falling back to substring match")
        needle = query.lower()
        candidates = [
            r
            for r in results_of(mem.get_all(user_id=user_id))
            # "memory" may be missing or None; treat both as empty text.
            if needle in (r.get("memory") or "").lower()
        ]

    items = rerank_results(query, candidates, top_k=limit)
    print(f"[search] user={user_id} query={query!r} hits={len(items)}")
    return SafeJSONResponse(content={"results": items})
async def handle_recent(req: Request, mem: Memory):
    """Return most recently created memories for a user.

    Payload fields read:
        userId / user_id: resolved by ``extract_user_id``.
        limit: maximum number of items to return (default 5). Negative values
            are treated as 0.

    All memories of the user are fetched with ``mem.get_all``; if that raises,
    ``mem.search(query="recent", ...)`` is used instead. Items are sorted by
    their ``created_at`` field, newest first, and items without one sort last.
    NOTE(review): assumes ``created_at`` values are mutually comparable
    strings (e.g. ISO-8601 timestamps) — confirm against the store.

    Returns:
        A ``SafeJSONResponse`` with ``{"results": [...]}``, or a 400 response
        when no user ID can be resolved or ``limit`` is not an integer.
    """
    data = await req.json()
    user_id = extract_user_id(data)
    if not user_id:
        return SafeJSONResponse(content={"error": "Missing userId"}, status_code=400)

    raw_limit = data.get("limit")
    try:
        # An explicit null behaves like an absent field; a negative limit
        # would otherwise make the final slice drop items from the tail.
        limit = 5 if raw_limit is None else max(int(raw_limit), 0)
    except (TypeError, ValueError):
        return SafeJSONResponse(
            content={"error": "'limit' must be an integer"}, status_code=400
        )

    try:
        results = mem.get_all(user_id=user_id)
    except Exception as exc:
        # Best-effort fallback, but leave a trace of why it was needed.
        print(f"[recent] get_all failed ({exc!r}); falling back to search")
        results = mem.search(query="recent", user_id=user_id)

    # Accept both the {"results": [...]} dict shape and a bare list.
    if isinstance(results, dict):
        items = results.get("results") or []
    elif isinstance(results, list):
        items = results
    else:
        items = []

    # `or ""` keeps a None created_at from being compared against strings.
    items = sorted(items, key=lambda r: r.get("created_at") or "", reverse=True)
    return SafeJSONResponse(content={"results": items[:limit]})
|