[kimi] Add rate limiting middleware for Matrix API endpoints (#683) (#746)

This commit is contained in:
2026-03-21 16:23:16 +00:00
parent 815933953c
commit dc9f0c04eb
4 changed files with 668 additions and 2 deletions

View File

@@ -24,6 +24,7 @@ from config import settings
# Import dedicated middleware
from dashboard.middleware.csrf import CSRFMiddleware
from dashboard.middleware.rate_limit import RateLimitMiddleware
from dashboard.middleware.request_logging import RequestLoggingMiddleware
from dashboard.middleware.security_headers import SecurityHeadersMiddleware
from dashboard.routes.agents import router as agents_router
@@ -559,10 +560,17 @@ def _is_tailscale_origin(origin: str) -> bool:
# Middleware registration.
# NOTE(review): Starlette's add_middleware() is understood to prepend to the
# stack, which would make the LAST middleware added the outermost one —
# confirm that the numbering below matches the intended execution order.
# 1. Logging (outermost to capture everything)
app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health"])
# 2. Rate Limiting (before security to prevent abuse early).
#    Only paths starting with /api/matrix/ are limited, at 30 requests per
#    minute; every other path passes through untouched.
app.add_middleware(
    RateLimitMiddleware,
    path_prefixes=["/api/matrix/"],
    requests_per_minute=30,
)
# 3. Security Headers (production mode is enabled whenever debug is off)
app.add_middleware(SecurityHeadersMiddleware, production=not settings.debug)
# 4. CSRF Protection
app.add_middleware(CSRFMiddleware)
# 5. Standard FastAPI middleware

View File

@@ -1,6 +1,7 @@
"""Dashboard middleware package."""
from .csrf import CSRFMiddleware, csrf_exempt, generate_csrf_token, validate_csrf_token
from .rate_limit import RateLimiter, RateLimitMiddleware
from .request_logging import RequestLoggingMiddleware
from .security_headers import SecurityHeadersMiddleware
@@ -9,6 +10,8 @@ __all__ = [
"csrf_exempt",
"generate_csrf_token",
"validate_csrf_token",
"RateLimiter",
"RateLimitMiddleware",
"SecurityHeadersMiddleware",
"RequestLoggingMiddleware",
]

View File

@@ -0,0 +1,209 @@
"""Rate limiting middleware for FastAPI.
Simple in-memory rate limiter for API endpoints. Tracks requests per IP
with configurable limits and automatic cleanup of stale entries.
"""
import logging
import time
from collections import deque
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
logger = logging.getLogger(__name__)
class RateLimiter:
    """In-memory sliding-window rate limiter keyed by client IP.

    Every client IP maps to a deque of request timestamps (``time.time()``
    values). A request is accepted while fewer than ``requests_per_minute``
    timestamps remain inside the trailing 60-second window. IPs whose deque
    has emptied are purged at most once every ``cleanup_interval_seconds``.

    State is kept in a plain dict on the instance.

    Attributes:
        requests_per_minute: Maximum requests allowed per minute per IP.
            A value of zero or less rejects every request.
        cleanup_interval_seconds: Minimum seconds between storage sweeps.
    """

    def __init__(
        self,
        requests_per_minute: int = 30,
        cleanup_interval_seconds: int = 60,
    ):
        """Create an empty limiter.

        Args:
            requests_per_minute: Per-IP request budget for each 60 s window.
            cleanup_interval_seconds: Minimum seconds between storage sweeps.
        """
        self.requests_per_minute = requests_per_minute
        self.cleanup_interval_seconds = cleanup_interval_seconds
        self._storage: dict[str, deque[float]] = {}
        self._last_cleanup: float = time.time()
        self._window_seconds: float = 60.0  # 1 minute window

    def _get_client_ip(self, request: "Request") -> str:
        """Resolve the client address used as the rate-limit key.

        Lookup order: first non-empty hop of ``X-Forwarded-For``, then
        ``X-Real-IP``, then the direct peer address, then ``"unknown"``.

        NOTE(review): both headers are supplied by the caller unless a
        trusted proxy overwrites them; a client that can set them freely can
        pick its own bucket. Confirm the deployment always sits behind such
        a proxy.

        Args:
            request: The incoming request.

        Returns:
            Client IP address string, or ``"unknown"`` if none is available.
        """
        forwarded = request.headers.get("x-forwarded-for")
        if forwarded:
            # Skip blank hops (e.g. ", 10.0.0.2") so a malformed header does
            # not collapse unrelated clients into one bucket keyed "".
            for hop in forwarded.split(","):
                candidate = hop.strip()
                if candidate:
                    return candidate

        real_ip = request.headers.get("x-real-ip")
        if real_ip:
            real_ip = real_ip.strip()
            if real_ip:
                return real_ip

        # Fall back to the direct connection.
        if request.client:
            return request.client.host
        return "unknown"

    def _cleanup_if_needed(self) -> None:
        """Drop expired timestamps and forget IPs that have none left.

        Runs at most once per ``cleanup_interval_seconds``. Timestamps are
        aged against the 60-second rate window, not the cleanup interval.
        """
        now = time.time()
        if now - self._last_cleanup < self.cleanup_interval_seconds:
            return

        cutoff = now - self._window_seconds
        stale_ips: list[str] = []
        for ip, timestamps in self._storage.items():
            while timestamps and timestamps[0] < cutoff:
                timestamps.popleft()
            if not timestamps:
                # Collected first: the dict must not shrink while iterating.
                stale_ips.append(ip)

        for ip in stale_ips:
            del self._storage[ip]

        self._last_cleanup = now
        if stale_ips:
            logger.debug("Rate limiter cleanup: removed %d stale IPs", len(stale_ips))

    def is_allowed(self, client_ip: str) -> tuple[bool, float]:
        """Check whether a request from ``client_ip`` fits the budget.

        An allowed request is recorded; a rejected one is not.

        Args:
            client_ip: The client's IP address.

        Returns:
            Tuple of (allowed, retry_after). ``retry_after`` is the number of
            seconds until the next request would be accepted, 0 if allowed.
        """
        if self.requests_per_minute <= 0:
            # No budget at all: reject without touching storage. Previously
            # this path indexed an empty deque and raised IndexError.
            return False, self._window_seconds

        now = time.time()
        cutoff = now - self._window_seconds

        timestamps = self._storage.setdefault(client_ip, deque())

        # Evict timestamps that slid out of the window.
        while timestamps and timestamps[0] < cutoff:
            timestamps.popleft()

        if len(timestamps) >= self.requests_per_minute:
            # The oldest surviving request frees a slot when it expires.
            retry_after = self._window_seconds - (now - timestamps[0])
            return False, max(0.0, retry_after)

        timestamps.append(now)
        return True, 0.0

    def check_request(self, request: "Request") -> tuple[bool, float]:
        """Run periodic cleanup, then rate-check the request's client IP.

        Args:
            request: The incoming request.

        Returns:
            Tuple of (allowed, retry_after), as returned by ``is_allowed``.
        """
        self._cleanup_if_needed()
        client_ip = self._get_client_ip(request)
        return self.is_allowed(client_ip)
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Starlette middleware that rate limits selected URL paths.

    Usage:
        # Limit every route (not recommended for public static files)
        app.add_middleware(RateLimitMiddleware)

        # Limit only paths under the given prefixes
        app.add_middleware(
            RateLimitMiddleware,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=30,
        )

    Attributes:
        path_prefixes: URL path prefixes that are rate limited. An empty
            list means every path is limited.
        limiter: The RateLimiter instance holding the per-IP counters.
    """

    def __init__(
        self,
        app,
        path_prefixes: list[str] | None = None,
        requests_per_minute: int = 30,
    ):
        """Wrap ``app`` and build the limiter.

        Args:
            app: The downstream ASGI application.
            path_prefixes: Prefixes to limit; None or empty limits all paths.
            requests_per_minute: Maximum requests per minute per IP.
        """
        super().__init__(app)
        self.limiter = RateLimiter(requests_per_minute=requests_per_minute)
        self.path_prefixes = path_prefixes if path_prefixes else []

    def _should_rate_limit(self, path: str) -> bool:
        """Tell whether ``path`` is subject to rate limiting.

        Args:
            path: The request URL path.

        Returns:
            True when no prefixes are configured, or when ``path`` starts
            with at least one configured prefix.
        """
        prefixes = self.path_prefixes
        if prefixes:
            # str.startswith accepts a tuple and matches any of its members.
            return path.startswith(tuple(prefixes))
        return True

    async def dispatch(self, request: Request, call_next) -> Response:
        """Reject over-limit requests with 429, otherwise pass them on.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request downstream.

        Returns:
            The downstream response, or a 429 JSON response carrying a
            ``Retry-After`` header when the client is over its limit.
        """
        if self._should_rate_limit(request.url.path):
            allowed, retry_after = self.limiter.check_request(request)
            if not allowed:
                # Truncate then add one so the advertised wait is never
                # shorter than the real one.
                wait_seconds = int(retry_after) + 1
                return JSONResponse(
                    status_code=429,
                    content={
                        "error": "Rate limit exceeded. Try again later.",
                        "retry_after": wait_seconds,
                    },
                    headers={"Retry-After": str(wait_seconds)},
                )

        # Path is not limited, or the client is within its budget.
        return await call_next(request)

View File

@@ -0,0 +1,446 @@
"""Tests for rate limiting middleware.
Tests the RateLimiter class and RateLimitMiddleware for correct
rate limiting behavior, cleanup, and edge cases.
"""
import time
from unittest.mock import Mock
import pytest
from starlette.requests import Request
from starlette.responses import JSONResponse
from dashboard.middleware.rate_limit import RateLimiter, RateLimitMiddleware
class TestRateLimiter:
    """Unit tests for RateLimiter: setup, limits, cleanup and IP lookup."""

    def test_init_defaults(self):
        """RateLimiter initializes with default values."""
        limiter = RateLimiter()
        assert limiter.requests_per_minute == 30
        assert limiter.cleanup_interval_seconds == 60
        assert limiter._storage == {}

    def test_init_custom_values(self):
        """RateLimiter accepts custom configuration."""
        limiter = RateLimiter(requests_per_minute=60, cleanup_interval_seconds=120)
        assert limiter.requests_per_minute == 60
        assert limiter.cleanup_interval_seconds == 120

    def test_is_allowed_first_request(self):
        """First request from an IP is always allowed and is recorded."""
        limiter = RateLimiter(requests_per_minute=5)
        allowed, retry_after = limiter.is_allowed("192.168.1.1")
        assert allowed is True
        assert retry_after == 0.0
        assert "192.168.1.1" in limiter._storage
        assert len(limiter._storage["192.168.1.1"]) == 1

    def test_is_allowed_under_limit(self):
        """Requests under the limit are allowed."""
        limiter = RateLimiter(requests_per_minute=5)
        # Make 4 requests (under limit of 5)
        for _ in range(4):
            allowed, _ = limiter.is_allowed("192.168.1.1")
            assert allowed is True
        assert len(limiter._storage["192.168.1.1"]) == 4

    def test_is_allowed_at_limit(self):
        """Request at the limit is allowed."""
        limiter = RateLimiter(requests_per_minute=5)
        # Make exactly 5 requests
        for _ in range(5):
            allowed, _ = limiter.is_allowed("192.168.1.1")
            assert allowed is True
        assert len(limiter._storage["192.168.1.1"]) == 5

    def test_is_allowed_over_limit(self):
        """Request over the limit is denied with a positive retry delay."""
        limiter = RateLimiter(requests_per_minute=5)
        # Make 5 requests to hit the limit
        for _ in range(5):
            limiter.is_allowed("192.168.1.1")
        # 6th request should be denied
        allowed, retry_after = limiter.is_allowed("192.168.1.1")
        assert allowed is False
        assert retry_after > 0

    def test_is_allowed_different_ips(self):
        """Rate limiting is per-IP, not global."""
        limiter = RateLimiter(requests_per_minute=5)
        # Hit limit for IP 1
        for _ in range(5):
            limiter.is_allowed("192.168.1.1")
        # IP 1 is now rate limited
        allowed, _ = limiter.is_allowed("192.168.1.1")
        assert allowed is False
        # IP 2 should still be allowed
        allowed, _ = limiter.is_allowed("192.168.1.2")
        assert allowed is True

    def test_window_expiration_allows_new_requests(self):
        """After the window expires, new requests are allowed."""
        limiter = RateLimiter(requests_per_minute=5)
        # Hit the limit
        for _ in range(5):
            limiter.is_allowed("192.168.1.1")
        # Should be rate limited
        allowed, _ = limiter.is_allowed("192.168.1.1")
        assert allowed is False
        # Simulate time passing by backdating every recorded timestamp past
        # the window (we can't wait 60 seconds in a test). Unlike clearing
        # the deque, this makes is_allowed() perform the eviction itself.
        timestamps = limiter._storage["192.168.1.1"]
        shift = limiter._window_seconds + 1
        aged = [stamp - shift for stamp in timestamps]
        timestamps.clear()
        timestamps.extend(aged)
        # Should now be allowed again, with only the new request recorded
        allowed, retry_after = limiter.is_allowed("192.168.1.1")
        assert allowed is True
        assert retry_after == 0.0
        assert len(limiter._storage["192.168.1.1"]) == 1

    def test_cleanup_removes_stale_entries(self):
        """Cleanup removes IPs with no recent requests and keeps the rest."""
        limiter = RateLimiter(
            requests_per_minute=5,
            cleanup_interval_seconds=1,  # Short interval for testing
        )
        # Add some requests
        limiter.is_allowed("192.168.1.1")
        limiter.is_allowed("192.168.1.2")
        # Both IPs should be in storage
        assert "192.168.1.1" in limiter._storage
        assert "192.168.1.2" in limiter._storage
        # Manually clear timestamps to simulate stale data
        limiter._storage["192.168.1.1"].clear()
        limiter._last_cleanup = time.time() - 2  # Force cleanup
        # Trigger cleanup via check_request with a mock
        mock_request = Mock()
        mock_request.headers = {}
        mock_request.client = Mock()
        mock_request.client.host = "192.168.1.3"
        mock_request.url.path = "/api/matrix/test"
        limiter.check_request(mock_request)
        # Stale IP (no timestamps left) should be removed
        assert "192.168.1.1" not in limiter._storage
        # IP 2 still has a request inside the window, so it must survive
        assert "192.168.1.2" in limiter._storage

    def test_get_client_ip_direct(self):
        """Extract client IP from direct connection."""
        limiter = RateLimiter()
        mock_request = Mock()
        mock_request.headers = {}
        mock_request.client = Mock()
        mock_request.client.host = "192.168.1.100"
        ip = limiter._get_client_ip(mock_request)
        assert ip == "192.168.1.100"

    def test_get_client_ip_x_forwarded_for(self):
        """Extract client IP from the first hop of X-Forwarded-For."""
        limiter = RateLimiter()
        mock_request = Mock()
        mock_request.headers = {"x-forwarded-for": "10.0.0.1, 192.168.1.1"}
        mock_request.client = Mock()
        mock_request.client.host = "192.168.1.100"
        ip = limiter._get_client_ip(mock_request)
        assert ip == "10.0.0.1"

    def test_get_client_ip_x_real_ip(self):
        """Extract client IP from X-Real-IP header."""
        limiter = RateLimiter()
        mock_request = Mock()
        mock_request.headers = {"x-real-ip": "10.0.0.5"}
        mock_request.client = Mock()
        mock_request.client.host = "192.168.1.100"
        ip = limiter._get_client_ip(mock_request)
        assert ip == "10.0.0.5"

    def test_get_client_ip_no_client(self):
        """Return 'unknown' when no client info available."""
        limiter = RateLimiter()
        mock_request = Mock()
        mock_request.headers = {}
        mock_request.client = None
        ip = limiter._get_client_ip(mock_request)
        assert ip == "unknown"
class TestRateLimitMiddleware:
    """Tests for RateLimitMiddleware configuration and ASGI dispatch."""

    @pytest.fixture
    def mock_app(self):
        """Provide a minimal ASGI app that always answers ``{"status": "ok"}``."""

        async def app(scope, receive, send):
            response = JSONResponse({"status": "ok"})
            await response(scope, receive, send)

        return app

    @pytest.fixture
    def mock_request(self):
        """Provide a mocked Request aimed at a rate-limited path."""
        request = Mock(spec=Request)
        request.url.path = "/api/matrix/test"
        request.headers = {}
        request.client = Mock()
        request.client.host = "192.168.1.1"
        return request

    @staticmethod
    def _get_scope(path):
        """Build a bare HTTP GET scope for ``path``."""
        return {
            "type": "http",
            "method": "GET",
            "path": path,
            "headers": [],
        }

    @staticmethod
    async def _drive(middleware, scope):
        """Send one request through ``middleware`` and collect sent messages."""

        async def receive():
            return {"type": "http.request", "body": b""}

        sent = []

        async def send(message):
            sent.append(message)

        await middleware(scope, receive, send)
        return sent

    @staticmethod
    def _response_start(messages):
        """Return the ``http.response.start`` message, or None if absent."""
        for message in messages:
            if message.get("type") == "http.response.start":
                return message
        return None

    def test_init_defaults(self, mock_app):
        """Middleware initializes with default values."""
        middleware = RateLimitMiddleware(mock_app)
        assert middleware.path_prefixes == []
        assert middleware.limiter.requests_per_minute == 30

    def test_init_custom_values(self, mock_app):
        """Middleware accepts custom configuration."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=60,
        )
        assert middleware.path_prefixes == ["/api/matrix/"]
        assert middleware.limiter.requests_per_minute == 60

    def test_should_rate_limit_no_prefixes(self, mock_app):
        """With no prefixes, all paths are rate limited."""
        middleware = RateLimitMiddleware(mock_app)
        assert middleware._should_rate_limit("/api/matrix/test") is True
        assert middleware._should_rate_limit("/api/other/test") is True
        assert middleware._should_rate_limit("/health") is True

    def test_should_rate_limit_with_prefixes(self, mock_app):
        """With prefixes, only matching paths are rate limited."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/", "/api/public/"],
        )
        assert middleware._should_rate_limit("/api/matrix/test") is True
        assert middleware._should_rate_limit("/api/matrix/") is True
        assert middleware._should_rate_limit("/api/public/data") is True
        assert middleware._should_rate_limit("/api/other/test") is False
        assert middleware._should_rate_limit("/health") is False

    @pytest.mark.asyncio
    async def test_dispatch_allows_matching_path_under_limit(self, mock_app):
        """Request to matching path under limit is allowed."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=5,
        )
        messages = await self._drive(middleware, self._get_scope("/api/matrix/test"))
        # Should have sent response messages
        assert len(messages) > 0
        # Check for 200 status in the response start message
        start_message = self._response_start(messages)
        assert start_message is not None
        assert start_message["status"] == 200

    @pytest.mark.asyncio
    async def test_dispatch_skips_non_matching_path(self, mock_app):
        """Request to non-matching path bypasses rate limiting."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=5,
        )
        # This path doesn't match /api/matrix/
        messages = await self._drive(middleware, self._get_scope("/api/other/test"))
        # Should have sent response messages
        assert len(messages) > 0
        start_message = self._response_start(messages)
        assert start_message is not None
        assert start_message["status"] == 200

    @pytest.mark.asyncio
    async def test_dispatch_returns_429_when_rate_limited(self, mock_app):
        """Request over limit returns 429 status."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=2,  # Low limit for testing
        )
        # The same scope object is reused for every request in this test
        test_scope = self._get_scope("/api/matrix/test")
        # Make requests to hit the limit
        for _ in range(2):
            messages = await self._drive(middleware, test_scope)
            start_message = self._response_start(messages)
            assert start_message["status"] == 200
        # 3rd request should be rate limited
        messages = await self._drive(middleware, test_scope)
        start_message = self._response_start(messages)
        assert start_message["status"] == 429
        # Check for Retry-After header
        headers = dict(start_message.get("headers", []))
        assert b"retry-after" in headers or b"Retry-After" in headers
