Sfoglia il codice sorgente

Cache Bitstamp clients per account

Lukas Goldschmidt 1 mese fa
parent
commit
273606c115
3 ha cambiato i file con 106 aggiunte e 12 eliminazioni
  1. 25 2
      src/exec_mcp/services_bitstamp.py
  2. 20 10
      src/exec_mcp/services_orders.py
  3. 61 0
      test_stress.sh

+ 25 - 2
src/exec_mcp/services_bitstamp.py

@@ -8,6 +8,7 @@ except ModuleNotFoundError:  # allows test runs without the optional dependency
     bitstamp = None  # type: ignore
 
 from . import repo
+from .bitstamp import BitstampClient
 from .bitstamp_fx import load_eur_usd
 
 BALANCE_CACHE_TTL_SECONDS = 20
@@ -15,6 +16,8 @@ ACCOUNT_INFO_CACHE_TTL_SECONDS = 30
 STALE_CACHE_TTL_SECONDS = 10 * 60
 _CACHE_LOCKS: dict[str, threading.Lock] = {}
 _CACHE_LOCKS_GUARD = threading.Lock()
+_BITSTAMP_CLIENTS: dict[str, BitstampClient] = {}
+_BITSTAMP_CLIENTS_GUARD = threading.Lock()
 
 
 def _ttl_from_env(name: str, default_seconds: int) -> int:
@@ -62,6 +65,26 @@ def _build_trading_client(account_id: str):
     )
 
 
+def get_bitstamp_client(account_id: str):
+    with _BITSTAMP_CLIENTS_GUARD:
+        client = _BITSTAMP_CLIENTS.get(account_id)
+        if client is None:
+            account = repo.get_account(account_id)
+            secrets = repo.get_account_secrets(account_id)
+            client = BitstampClient(
+                username=account["venue_account_ref"],
+                api_key=secrets["api_key"],
+                api_secret=secrets["api_secret"],
+            )
+            _BITSTAMP_CLIENTS[account_id] = client
+        return client
+
+
+def clear_bitstamp_trading_client(account_id: str) -> None:
+    with _BITSTAMP_CLIENTS_GUARD:
+        _BITSTAMP_CLIENTS.pop(account_id, None)
+
+
 def _normalize_account_balances_payload(payload: list[dict], account_id: str) -> list[dict]:
     balances: list[dict] = []
     for item in payload:
@@ -108,8 +131,8 @@ def fetch_account_balance(account_id: str) -> dict:
             return cached
 
         try:
-            client = _build_trading_client(account_id)
-            payload = client._post("account_balances/", return_json=True, version=2)
+            client = get_bitstamp_client(account_id)
+            payload = client.trading._post("account_balances/", return_json=True, version=2)
         except Exception as exc:
             _cache_error(cache_key, str(exc))
             raise

+ 20 - 10
src/exec_mcp/services_orders.py

@@ -11,6 +11,7 @@ from uuid import uuid4
 from fastapi import HTTPException
 
 from .bitstamp import BitstampClient, BitstampError
+from .services_bitstamp import get_bitstamp_client, clear_bitstamp_trading_client
 from .bitstamp_metadata import load_market_by_symbol
 from .storage import get_connection
 
@@ -48,14 +49,16 @@ def _utc_now() -> str:
 
 
 def _get_client(account_id: str) -> BitstampClient:
-    from .repo import get_account, get_account_secrets
-    account = get_account(account_id)
-    secrets = get_account_secrets(account_id)
-    return BitstampClient(
-        username=account["venue_account_ref"],
-        api_key=secrets["api_key"],
-        api_secret=secrets["api_secret"],
-    )
+    return BitstampClientWrapper(account_id)
+
+
+class BitstampClientWrapper:
+    def __init__(self, account_id: str):
+        self.trading = get_bitstamp_client(account_id).trading
+
+
+def _invalidate_client(account_id: str) -> None:
+    clear_bitstamp_trading_client(account_id)
 
 
 def _format_decimal(value, decimals: int) -> str:
