forked from Rockachopa/Timmy-time-dashboard
This commit is contained in:
@@ -24,6 +24,7 @@ from config import settings
|
||||
|
||||
# Import dedicated middleware
|
||||
from dashboard.middleware.csrf import CSRFMiddleware
|
||||
from dashboard.middleware.rate_limit import RateLimitMiddleware
|
||||
from dashboard.middleware.request_logging import RequestLoggingMiddleware
|
||||
from dashboard.middleware.security_headers import SecurityHeadersMiddleware
|
||||
from dashboard.routes.agents import router as agents_router
|
||||
@@ -559,10 +560,17 @@ def _is_tailscale_origin(origin: str) -> bool:
|
||||
# 1. Logging (outermost to capture everything)
|
||||
app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health"])
|
||||
|
||||
# 2. Security Headers
|
||||
# 2. Rate Limiting (before security to prevent abuse early)
|
||||
app.add_middleware(
|
||||
RateLimitMiddleware,
|
||||
path_prefixes=["/api/matrix/"],
|
||||
requests_per_minute=30,
|
||||
)
|
||||
|
||||
# 3. Security Headers
|
||||
app.add_middleware(SecurityHeadersMiddleware, production=not settings.debug)
|
||||
|
||||
# 3. CSRF Protection
|
||||
# 4. CSRF Protection
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
# 4. Standard FastAPI middleware
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Dashboard middleware package."""
|
||||
|
||||
from .csrf import CSRFMiddleware, csrf_exempt, generate_csrf_token, validate_csrf_token
|
||||
from .rate_limit import RateLimiter, RateLimitMiddleware
|
||||
from .request_logging import RequestLoggingMiddleware
|
||||
from .security_headers import SecurityHeadersMiddleware
|
||||
|
||||
@@ -9,6 +10,8 @@ __all__ = [
|
||||
"csrf_exempt",
|
||||
"generate_csrf_token",
|
||||
"validate_csrf_token",
|
||||
"RateLimiter",
|
||||
"RateLimitMiddleware",
|
||||
"SecurityHeadersMiddleware",
|
||||
"RequestLoggingMiddleware",
|
||||
]
|
||||
|
||||
209
src/dashboard/middleware/rate_limit.py
Normal file
209
src/dashboard/middleware/rate_limit.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""Rate limiting middleware for FastAPI.
|
||||
|
||||
Simple in-memory rate limiter for API endpoints. Tracks requests per IP
|
||||
with configurable limits and automatic cleanup of stale entries.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""In-memory rate limiter for tracking requests per IP.
|
||||
|
||||
Stores request timestamps in a dict keyed by client IP.
|
||||
Automatically cleans up stale entries every 60 seconds.
|
||||
|
||||
Attributes:
|
||||
requests_per_minute: Maximum requests allowed per minute per IP.
|
||||
cleanup_interval_seconds: How often to clean stale entries.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
requests_per_minute: int = 30,
|
||||
cleanup_interval_seconds: int = 60,
|
||||
):
|
||||
self.requests_per_minute = requests_per_minute
|
||||
self.cleanup_interval_seconds = cleanup_interval_seconds
|
||||
self._storage: dict[str, deque[float]] = {}
|
||||
self._last_cleanup: float = time.time()
|
||||
self._window_seconds: float = 60.0 # 1 minute window
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""Extract client IP from request, respecting X-Forwarded-For header.
|
||||
|
||||
Args:
|
||||
request: The incoming request.
|
||||
|
||||
Returns:
|
||||
Client IP address string.
|
||||
"""
|
||||
# Check for forwarded IP (when behind proxy/load balancer)
|
||||
forwarded = request.headers.get("x-forwarded-for")
|
||||
if forwarded:
|
||||
# Take the first IP in the chain
|
||||
return forwarded.split(",")[0].strip()
|
||||
|
||||
real_ip = request.headers.get("x-real-ip")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
# Fall back to direct connection
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _cleanup_if_needed(self) -> None:
|
||||
"""Remove stale entries older than the cleanup interval."""
|
||||
now = time.time()
|
||||
if now - self._last_cleanup < self.cleanup_interval_seconds:
|
||||
return
|
||||
|
||||
cutoff = now - self._window_seconds
|
||||
stale_ips: list[str] = []
|
||||
|
||||
for ip, timestamps in self._storage.items():
|
||||
# Remove timestamps older than the window
|
||||
while timestamps and timestamps[0] < cutoff:
|
||||
timestamps.popleft()
|
||||
# Mark IP for removal if no recent requests
|
||||
if not timestamps:
|
||||
stale_ips.append(ip)
|
||||
|
||||
# Remove stale IP entries
|
||||
for ip in stale_ips:
|
||||
del self._storage[ip]
|
||||
|
||||
self._last_cleanup = now
|
||||
if stale_ips:
|
||||
logger.debug("Rate limiter cleanup: removed %d stale IPs", len(stale_ips))
|
||||
|
||||
def is_allowed(self, client_ip: str) -> tuple[bool, float]:
|
||||
"""Check if a request from the given IP is allowed.
|
||||
|
||||
Args:
|
||||
client_ip: The client's IP address.
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed: bool, retry_after: float).
|
||||
retry_after is seconds until next allowed request, 0 if allowed now.
|
||||
"""
|
||||
now = time.time()
|
||||
cutoff = now - self._window_seconds
|
||||
|
||||
# Get or create timestamp deque for this IP
|
||||
if client_ip not in self._storage:
|
||||
self._storage[client_ip] = deque()
|
||||
|
||||
timestamps = self._storage[client_ip]
|
||||
|
||||
# Remove timestamps outside the window
|
||||
while timestamps and timestamps[0] < cutoff:
|
||||
timestamps.popleft()
|
||||
|
||||
# Check if limit exceeded
|
||||
if len(timestamps) >= self.requests_per_minute:
|
||||
# Calculate retry after time
|
||||
oldest = timestamps[0]
|
||||
retry_after = self._window_seconds - (now - oldest)
|
||||
return False, max(0.0, retry_after)
|
||||
|
||||
# Record this request
|
||||
timestamps.append(now)
|
||||
return True, 0.0
|
||||
|
||||
def check_request(self, request: Request) -> tuple[bool, float]:
|
||||
"""Check if the request is allowed under rate limits.
|
||||
|
||||
Args:
|
||||
request: The incoming request.
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed: bool, retry_after: float).
|
||||
"""
|
||||
self._cleanup_if_needed()
|
||||
client_ip = self._get_client_ip(request)
|
||||
return self.is_allowed(client_ip)
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to apply rate limiting to specific routes.
|
||||
|
||||
Usage:
|
||||
# Apply to all routes (not recommended for public static files)
|
||||
app.add_middleware(RateLimitMiddleware)
|
||||
|
||||
# Apply only to specific paths
|
||||
app.add_middleware(
|
||||
RateLimitMiddleware,
|
||||
path_prefixes=["/api/matrix/"],
|
||||
requests_per_minute=30,
|
||||
)
|
||||
|
||||
Attributes:
|
||||
path_prefixes: List of URL path prefixes to rate limit.
|
||||
If empty, applies to all paths.
|
||||
requests_per_minute: Maximum requests per minute per IP.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app,
|
||||
path_prefixes: list[str] | None = None,
|
||||
requests_per_minute: int = 30,
|
||||
):
|
||||
super().__init__(app)
|
||||
self.path_prefixes = path_prefixes or []
|
||||
self.limiter = RateLimiter(requests_per_minute=requests_per_minute)
|
||||
|
||||
def _should_rate_limit(self, path: str) -> bool:
|
||||
"""Check if the given path should be rate limited.
|
||||
|
||||
Args:
|
||||
path: The request URL path.
|
||||
|
||||
Returns:
|
||||
True if path matches any configured prefix.
|
||||
"""
|
||||
if not self.path_prefixes:
|
||||
return True
|
||||
return any(path.startswith(prefix) for prefix in self.path_prefixes)
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Apply rate limiting to configured paths.
|
||||
|
||||
Args:
|
||||
request: The incoming request.
|
||||
call_next: Callable to get the response from downstream.
|
||||
|
||||
Returns:
|
||||
Response from downstream, or 429 if rate limited.
|
||||
"""
|
||||
# Skip if path doesn't match configured prefixes
|
||||
if not self._should_rate_limit(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# Check rate limit
|
||||
allowed, retry_after = self.limiter.check_request(request)
|
||||
|
||||
if not allowed:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"error": "Rate limit exceeded. Try again later.",
|
||||
"retry_after": int(retry_after) + 1,
|
||||
},
|
||||
headers={"Retry-After": str(int(retry_after) + 1)},
|
||||
)
|
||||
|
||||
# Process the request
|
||||
return await call_next(request)
|
||||
446
tests/unit/test_rate_limit.py
Normal file
446
tests/unit/test_rate_limit.py
Normal file
@@ -0,0 +1,446 @@
|
||||
"""Tests for rate limiting middleware.
|
||||
|
||||
Tests the RateLimiter class and RateLimitMiddleware for correct
|
||||
rate limiting behavior, cleanup, and edge cases.
|
||||
"""
|
||||
|
||||
import time
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from dashboard.middleware.rate_limit import RateLimiter, RateLimitMiddleware
|
||||
|
||||
|
||||
class TestRateLimiter:
|
||||
"""Tests for the RateLimiter class."""
|
||||
|
||||
def test_init_defaults(self):
|
||||
"""RateLimiter initializes with default values."""
|
||||
limiter = RateLimiter()
|
||||
assert limiter.requests_per_minute == 30
|
||||
assert limiter.cleanup_interval_seconds == 60
|
||||
assert limiter._storage == {}
|
||||
|
||||
def test_init_custom_values(self):
|
||||
"""RateLimiter accepts custom configuration."""
|
||||
limiter = RateLimiter(requests_per_minute=60, cleanup_interval_seconds=120)
|
||||
assert limiter.requests_per_minute == 60
|
||||
assert limiter.cleanup_interval_seconds == 120
|
||||
|
||||
def test_is_allowed_first_request(self):
|
||||
"""First request from an IP is always allowed."""
|
||||
limiter = RateLimiter(requests_per_minute=5)
|
||||
allowed, retry_after = limiter.is_allowed("192.168.1.1")
|
||||
assert allowed is True
|
||||
assert retry_after == 0.0
|
||||
assert "192.168.1.1" in limiter._storage
|
||||
assert len(limiter._storage["192.168.1.1"]) == 1
|
||||
|
||||
def test_is_allowed_under_limit(self):
|
||||
"""Requests under the limit are allowed."""
|
||||
limiter = RateLimiter(requests_per_minute=5)
|
||||
|
||||
# Make 4 requests (under limit of 5)
|
||||
for _ in range(4):
|
||||
allowed, _ = limiter.is_allowed("192.168.1.1")
|
||||
assert allowed is True
|
||||
|
||||
assert len(limiter._storage["192.168.1.1"]) == 4
|
||||
|
||||
def test_is_allowed_at_limit(self):
|
||||
"""Request at the limit is allowed."""
|
||||
limiter = RateLimiter(requests_per_minute=5)
|
||||
|
||||
# Make exactly 5 requests
|
||||
for _ in range(5):
|
||||
allowed, _ = limiter.is_allowed("192.168.1.1")
|
||||
assert allowed is True
|
||||
|
||||
assert len(limiter._storage["192.168.1.1"]) == 5
|
||||
|
||||
def test_is_allowed_over_limit(self):
|
||||
"""Request over the limit is denied."""
|
||||
limiter = RateLimiter(requests_per_minute=5)
|
||||
|
||||
# Make 5 requests to hit the limit
|
||||
for _ in range(5):
|
||||
limiter.is_allowed("192.168.1.1")
|
||||
|
||||
# 6th request should be denied
|
||||
allowed, retry_after = limiter.is_allowed("192.168.1.1")
|
||||
assert allowed is False
|
||||
assert retry_after > 0
|
||||
|
||||
def test_is_allowed_different_ips(self):
|
||||
"""Rate limiting is per-IP, not global."""
|
||||
limiter = RateLimiter(requests_per_minute=5)
|
||||
|
||||
# Hit limit for IP 1
|
||||
for _ in range(5):
|
||||
limiter.is_allowed("192.168.1.1")
|
||||
|
||||
# IP 1 is now rate limited
|
||||
allowed, _ = limiter.is_allowed("192.168.1.1")
|
||||
assert allowed is False
|
||||
|
||||
# IP 2 should still be allowed
|
||||
allowed, _ = limiter.is_allowed("192.168.1.2")
|
||||
assert allowed is True
|
||||
|
||||
def test_window_expiration_allows_new_requests(self):
|
||||
"""After window expires, new requests are allowed."""
|
||||
limiter = RateLimiter(requests_per_minute=5)
|
||||
|
||||
# Hit the limit
|
||||
for _ in range(5):
|
||||
limiter.is_allowed("192.168.1.1")
|
||||
|
||||
# Should be rate limited
|
||||
allowed, _ = limiter.is_allowed("192.168.1.1")
|
||||
assert allowed is False
|
||||
|
||||
# Simulate time passing by clearing timestamps manually
|
||||
# (we can't wait 60 seconds in a test)
|
||||
limiter._storage["192.168.1.1"].clear()
|
||||
|
||||
# Should now be allowed again
|
||||
allowed, _ = limiter.is_allowed("192.168.1.1")
|
||||
assert allowed is True
|
||||
|
||||
def test_cleanup_removes_stale_entries(self):
|
||||
"""Cleanup removes IPs with no recent requests."""
|
||||
limiter = RateLimiter(
|
||||
requests_per_minute=5,
|
||||
cleanup_interval_seconds=1, # Short interval for testing
|
||||
)
|
||||
|
||||
# Add some requests
|
||||
limiter.is_allowed("192.168.1.1")
|
||||
limiter.is_allowed("192.168.1.2")
|
||||
|
||||
# Both IPs should be in storage
|
||||
assert "192.168.1.1" in limiter._storage
|
||||
assert "192.168.1.2" in limiter._storage
|
||||
|
||||
# Manually clear timestamps to simulate stale data
|
||||
limiter._storage["192.168.1.1"].clear()
|
||||
limiter._last_cleanup = time.time() - 2 # Force cleanup
|
||||
|
||||
# Trigger cleanup via check_request with a mock
|
||||
mock_request = Mock()
|
||||
mock_request.headers = {}
|
||||
mock_request.client = Mock()
|
||||
mock_request.client.host = "192.168.1.3"
|
||||
mock_request.url.path = "/api/matrix/test"
|
||||
|
||||
limiter.check_request(mock_request)
|
||||
|
||||
# Stale IP should be removed
|
||||
assert "192.168.1.1" not in limiter._storage
|
||||
# IP with no requests (cleared) is also stale
|
||||
assert "192.168.1.2" in limiter._storage
|
||||
|
||||
def test_get_client_ip_direct(self):
|
||||
"""Extract client IP from direct connection."""
|
||||
limiter = RateLimiter()
|
||||
|
||||
mock_request = Mock()
|
||||
mock_request.headers = {}
|
||||
mock_request.client = Mock()
|
||||
mock_request.client.host = "192.168.1.100"
|
||||
|
||||
ip = limiter._get_client_ip(mock_request)
|
||||
assert ip == "192.168.1.100"
|
||||
|
||||
def test_get_client_ip_x_forwarded_for(self):
|
||||
"""Extract client IP from X-Forwarded-For header."""
|
||||
limiter = RateLimiter()
|
||||
|
||||
mock_request = Mock()
|
||||
mock_request.headers = {"x-forwarded-for": "10.0.0.1, 192.168.1.1"}
|
||||
mock_request.client = Mock()
|
||||
mock_request.client.host = "192.168.1.100"
|
||||
|
||||
ip = limiter._get_client_ip(mock_request)
|
||||
assert ip == "10.0.0.1"
|
||||
|
||||
def test_get_client_ip_x_real_ip(self):
|
||||
"""Extract client IP from X-Real-IP header."""
|
||||
limiter = RateLimiter()
|
||||
|
||||
mock_request = Mock()
|
||||
mock_request.headers = {"x-real-ip": "10.0.0.5"}
|
||||
mock_request.client = Mock()
|
||||
mock_request.client.host = "192.168.1.100"
|
||||
|
||||
ip = limiter._get_client_ip(mock_request)
|
||||
assert ip == "10.0.0.5"
|
||||
|
||||
def test_get_client_ip_no_client(self):
|
||||
"""Return 'unknown' when no client info available."""
|
||||
limiter = RateLimiter()
|
||||
|
||||
mock_request = Mock()
|
||||
mock_request.headers = {}
|
||||
mock_request.client = None
|
||||
|
||||
ip = limiter._get_client_ip(mock_request)
|
||||
assert ip == "unknown"
|
||||
|
||||
|
||||
class TestRateLimitMiddleware:
|
||||
"""Tests for the RateLimitMiddleware class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
"""Create a mock ASGI app."""
|
||||
|
||||
async def app(scope, receive, send):
|
||||
response = JSONResponse({"status": "ok"})
|
||||
await response(scope, receive, send)
|
||||
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request(self):
|
||||
"""Create a mock Request object."""
|
||||
request = Mock(spec=Request)
|
||||
request.url.path = "/api/matrix/test"
|
||||
request.headers = {}
|
||||
request.client = Mock()
|
||||
request.client.host = "192.168.1.1"
|
||||
return request
|
||||
|
||||
def test_init_defaults(self, mock_app):
|
||||
"""Middleware initializes with default values."""
|
||||
middleware = RateLimitMiddleware(mock_app)
|
||||
assert middleware.path_prefixes == []
|
||||
assert middleware.limiter.requests_per_minute == 30
|
||||
|
||||
def test_init_custom_values(self, mock_app):
|
||||
"""Middleware accepts custom configuration."""
|
||||
middleware = RateLimitMiddleware(
|
||||
mock_app,
|
||||
path_prefixes=["/api/matrix/"],
|
||||
requests_per_minute=60,
|
||||
)
|
||||
assert middleware.path_prefixes == ["/api/matrix/"]
|
||||
assert middleware.limiter.requests_per_minute == 60
|
||||
|
||||
def test_should_rate_limit_no_prefixes(self, mock_app):
|
||||
"""With no prefixes, all paths are rate limited."""
|
||||
middleware = RateLimitMiddleware(mock_app)
|
||||
assert middleware._should_rate_limit("/api/matrix/test") is True
|
||||
assert middleware._should_rate_limit("/api/other/test") is True
|
||||
assert middleware._should_rate_limit("/health") is True
|
||||
|
||||
def test_should_rate_limit_with_prefixes(self, mock_app):
|
||||
"""With prefixes, only matching paths are rate limited."""
|
||||
middleware = RateLimitMiddleware(
|
||||
mock_app,
|
||||
path_prefixes=["/api/matrix/", "/api/public/"],
|
||||
)
|
||||
assert middleware._should_rate_limit("/api/matrix/test") is True
|
||||
assert middleware._should_rate_limit("/api/matrix/") is True
|
||||
assert middleware._should_rate_limit("/api/public/data") is True
|
||||
assert middleware._should_rate_limit("/api/other/test") is False
|
||||
assert middleware._should_rate_limit("/health") is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_allows_matching_path_under_limit(self, mock_app):
|
||||
"""Request to matching path under limit is allowed."""
|
||||
middleware = RateLimitMiddleware(
|
||||
mock_app,
|
||||
path_prefixes=["/api/matrix/"],
|
||||
requests_per_minute=5,
|
||||
)
|
||||
|
||||
# Create a proper ASGI scope
|
||||
scope = {
|
||||
"type": "http",
|
||||
"method": "GET",
|
||||
"path": "/api/matrix/test",
|
||||
"headers": [],
|
||||
}
|
||||
|
||||
async def receive():
|
||||
return {"type": "http.request", "body": b""}
|
||||
|
||||
response_body = []
|
||||
|
||||
async def send(message):
|
||||
response_body.append(message)
|
||||
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
# Should have sent response messages
|
||||
assert len(response_body) > 0
|
||||
# Check for 200 status in the response start message
|
||||
start_message = next(
|
||||
(m for m in response_body if m.get("type") == "http.response.start"), None
|
||||
)
|
||||
assert start_message is not None
|
||||
assert start_message["status"] == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_skips_non_matching_path(self, mock_app):
|
||||
"""Request to non-matching path bypasses rate limiting."""
|
||||
middleware = RateLimitMiddleware(
|
||||
mock_app,
|
||||
path_prefixes=["/api/matrix/"],
|
||||
requests_per_minute=5,
|
||||
)
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
"method": "GET",
|
||||
"path": "/api/other/test", # Doesn't match /api/matrix/
|
||||
"headers": [],
|
||||
}
|
||||
|
||||
async def receive():
|
||||
return {"type": "http.request", "body": b""}
|
||||
|
||||
response_body = []
|
||||
|
||||
async def send(message):
|
||||
response_body.append(message)
|
||||
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
# Should have sent response messages
|
||||
assert len(response_body) > 0
|
||||
start_message = next(
|
||||
(m for m in response_body if m.get("type") == "http.response.start"), None
|
||||
)
|
||||
assert start_message is not None
|
||||
assert start_message["status"] == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_returns_429_when_rate_limited(self, mock_app):
|
||||
"""Request over limit returns 429 status."""
|
||||
middleware = RateLimitMiddleware(
|
||||
mock_app,
|
||||
path_prefixes=["/api/matrix/"],
|
||||
requests_per_minute=2, # Low limit for testing
|
||||
)
|
||||
|
||||
# First request - allowed
|
||||
test_scope = {
|
||||
"type": "http",
|
||||
"method": "GET",
|
||||
"path": "/api/matrix/test",
|
||||
"headers": [],
|
||||
}
|
||||
|
||||
async def receive():
|
||||
return {"type": "http.request", "body": b""}
|
||||
|
||||
# Helper to capture response
|
||||
def make_send(captured):
|
||||
async def send(message):
|
||||
captured.append(message)
|
||||
|
||||
return send
|
||||
|
||||
# Make requests to hit the limit
|
||||
for _ in range(2):
|
||||
response_body = []
|
||||
await middleware(test_scope, receive, make_send(response_body))
|
||||
|
||||
start_message = next(
|
||||
(m for m in response_body if m.get("type") == "http.response.start"),
|
||||
None,
|
||||
)
|
||||
assert start_message["status"] == 200
|
||||
|
||||
# 3rd request should be rate limited
|
||||
response_body = []
|
||||
await middleware(test_scope, receive, make_send(response_body))
|
||||
|
||||
start_message = next(
|
||||
(m for m in response_body if m.get("type") == "http.response.start"), None
|
||||
)
|
||||
assert start_message["status"] == 429
|
||||
|
||||
# Check for Retry-After header
|
||||
headers = dict(start_message.get("headers", []))
|
||||
assert b"retry-after" in headers or b"Retry-After" in headers
|
||||
|
||||
|
||||
class TestRateLimiterIntegration:
|
||||
"""Integration-style tests for rate limiter behavior."""
|
||||
|
||||
def test_multiple_ips_independent_limits(self):
|
||||
"""Each IP has its own independent rate limit."""
|
||||
limiter = RateLimiter(requests_per_minute=3)
|
||||
|
||||
# Use up limit for IP 1
|
||||
for _ in range(3):
|
||||
limiter.is_allowed("10.0.0.1")
|
||||
|
||||
# Use up limit for IP 2
|
||||
for _ in range(3):
|
||||
limiter.is_allowed("10.0.0.2")
|
||||
|
||||
# Both should now be rate limited
|
||||
assert limiter.is_allowed("10.0.0.1")[0] is False
|
||||
assert limiter.is_allowed("10.0.0.2")[0] is False
|
||||
|
||||
# IP 3 should still be allowed
|
||||
assert limiter.is_allowed("10.0.0.3")[0] is True
|
||||
|
||||
def test_timestamp_window_sliding(self):
|
||||
"""Rate limit window slides correctly as time passes."""
|
||||
from collections import deque
|
||||
|
||||
limiter = RateLimiter(requests_per_minute=3)
|
||||
|
||||
# Add 3 timestamps manually (simulating old requests)
|
||||
now = time.time()
|
||||
limiter._storage["test-ip"] = deque(
|
||||
[
|
||||
now - 100, # 100 seconds ago (outside 60s window)
|
||||
now - 50, # 50 seconds ago (inside window)
|
||||
now - 10, # 10 seconds ago (inside window)
|
||||
]
|
||||
)
|
||||
|
||||
# Currently have 2 requests in window, so 1 more allowed
|
||||
allowed, _ = limiter.is_allowed("test-ip")
|
||||
assert allowed is True
|
||||
|
||||
# Now 3 in window, should be rate limited
|
||||
allowed, _ = limiter.is_allowed("test-ip")
|
||||
assert allowed is False
|
||||
|
||||
def test_cleanup_preserves_active_ips(self):
|
||||
"""Cleanup only removes IPs with no recent requests."""
|
||||
from collections import deque
|
||||
|
||||
limiter = RateLimiter(
|
||||
requests_per_minute=3,
|
||||
cleanup_interval_seconds=1,
|
||||
)
|
||||
|
||||
now = time.time()
|
||||
# IP 1: active recently
|
||||
limiter._storage["active-ip"] = deque([now - 10])
|
||||
# IP 2: no timestamps (stale)
|
||||
limiter._storage["stale-ip"] = deque()
|
||||
# IP 3: old timestamps only
|
||||
limiter._storage["old-ip"] = deque([now - 100])
|
||||
|
||||
limiter._last_cleanup = now - 2 # Force cleanup
|
||||
|
||||
# Run cleanup
|
||||
limiter._cleanup_if_needed()
|
||||
|
||||
# Active IP should remain
|
||||
assert "active-ip" in limiter._storage
|
||||
# Stale IPs should be removed
|
||||
assert "stale-ip" not in limiter._storage
|
||||
assert "old-ip" not in limiter._storage
|
||||
Reference in New Issue
Block a user