This repository has been archived on 2026-03-24. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
Timmy-time-dashboard/tests/unit/test_rate_limit.py

447 lines
15 KiB
Python

"""Tests for rate limiting middleware.
Tests the RateLimiter class and RateLimitMiddleware for correct
rate limiting behavior, cleanup, and edge cases.
"""
import time
from collections import deque
from unittest.mock import Mock

import pytest
from starlette.requests import Request
from starlette.responses import JSONResponse

from dashboard.middleware.rate_limit import RateLimiter, RateLimitMiddleware
class TestRateLimiter:
    """Tests for the RateLimiter class."""

    @staticmethod
    def _make_request(headers=None, host="192.168.1.100", path="/api/matrix/test"):
        """Build a Mock standing in for an incoming request.

        Args:
            headers: Mapping exposed as ``request.headers`` (defaults to empty).
            host: Value for ``request.client.host``; ``None`` sets
                ``request.client`` itself to ``None`` (no client info).
            path: Value for ``request.url.path``.

        Returns:
            The configured ``Mock``.
        """
        request = Mock()
        request.headers = {} if headers is None else headers
        if host is None:
            request.client = None
        else:
            request.client = Mock()
            request.client.host = host
        request.url.path = path
        return request

    def test_init_defaults(self):
        """RateLimiter initializes with default values."""
        limiter = RateLimiter()
        assert limiter.requests_per_minute == 30
        assert limiter.cleanup_interval_seconds == 60
        assert limiter._storage == {}

    def test_init_custom_values(self):
        """RateLimiter accepts custom configuration."""
        limiter = RateLimiter(requests_per_minute=60, cleanup_interval_seconds=120)
        assert limiter.requests_per_minute == 60
        assert limiter.cleanup_interval_seconds == 120

    def test_is_allowed_first_request(self):
        """First request from an IP is always allowed."""
        limiter = RateLimiter(requests_per_minute=5)
        allowed, retry_after = limiter.is_allowed("192.168.1.1")
        assert allowed is True
        assert retry_after == 0.0
        assert "192.168.1.1" in limiter._storage
        assert len(limiter._storage["192.168.1.1"]) == 1

    def test_is_allowed_under_limit(self):
        """Requests under the limit are allowed."""
        limiter = RateLimiter(requests_per_minute=5)
        # Make 4 requests (under limit of 5)
        for _ in range(4):
            allowed, _ = limiter.is_allowed("192.168.1.1")
            assert allowed is True
        assert len(limiter._storage["192.168.1.1"]) == 4

    def test_is_allowed_at_limit(self):
        """Request at the limit is allowed."""
        limiter = RateLimiter(requests_per_minute=5)
        # Make exactly 5 requests
        for _ in range(5):
            allowed, _ = limiter.is_allowed("192.168.1.1")
            assert allowed is True
        assert len(limiter._storage["192.168.1.1"]) == 5

    def test_is_allowed_over_limit(self):
        """Request over the limit is denied."""
        limiter = RateLimiter(requests_per_minute=5)
        # Make 5 requests to hit the limit
        for _ in range(5):
            limiter.is_allowed("192.168.1.1")
        # 6th request should be denied
        allowed, retry_after = limiter.is_allowed("192.168.1.1")
        assert allowed is False
        assert retry_after > 0

    def test_is_allowed_different_ips(self):
        """Rate limiting is per-IP, not global."""
        limiter = RateLimiter(requests_per_minute=5)
        # Hit limit for IP 1
        for _ in range(5):
            limiter.is_allowed("192.168.1.1")
        # IP 1 is now rate limited
        allowed, _ = limiter.is_allowed("192.168.1.1")
        assert allowed is False
        # IP 2 should still be allowed
        allowed, _ = limiter.is_allowed("192.168.1.2")
        assert allowed is True

    def test_window_expiration_allows_new_requests(self):
        """After the window expires, new requests are allowed."""
        limiter = RateLimiter(requests_per_minute=5)
        ip = "192.168.1.1"
        # Hit the limit
        for _ in range(5):
            limiter.is_allowed(ip)
        # Should be rate limited
        allowed, _ = limiter.is_allowed(ip)
        assert allowed is False

        # Simulate time passing (we can't wait 60 seconds in a test) by
        # backdating every recorded timestamp far outside the window. Unlike
        # clearing the history, this exercises the limiter's own expiry logic.
        # Mutate in place so the stored container object/type is preserved.
        timestamps = limiter._storage[ip]
        backdated = [ts - 120 for ts in timestamps]
        timestamps.clear()
        timestamps.extend(backdated)

        # Should now be allowed again
        allowed, _ = limiter.is_allowed(ip)
        assert allowed is True

    def test_cleanup_removes_stale_entries(self):
        """Cleanup removes IPs with no recent requests and keeps active ones."""
        limiter = RateLimiter(
            requests_per_minute=5,
            cleanup_interval_seconds=1,  # Short interval for testing
        )
        # Add some requests
        limiter.is_allowed("192.168.1.1")
        limiter.is_allowed("192.168.1.2")
        # Both IPs should be in storage
        assert "192.168.1.1" in limiter._storage
        assert "192.168.1.2" in limiter._storage

        # Empty one IP's history to make it stale, and backdate the last
        # cleanup so the next check is past the cleanup interval.
        limiter._storage["192.168.1.1"].clear()
        limiter._last_cleanup = time.time() - 2

        # Trigger cleanup via check_request from an unrelated client.
        limiter.check_request(self._make_request(host="192.168.1.3"))

        # The IP with an empty history is removed ...
        assert "192.168.1.1" not in limiter._storage
        # ... while the IP with a request inside the window is kept.
        assert "192.168.1.2" in limiter._storage

    def test_get_client_ip_direct(self):
        """Extract client IP from direct connection."""
        limiter = RateLimiter()
        request = self._make_request(host="192.168.1.100")
        assert limiter._get_client_ip(request) == "192.168.1.100"

    def test_get_client_ip_x_forwarded_for(self):
        """Extract client IP from X-Forwarded-For header (first hop wins)."""
        limiter = RateLimiter()
        request = self._make_request(
            headers={"x-forwarded-for": "10.0.0.1, 192.168.1.1"},
            host="192.168.1.100",
        )
        assert limiter._get_client_ip(request) == "10.0.0.1"

    def test_get_client_ip_x_real_ip(self):
        """Extract client IP from X-Real-IP header."""
        limiter = RateLimiter()
        request = self._make_request(
            headers={"x-real-ip": "10.0.0.5"},
            host="192.168.1.100",
        )
        assert limiter._get_client_ip(request) == "10.0.0.5"

    def test_get_client_ip_no_client(self):
        """Return 'unknown' when no client info available."""
        limiter = RateLimiter()
        request = self._make_request(host=None)
        assert limiter._get_client_ip(request) == "unknown"
