瀏覽代碼

Harden Bitstamp cancel and balance flows

Lukas Goldschmidt 1 月之前
父節點
當前提交
58d514a36d
共有 3 個文件被更改,包括 173 次插入47 次删除
  1. 28 1
      src/exec_mcp/bitstamp.py
  2. 113 45
      src/exec_mcp/services_bitstamp.py
  3. 32 1
      src/exec_mcp/services_orders.py

+ 28 - 1
src/exec_mcp/bitstamp.py

@@ -2,6 +2,8 @@ from __future__ import annotations
 
 import inspect
 import json
+import time
+import threading
 from dataclasses import dataclass
 
 from .bitstamp_rate_limit import throttle_bitstamp_request
@@ -24,13 +26,38 @@ class AccountInfo:
     metadata: dict | None = None
 
 
+_AUTH_BREAKER_LOCK = threading.Lock()
+_AUTH_BREAKER_NEXT_ALLOWED: dict[str, float] = {}
+_AUTH_BREAKER_SECONDS = 5.0
+
+
 class LG_Trading(Trading):
     def __init__(self, username, key, secret, *args, **kwargs):
+        self._username = str(username)
         super(LG_Trading, self).__init__(username=username, key=key, secret=secret, *args, **kwargs)
 
+    def _breaker_is_open(self) -> bool:
+        with _AUTH_BREAKER_LOCK:
+            return _AUTH_BREAKER_NEXT_ALLOWED.get(self._username, 0.0) > time.monotonic()
+
+    def _breaker_trip(self) -> None:
+        with _AUTH_BREAKER_LOCK:
+            _AUTH_BREAKER_NEXT_ALLOWED[self._username] = time.monotonic() + _AUTH_BREAKER_SECONDS
+
     def _post(self, url, data=None, return_json=True, version=2, **kwargs):
+        if self._breaker_is_open():
+            raise RuntimeError("Bitstamp auth breaker active, retry later")
         throttle_bitstamp_request()
-        return super()._post(url, data=data, return_json=return_json, version=version, **kwargs)
+        try:
+            return super()._post(url, data=data, return_json=return_json, version=version, **kwargs)
+        except Exception as exc:
+            msg = str(exc).lower()
+            if "403" in msg and ("authentication failed" in msg or "nonce" in msg or "timestamp" in msg):
+                self._breaker_trip()
+                time.sleep(0.75)
+                throttle_bitstamp_request()
+                return super()._post(url, data=data, return_json=return_json, version=version, **kwargs)
+            raise
 
     def order_status_v2(self, order_id, client_order_id=None, omit_transactions=None):
         data = {'id': order_id}

+ 113 - 45
src/exec_mcp/services_bitstamp.py

@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import threading
+
 try:
     import bitstamp.client
 except ModuleNotFoundError:  # allows test runs without the optional dependency
@@ -10,6 +12,9 @@ from .bitstamp_fx import load_eur_usd
 
 BALANCE_CACHE_TTL_SECONDS = 20
 ACCOUNT_INFO_CACHE_TTL_SECONDS = 30
+STALE_CACHE_TTL_SECONDS = 10 * 60
+_CACHE_LOCKS: dict[str, threading.Lock] = {}
+_CACHE_LOCKS_GUARD = threading.Lock()
 
 
 def _ttl_from_env(name: str, default_seconds: int) -> int:
@@ -19,6 +24,28 @@ def _ttl_from_env(name: str, default_seconds: int) -> int:
         return default_seconds
 
 
+def _cache_lock(cache_key: str) -> threading.Lock:
+    with _CACHE_LOCKS_GUARD:
+        lock = _CACHE_LOCKS.get(cache_key)
+        if lock is None:
+            lock = threading.Lock()
+            _CACHE_LOCKS[cache_key] = lock
+        return lock
+
+
+def _cache_error(cache_key: str, detail: str, ttl_seconds: int = 15) -> None:
+    repo.cache_put(cache_key, {"_cached_error": detail}, ttl_seconds)
+
+
+def _raise_cached_error(payload: dict) -> None:
+    detail = str(payload.get("_cached_error") or "Bitstamp request failed")
+    raise RuntimeError(detail)
+
+
+def _stale_key(cache_key: str) -> str:
+    return f"{cache_key}:stale"
+
+
 def _require_client() -> None:
     if bitstamp is None:
         raise RuntimeError("bitstamp-python-client dependency is not installed")
@@ -63,59 +90,100 @@ def fetch_account_balance(account_id: str) -> dict:
     cache_key = f"bitstamp:account_balance:{account_id}"
     cached = repo.cache_get(cache_key)
     if cached is not None:
