|
@@ -1,58 +1,66 @@
|
|
|
|
|
+import os
|
|
|
|
|
+import httpx
|
|
|
from fastapi import FastAPI, Request
|
|
from fastapi import FastAPI, Request
|
|
|
from fastapi.responses import JSONResponse
|
|
from fastapi.responses import JSONResponse
|
|
|
from mem0 import Memory
|
|
from mem0 import Memory
|
|
|
|
|
|
|
|
-app = FastAPI()
|
|
|
|
|
|
|
+# --- Env validation -----------------------------------------------------------
|
|
|
|
|
+GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
|
|
|
|
|
+if not GROQ_API_KEY:
|
|
|
|
|
+ raise RuntimeError("GROQ_API_KEY environment variable is not set.")
|
|
|
|
|
|
|
|
|
|
+RERANKER_URL = os.environ.get("RERANKER_URL", "http://192.168.0.200:5200/rerank")
|
|
|
|
|
+
|
|
|
|
|
+# --- mem0 config --------------------------------------------------------------
|
|
|
config = {
|
|
config = {
|
|
|
"llm": {
|
|
"llm": {
|
|
|
"provider": "groq",
|
|
"provider": "groq",
|
|
|
"config": {
|
|
"config": {
|
|
|
- "model": "llama-3.1-8b-instant",
|
|
|
|
|
- "temperature": 0.1,
|
|
|
|
|
- "max_tokens": 1500
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ "model": "meta-llama/llama-4-scout-17b-16e-instruct",
|
|
|
|
|
+ "temperature": 0.025,
|
|
|
|
|
+ "max_tokens": 1500,
|
|
|
|
|
+ },
|
|
|
},
|
|
},
|
|
|
"vector_store": {
|
|
"vector_store": {
|
|
|
"provider": "chroma",
|
|
"provider": "chroma",
|
|
|
"config": {
|
|
"config": {
|
|
|
"host": "192.168.0.200",
|
|
"host": "192.168.0.200",
|
|
|
"port": 8001,
|
|
"port": 8001,
|
|
|
- "collection_name": "openclaw_mem"
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ "collection_name": "openclaw_mem",
|
|
|
|
|
+ },
|
|
|
},
|
|
},
|
|
|
"embedder": {
|
|
"embedder": {
|
|
|
"provider": "ollama",
|
|
"provider": "ollama",
|
|
|
"config": {
|
|
"config": {
|
|
|
"model": "nomic-embed-text",
|
|
"model": "nomic-embed-text",
|
|
|
- "ollama_base_url": "http://192.168.0.200:11434"
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ "ollama_base_url": "http://192.168.0.200:11434",
|
|
|
|
|
+ },
|
|
|
|
|
+ },
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
memory = Memory.from_config(config)
|
|
memory = Memory.from_config(config)
|
|
|
|
|
|
|
|
-# Patch Chroma empty-filter crash (mem0 sometimes calls search with {} filters)
|
|
|
|
|
|
|
+# --- Patch: Chroma empty-filter crash -----------------------------------------
|
|
|
orig_search = memory.vector_store.search
|
|
orig_search = memory.vector_store.search
|
|
|
|
|
|
|
|
|
|
+NOOP_WHERE = {"$and": [
|
|
|
|
|
+ {"user_id": {"$ne": ""}},
|
|
|
|
|
+ {"user_id": {"$ne": ""}},
|
|
|
|
|
+]}
|
|
|
|
|
+
|
|
|
def is_effectively_empty(filters):
|
|
def is_effectively_empty(filters):
|
|
|
if not filters:
|
|
if not filters:
|
|
|
return True
|
|
return True
|
|
|
- if filters == {"AND": []} or filters == {"OR": []}:
|
|
|
|
|
|
|
+ if filters in ({"AND": []}, {"OR": []}):
|
|
|
return True
|
|
return True
|
|
|
return False
|
|
return False
|
|
|
|
|
|
|
|
-NOOP_WHERE = {"$and": [
|
|
|
|
|
- {"user_id": {"$ne": ""}},
|
|
|
|
|
- {"user_id": {"$ne": ""}}
|
|
|
|
|
-]}
|
|
|
|
|
-
|
|
|
|
|
def safe_search(query, vectors, limit=10, filters=None):
|
|
def safe_search(query, vectors, limit=10, filters=None):
|
|
|
if is_effectively_empty(filters):
|
|
if is_effectively_empty(filters):
|
|
|
return memory.vector_store.collection.query(
|
|
return memory.vector_store.collection.query(
|
|
|
query_embeddings=vectors,
|
|
query_embeddings=vectors,
|
|
|
n_results=limit,
|
|
n_results=limit,
|
|
|
- where=NOOP_WHERE
|
|
|
|
|
|
|
+ where=NOOP_WHERE,
|
|
|
)
|
|
)
|
|
|
try:
|
|
try:
|
|
|
return orig_search(query=query, vectors=vectors, limit=limit, filters=filters)
|
|
return orig_search(query=query, vectors=vectors, limit=limit, filters=filters)
|
|
@@ -61,13 +69,51 @@ def safe_search(query, vectors, limit=10, filters=None):
|
|
|
return memory.vector_store.collection.query(
|
|
return memory.vector_store.collection.query(
|
|
|
query_embeddings=vectors,
|
|
query_embeddings=vectors,
|
|
|
n_results=limit,
|
|
n_results=limit,
|
|
|
- where=NOOP_WHERE
|
|
|
|
|
|
|
+ where=NOOP_WHERE,
|
|
|
)
|
|
)
|
|
|
raise
|
|
raise
|
|
|
|
|
|
|
|
-
|
|
|
|
|
memory.vector_store.search = safe_search
|
|
memory.vector_store.search = safe_search
|
|
|
|
|
|
|
|
|
|
+# --- Reranker -----------------------------------------------------------------
|
|
|
|
|
+def rerank_results(query: str, items: list, top_k: int) -> list:
|
|
|
|
|
+ """
|
|
|
|
|
+ Call the local reranker server and re-order mem0 results by score.
|
|
|
|
|
+ Falls back to the original list if the reranker is unavailable.
|
|
|
|
|
+ """
|
|
|
|
|
+ if not items:
|
|
|
|
|
+ return items
|
|
|
|
|
+
|
|
|
|
|
+ documents = [r.get("memory", "") for r in items]
|
|
|
|
|
+ try:
|
|
|
|
|
+ resp = httpx.post(
|
|
|
|
|
+ RERANKER_URL,
|
|
|
|
|
+ json={"query": query, "documents": documents, "top_k": top_k},
|
|
|
|
|
+ timeout=5.0,
|
|
|
|
|
+ )
|
|
|
|
|
+ resp.raise_for_status()
|
|
|
|
|
+ reranked = resp.json()["results"]
|
|
|
|
|
+ except Exception as exc:
|
|
|
|
|
+ print(f"[reranker] unavailable, skipping rerank: {exc}")
|
|
|
|
|
+ return items[:top_k]
|
|
|
|
|
+
|
|
|
|
|
+ # Re-attach original mem0 metadata by matching text
|
|
|
|
|
+ text_to_meta = {r.get("memory", ""): r for r in items}
|
|
|
|
|
+ merged = []
|
|
|
|
|
+ for r in reranked:
|
|
|
|
|
+ meta = text_to_meta.get(r["text"])
|
|
|
|
|
+ if meta:
|
|
|
|
|
+ merged.append({**meta, "rerank_score": r["score"]})
|
|
|
|
|
+ return merged
|
|
|
|
|
+
|
|
|
|
|
+# --- App ----------------------------------------------------------------------
|
|
|
|
|
+app = FastAPI(title="mem0 server")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/health")
|
|
|
|
|
+async def health():
|
|
|
|
|
+ return {"status": "ok", "reranker_url": RERANKER_URL}
|
|
|
|
|
+
|
|
|
|
|
|
|
|
@app.post("/memories")
|
|
@app.post("/memories")
|
|
|
async def add_memory(req: Request):
|
|
async def add_memory(req: Request):
|
|
@@ -76,62 +122,61 @@ async def add_memory(req: Request):
|
|
|
user_id = data.get("userId") or data.get("user_id") or "default"
|
|
user_id = data.get("userId") or data.get("user_id") or "default"
|
|
|
if not text:
|
|
if not text:
|
|
|
return JSONResponse({"error": "Empty 'text' field"}, status_code=400)
|
|
return JSONResponse({"error": "Empty 'text' field"}, status_code=400)
|
|
|
-
|
|
|
|
|
result = memory.add(text, user_id=user_id)
|
|
result = memory.add(text, user_id=user_id)
|
|
|
print("add_memory:", {"user_id": user_id, "text": text[:80], "result": result})
|
|
print("add_memory:", {"user_id": user_id, "text": text[:80], "result": result})
|
|
|
return result
|
|
return result
|
|
|
|
|
|
|
|
|
|
+
|
|
|
@app.post("/memories/search")
|
|
@app.post("/memories/search")
|
|
|
async def search(req: Request):
|
|
async def search(req: Request):
|
|
|
data = await req.json()
|
|
data = await req.json()
|
|
|
query = (data.get("query") or "").strip()
|
|
query = (data.get("query") or "").strip()
|
|
|
user_id = data.get("userId") or data.get("user_id") or "default"
|
|
user_id = data.get("userId") or data.get("user_id") or "default"
|
|
|
|
|
+ limit = int(data.get("limit", 5))
|
|
|
|
|
|
|
|
if not query:
|
|
if not query:
|
|
|
return {"results": []}
|
|
return {"results": []}
|
|
|
|
|
|
|
|
|
|
+ # 1. Retrieve candidates from mem0 (fetch more than limit for reranker)
|
|
|
|
|
+ fetch_k = max(limit * 3, 15)
|
|
|
try:
|
|
try:
|
|
|
- result = memory.search(query, user_id=user_id)
|
|
|
|
|
|
|
+ result = memory.search(query, user_id=user_id, limit=fetch_k)
|
|
|
except Exception:
|
|
except Exception:
|
|
|
- # fallback: get_all + simple text filter
|
|
|
|
|
|
|
+ # Fallback: get_all + simple text filter
|
|
|
all_res = memory.get_all(user_id=user_id)
|
|
all_res = memory.get_all(user_id=user_id)
|
|
|
- if isinstance(all_res, dict):
|
|
|
|
|
- items = all_res.get("results", [])
|
|
|
|
|
- elif isinstance(all_res, list):
|
|
|
|
|
- items = all_res
|
|
|
|
|
- else:
|
|
|
|
|
- items = []
|
|
|
|
|
-
|
|
|
|
|
|
|
+ items = all_res.get("results", []) if isinstance(all_res, dict) else (all_res if isinstance(all_res, list) else [])
|
|
|
q = query.lower()
|
|
q = query.lower()
|
|
|
- items = [r for r in items if q in (r.get("memory", "").lower())]
|
|
|
|
|
|
|
+ items = [r for r in items if q in r.get("memory", "").lower()]
|
|
|
result = {"results": items}
|
|
result = {"results": items}
|
|
|
|
|
|
|
|
- print("search:", {"user_id": user_id, "query": query, "count": len(result.get("results", []))})
|
|
|
|
|
- limit = int(data.get("limit", 5))
|
|
|
|
|
items = result.get("results", [])
|
|
items = result.get("results", [])
|
|
|
- items = sorted(items, key=lambda r: r.get("score", float("inf")))[:limit]
|
|
|
|
|
- result = {"results": items}
|
|
|
|
|
- print("search:", {"user_id": user_id, "query": query, "count": len(result.get("results", []))})
|
|
|
|
|
|
|
|
|
|
|
|
+ # 2. Rerank
|
|
|
|
|
+ items = rerank_results(query, items, top_k=limit)
|
|
|
|
|
+
|
|
|
|
|
+ result = {"results": items}
|
|
|
|
|
+ print("search:", {"user_id": user_id, "query": query, "count": len(items)})
|
|
|
return result
|
|
return result
|
|
|
|
|
|
|
|
|
|
+
|
|
|
@app.delete("/memories")
|
|
@app.delete("/memories")
|
|
|
async def delete(req: Request):
|
|
async def delete(req: Request):
|
|
|
data = await req.json()
|
|
data = await req.json()
|
|
|
return memory.delete(data.get("filter", {}))
|
|
return memory.delete(data.get("filter", {}))
|
|
|
|
|
|
|
|
|
|
+
|
|
|
@app.post("/memories/recent")
|
|
@app.post("/memories/recent")
|
|
|
async def recent(req: Request):
|
|
async def recent(req: Request):
|
|
|
data = await req.json()
|
|
data = await req.json()
|
|
|
user_id = data.get("userId") or data.get("user_id") or "default"
|
|
user_id = data.get("userId") or data.get("user_id") or "default"
|
|
|
if not user_id:
|
|
if not user_id:
|
|
|
- return JSONResponse({"error":"Missing userId"}, status_code=400)
|
|
|
|
|
- print("recent payload:", data, "user_id:", user_id)
|
|
|
|
|
|
|
+ return JSONResponse({"error": "Missing userId"}, status_code=400)
|
|
|
limit = int(data.get("limit", 5))
|
|
limit = int(data.get("limit", 5))
|
|
|
|
|
+ print("recent payload:", data, "user_id:", user_id)
|
|
|
try:
|
|
try:
|
|
|
results = memory.get_all(user_id=user_id)
|
|
results = memory.get_all(user_id=user_id)
|
|
|
except Exception:
|
|
except Exception:
|
|
|
results = memory.search(query="*", user_id=user_id)
|
|
results = memory.search(query="*", user_id=user_id)
|
|
|
items = results.get("results", [])
|
|
items = results.get("results", [])
|
|
|
items = sorted(items, key=lambda r: r.get("created_at", ""), reverse=True)
|
|
items = sorted(items, key=lambda r: r.get("created_at", ""), reverse=True)
|
|
|
- return {"results": items[:limit]}
|
|
|
|
|
|
|
+ return {"results": items[:limit]}
|