forked from Rockachopa/Timmy-time-dashboard
447 lines
15 KiB
Python
447 lines
15 KiB
Python
|
|
"""Tests for rate limiting middleware.

Tests the RateLimiter class and RateLimitMiddleware for correct
rate limiting behavior, cleanup, and edge cases.
"""
|
||
|
|
|
||
|
|
import time
|
||
|
|
from unittest.mock import Mock
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from starlette.requests import Request
|
||
|
|
from starlette.responses import JSONResponse
|
||
|
|
|
||
|
|
from dashboard.middleware.rate_limit import RateLimiter, RateLimitMiddleware
|
||
|
|
|
||
|
|
|
||
|
|
class TestRateLimiter:
    """Tests for the RateLimiter class."""

    def test_init_defaults(self):
        """RateLimiter initializes with default values."""
        limiter = RateLimiter()
        assert limiter.requests_per_minute == 30
        assert limiter.cleanup_interval_seconds == 60
        assert limiter._storage == {}

    def test_init_custom_values(self):
        """RateLimiter accepts custom configuration."""
        limiter = RateLimiter(requests_per_minute=60, cleanup_interval_seconds=120)
        assert limiter.requests_per_minute == 60
        assert limiter.cleanup_interval_seconds == 120

    def test_is_allowed_first_request(self):
        """First request from an IP is always allowed."""
        limiter = RateLimiter(requests_per_minute=5)
        allowed, retry_after = limiter.is_allowed("192.168.1.1")
        assert allowed is True
        assert retry_after == 0.0
        assert "192.168.1.1" in limiter._storage
        assert len(limiter._storage["192.168.1.1"]) == 1

    def test_is_allowed_under_limit(self):
        """Requests under the limit are allowed."""
        limiter = RateLimiter(requests_per_minute=5)

        # Make 4 requests (under limit of 5)
        for _ in range(4):
            allowed, _ = limiter.is_allowed("192.168.1.1")
            assert allowed is True

        assert len(limiter._storage["192.168.1.1"]) == 4

    def test_is_allowed_at_limit(self):
        """Request at the limit is allowed."""
        limiter = RateLimiter(requests_per_minute=5)

        # Make exactly 5 requests
        for _ in range(5):
            allowed, _ = limiter.is_allowed("192.168.1.1")
            assert allowed is True

        assert len(limiter._storage["192.168.1.1"]) == 5

    def test_is_allowed_over_limit(self):
        """Request over the limit is denied."""
        limiter = RateLimiter(requests_per_minute=5)

        # Make 5 requests to hit the limit
        for _ in range(5):
            limiter.is_allowed("192.168.1.1")

        # 6th request should be denied
        allowed, retry_after = limiter.is_allowed("192.168.1.1")
        assert allowed is False
        assert retry_after > 0

    def test_is_allowed_different_ips(self):
        """Rate limiting is per-IP, not global."""
        limiter = RateLimiter(requests_per_minute=5)

        # Hit limit for IP 1
        for _ in range(5):
            limiter.is_allowed("192.168.1.1")

        # IP 1 is now rate limited
        allowed, _ = limiter.is_allowed("192.168.1.1")
        assert allowed is False

        # IP 2 should still be allowed
        allowed, _ = limiter.is_allowed("192.168.1.2")
        assert allowed is True

    def test_window_expiration_allows_new_requests(self):
        """After the window expires, new requests are allowed again."""
        limiter = RateLimiter(requests_per_minute=5)

        # Hit the limit
        for _ in range(5):
            limiter.is_allowed("192.168.1.1")

        # Should be rate limited
        allowed, _ = limiter.is_allowed("192.168.1.1")
        assert allowed is False

        # Simulate time passing (we can't wait 60 seconds in a test) by
        # backdating every recorded timestamp to just outside the 60-second
        # window. Unlike clearing the container, this makes the limiter's own
        # expiry logic do the work, which is what this test is meant to cover.
        timestamps = limiter._storage["192.168.1.1"]
        aged = [stamp - 61 for stamp in timestamps]
        timestamps.clear()
        timestamps.extend(aged)

        # Should now be allowed again
        allowed, _ = limiter.is_allowed("192.168.1.1")
        assert allowed is True

    def test_cleanup_removes_stale_entries(self):
        """Cleanup removes IPs with no recent requests."""
        limiter = RateLimiter(
            requests_per_minute=5,
            cleanup_interval_seconds=1,  # Short interval for testing
        )

        # Add some requests
        limiter.is_allowed("192.168.1.1")
        limiter.is_allowed("192.168.1.2")

        # Both IPs should be in storage
        assert "192.168.1.1" in limiter._storage
        assert "192.168.1.2" in limiter._storage

        # Manually clear timestamps to simulate stale data
        limiter._storage["192.168.1.1"].clear()
        limiter._last_cleanup = time.time() - 2  # Force cleanup

        # Trigger cleanup via check_request with a mock
        mock_request = Mock()
        mock_request.headers = {}
        mock_request.client = Mock()
        mock_request.client.host = "192.168.1.3"
        mock_request.url.path = "/api/matrix/test"

        limiter.check_request(mock_request)

        # Stale IP (timestamps cleared) should be removed
        assert "192.168.1.1" not in limiter._storage
        # IP with a recent request is preserved
        assert "192.168.1.2" in limiter._storage

    def test_get_client_ip_direct(self):
        """Extract client IP from direct connection."""
        limiter = RateLimiter()

        mock_request = Mock()
        mock_request.headers = {}
        mock_request.client = Mock()
        mock_request.client.host = "192.168.1.100"

        ip = limiter._get_client_ip(mock_request)
        assert ip == "192.168.1.100"

    def test_get_client_ip_x_forwarded_for(self):
        """Extract client IP from X-Forwarded-For header."""
        limiter = RateLimiter()

        mock_request = Mock()
        mock_request.headers = {"x-forwarded-for": "10.0.0.1, 192.168.1.1"}
        mock_request.client = Mock()
        mock_request.client.host = "192.168.1.100"

        ip = limiter._get_client_ip(mock_request)
        assert ip == "10.0.0.1"

    def test_get_client_ip_x_real_ip(self):
        """Extract client IP from X-Real-IP header."""
        limiter = RateLimiter()

        mock_request = Mock()
        mock_request.headers = {"x-real-ip": "10.0.0.5"}
        mock_request.client = Mock()
        mock_request.client.host = "192.168.1.100"

        ip = limiter._get_client_ip(mock_request)
        assert ip == "10.0.0.5"

    def test_get_client_ip_no_client(self):
        """Return 'unknown' when no client info available."""
        limiter = RateLimiter()

        mock_request = Mock()
        mock_request.headers = {}
        mock_request.client = None

        ip = limiter._get_client_ip(mock_request)
        assert ip == "unknown"
