diff --git a/gateway/platforms/api_server.py b/gateway/platforms/api_server.py
index 76a39b99a..2997e80de 100644
--- a/gateway/platforms/api_server.py
+++ b/gateway/platforms/api_server.py
@@ -241,6 +241,45 @@
 else:
     security_headers_middleware = None  # type: ignore[assignment]
 
+# SECURITY FIX (V-016): Rate limiting middleware
+if AIOHTTP_AVAILABLE:
+    @web.middleware
+    async def rate_limit_middleware(request, handler):
+        """Apply rate limiting per client IP.
+
+        Returns 429 Too Many Requests if rate limit exceeded.
+        Configurable via API_SERVER_RATE_LIMIT env var (requests per minute).
+        """
+        # Skip rate limiting for health checks
+        if request.path == "/health":
+            return await handler(request)
+
+        # Get client IP (respecting X-Forwarded-For if behind proxy).
+        # NOTE(review): X-Forwarded-For is client-supplied; it is only trustworthy
+        # when a fronting proxy strips/overwrites it. Confirm deployment topology.
+        client_ip = request.headers.get("X-Forwarded-For", request.remote)
+        if client_ip and "," in client_ip:
+            client_ip = client_ip.split(",")[0].strip()
+
+        limiter = _get_rate_limiter()
+        if not limiter.acquire(client_ip):
+            retry_after = limiter.get_retry_after(client_ip)
+            logger.warning(f"Rate limit exceeded for {client_ip}")
+            return web.json_response(
+                _openai_error(
+                    f"Rate limit exceeded. Try again in {retry_after} seconds.",
+                    err_type="rate_limit_error",
+                    code="rate_limit_exceeded"
+                ),
+                status=429,
+                headers={"Retry-After": str(retry_after)}
+            )
+
+        return await handler(request)
+else:
+    rate_limit_middleware = None  # type: ignore[assignment]
+
+
 class _IdempotencyCache:
     """In-memory idempotency cache with TTL and basic LRU semantics."""
     def __init__(self, max_items: int = 1000, ttl_seconds: int = 300):
@@ -273,6 +312,79 @@
 
 _idem_cache = _IdempotencyCache()
 
+# SECURITY FIX (V-016): Rate limiting
+class _RateLimiter:
+    """Token bucket rate limiter per client IP.
+
+    Default: 100 requests per minute per IP.
+    Configurable via API_SERVER_RATE_LIMIT env var (requests per minute).
+    """
+
+    # Prune once the table grows past this many client keys so an attacker
+    # rotating source addresses cannot grow memory without bound.
+    _MAX_BUCKETS = 10000
+
+    def __init__(self, requests_per_minute: int = 100):
+        from collections import defaultdict
+        self._buckets = defaultdict(lambda: {"tokens": requests_per_minute, "last": 0})
+        self._rate = requests_per_minute / 60.0  # tokens per second
+        self._max_tokens = requests_per_minute
+        self._lock = threading.Lock()
+
+    def _refill(self, key: str) -> dict:
+        """Advance *key*'s bucket by elapsed time. Caller must hold self._lock."""
+        import time
+        bucket = self._buckets[key]
+        now = time.time()
+        elapsed = now - bucket["last"]
+        bucket["last"] = now
+        # Add tokens based on elapsed time, capped at the burst size
+        bucket["tokens"] = min(
+            self._max_tokens,
+            bucket["tokens"] + elapsed * self._rate
+        )
+        return bucket
+
+    def _prune_idle(self) -> None:
+        """Drop buckets back at full capacity (idle clients). Caller must hold self._lock."""
+        for k in [k for k, b in self._buckets.items() if b["tokens"] >= self._max_tokens]:
+            del self._buckets[k]
+
+    def acquire(self, key: str) -> bool:
+        """Try to acquire a token. Returns True if allowed, False if rate limited.
+
+        Refill and consume happen under a single lock acquisition so two
+        threads cannot both spend the same last token.
+        """
+        with self._lock:
+            if len(self._buckets) > self._MAX_BUCKETS:
+                self._prune_idle()
+            bucket = self._refill(key)
+            if bucket["tokens"] >= 1:
+                bucket["tokens"] -= 1
+                return True
+            return False
+
+    def get_retry_after(self, key: str) -> int:
+        """Seconds until the next token is available for *key* (at least 1)."""
+        import math
+        with self._lock:
+            deficit = 1.0 - self._refill(key)["tokens"]
+        # A full-enough bucket means a token is available now; keep a 1s floor.
+        return max(1, math.ceil(deficit / self._rate)) if deficit > 0 else 1
+
+
+_rate_limiter = None
+
+def _get_rate_limiter() -> _RateLimiter:
+    """Return the process-wide rate limiter, creating it on first use."""
+    global _rate_limiter
+    if _rate_limiter is None:
+        # Parse rate limit from env (default 100 req/min)
+        rate_limit = int(os.getenv("API_SERVER_RATE_LIMIT", "100"))
+        _rate_limiter = _RateLimiter(rate_limit)
+    return _rate_limiter
+
 def _make_request_fingerprint(body: Dict[str, Any], keys: List[str]) -> str:
     from hashlib import sha256
     subset = {k: body.get(k) for k in keys}
@@ -1282,7 +1394,8 @@
             return False
 
         try:
-            mws = [mw for mw in (cors_middleware, body_limit_middleware, security_headers_middleware) if mw is not None]
+            # SECURITY FIX (V-016): Add rate limiting middleware
+            mws = [mw for mw in (cors_middleware, body_limit_middleware, security_headers_middleware, rate_limit_middleware) if mw is not None]
             self._app = web.Application(middlewares=mws)
             self._app["api_server_adapter"] = self
             self._app.router.add_get("/health", self._handle_health)