"""Tests for rate limiting middleware.

Tests the RateLimiter class and RateLimitMiddleware for correct rate limiting
behavior, cleanup, and edge cases.
"""

import time
from collections import deque
from unittest.mock import Mock

import pytest
from starlette.requests import Request
from starlette.responses import JSONResponse

from dashboard.middleware.rate_limit import RateLimiter, RateLimitMiddleware


def _make_mock_request(host="192.168.1.1", headers=None, path="/api/matrix/test"):
    """Build a ``Mock`` standing in for a Starlette request.

    Args:
        host: Value for ``request.client.host``. Pass ``None`` to simulate a
            request with no client information (``request.client is None``).
        headers: Header mapping exposed as ``request.headers``; defaults to
            an empty dict.
        path: Value for ``request.url.path``.

    Returns:
        The configured ``Mock``.
    """
    request = Mock()
    request.headers = {} if headers is None else headers
    if host is None:
        request.client = None
    else:
        request.client = Mock()
        request.client.host = host
    request.url.path = path
    return request


async def _asgi_receive():
    """ASGI ``receive`` callable that always yields an empty request body."""
    return {"type": "http.request", "body": b""}


async def _call_middleware(middleware, path):
    """Drive one GET request for *path* through *middleware* via raw ASGI.

    The scope carries no ``client`` entry, so every request made through this
    helper is attributed to the same (unknown) client.

    Returns:
        The ``http.response.start`` message the middleware sent, or ``None``
        if it sent none.
    """
    scope = {
        "type": "http",
        "method": "GET",
        "path": path,
        "headers": [],
    }
    messages = []

    async def send(message):
        messages.append(message)

    await middleware(scope, _asgi_receive, send)
    assert messages, "middleware sent no ASGI messages"
    return next(
        (m for m in messages if m.get("type") == "http.response.start"),
        None,
    )