+        if isinstance(cached, dict) and cached.get("_cached_error"):
+            stale = repo.cache_get(_stale_key(cache_key))
+            if stale is not None:
+                return stale
+            _raise_cached_error(cached)
         return cached
 
-    client = _build_trading_client(account_id)
-    payload = client._post("account_balances/", return_json=True, version=2)
-    normalized = _normalize_account_balances_payload(payload, account_id)
+    with _cache_lock(cache_key):
+        cached = repo.cache_get(cache_key)
+        if cached is not None:
+            if isinstance(cached, dict) and cached.get("_cached_error"):
+                stale = repo.cache_get(_stale_key(cache_key))
+                if stale is not None:
+                    return stale
+                _raise_cached_error(cached)
+            return cached
+
+        try:
+            client = _build_trading_client(account_id)
+            payload = client._post("account_balances/", return_json=True, version=2)
+        except Exception as exc:
+            _cache_error(cache_key, str(exc))
+            raise
+
+        normalized = _normalize_account_balances_payload(payload, account_id)
 
-    result = {"source": "bitstamp", "cached": False, "balances": normalized, "payload": payload}
-    repo.cache_put(cache_key, result, _ttl_from_env("BITSTAMP_BALANCE_CACHE_TTL_SECONDS", BALANCE_CACHE_TTL_SECONDS))
-    return result
+        result = {"source": "bitstamp", "cached": False, "balances": normalized, "payload": payload}
+        repo.cache_put(cache_key, result, _ttl_from_env("BITSTAMP_BALANCE_CACHE_TTL_SECONDS", BALANCE_CACHE_TTL_SECONDS))
+        repo.cache_put(_stale_key(cache_key), result, STALE_CACHE_TTL_SECONDS)
+        return result
 
 
 def fetch_account_info(account_id: str) -> dict:
     cache_key = f"bitstamp:account_info:{account_id}"
     cached = repo.cache_get(cache_key)
     if cached is not None:
+        if isinstance(cached, dict) and cached.get("_cached_error"):
+            stale = repo.cache_get(_stale_key(cache_key))
+            if stale is not None:
+                return stale
+            _raise_cached_error(cached)
         return cached
 
-    account = repo.get_account(account_id)
-    balance = fetch_account_balance(account_id)
-
-    valued_balances = []
-    total_value_usd = 0.0
-    for item in balance["balances"]:
-        asset = item["asset_code"].lower()
-        total = float(item["total"])
-        if asset == "usd":
-            value_usd = total
-        else:
-            value_usd = None
-            market = f"{asset}usd"
-            price = repo.get_latest_price(market)
-            if price is not None:
-                value_usd = total * price
-            elif asset == "eur":
-                fx = load_eur_usd()
-                if fx and fx.get("sell") is not None:
-                    value_usd = total * float(fx["sell"])
-        if value_usd is not None:
-            total_value_usd += value_usd
-        valued_balances.append({**item, "value_currency": "USD", "value_usd": value_usd})
-
-    result = {
-        "id": account["id"],
-        "display_name": account["display_name"],
-        "venue": account["venue"],
-        "venue_account_ref": account["venue_account_ref"],
-        "description": account["description"],
-        "enabled": account["enabled"],
-        "metadata": account["metadata"],
-        "balances": valued_balances,
-        "total_value_usd": total_value_usd,
-        "raw_balance": balance["payload"],
-    }
-
-    repo.cache_put(cache_key, result, _ttl_from_env("BITSTAMP_ACCOUNT_INFO_CACHE_TTL_SECONDS", ACCOUNT_INFO_CACHE_TTL_SECONDS))
-    return result
+    with _cache_lock(cache_key):
+        cached = repo.cache_get(cache_key)
+        if cached is not None:
+            if isinstance(cached, dict) and cached.get("_cached_error"):
+                stale = repo.cache_get(_stale_key(cache_key))
+                if stale is not None:
+                    return stale
+                _raise_cached_error(cached)
+            return cached
+
+        try:
+            account = repo.get_account(account_id)
+            balance = fetch_account_balance(account_id)
+        except Exception as exc:
+            _cache_error(cache_key, str(exc))
+            raise
+
+        valued_balances = []
+        total_value_usd = 0.0
+        for item in balance["balances"]:
+            asset = item["asset_code"].lower()
+            total = float(item["total"])
+            if asset == "usd":
+                value_usd = total
+            else:
+                value_usd = None
+                market = f"{asset}usd"
+                price = repo.get_latest_price(market)
+                if price is not None:
+                    value_usd = total * price
+                elif asset == "eur":
+                    fx = load_eur_usd()
+                    if fx and fx.get("sell") is not None:
+                        value_usd = total * float(fx["sell"])
+            if value_usd is not None:
+                total_value_usd += value_usd
+            valued_balances.append({**item, "value_currency": "USD", "value_usd": value_usd})
+
+        result = {
+            "id": account["id"],
+            "display_name": account["display_name"],
+            "venue": account["venue"],
+            "venue_account_ref": account["venue_account_ref"],
+            "description": account["description"],
+            "enabled": account["enabled"],
+            "metadata": account["metadata"],
+            "balances": valued_balances,
+            "total_value_usd": total_value_usd,
+            "raw_balance": balance["payload"],
+        }
+
+        repo.cache_put(cache_key, result, _ttl_from_env("BITSTAMP_ACCOUNT_INFO_CACHE_TTL_SECONDS", ACCOUNT_INFO_CACHE_TTL_SECONDS))
+        repo.cache_put(_stale_key(cache_key), result, STALE_CACHE_TTL_SECONDS)
+        return result

