"""Rate limiting middleware for FastAPI.
|
|
|
|
Simple in-memory rate limiter for API endpoints. Tracks requests per IP
|
|
with configurable limits and automatic cleanup of stale entries.
|
|
"""
|
|
|
|
import logging
|
|
import time
|
|
from collections import deque
|
|
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.requests import Request
|
|
from starlette.responses import JSONResponse, Response
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RateLimiter:
    """In-memory rate limiter tracking request timestamps per client IP.

    Sliding-window limiter: each IP maps to a deque of request timestamps;
    a request is rejected once ``requests_per_minute`` timestamps fall
    within the 60-second window. Stale per-IP entries are purged every
    ``cleanup_interval_seconds`` so storage does not grow without bound.

    NOTE: state is per-process only — not suitable as-is for
    multi-worker/multi-host deployments.

    Attributes:
        requests_per_minute: Maximum requests allowed per minute per IP.
        cleanup_interval_seconds: How often to clean stale entries.
    """

    def __init__(
        self,
        requests_per_minute: int = 30,
        cleanup_interval_seconds: int = 60,
    ):
        self.requests_per_minute = requests_per_minute
        self.cleanup_interval_seconds = cleanup_interval_seconds
        # client IP -> deque of request timestamps (monotonic clock)
        self._storage: dict[str, deque[float]] = {}
        # time.monotonic() instead of time.time(): wall-clock adjustments
        # (NTP steps, manual changes) must not expand or shrink the
        # rate-limit window. All timestamps here are internal-only.
        self._last_cleanup: float = time.monotonic()
        self._window_seconds: float = 60.0  # 1 minute window

    def _get_client_ip(self, request: "Request") -> str:
        """Extract client IP from request, respecting X-Forwarded-For header.

        NOTE(review): these headers are client-controlled unless a trusted
        proxy in front of the app sets/strips them; without such a proxy a
        client can spoof them to evade limits — confirm deployment topology.

        Args:
            request: The incoming request.

        Returns:
            Client IP address string ("unknown" if it cannot be determined).
        """
        # Check for forwarded IP (when behind proxy/load balancer)
        forwarded = request.headers.get("x-forwarded-for")
        if forwarded:
            # First entry in the comma-separated chain is the origin client
            return forwarded.split(",")[0].strip()

        real_ip = request.headers.get("x-real-ip")
        if real_ip:
            return real_ip

        # Fall back to the direct socket peer
        if request.client:
            return request.client.host

        return "unknown"

    def _cleanup_if_needed(self) -> None:
        """Remove stale entries older than the cleanup interval.

        Runs at most once per ``cleanup_interval_seconds``; drops
        timestamps outside the window and deletes IPs left empty.
        """
        now = time.monotonic()
        if now - self._last_cleanup < self.cleanup_interval_seconds:
            return

        cutoff = now - self._window_seconds
        stale_ips: list[str] = []

        for ip, timestamps in self._storage.items():
            # Remove timestamps older than the window
            while timestamps and timestamps[0] < cutoff:
                timestamps.popleft()
            # Mark IP for removal if no recent requests
            if not timestamps:
                stale_ips.append(ip)

        # Remove stale IP entries (cannot delete while iterating above)
        for ip in stale_ips:
            del self._storage[ip]

        self._last_cleanup = now
        if stale_ips:
            logger.debug("Rate limiter cleanup: removed %d stale IPs", len(stale_ips))

    def is_allowed(self, client_ip: str) -> tuple[bool, float]:
        """Check if a request from the given IP is allowed.

        Side effect: records the current timestamp for ``client_ip`` when
        the request is allowed.

        Args:
            client_ip: The client's IP address.

        Returns:
            Tuple of (allowed: bool, retry_after: float).
            retry_after is seconds until next allowed request, 0 if allowed now.
        """
        now = time.monotonic()
        cutoff = now - self._window_seconds

        # Get or create the timestamp deque for this IP in one lookup
        timestamps = self._storage.setdefault(client_ip, deque())

        # Remove timestamps outside the window
        while timestamps and timestamps[0] < cutoff:
            timestamps.popleft()

        # Check if limit exceeded
        if len(timestamps) >= self.requests_per_minute:
            # Retry once the oldest recorded request ages out of the window
            retry_after = self._window_seconds - (now - timestamps[0])
            return False, max(0.0, retry_after)

        # Record this request
        timestamps.append(now)
        return True, 0.0

    def check_request(self, request: "Request") -> tuple[bool, float]:
        """Check if the request is allowed under rate limits.

        Args:
            request: The incoming request.

        Returns:
            Tuple of (allowed: bool, retry_after: float).
        """
        self._cleanup_if_needed()
        client_ip = self._get_client_ip(request)
        return self.is_allowed(client_ip)

class RateLimitMiddleware(BaseHTTPMiddleware):
    """Starlette middleware that enforces per-IP rate limits on routes.

    Usage:
        # Apply to all routes (not recommended for public static files)
        app.add_middleware(RateLimitMiddleware)

        # Apply only to specific paths
        app.add_middleware(
            RateLimitMiddleware,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=30,
        )

    Attributes:
        path_prefixes: List of URL path prefixes to rate limit.
            If empty, applies to all paths.
        requests_per_minute: Maximum requests per minute per IP.
    """

    def __init__(
        self,
        app,
        path_prefixes: list[str] | None = None,
        requests_per_minute: int = 30,
    ):
        super().__init__(app)
        # Normalize None to an empty list ("limit everything")
        self.path_prefixes = path_prefixes or []
        self.limiter = RateLimiter(requests_per_minute=requests_per_minute)

    def _should_rate_limit(self, path: str) -> bool:
        """Return True when *path* falls under a configured prefix.

        An empty prefix list means every path is rate limited.

        Args:
            path: The request URL path.

        Returns:
            True if path matches any configured prefix.
        """
        prefixes = self.path_prefixes
        # str.startswith accepts a tuple of alternatives in one call
        return not prefixes or path.startswith(tuple(prefixes))

    async def dispatch(self, request: Request, call_next) -> Response:
        """Reject over-limit requests with 429; otherwise pass downstream.

        Args:
            request: The incoming request.
            call_next: Callable to get the response from downstream.

        Returns:
            Response from downstream, or 429 if rate limited.
        """
        if self._should_rate_limit(request.url.path):
            allowed, retry_after = self.limiter.check_request(request)
            if not allowed:
                # Round up so the client never retries a moment too early
                wait_seconds = int(retry_after) + 1
                return JSONResponse(
                    status_code=429,
                    content={
                        "error": "Rate limit exceeded. Try again later.",
                        "retry_after": wait_seconds,
                    },
                    headers={"Retry-After": str(wait_seconds)},
                )

        # Path not limited, or request within limits: process normally
        return await call_next(request)