@@ -149,7 +152,10 @@ def place_order(*, account_id: str, market: str, side: str, order_type: str, amo
         else:
             raise HTTPException(status_code=400, detail="invalid side")
     except BitstampError as exc:
-        raise HTTPException(status_code=400, detail=str(exc)) from exc
+        detail = str(exc)
+        if _looks_like_auth_failure(detail):
+            _invalidate_client(account_id)
+        raise HTTPException(status_code=400, detail=detail) from exc
 
     bitstamp_order_id = str(result.get("id") or result.get("order_id") or "")
     if not bitstamp_order_id:
@@ -266,7 +272,10 @@ def query_order(*, account_id: str, order_id, client_order_id: str | None = None
     try:
         result = client.trading.order_status_v2(order_id=order_id, client_order_id=client_order_id, omit_transactions=omit_transactions)
     except BitstampError as exc:
-        raise HTTPException(status_code=400, detail=str(exc)) from exc
+        detail = str(exc)
+        if _looks_like_auth_failure(detail):
+            _invalidate_client(account_id)
+        raise HTTPException(status_code=400, detail=detail) from exc
 
     with get_connection() as conn:
         conn.execute(
@@ -289,6 +298,7 @@ def cancel_order(*, account_id: str, order_id) -> dict:
         detail = str(exc)
         if _looks_like_auth_failure(detail):
             _cancel_breaker_trip(account_id)
+            _invalidate_client(account_id)
         raise HTTPException(status_code=400, detail=detail) from exc
 
     status = "cancelled" if result else "cancel_failed"

+ 61 - 0
test_stress.sh

@@ -0,0 +1,61 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+cd "$ROOT_DIR"
+
+ACCOUNT_ID="${ACCOUNT_ID:-qndd8o9ppop6}"
+ITERATIONS="${ITERATIONS:-10}"
+
+CONFIG="${CONFIG:-/home/lucky/.openclaw/workspace/config/mcporter.json}"
+
+if [[ -f .venv/bin/activate ]]; then
+  # shellcheck disable=SC1091
+  source .venv/bin/activate
+fi
+
+ok_count=0
+fail_count=0
+total_ms=0
+
+for i in $(seq 1 "$ITERATIONS"); do
+  start_ns=$(date +%s%N)
+  if ! output="$(mcporter --config "$CONFIG" call exec.get_account_info account_id="$ACCOUNT_ID" 2>&1)"; then
+    end_ns=$(date +%s%N)
+    elapsed_ms=$(( (end_ns - start_ns) / 1000000 ))
+    fail_count=$((fail_count + 1))
+    total_ms=$((total_ms + elapsed_ms))
+    echo "[$i/$ITERATIONS] FAIL ${elapsed_ms}ms"
+    echo "$output"
+    continue
+  fi
+  end_ns=$(date +%s%N)
+  elapsed_ms=$(( (end_ns - start_ns) / 1000000 ))
+
+  if python3 - "$output" <<'PY'
+import json, sys
+payload = sys.argv[1]
+json.loads(payload)
+print("ok")
+PY
+  then
+    ok_count=$((ok_count + 1))
+  else
+    fail_count=$((fail_count + 1))
+    echo "[$i/$ITERATIONS] INVALID_JSON ${elapsed_ms}ms"
+    echo "$output"
+    continue
+  fi
+
+  total_ms=$((total_ms + elapsed_ms))
+  echo "[$i/$ITERATIONS] ${elapsed_ms}ms"
+done
+
+if [[ "$ITERATIONS" -gt 0 ]]; then
+  avg_ms=$((total_ms / ITERATIONS))
+else
+  avg_ms=0
+fi
+
+echo "---"
+echo "ok=$ok_count fail=$fail_count total=$ITERATIONS avg_ms=$avg_ms"