class TestRateLimitMiddleware:
    """Tests for the RateLimitMiddleware class."""

    @pytest.fixture
    def mock_app(self):
        """Create a minimal ASGI app that always answers with ``{"status": "ok"}``."""

        async def app(scope, receive, send):
            response = JSONResponse({"status": "ok"})
            await response(scope, receive, send)

        return app

    @pytest.fixture
    def mock_request(self):
        """Create a mock Request object."""
        request = Mock(spec=Request)
        request.url.path = "/api/matrix/test"
        request.headers = {}
        request.client = Mock()
        request.client.host = "192.168.1.1"
        return request

    @staticmethod
    async def _send_get(middleware, path):
        """Drive one HTTP GET for *path* through *middleware*.

        A fresh scope dict is built per call so state written into the scope
        by one request cannot leak into the next.

        Returns:
            The list of ASGI messages the middleware sent.
        """
        scope = {
            "type": "http",
            "method": "GET",
            "path": path,
            "headers": [],
        }

        async def receive():
            return {"type": "http.request", "body": b""}

        messages = []

        async def send(message):
            messages.append(message)

        await middleware(scope, receive, send)
        return messages

    @staticmethod
    def _response_start(messages):
        """Return the ``http.response.start`` message, failing clearly if absent."""
        start = next(
            (m for m in messages if m.get("type") == "http.response.start"), None
        )
        assert start is not None, "middleware sent no http.response.start message"
        return start

    def test_init_defaults(self, mock_app):
        """Middleware initializes with default values."""
        middleware = RateLimitMiddleware(mock_app)
        assert middleware.path_prefixes == []
        assert middleware.limiter.requests_per_minute == 30

    def test_init_custom_values(self, mock_app):
        """Middleware accepts custom configuration."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=60,
        )
        assert middleware.path_prefixes == ["/api/matrix/"]
        assert middleware.limiter.requests_per_minute == 60

    def test_should_rate_limit_no_prefixes(self, mock_app):
        """With no prefixes, all paths are rate limited."""
        middleware = RateLimitMiddleware(mock_app)
        assert middleware._should_rate_limit("/api/matrix/test") is True
        assert middleware._should_rate_limit("/api/other/test") is True
        assert middleware._should_rate_limit("/health") is True

    def test_should_rate_limit_with_prefixes(self, mock_app):
        """With prefixes, only matching paths are rate limited."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/", "/api/public/"],
        )
        assert middleware._should_rate_limit("/api/matrix/test") is True
        assert middleware._should_rate_limit("/api/matrix/") is True
        assert middleware._should_rate_limit("/api/public/data") is True
        assert middleware._should_rate_limit("/api/other/test") is False
        assert middleware._should_rate_limit("/health") is False

    @pytest.mark.asyncio
    async def test_dispatch_allows_matching_path_under_limit(self, mock_app):
        """Request to matching path under limit is allowed."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=5,
        )
        messages = await self._send_get(middleware, "/api/matrix/test")
        assert len(messages) > 0
        assert self._response_start(messages)["status"] == 200

    @pytest.mark.asyncio
    async def test_dispatch_skips_non_matching_path(self, mock_app):
        """Request to non-matching path bypasses rate limiting."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=5,
        )
        # "/api/other/test" doesn't match "/api/matrix/"
        messages = await self._send_get(middleware, "/api/other/test")
        assert len(messages) > 0
        assert self._response_start(messages)["status"] == 200

    @pytest.mark.asyncio
    async def test_dispatch_returns_429_when_rate_limited(self, mock_app):
        """Request over limit returns 429 status with a Retry-After header."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=2,  # Low limit for testing
        )
        # Requests up to the limit are served normally.
        for attempt in range(2):
            messages = await self._send_get(middleware, "/api/matrix/test")
            status = self._response_start(messages)["status"]
            assert status == 200, f"request {attempt + 1} should be allowed"

        # 3rd request should be rate limited
        messages = await self._send_get(middleware, "/api/matrix/test")
        start = self._response_start(messages)
        assert start["status"] == 429

        # ASGI header names are byte strings; normalize case before checking.
        header_names = {name.lower() for name, _ in start.get("headers", [])}
        assert b"retry-after" in header_names
class TestRateLimiterIntegration:
    """Integration-style tests for rate limiter behavior."""

    def test_multiple_ips_independent_limits(self):
        """Each IP has its own independent rate limit."""
        limiter = RateLimiter(requests_per_minute=3)
        # Use up the limit for two IPs (IP 1 first, then IP 2). Every request
        # up to the limit must be allowed, which pins the limit at exactly 3.
        for ip in ("10.0.0.1", "10.0.0.2"):
            for _ in range(3):
                assert limiter.is_allowed(ip)[0] is True
        # Both should now be rate limited
        assert limiter.is_allowed("10.0.0.1")[0] is False
        assert limiter.is_allowed("10.0.0.2")[0] is False
        # IP 3 should still be allowed
        assert limiter.is_allowed("10.0.0.3")[0] is True

    def test_timestamp_window_sliding(self):
        """Rate limit window slides correctly as time passes."""
        limiter = RateLimiter(requests_per_minute=3)
        # Seed history directly (simulating old requests).
        now = time.time()
        limiter._storage["test-ip"] = deque(
            [
                now - 100,  # 100 seconds ago (outside 60s window)
                now - 50,  # 50 seconds ago (inside window)
                now - 10,  # 10 seconds ago (inside window)
            ]
        )
        # Currently have 2 requests in window, so 1 more allowed
        allowed, _ = limiter.is_allowed("test-ip")
        assert allowed is True
        # Now 3 in window: denied, with a positive retry-after hint
        allowed, retry_after = limiter.is_allowed("test-ip")
        assert allowed is False
        assert retry_after > 0

    def test_cleanup_preserves_active_ips(self):
        """Cleanup only removes IPs with no recent requests."""
        limiter = RateLimiter(
            requests_per_minute=3,
            cleanup_interval_seconds=1,
        )
        now = time.time()
        # IP 1: active recently
        limiter._storage["active-ip"] = deque([now - 10])
        # IP 2: no timestamps (stale)
        limiter._storage["stale-ip"] = deque()
        # IP 3: old timestamps only
        limiter._storage["old-ip"] = deque([now - 100])
        # Backdate the last cleanup past the 1s interval to force a cleanup.
        limiter._last_cleanup = now - 2

        limiter._cleanup_if_needed()

        # Active IP should remain
        assert "active-ip" in limiter._storage
        # Stale IPs should be removed
        assert "stale-ip" not in limiter._storage
        assert "old-ip" not in limiter._storage