From dc9f0c04eb20e9711a52b4013990f668dd66bde5 Mon Sep 17 00:00:00 2001 From: Kimi Agent Date: Sat, 21 Mar 2026 16:23:16 +0000 Subject: [PATCH] [kimi] Add rate limiting middleware for Matrix API endpoints (#683) (#746) --- src/dashboard/app.py | 12 +- src/dashboard/middleware/__init__.py | 3 + src/dashboard/middleware/rate_limit.py | 209 ++++++++++++ tests/unit/test_rate_limit.py | 446 +++++++++++++++++++++++++ 4 files changed, 668 insertions(+), 2 deletions(-) create mode 100644 src/dashboard/middleware/rate_limit.py create mode 100644 tests/unit/test_rate_limit.py diff --git a/src/dashboard/app.py b/src/dashboard/app.py index 3bf6ac3..ebb0421 100644 --- a/src/dashboard/app.py +++ b/src/dashboard/app.py @@ -24,6 +24,7 @@ from config import settings # Import dedicated middleware from dashboard.middleware.csrf import CSRFMiddleware +from dashboard.middleware.rate_limit import RateLimitMiddleware from dashboard.middleware.request_logging import RequestLoggingMiddleware from dashboard.middleware.security_headers import SecurityHeadersMiddleware from dashboard.routes.agents import router as agents_router @@ -559,10 +560,17 @@ def _is_tailscale_origin(origin: str) -> bool: # 1. Logging (outermost to capture everything) app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health"]) -# 2. Security Headers +# 2. Rate Limiting (before security to prevent abuse early) +app.add_middleware( + RateLimitMiddleware, + path_prefixes=["/api/matrix/"], + requests_per_minute=30, +) + +# 3. Security Headers app.add_middleware(SecurityHeadersMiddleware, production=not settings.debug) -# 3. CSRF Protection +# 4. CSRF Protection app.add_middleware(CSRFMiddleware) # 4. Standard FastAPI middleware diff --git a/src/dashboard/middleware/__init__.py b/src/dashboard/middleware/__init__.py index 24a85ff..98c2988 100644 --- a/src/dashboard/middleware/__init__.py +++ b/src/dashboard/middleware/__init__.py @@ -1,6 +1,7 @@ """Dashboard middleware package.""" from .csrf import CSRFMiddleware, csrf_exempt, generate_csrf_token, validate_csrf_token +from .rate_limit import RateLimiter, RateLimitMiddleware from .request_logging import RequestLoggingMiddleware from .security_headers import SecurityHeadersMiddleware @@ -9,6 +10,8 @@ __all__ = [ "csrf_exempt", "generate_csrf_token", "validate_csrf_token", + "RateLimiter", + "RateLimitMiddleware", "SecurityHeadersMiddleware", "RequestLoggingMiddleware", ] diff --git a/src/dashboard/middleware/rate_limit.py b/src/dashboard/middleware/rate_limit.py new file mode 100644 index 0000000..a6edf6d --- /dev/null +++ b/src/dashboard/middleware/rate_limit.py @@ -0,0 +1,209 @@ +"""Rate limiting middleware for FastAPI. + +Simple in-memory rate limiter for API endpoints. Tracks requests per IP +with configurable limits and automatic cleanup of stale entries. +""" + +import logging +import time +from collections import deque + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +logger = logging.getLogger(__name__) + + +class RateLimiter: + """In-memory rate limiter for tracking requests per IP. + + Stores request timestamps in a dict keyed by client IP. + Automatically cleans up stale entries every 60 seconds. + + Attributes: + requests_per_minute: Maximum requests allowed per minute per IP. + cleanup_interval_seconds: How often to clean stale entries. + """ + + def __init__( + self, + requests_per_minute: int = 30, + cleanup_interval_seconds: int = 60, + ): + self.requests_per_minute = requests_per_minute + self.cleanup_interval_seconds = cleanup_interval_seconds + self._storage: dict[str, deque[float]] = {} + self._last_cleanup: float = time.time() + self._window_seconds: float = 60.0 # 1 minute window + + def _get_client_ip(self, request: Request) -> str: + """Extract client IP from request, respecting X-Forwarded-For header. + + Args: + request: The incoming request. + + Returns: + Client IP address string. + """ + # Check for forwarded IP (when behind proxy/load balancer) + forwarded = request.headers.get("x-forwarded-for") + if forwarded: + # Take the first IP in the chain + return forwarded.split(",")[0].strip() + + real_ip = request.headers.get("x-real-ip") + if real_ip: + return real_ip + + # Fall back to direct connection + if request.client: + return request.client.host + + return "unknown" + + def _cleanup_if_needed(self) -> None: + """Remove stale entries older than the cleanup interval.""" + now = time.time() + if now - self._last_cleanup < self.cleanup_interval_seconds: + return + + cutoff = now - self._window_seconds + stale_ips: list[str] = [] + + for ip, timestamps in self._storage.items(): + # Remove timestamps older than the window + while timestamps and timestamps[0] < cutoff: + timestamps.popleft() + # Mark IP for removal if no recent requests + if not timestamps: + stale_ips.append(ip) + + # Remove stale IP entries + for ip in stale_ips: + del self._storage[ip] + + self._last_cleanup = now + if stale_ips: + logger.debug("Rate limiter cleanup: removed %d stale IPs", len(stale_ips)) + + def is_allowed(self, client_ip: str) -> tuple[bool, float]: + """Check if a request from the given IP is allowed. + + Args: + client_ip: The client's IP address. + + Returns: + Tuple of (allowed: bool, retry_after: float). + retry_after is seconds until next allowed request, 0 if allowed now. + """ + now = time.time() + cutoff = now - self._window_seconds + + # Get or create timestamp deque for this IP + if client_ip not in self._storage: + self._storage[client_ip] = deque() + + timestamps = self._storage[client_ip] + + # Remove timestamps outside the window + while timestamps and timestamps[0] < cutoff: + timestamps.popleft() + + # Check if limit exceeded + if len(timestamps) >= self.requests_per_minute: + # Calculate retry after time + oldest = timestamps[0] + retry_after = self._window_seconds - (now - oldest) + return False, max(0.0, retry_after) + + # Record this request + timestamps.append(now) + return True, 0.0 + + def check_request(self, request: Request) -> tuple[bool, float]: + """Check if the request is allowed under rate limits. + + Args: + request: The incoming request. + + Returns: + Tuple of (allowed: bool, retry_after: float). + """ + self._cleanup_if_needed() + client_ip = self._get_client_ip(request) + return self.is_allowed(client_ip) + + +class RateLimitMiddleware(BaseHTTPMiddleware): + """Middleware to apply rate limiting to specific routes. + + Usage: + # Apply to all routes (not recommended for public static files) + app.add_middleware(RateLimitMiddleware) + + # Apply only to specific paths + app.add_middleware( + RateLimitMiddleware, + path_prefixes=["/api/matrix/"], + requests_per_minute=30, + ) + + Attributes: + path_prefixes: List of URL path prefixes to rate limit. + If empty, applies to all paths. + requests_per_minute: Maximum requests per minute per IP. + """ + + def __init__( + self, + app, + path_prefixes: list[str] | None = None, + requests_per_minute: int = 30, + ): + super().__init__(app) + self.path_prefixes = path_prefixes or [] + self.limiter = RateLimiter(requests_per_minute=requests_per_minute) + + def _should_rate_limit(self, path: str) -> bool: + """Check if the given path should be rate limited. + + Args: + path: The request URL path. + + Returns: + True if path matches any configured prefix. + """ + if not self.path_prefixes: + return True + return any(path.startswith(prefix) for prefix in self.path_prefixes) + + async def dispatch(self, request: Request, call_next) -> Response: + """Apply rate limiting to configured paths. + + Args: + request: The incoming request. + call_next: Callable to get the response from downstream. + + Returns: + Response from downstream, or 429 if rate limited. + """ + # Skip if path doesn't match configured prefixes + if not self._should_rate_limit(request.url.path): + return await call_next(request) + + # Check rate limit + allowed, retry_after = self.limiter.check_request(request) + + if not allowed: + return JSONResponse( + status_code=429, + content={ + "error": "Rate limit exceeded. Try again later.", + "retry_after": int(retry_after) + 1, + }, + headers={"Retry-After": str(int(retry_after) + 1)}, + ) + + # Process the request + return await call_next(request) diff --git a/tests/unit/test_rate_limit.py b/tests/unit/test_rate_limit.py new file mode 100644 index 0000000..d37e26e --- /dev/null +++ b/tests/unit/test_rate_limit.py @@ -0,0 +1,446 @@ +"""Tests for rate limiting middleware. + +Tests the RateLimiter class and RateLimitMiddleware for correct +rate limiting behavior, cleanup, and edge cases. +""" + +import time +from unittest.mock import Mock + +import pytest +from starlette.requests import Request +from starlette.responses import JSONResponse + +from dashboard.middleware.rate_limit import RateLimiter, RateLimitMiddleware + + +class TestRateLimiter: + """Tests for the RateLimiter class.""" + + def test_init_defaults(self): + """RateLimiter initializes with default values.""" + limiter = RateLimiter() + assert limiter.requests_per_minute == 30 + assert limiter.cleanup_interval_seconds == 60 + assert limiter._storage == {} + + def test_init_custom_values(self): + """RateLimiter accepts custom configuration.""" + limiter = RateLimiter(requests_per_minute=60, cleanup_interval_seconds=120) + assert limiter.requests_per_minute == 60 + assert limiter.cleanup_interval_seconds == 120 + + def test_is_allowed_first_request(self): + """First request from an IP is always allowed.""" + limiter = RateLimiter(requests_per_minute=5) + allowed, retry_after = limiter.is_allowed("192.168.1.1") + assert allowed is True + assert retry_after == 0.0 + assert "192.168.1.1" in limiter._storage + assert len(limiter._storage["192.168.1.1"]) == 1 + + def test_is_allowed_under_limit(self): + """Requests under the limit are allowed.""" + limiter = RateLimiter(requests_per_minute=5) + + # Make 4 requests (under limit of 5) + for _ in range(4): + allowed, _ = limiter.is_allowed("192.168.1.1") + assert allowed is True + + assert len(limiter._storage["192.168.1.1"]) == 4 + + def test_is_allowed_at_limit(self): + """Request at the limit is allowed.""" + limiter = RateLimiter(requests_per_minute=5) + + # Make exactly 5 requests + for _ in range(5): + allowed, _ = limiter.is_allowed("192.168.1.1") + assert allowed is True + + assert len(limiter._storage["192.168.1.1"]) == 5 + + def test_is_allowed_over_limit(self): + """Request over the limit is denied.""" + limiter = RateLimiter(requests_per_minute=5) + + # Make 5 requests to hit the limit + for _ in range(5): + limiter.is_allowed("192.168.1.1") + + # 6th request should be denied + allowed, retry_after = limiter.is_allowed("192.168.1.1") + assert allowed is False + assert retry_after > 0 + + def test_is_allowed_different_ips(self): + """Rate limiting is per-IP, not global.""" + limiter = RateLimiter(requests_per_minute=5) + + # Hit limit for IP 1 + for _ in range(5): + limiter.is_allowed("192.168.1.1") + + # IP 1 is now rate limited + allowed, _ = limiter.is_allowed("192.168.1.1") + assert allowed is False + + # IP 2 should still be allowed + allowed, _ = limiter.is_allowed("192.168.1.2") + assert allowed is True + + def test_window_expiration_allows_new_requests(self): + """After window expires, new requests are allowed.""" + limiter = RateLimiter(requests_per_minute=5) + + # Hit the limit + for _ in range(5): + limiter.is_allowed("192.168.1.1") + + # Should be rate limited + allowed, _ = limiter.is_allowed("192.168.1.1") + assert allowed is False + + # Simulate time passing by clearing timestamps manually + # (we can't wait 60 seconds in a test) + limiter._storage["192.168.1.1"].clear() + + # Should now be allowed again + allowed, _ = limiter.is_allowed("192.168.1.1") + assert allowed is True + + def test_cleanup_removes_stale_entries(self): + """Cleanup removes IPs with no recent requests.""" + limiter = RateLimiter( + requests_per_minute=5, + cleanup_interval_seconds=1, # Short interval for testing + ) + + # Add some requests + limiter.is_allowed("192.168.1.1") + limiter.is_allowed("192.168.1.2") + + # Both IPs should be in storage + assert "192.168.1.1" in limiter._storage + assert "192.168.1.2" in limiter._storage + + # Manually clear timestamps to simulate stale data + limiter._storage["192.168.1.1"].clear() + limiter._last_cleanup = time.time() - 2 # Force cleanup + + # Trigger cleanup via check_request with a mock + mock_request = Mock() + mock_request.headers = {} + mock_request.client = Mock() + mock_request.client.host = "192.168.1.3" + mock_request.url.path = "/api/matrix/test" + + limiter.check_request(mock_request) + + # Stale IP should be removed + assert "192.168.1.1" not in limiter._storage + # IP with no requests (cleared) is also stale + assert "192.168.1.2" in limiter._storage + + def test_get_client_ip_direct(self): + """Extract client IP from direct connection.""" + limiter = RateLimiter() + + mock_request = Mock() + mock_request.headers = {} + mock_request.client = Mock() + mock_request.client.host = "192.168.1.100" + + ip = limiter._get_client_ip(mock_request) + assert ip == "192.168.1.100" + + def test_get_client_ip_x_forwarded_for(self): + """Extract client IP from X-Forwarded-For header.""" + limiter = RateLimiter() + + mock_request = Mock() + mock_request.headers = {"x-forwarded-for": "10.0.0.1, 192.168.1.1"} + mock_request.client = Mock() + mock_request.client.host = "192.168.1.100" + + ip = limiter._get_client_ip(mock_request) + assert ip == "10.0.0.1" + + def test_get_client_ip_x_real_ip(self): + """Extract client IP from X-Real-IP header.""" + limiter = RateLimiter() + + mock_request = Mock() + mock_request.headers = {"x-real-ip": "10.0.0.5"} + mock_request.client = Mock() + mock_request.client.host = "192.168.1.100" + + ip = limiter._get_client_ip(mock_request) + assert ip == "10.0.0.5" + + def test_get_client_ip_no_client(self): + """Return 'unknown' when no client info available.""" + limiter = RateLimiter() + + mock_request = Mock() + mock_request.headers = {} + mock_request.client = None + + ip = limiter._get_client_ip(mock_request) + assert ip == "unknown" + + +class TestRateLimitMiddleware: + """Tests for the RateLimitMiddleware class.""" + + @pytest.fixture + def mock_app(self): + """Create a mock ASGI app.""" + + async def app(scope, receive, send): + response = JSONResponse({"status": "ok"}) + await response(scope, receive, send) + + return app + + @pytest.fixture + def mock_request(self): + """Create a mock Request object.""" + request = Mock(spec=Request) + request.url.path = "/api/matrix/test" + request.headers = {} + request.client = Mock() + request.client.host = "192.168.1.1" + return request + + def test_init_defaults(self, mock_app): + """Middleware initializes with default values.""" + middleware = RateLimitMiddleware(mock_app) + assert middleware.path_prefixes == [] + assert middleware.limiter.requests_per_minute == 30 + + def test_init_custom_values(self, mock_app): + """Middleware accepts custom configuration.""" + middleware = RateLimitMiddleware( + mock_app, + path_prefixes=["/api/matrix/"], + requests_per_minute=60, + ) + assert middleware.path_prefixes == ["/api/matrix/"] + assert middleware.limiter.requests_per_minute == 60 + + def test_should_rate_limit_no_prefixes(self, mock_app): + """With no prefixes, all paths are rate limited.""" + middleware = RateLimitMiddleware(mock_app) + assert middleware._should_rate_limit("/api/matrix/test") is True + assert middleware._should_rate_limit("/api/other/test") is True + assert middleware._should_rate_limit("/health") is True + + def test_should_rate_limit_with_prefixes(self, mock_app): + """With prefixes, only matching paths are rate limited.""" + middleware = RateLimitMiddleware( + mock_app, + path_prefixes=["/api/matrix/", "/api/public/"], + ) + assert middleware._should_rate_limit("/api/matrix/test") is True + assert middleware._should_rate_limit("/api/matrix/") is True + assert middleware._should_rate_limit("/api/public/data") is True + assert middleware._should_rate_limit("/api/other/test") is False + assert middleware._should_rate_limit("/health") is False + + @pytest.mark.asyncio + async def test_dispatch_allows_matching_path_under_limit(self, mock_app): + """Request to matching path under limit is allowed.""" + middleware = RateLimitMiddleware( + mock_app, + path_prefixes=["/api/matrix/"], + requests_per_minute=5, + ) + + # Create a proper ASGI scope + scope = { + "type": "http", + "method": "GET", + "path": "/api/matrix/test", + "headers": [], + } + + async def receive(): + return {"type": "http.request", "body": b""} + + response_body = [] + + async def send(message): + response_body.append(message) + + await middleware(scope, receive, send) + + # Should have sent response messages + assert len(response_body) > 0 + # Check for 200 status in the response start message + start_message = next( + (m for m in response_body if m.get("type") == "http.response.start"), None + ) + assert start_message is not None + assert start_message["status"] == 200 + + @pytest.mark.asyncio + async def test_dispatch_skips_non_matching_path(self, mock_app): + """Request to non-matching path bypasses rate limiting.""" + middleware = RateLimitMiddleware( + mock_app, + path_prefixes=["/api/matrix/"], + requests_per_minute=5, + ) + + scope = { + "type": "http", + "method": "GET", + "path": "/api/other/test", # Doesn't match /api/matrix/ + "headers": [], + } + + async def receive(): + return {"type": "http.request", "body": b""} + + response_body = [] + + async def send(message): + response_body.append(message) + + await middleware(scope, receive, send) + + # Should have sent response messages + assert len(response_body) > 0 + start_message = next( + (m for m in response_body if m.get("type") == "http.response.start"), None + ) + assert start_message is not None + assert start_message["status"] == 200 + + @pytest.mark.asyncio + async def test_dispatch_returns_429_when_rate_limited(self, mock_app): + """Request over limit returns 429 status.""" + middleware = RateLimitMiddleware( + mock_app, + path_prefixes=["/api/matrix/"], + requests_per_minute=2, # Low limit for testing + ) + + # First request - allowed + test_scope = { + "type": "http", + "method": "GET", + "path": "/api/matrix/test", + "headers": [], + } + + async def receive(): + return {"type": "http.request", "body": b""} + + # Helper to capture response + def make_send(captured): + async def send(message): + captured.append(message) + + return send + + # Make requests to hit the limit + for _ in range(2): + response_body = [] + await middleware(test_scope, receive, make_send(response_body)) + + start_message = next( + (m for m in response_body if m.get("type") == "http.response.start"), + None, + ) + assert start_message["status"] == 200 + + # 3rd request should be rate limited + response_body = [] + await middleware(test_scope, receive, make_send(response_body)) + + start_message = next( + (m for m in response_body if m.get("type") == "http.response.start"), None + ) + assert start_message["status"] == 429 + + # Check for Retry-After header + headers = dict(start_message.get("headers", [])) + assert b"retry-after" in headers or b"Retry-After" in headers + + +class TestRateLimiterIntegration: + """Integration-style tests for rate limiter behavior.""" + + def test_multiple_ips_independent_limits(self): + """Each IP has its own independent rate limit.""" + limiter = RateLimiter(requests_per_minute=3) + + # Use up limit for IP 1 + for _ in range(3): + limiter.is_allowed("10.0.0.1") + + # Use up limit for IP 2 + for _ in range(3): + limiter.is_allowed("10.0.0.2") + + # Both should now be rate limited + assert limiter.is_allowed("10.0.0.1")[0] is False + assert limiter.is_allowed("10.0.0.2")[0] is False + + # IP 3 should still be allowed + assert limiter.is_allowed("10.0.0.3")[0] is True + + def test_timestamp_window_sliding(self): + """Rate limit window slides correctly as time passes.""" + from collections import deque + + limiter = RateLimiter(requests_per_minute=3) + + # Add 3 timestamps manually (simulating old requests) + now = time.time() + limiter._storage["test-ip"] = deque( + [ + now - 100, # 100 seconds ago (outside 60s window) + now - 50, # 50 seconds ago (inside window) + now - 10, # 10 seconds ago (inside window) + ] + ) + + # Currently have 2 requests in window, so 1 more allowed + allowed, _ = limiter.is_allowed("test-ip") + assert allowed is True + + # Now 3 in window, should be rate limited + allowed, _ = limiter.is_allowed("test-ip") + assert allowed is False + + def test_cleanup_preserves_active_ips(self): + """Cleanup only removes IPs with no recent requests.""" + from collections import deque + + limiter = RateLimiter( + requests_per_minute=3, + cleanup_interval_seconds=1, + ) + + now = time.time() + # IP 1: active recently + limiter._storage["active-ip"] = deque([now - 10]) + # IP 2: no timestamps (stale) + limiter._storage["stale-ip"] = deque() + # IP 3: old timestamps only + limiter._storage["old-ip"] = deque([now - 100]) + + limiter._last_cleanup = now - 2 # Force cleanup + + # Run cleanup + limiter._cleanup_if_needed() + + # Active IP should remain + assert "active-ip" in limiter._storage + # Stale IPs should be removed + assert "stale-ip" not in limiter._storage + assert "old-ip" not in limiter._storage