This repository has been archived on 2026-03-24. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
Timmy-time-dashboard/src/dashboard/middleware/rate_limit.py

210 lines
6.5 KiB
Python

"""Rate limiting middleware for FastAPI.

Simple in-memory rate limiter for API endpoints. Tracks request timestamps
per client IP in a sliding one-minute window, with configurable limits and
periodic cleanup of stale entries. State lives in this process's memory, so
it is not shared between separate worker processes.
"""
import logging
import time
from collections import deque

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

# Module-level logger; the rate limiter's cleanup pass logs at DEBUG level.
logger = logging.getLogger(__name__)
class RateLimiter:
    """In-memory sliding-window rate limiter keyed by client IP.

    Each client IP maps to a deque of request timestamps (seconds from
    ``time.time()``). A request is admitted when fewer than
    ``requests_per_minute`` timestamps fall inside the trailing 60-second
    window. IPs with no recent requests are purged at most once every
    ``cleanup_interval_seconds``, and only when :meth:`check_request` is
    called. No locking is performed.

    Attributes:
        requests_per_minute: Maximum requests allowed per window per IP.
            A value <= 0 rejects every request.
        cleanup_interval_seconds: Minimum seconds between stale-entry sweeps.
        trust_proxy_headers: When True (the default), derive the client IP
            from ``X-Forwarded-For`` / ``X-Real-IP``. Clients can set these
            headers themselves unless a trusted proxy overwrites them, so
            enable this only behind such a proxy. Otherwise a client can
            evade the limit by forging a different address per request.
    """

    def __init__(
        self,
        requests_per_minute: int = 30,
        cleanup_interval_seconds: int = 60,
        *,
        trust_proxy_headers: bool = True,
    ):
        self.requests_per_minute = requests_per_minute
        self.cleanup_interval_seconds = cleanup_interval_seconds
        self.trust_proxy_headers = trust_proxy_headers
        self._storage: dict[str, deque[float]] = {}
        # NOTE(review): time.time() is wall-clock and can jump (e.g. NTP
        # corrections); time.monotonic() would be more robust. Kept as-is in
        # case callers or tests patch time.time -- confirm before switching.
        self._last_cleanup: float = time.time()
        self._window_seconds: float = 60.0  # length of the sliding window

    @staticmethod
    def _evict_before(timestamps: deque[float], cutoff: float) -> None:
        """Pop timestamps older than *cutoff* off the left end of the deque."""
        # Timestamps are appended in arrival order, so the oldest ones sit
        # at the left end and can be discarded with popleft().
        while timestamps and timestamps[0] < cutoff:
            timestamps.popleft()

    def _get_client_ip(self, request: Request) -> str:
        """Return the client IP used as the rate-limit key.

        When ``trust_proxy_headers`` is set, the first non-empty entry of
        ``X-Forwarded-For`` is used, then ``X-Real-IP``. If proxy headers
        are not trusted, or are absent or blank, the direct peer address
        is used.

        Args:
            request: The incoming request.

        Returns:
            Client IP address string, or ``"unknown"`` if none is available.
        """
        if self.trust_proxy_headers:
            # By convention the left-most X-Forwarded-For entry is the
            # originating client. Skip blank entries from malformed headers
            # (e.g. ", 10.0.0.1") rather than keying every such request on "".
            forwarded = request.headers.get("x-forwarded-for", "")
            for candidate in forwarded.split(","):
                ip = candidate.strip()
                if ip:
                    return ip
            real_ip = request.headers.get("x-real-ip", "").strip()
            if real_ip:
                return real_ip
        # Fall back to the direct connection's peer address.
        if request.client:
            return request.client.host
        return "unknown"

    def _cleanup_if_needed(self) -> None:
        """Discard expired timestamps and drop IPs left with none.

        Does nothing if called within ``cleanup_interval_seconds`` of the
        previous sweep. A sweep removes timestamps older than the sliding
        window; an IP whose deque becomes empty is removed from storage.
        """
        now = time.time()
        if now - self._last_cleanup < self.cleanup_interval_seconds:
            return
        cutoff = now - self._window_seconds
        stale_ips: list[str] = []
        for ip, timestamps in self._storage.items():
            self._evict_before(timestamps, cutoff)
            if not timestamps:
                stale_ips.append(ip)
        # Delete after iterating: resizing a dict during iteration raises.
        for ip in stale_ips:
            del self._storage[ip]
        self._last_cleanup = now
        if stale_ips:
            logger.debug("Rate limiter cleanup: removed %d stale IPs", len(stale_ips))

    def is_allowed(self, client_ip: str) -> tuple[bool, float]:
        """Check a request from *client_ip*, recording it if allowed.

        Rejected requests are not recorded, so they do not extend the wait.

        Args:
            client_ip: The client's IP address (or any rate-limit key).

        Returns:
            Tuple of (allowed, retry_after). ``retry_after`` is 0.0 when the
            request is allowed. Otherwise it is the number of seconds until
            the oldest recorded request leaves the window, or a full window
            when nothing is recorded (only possible when the limit is <= 0).
        """
        now = time.time()
        timestamps = self._storage.setdefault(client_ip, deque())
        self._evict_before(timestamps, now - self._window_seconds)

        if len(timestamps) >= self.requests_per_minute:
            if timestamps:
                retry_after = self._window_seconds - (now - timestamps[0])
            else:
                # Only reachable when requests_per_minute <= 0: nothing has
                # been recorded, so there is no oldest timestamp to measure
                # from. Ask the client to wait a full window instead of
                # indexing into an empty deque.
                retry_after = self._window_seconds
            return False, max(0.0, retry_after)

        timestamps.append(now)
        return True, 0.0

    def check_request(self, request: Request) -> tuple[bool, float]:
        """Check whether *request* is allowed under the rate limit.

        Runs the periodic stale-entry sweep first, then resolves the client
        IP and delegates to :meth:`is_allowed`.

        Args:
            request: The incoming request.

        Returns:
            Tuple of (allowed, retry_after); see :meth:`is_allowed`.
        """
        self._cleanup_if_needed()
        return self.is_allowed(self._get_client_ip(request))
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Middleware that applies per-IP rate limiting to selected routes.

    Usage:
        # Apply to all routes (not recommended for public static files)
        app.add_middleware(RateLimitMiddleware)

        # Apply only to specific paths
        app.add_middleware(
            RateLimitMiddleware,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=30,
        )

    Attributes:
        path_prefixes: URL path prefixes to rate limit. If empty, every path
            is rate limited.
        limiter: The RateLimiter instance holding the per-IP request state.
    """

    def __init__(
        self,
        app,
        path_prefixes: list[str] | None = None,
        requests_per_minute: int = 30,
        *,
        cleanup_interval_seconds: int = 60,
    ):
        """Wrap *app* with rate limiting.

        Args:
            app: The downstream ASGI application.
            path_prefixes: URL path prefixes to limit; None or empty means
                every path is limited.
            requests_per_minute: Maximum requests per minute per client IP.
            cleanup_interval_seconds: Minimum seconds between the limiter's
                stale-entry sweeps (forwarded to RateLimiter).
        """
        super().__init__(app)
        # Copy so a caller mutating (or a one-shot iterable exhausting) the
        # original sequence cannot change which routes are limited.
        self.path_prefixes = list(path_prefixes) if path_prefixes else []
        self.limiter = RateLimiter(
            requests_per_minute=requests_per_minute,
            cleanup_interval_seconds=cleanup_interval_seconds,
        )

    def _should_rate_limit(self, path: str) -> bool:
        """Return True if *path* is subject to rate limiting.

        Args:
            path: The request URL path.

        Returns:
            True if no prefixes are configured, or if *path* starts with any
            configured prefix.
        """
        if not self.path_prefixes:
            return True
        return path.startswith(tuple(self.path_prefixes))

    async def dispatch(self, request: Request, call_next) -> Response:
        """Apply rate limiting to configured paths.

        Args:
            request: The incoming request.
            call_next: Callable that produces the downstream response.

        Returns:
            The downstream response, or a 429 JSON response with a
            ``Retry-After`` header when the client is over its limit.
        """
        if not self._should_rate_limit(request.url.path):
            return await call_next(request)

        allowed, retry_after = self.limiter.check_request(request)
        if allowed:
            return await call_next(request)

        # Retry-After takes whole seconds. int() truncates, so +1 ensures the
        # client never retries early. Body and header share one value.
        wait_seconds = int(retry_after) + 1
        return JSONResponse(
            status_code=429,
            content={
                "error": "Rate limit exceeded. Try again later.",
                "retry_after": wait_seconds,
            },
            headers={"Retry-After": str(wait_seconds)},
        )