class TestRateLimiter:
    """Tests for the RateLimiter class."""

    def test_init_defaults(self):
        """RateLimiter initializes with default values."""
        limiter = RateLimiter()
        assert limiter.requests_per_minute == 30
        assert limiter.cleanup_interval_seconds == 60
        assert limiter._storage == {}

    def test_init_custom_values(self):
        """RateLimiter accepts custom configuration."""
        limiter = RateLimiter(requests_per_minute=60, cleanup_interval_seconds=120)
        assert limiter.requests_per_minute == 60
        assert limiter.cleanup_interval_seconds == 120

    def test_is_allowed_first_request(self):
        """First request from an IP is always allowed."""
        limiter = RateLimiter(requests_per_minute=5)
        allowed, retry_after = limiter.is_allowed("192.168.1.1")
        assert allowed is True
        assert retry_after == 0.0
        assert "192.168.1.1" in limiter._storage
        assert len(limiter._storage["192.168.1.1"]) == 1

    def test_is_allowed_under_limit(self):
        """Requests under the limit are allowed."""
        limiter = RateLimiter(requests_per_minute=5)
        # Make 4 requests (under limit of 5)
        for _ in range(4):
            allowed, _ = limiter.is_allowed("192.168.1.1")
            assert allowed is True
        assert len(limiter._storage["192.168.1.1"]) == 4

    def test_is_allowed_at_limit(self):
        """Request at the limit is allowed."""
        limiter = RateLimiter(requests_per_minute=5)
        # Make exactly 5 requests
        for _ in range(5):
            allowed, _ = limiter.is_allowed("192.168.1.1")
            assert allowed is True
        assert len(limiter._storage["192.168.1.1"]) == 5

    def test_is_allowed_over_limit(self):
        """Request over the limit is denied."""
        limiter = RateLimiter(requests_per_minute=5)
        # Make 5 requests to hit the limit
        for _ in range(5):
            limiter.is_allowed("192.168.1.1")
        # 6th request should be denied
        allowed, retry_after = limiter.is_allowed("192.168.1.1")
        assert allowed is False
        assert retry_after > 0

    def test_is_allowed_different_ips(self):
        """Rate limiting is per-IP, not global."""
        limiter = RateLimiter(requests_per_minute=5)
        # Hit limit for IP 1
        for _ in range(5):
            limiter.is_allowed("192.168.1.1")
        # IP 1 is now rate limited
        allowed, _ = limiter.is_allowed("192.168.1.1")
        assert allowed is False
        # IP 2 should still be allowed
        allowed, _ = limiter.is_allowed("192.168.1.2")
        assert allowed is True

    def test_window_expiration_allows_new_requests(self):
        """After window expires, new requests are allowed."""
        limiter = RateLimiter(requests_per_minute=5)
        # Hit the limit
        for _ in range(5):
            limiter.is_allowed("192.168.1.1")
        # Should be rate limited
        allowed, _ = limiter.is_allowed("192.168.1.1")
        assert allowed is False
        # Simulate the window elapsing (we can't wait 60 seconds in a test) by
        # back-dating every recorded timestamp to well outside the window.
        # Unlike clearing the container, this makes the limiter's own expiry
        # logic prune the old entries. The container is edited in place so
        # its type and configuration are preserved.
        timestamps = limiter._storage["192.168.1.1"]
        aged = [ts - 100 for ts in timestamps]
        timestamps.clear()
        timestamps.extend(aged)
        # Should now be allowed again
        allowed, _ = limiter.is_allowed("192.168.1.1")
        assert allowed is True

    def test_cleanup_removes_stale_entries(self):
        """Cleanup removes IPs with no recent requests."""
        limiter = RateLimiter(
            requests_per_minute=5,
            cleanup_interval_seconds=1,  # Short interval for testing
        )
        # Add some requests
        limiter.is_allowed("192.168.1.1")
        limiter.is_allowed("192.168.1.2")
        # Both IPs should be in storage
        assert "192.168.1.1" in limiter._storage
        assert "192.168.1.2" in limiter._storage
        # Manually clear timestamps to simulate stale data
        limiter._storage["192.168.1.1"].clear()
        limiter._last_cleanup = time.time() - 2  # Force cleanup
        # Trigger cleanup via check_request from a third, unrelated client
        limiter.check_request(_make_mock_request(host="192.168.1.3"))
        # IP whose timestamps were cleared is stale and must be removed
        assert "192.168.1.1" not in limiter._storage
        # IP with a recent request is still active and must be preserved
        assert "192.168.1.2" in limiter._storage

    def test_get_client_ip_direct(self):
        """Extract client IP from direct connection."""
        limiter = RateLimiter()
        request = _make_mock_request(host="192.168.1.100")
        assert limiter._get_client_ip(request) == "192.168.1.100"

    def test_get_client_ip_x_forwarded_for(self):
        """Extract client IP from X-Forwarded-For header."""
        limiter = RateLimiter()
        request = _make_mock_request(
            host="192.168.1.100",
            headers={"x-forwarded-for": "10.0.0.1, 192.168.1.1"},
        )
        # The first (left-most) entry in the forwarded chain is used
        assert limiter._get_client_ip(request) == "10.0.0.1"

    def test_get_client_ip_x_real_ip(self):
        """Extract client IP from X-Real-IP header."""
        limiter = RateLimiter()
        request = _make_mock_request(
            host="192.168.1.100",
            headers={"x-real-ip": "10.0.0.5"},
        )
        assert limiter._get_client_ip(request) == "10.0.0.5"

    def test_get_client_ip_no_client(self):
        """Return 'unknown' when no client info available."""
        limiter = RateLimiter()
        request = _make_mock_request(host=None)
        assert limiter._get_client_ip(request) == "unknown"