+ 32 - 1
src/exec_mcp/services_orders.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 import json
 import os
 import time
+import threading
 from datetime import datetime, timezone, timedelta
 from decimal import Decimal, ROUND_DOWN
 from uuid import uuid4
@@ -15,6 +16,24 @@ from .storage import get_connection
 
 
 OPEN_ORDER_STATUSES = {"open", "new", "partially_filled"}
+_CANCEL_BREAKER_LOCK = threading.Lock()
+_CANCEL_BREAKER_NEXT_ALLOWED: dict[str, float] = {}
+_CANCEL_BREAKER_SECONDS = 3.0
+
+
+def _cancel_breaker_is_open(account_id: str) -> bool:
+    with _CANCEL_BREAKER_LOCK:
+        return _CANCEL_BREAKER_NEXT_ALLOWED.get(account_id, 0.0) > time.monotonic()
+
+
+def _cancel_breaker_trip(account_id: str) -> None:
+    with _CANCEL_BREAKER_LOCK:
+        _CANCEL_BREAKER_NEXT_ALLOWED[account_id] = time.monotonic() + _CANCEL_BREAKER_SECONDS
+
+
+def _looks_like_auth_failure(detail: str) -> bool:
+    text = str(detail or "").lower()
+    return "403" in text and ("authentication failed" in text or "nonce" in text or "timestamp" in text)
 
 
 def _bitstamp_call_delay_seconds() -> float:
@@ -216,6 +235,9 @@ def cancel_all_orders(*, account_id: str, client_id: str | None = None) -> dict:
     results = []
     delay = _bitstamp_call_delay_seconds()
     for order in orders:
+        if _cancel_breaker_is_open(account_id):
+            results.append({"ok": False, "order_id": order.get("bitstamp_order_id"), "error": "cancel breaker active", "status": "deferred"})
+            continue
         bitstamp_order_id = order.get("bitstamp_order_id")
         if not bitstamp_order_id:
             results.append({"order_id": None, "ok": False, "error": "missing bitstamp_order_id"})
@@ -224,6 +246,10 @@ def cancel_all_orders(*, account_id: str, client_id: str | None = None) -> dict:
             results.append(cancel_order(account_id=account_id, order_id=bitstamp_order_id))
         except HTTPException as exc:
             detail = str(exc.detail)
+            if _looks_like_auth_failure(detail):
+                _cancel_breaker_trip(account_id)
+                results.append({"ok": False, "order_id": bitstamp_order_id, "error": detail, "status": "deferred"})
+                break
             if "not found" in detail.lower():
                 _set_local_order_status(bitstamp_order_id=bitstamp_order_id, status="missing")
                 results.append({"ok": False, "order_id": bitstamp_order_id, "error": detail, "status": "missing"})
@@ -254,11 +280,16 @@ def query_order(*, account_id: str, order_id, client_order_id: str | None = None
 
 def cancel_order(*, account_id: str, order_id) -> dict:
     order_id = str(order_id)
+    if _cancel_breaker_is_open(account_id):
+        raise HTTPException(status_code=503, detail="cancel breaker active, retry later")
     client = _get_client(account_id)
     try:
         result = client.trading.cancel_order(order_id=order_id, version=2)
     except BitstampError as exc:
-        raise HTTPException(status_code=400, detail=str(exc)) from exc
+        detail = str(exc)
+        if _looks_like_auth_failure(detail):
+            _cancel_breaker_trip(account_id)
+        raise HTTPException(status_code=400, detail=detail) from exc
 
     status = "cancelled" if result else "cancel_failed"
     with get_connection() as conn: