|
|
@@ -19,17 +19,56 @@ def sanitize_metadata(meta: dict) -> dict:
|
|
|
return clean
|
|
|
|
|
|
|
|
|
+def extract_ids(result: dict | list | None) -> list[str]:
|
|
|
+ """Best-effort extraction of memory IDs from a mem0 add() response."""
|
|
|
+ if not result:
|
|
|
+ return []
|
|
|
+ if isinstance(result, dict):
|
|
|
+ if isinstance(result.get("id"), str):
|
|
|
+ return [result["id"]]
|
|
|
+ if isinstance(result.get("memory_id"), str):
|
|
|
+ return [result["memory_id"]]
|
|
|
+ results = result.get("results") or result.get("memories") or []
|
|
|
+ if isinstance(results, list):
|
|
|
+ return [r.get("id") for r in results if isinstance(r, dict) and r.get("id")]
|
|
|
+ if isinstance(result, list):
|
|
|
+ return [r.get("id") for r in result if isinstance(r, dict) and r.get("id")]
|
|
|
+ return []
|
|
|
+
|
|
|
+
|
|
|
+def override_created_at(mem: Memory, memory_ids: list[str], created_at: str) -> None:
|
|
|
+ """Force a created_at override in Chroma metadata for the given memories."""
|
|
|
+ if not memory_ids or not created_at:
|
|
|
+ return
|
|
|
+ collection = mem.vector_store.collection
|
|
|
+ existing = collection.get(ids=memory_ids, include=["metadatas"])
|
|
|
+ metadatas = existing.get("metadatas", [])
|
|
|
+ updated = []
|
|
|
+ for meta in metadatas:
|
|
|
+ meta = meta or {}
|
|
|
+ meta["created_at"] = created_at
|
|
|
+ updated.append(meta)
|
|
|
+ if updated:
|
|
|
+ collection.update(ids=memory_ids, metadatas=updated)
|
|
|
+
|
|
|
+
|
|
|
def extract_user_id(data: dict) -> str:
|
|
|
"""Read user ID from either snake_case or camelCase payload fields."""
|
|
|
return data.get("userId") or data.get("user_id") or DEFAULT_USER_ID
|
|
|
|
|
|
|
|
|
-async def handle_add(req: Request, mem: Memory, verbatim_allowed: bool = False):
|
|
|
+async def handle_add(
|
|
|
+ req: Request,
|
|
|
+ mem: Memory,
|
|
|
+ verbatim_allowed: bool = False,
|
|
|
+ allow_created_at_override: bool = False,
|
|
|
+):
|
|
|
"""Handle add for /memories and /knowledge with optional verbatim mode."""
|
|
|
data = await req.json()
|
|
|
user_id = extract_user_id(data)
|
|
|
raw_meta = data.get("metadata")
|
|
|
metadata = sanitize_metadata(raw_meta) if raw_meta else None
|
|
|
+ desired_created_at = metadata.get("created_at") if metadata else None
|
|
|
messages = data.get("messages")
|
|
|
text = data.get("text")
|
|
|
|
|
|
@@ -39,6 +78,11 @@ async def handle_add(req: Request, mem: Memory, verbatim_allowed: bool = False):
|
|
|
if verbatim_allowed:
|
|
|
content = text or " ".join(m["content"] for m in messages if m.get("role") == "user")
|
|
|
result = mem.add(content, user_id=user_id, metadata=metadata, infer=False)
|
|
|
+ if allow_created_at_override and desired_created_at:
|
|
|
+ memory_ids = extract_ids(result)
|
|
|
+ override_created_at(mem, memory_ids, desired_created_at)
|
|
|
+ if isinstance(result, dict) and metadata:
|
|
|
+ result["metadata"] = {**(result.get("metadata") or {}), "created_at": desired_created_at}
|
|
|
print(f"[add verbatim] user={user_id} chars={len(content)} meta={metadata}")
|
|
|
return SafeJSONResponse(content=result)
|
|
|
|