|
||
|
|
|
||
|
|
|
||
|
|
class TestRateLimitMiddleware:
    """Tests for the RateLimitMiddleware class."""

    @staticmethod
    def _scope(path):
        """Return a minimal HTTP GET ASGI scope for *path*.

        A fresh dict is built on every call so anything the middleware (or
        Starlette) may attach to the scope cannot leak between requests.
        """
        return {
            "type": "http",
            "method": "GET",
            "path": path,
            "headers": [],
        }

    @staticmethod
    async def _receive():
        """ASGI ``receive`` callable that always yields an empty request body."""
        return {"type": "http.request", "body": b""}

    async def _start_message(self, middleware, scope):
        """Drive *middleware* with *scope* and return its response-start message.

        Fails the calling test with a clear message if the middleware sends
        nothing or never emits ``http.response.start``, rather than letting the
        caller hit a ``TypeError`` when indexing ``None``.
        """
        sent = []

        async def send(message):
            sent.append(message)

        await middleware(scope, self._receive, send)

        assert sent, "middleware sent no ASGI messages"
        start_message = next(
            (m for m in sent if m.get("type") == "http.response.start"), None
        )
        assert start_message is not None, "no http.response.start message was sent"
        return start_message

    @pytest.fixture
    def mock_app(self):
        """Create a mock ASGI app that always answers 200 with a JSON body."""

        async def app(scope, receive, send):
            response = JSONResponse({"status": "ok"})
            await response(scope, receive, send)

        return app

    @pytest.fixture
    def mock_request(self):
        """Create a mock Request object."""
        # NOTE(review): no test in this class requests this fixture.
        request = Mock(spec=Request)
        request.url.path = "/api/matrix/test"
        request.headers = {}
        request.client = Mock()
        request.client.host = "192.168.1.1"
        return request

    def test_init_defaults(self, mock_app):
        """Middleware initializes with default values."""
        middleware = RateLimitMiddleware(mock_app)
        assert middleware.path_prefixes == []
        assert middleware.limiter.requests_per_minute == 30

    def test_init_custom_values(self, mock_app):
        """Middleware accepts custom configuration."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=60,
        )
        assert middleware.path_prefixes == ["/api/matrix/"]
        assert middleware.limiter.requests_per_minute == 60

    def test_should_rate_limit_no_prefixes(self, mock_app):
        """With no prefixes, all paths are rate limited."""
        middleware = RateLimitMiddleware(mock_app)
        assert middleware._should_rate_limit("/api/matrix/test") is True
        assert middleware._should_rate_limit("/api/other/test") is True
        assert middleware._should_rate_limit("/health") is True

    def test_should_rate_limit_with_prefixes(self, mock_app):
        """With prefixes, only matching paths are rate limited."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/", "/api/public/"],
        )
        assert middleware._should_rate_limit("/api/matrix/test") is True
        assert middleware._should_rate_limit("/api/matrix/") is True
        assert middleware._should_rate_limit("/api/public/data") is True
        assert middleware._should_rate_limit("/api/other/test") is False
        assert middleware._should_rate_limit("/health") is False

    @pytest.mark.asyncio
    async def test_dispatch_allows_matching_path_under_limit(self, mock_app):
        """Request to matching path under limit is allowed."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=5,
        )

        start_message = await self._start_message(
            middleware, self._scope("/api/matrix/test")
        )
        assert start_message["status"] == 200

    @pytest.mark.asyncio
    async def test_dispatch_skips_non_matching_path(self, mock_app):
        """Request to non-matching path bypasses rate limiting."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=5,
        )

        # Send one more request than the limit allows. A single request would
        # succeed even if the path were (wrongly) rate limited, so only
        # exceeding the limit proves that /api/other/ is never throttled.
        for _ in range(6):
            start_message = await self._start_message(
                middleware, self._scope("/api/other/test")
            )
            assert start_message["status"] == 200

    @pytest.mark.asyncio
    async def test_dispatch_returns_429_when_rate_limited(self, mock_app):
        """Request over limit returns 429 status."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=2,  # Low limit for testing
        )

        # Make requests to hit the limit
        for _ in range(2):
            start_message = await self._start_message(
                middleware, self._scope("/api/matrix/test")
            )
            assert start_message["status"] == 200

        # 3rd request should be rate limited
        start_message = await self._start_message(
            middleware, self._scope("/api/matrix/test")
        )
        assert start_message["status"] == 429

        # Header names are compared case-insensitively via lowercasing.
        header_names = {name.lower() for name, _ in start_message.get("headers", [])}
        assert b"retry-after" in header_names
|
||
|
|
|
||
|
|
|
||
|
|
class TestRateLimiterIntegration:
    """End-to-end style checks of RateLimiter's observable behavior."""

    def test_multiple_ips_independent_limits(self):
        """Each IP has its own independent rate limit."""
        limiter = RateLimiter(requests_per_minute=3)

        # Spend the whole quota of the first two clients, one client at a time.
        for client in ("10.0.0.1", "10.0.0.2"):
            for _ in range(3):
                limiter.is_allowed(client)

        # Both exhausted clients are now rejected...
        first_ok, _ = limiter.is_allowed("10.0.0.1")
        assert first_ok is False
        second_ok, _ = limiter.is_allowed("10.0.0.2")
        assert second_ok is False

        # ...while a client that has not sent anything is unaffected.
        fresh_ok, _ = limiter.is_allowed("10.0.0.3")
        assert fresh_ok is True

    def test_timestamp_window_sliding(self):
        """Rate limit window slides correctly as time passes."""
        from collections import deque

        limiter = RateLimiter(requests_per_minute=3)

        # Seed history: 100 s ago falls outside the 60 s window, while
        # 50 s and 10 s ago are still inside it.
        current = time.time()
        seconds_ago = (100, 50, 10)
        limiter._storage["test-ip"] = deque(current - age for age in seconds_ago)

        # Two hits in the window leave room for exactly one more.
        ok, _ = limiter.is_allowed("test-ip")
        assert ok is True

        # That request fills the window, so the next is refused.
        ok, _ = limiter.is_allowed("test-ip")
        assert ok is False

    def test_cleanup_preserves_active_ips(self):
        """Cleanup only removes IPs with no recent requests."""
        from collections import deque

        limiter = RateLimiter(requests_per_minute=3, cleanup_interval_seconds=1)

        current = time.time()
        seeded = {
            "active-ip": deque([current - 10]),  # seen recently
            "stale-ip": deque(),  # nothing recorded at all
            "old-ip": deque([current - 100]),  # only expired entries
        }
        for client, stamps in seeded.items():
            limiter._storage[client] = stamps

        # Pretend the last sweep was longer ago than the 1 s interval.
        limiter._last_cleanup = current - 2
        limiter._cleanup_if_needed()

        # The recently active client survives the sweep...
        assert "active-ip" in limiter._storage
        # ...and both idle clients are dropped.
        for evicted in ("stale-ip", "old-ip"):
            assert evicted not in limiter._storage
|