"""Request logging middleware for FastAPI. Logs HTTP requests with timing, status codes, and client information for monitoring and debugging purposes. """ import logging import time import uuid from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import Response logger = logging.getLogger("timmy.requests") class RequestLoggingMiddleware(BaseHTTPMiddleware): """Middleware to log all HTTP requests. Logs the following information for each request: - HTTP method and path - Response status code - Request processing time - Client IP address - User-Agent header - Correlation ID for tracing Usage: app.add_middleware(RequestLoggingMiddleware) # Skip certain paths: app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health", "/metrics"]) Attributes: skip_paths: List of URL paths to skip logging. log_level: Logging level for successful requests. """ def __init__(self, app, skip_paths: list[str] | None = None, log_level: int = logging.INFO): super().__init__(app) self.skip_paths = set(skip_paths or []) self.log_level = log_level def _should_skip_path(self, path: str) -> bool: """Check if the request path should be skipped from logging. Args: path: The request URL path. Returns: True if the path should be skipped, False otherwise. """ return path in self.skip_paths def _prepare_request_context(self, request: Request) -> tuple[str, float]: """Prepare context for request processing. Generates a correlation ID and records the start time. Args: request: The incoming request. Returns: Tuple of (correlation_id, start_time). """ correlation_id = str(uuid.uuid4())[:8] request.state.correlation_id = correlation_id start_time = time.time() return correlation_id, start_time def _get_duration_ms(self, start_time: float) -> float: """Calculate the request duration in milliseconds. Args: start_time: The start time from time.time(). Returns: Duration in milliseconds. """ return (time.time() - start_time) * 1000 def _log_success( self, request: Request, response: Response, correlation_id: str, duration_ms: float, client_ip: str, user_agent: str, ) -> None: """Log a successful request. Args: request: The incoming request. response: The response from downstream. correlation_id: The request correlation ID. duration_ms: Request duration in milliseconds. client_ip: Client IP address. user_agent: User-Agent header value. """ self._log_request( method=request.method, path=request.url.path, status_code=response.status_code, duration_ms=duration_ms, client_ip=client_ip, user_agent=user_agent, correlation_id=correlation_id, ) def _log_error( self, request: Request, exc: Exception, correlation_id: str, duration_ms: float, client_ip: str, ) -> None: """Log a failed request and capture the error. Args: request: The incoming request. exc: The exception that was raised. correlation_id: The request correlation ID. duration_ms: Request duration in milliseconds. client_ip: Client IP address. """ logger.error( f"[{correlation_id}] {request.method} {request.url.path} " f"- ERROR - {duration_ms:.2f}ms - {client_ip} - {str(exc)}" ) # Auto-escalate: create bug report task from unhandled exception try: from infrastructure.error_capture import capture_error capture_error( exc, source="http", context={ "method": request.method, "path": request.url.path, "correlation_id": correlation_id, "client_ip": client_ip, "duration_ms": f"{duration_ms:.0f}", }, ) except Exception: logger.warning("Escalation logging error: capture failed") # never let escalation break the request async def dispatch(self, request: Request, call_next) -> Response: """Log the request and response details. Args: request: The incoming request. call_next: Callable to get the response from downstream. Returns: The response from downstream. """ if self._should_skip_path(request.url.path): return await call_next(request) correlation_id, start_time = self._prepare_request_context(request) client_ip = self._get_client_ip(request) user_agent = request.headers.get("user-agent", "-") try: response = await call_next(request) duration_ms = self._get_duration_ms(start_time) self._log_success(request, response, correlation_id, duration_ms, client_ip, user_agent) response.headers["X-Correlation-ID"] = correlation_id return response except Exception as exc: duration_ms = self._get_duration_ms(start_time) self._log_error(request, exc, correlation_id, duration_ms, client_ip) raise def _get_client_ip(self, request: Request) -> str: """Extract the client IP address from the request. Checks X-Forwarded-For and X-Real-IP headers first for proxied requests, falls back to the direct client IP. Args: request: The incoming request. Returns: Client IP address string. """ # Check for forwarded IP (behind proxy/load balancer) forwarded_for = request.headers.get("x-forwarded-for") if forwarded_for: # X-Forwarded-For can contain multiple IPs, take the first one return forwarded_for.split(",")[0].strip() real_ip = request.headers.get("x-real-ip") if real_ip: return real_ip # Fall back to direct connection if request.client: return request.client.host return "-" def _log_request( self, method: str, path: str, status_code: int, duration_ms: float, client_ip: str, user_agent: str, correlation_id: str, ) -> None: """Format and log the request details. Args: method: HTTP method. path: Request path. status_code: HTTP status code. duration_ms: Request duration in milliseconds. client_ip: Client IP address. user_agent: User-Agent header value. correlation_id: Request correlation ID. """ # Determine log level based on status code level = self.log_level if status_code >= 500: level = logging.ERROR elif status_code >= 400: level = logging.WARNING message = ( f"[{correlation_id}] {method} {path} - {status_code} " f"- {duration_ms:.2f}ms - {client_ip}" ) # Add user agent for non-health requests if path not in self.skip_paths: message += f" - {user_agent[:50]}" logger.log(level, message)