class TestRateLimiterIntegration:
    """Scenario tests covering several IPs, window sliding and cleanup."""

    def test_multiple_ips_independent_limits(self):
        """Each IP has its own independent rate limit."""
        limiter = RateLimiter(requests_per_minute=3)
        # Use up the limit for IP 1, then for IP 2
        for exhausted_ip in ("10.0.0.1", "10.0.0.2"):
            for _ in range(3):
                limiter.is_allowed(exhausted_ip)
        # Both should now be rate limited
        first_allowed = limiter.is_allowed("10.0.0.1")[0]
        assert first_allowed is False
        second_allowed = limiter.is_allowed("10.0.0.2")[0]
        assert second_allowed is False
        # IP 3 should still be allowed
        third_allowed = limiter.is_allowed("10.0.0.3")[0]
        assert third_allowed is True

    def test_timestamp_window_sliding(self):
        """Rate limit window slides correctly as time passes."""
        from collections import deque

        limiter = RateLimiter(requests_per_minute=3)
        # Seed three timestamps by hand: one outside the 60s window (100s
        # ago) and two inside it (50s and 10s ago).
        now = time.time()
        ages_in_seconds = (100, 50, 10)
        limiter._storage["test-ip"] = deque([now - age for age in ages_in_seconds])
        # Only 2 requests are inside the window, so 1 more is allowed
        outcome = limiter.is_allowed("test-ip")
        assert outcome[0] is True
        # Now 3 in window, should be rate limited
        outcome = limiter.is_allowed("test-ip")
        assert outcome[0] is False

    def test_cleanup_preserves_active_ips(self):
        """Cleanup only removes IPs with no recent requests."""
        from collections import deque

        limiter = RateLimiter(
            requests_per_minute=3,
            cleanup_interval_seconds=1,
        )
        now = time.time()
        # IP 1: active recently
        limiter._storage["active-ip"] = deque([now - 10])
        # IP 2: no timestamps (stale)
        limiter._storage["stale-ip"] = deque()
        # IP 3: old timestamps only
        limiter._storage["old-ip"] = deque([now - 100])
        limiter._last_cleanup = now - 2  # Force cleanup
        # Run cleanup
        limiter._cleanup_if_needed()
        remaining = limiter._storage
        # Active IP should remain
        assert "active-ip" in remaining
        # Stale IPs should be removed
        assert "stale-ip" not in remaining
        assert "old-ip" not in remaining