class TestRateLimitMiddleware:
    """Tests for the RateLimitMiddleware class."""

    @pytest.fixture
    def mock_app(self):
        """Create a mock ASGI app that always answers 200 ``{"status": "ok"}``."""

        async def app(scope, receive, send):
            response = JSONResponse({"status": "ok"})
            await response(scope, receive, send)

        return app

    @pytest.fixture
    def mock_request(self):
        """Create a mock Request object."""
        request = Mock(spec=Request)
        request.url.path = "/api/matrix/test"
        request.headers = {}
        request.client = Mock()
        request.client.host = "192.168.1.1"
        return request

    def test_init_defaults(self, mock_app):
        """Middleware initializes with default values."""
        middleware = RateLimitMiddleware(mock_app)
        assert middleware.path_prefixes == []
        assert middleware.limiter.requests_per_minute == 30

    def test_init_custom_values(self, mock_app):
        """Middleware accepts custom configuration."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=60,
        )
        assert middleware.path_prefixes == ["/api/matrix/"]
        assert middleware.limiter.requests_per_minute == 60

    def test_should_rate_limit_no_prefixes(self, mock_app):
        """With no prefixes, all paths are rate limited."""
        middleware = RateLimitMiddleware(mock_app)
        assert middleware._should_rate_limit("/api/matrix/test") is True
        assert middleware._should_rate_limit("/api/other/test") is True
        assert middleware._should_rate_limit("/health") is True

    def test_should_rate_limit_with_prefixes(self, mock_app):
        """With prefixes, only matching paths are rate limited."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/", "/api/public/"],
        )
        assert middleware._should_rate_limit("/api/matrix/test") is True
        assert middleware._should_rate_limit("/api/matrix/") is True
        assert middleware._should_rate_limit("/api/public/data") is True
        assert middleware._should_rate_limit("/api/other/test") is False
        assert middleware._should_rate_limit("/health") is False

    @pytest.mark.asyncio
    async def test_dispatch_allows_matching_path_under_limit(self, mock_app):
        """Request to matching path under limit is allowed."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=5,
        )
        start_message = await _call_middleware(middleware, "/api/matrix/test")
        assert start_message is not None
        assert start_message["status"] == 200

    @pytest.mark.asyncio
    async def test_dispatch_skips_non_matching_path(self, mock_app):
        """Request to non-matching path bypasses rate limiting."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=5,
        )
        # "/api/other/test" does not match the "/api/matrix/" prefix
        start_message = await _call_middleware(middleware, "/api/other/test")
        assert start_message is not None
        assert start_message["status"] == 200

    @pytest.mark.asyncio
    async def test_dispatch_returns_429_when_rate_limited(self, mock_app):
        """Request over limit returns 429 status."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=2,  # Low limit for testing
        )
        # Make requests up to the limit; each is allowed
        for _ in range(2):
            start_message = await _call_middleware(middleware, "/api/matrix/test")
            assert start_message is not None
            assert start_message["status"] == 200

        # 3rd request should be rate limited
        start_message = await _call_middleware(middleware, "/api/matrix/test")
        assert start_message is not None
        assert start_message["status"] == 429

        # A Retry-After header with a positive delay must accompany the 429.
        # Raw ASGI header names are normally lowercase; accept either casing.
        headers = dict(start_message.get("headers", []))
        retry_after = headers.get(b"retry-after", headers.get(b"Retry-After"))
        assert retry_after is not None
        assert float(retry_after) > 0


class TestRateLimiterIntegration:
    """Integration-style tests for rate limiter behavior."""

    def test_multiple_ips_independent_limits(self):
        """Each IP has its own independent rate limit."""
        limiter = RateLimiter(requests_per_minute=3)
        # Use up limit for IP 1
        for _ in range(3):
            limiter.is_allowed("10.0.0.1")
        # Use up limit for IP 2
        for _ in range(3):
            limiter.is_allowed("10.0.0.2")
        # Both should now be rate limited
        assert limiter.is_allowed("10.0.0.1")[0] is False
        assert limiter.is_allowed("10.0.0.2")[0] is False
        # IP 3 should still be allowed
        assert limiter.is_allowed("10.0.0.3")[0] is True

    def test_timestamp_window_sliding(self):
        """Rate limit window slides correctly as time passes."""
        limiter = RateLimiter(requests_per_minute=3)
        # Add 3 timestamps manually (simulating old requests)
        now = time.time()
        limiter._storage["test-ip"] = deque(
            [
                now - 100,  # 100 seconds ago (outside 60s window)
                now - 50,  # 50 seconds ago (inside window)
                now - 10,  # 10 seconds ago (inside window)
            ]
        )
        # Currently have 2 requests in window, so 1 more allowed
        allowed, _ = limiter.is_allowed("test-ip")
        assert allowed is True
        # Now 3 in window, should be rate limited
        allowed, _ = limiter.is_allowed("test-ip")
        assert allowed is False

    def test_cleanup_preserves_active_ips(self):
        """Cleanup only removes IPs with no recent requests."""
        limiter = RateLimiter(
            requests_per_minute=3,
            cleanup_interval_seconds=1,
        )
        now = time.time()
        # IP 1: active recently
        limiter._storage["active-ip"] = deque([now - 10])
        # IP 2: no timestamps (stale)
        limiter._storage["stale-ip"] = deque()
        # IP 3: old timestamps only
        limiter._storage["old-ip"] = deque([now - 100])
        limiter._last_cleanup = now - 2  # Force cleanup
        # Run cleanup
        limiter._cleanup_if_needed()
        # Active IP should remain
        assert "active-ip" in limiter._storage
        # Stale IPs should be removed
        assert "stale-ip" not in limiter._storage
        assert "old-ip" not in limiter._storage