248 lines
7.6 KiB
Python
248 lines
7.6 KiB
Python
"""Request logging middleware for FastAPI.

Logs HTTP requests with timing, status codes, and client information
for monitoring and debugging purposes. Each logged request is tagged with a
short correlation ID, which is also returned to the client in the
``X-Correlation-ID`` response header. Unhandled exceptions raised downstream
are additionally reported via ``infrastructure.error_capture`` on a
best-effort basis.
"""
|
|
|
|
import logging
|
|
import time
|
|
import uuid
|
|
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.requests import Request
|
|
from starlette.responses import Response
|
|
|
|
# Named logger used for every request log line emitted by this module.
logger = logging.getLogger("timmy.requests")
|
|
|
|
|
|
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Middleware to log all HTTP requests.

    Logs the following information for each request:
    - HTTP method and path
    - Response status code
    - Request processing time
    - Client IP address
    - User-Agent header (truncated to 50 characters)
    - Correlation ID for tracing (also returned to the client in the
      ``X-Correlation-ID`` response header)

    Usage:
        app.add_middleware(RequestLoggingMiddleware)

        # Skip certain paths:
        app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health", "/metrics"])

    Attributes:
        skip_paths: Set of exact URL paths that bypass this middleware's
            logging entirely.
        log_level: Logging level for responses with a status code below 400.
            4xx responses are logged at WARNING and 5xx at ERROR regardless.
    """

    def __init__(self, app, skip_paths: list[str] | None = None, log_level: int = logging.INFO):
        super().__init__(app)
        # Stored as a set for O(1) exact-match lookups on every request.
        self.skip_paths = set(skip_paths or [])
        self.log_level = log_level

    def _should_skip_path(self, path: str) -> bool:
        """Check if the request path should be skipped from logging.

        Matching is exact: ``"/health"`` does not match ``"/health/"``.

        Args:
            path: The request URL path.

        Returns:
            True if the path should be skipped, False otherwise.
        """
        return path in self.skip_paths

    def _prepare_request_context(self, request: Request) -> tuple[str, float]:
        """Prepare context for request processing.

        Generates a correlation ID, stores it on ``request.state`` where
        downstream handlers can read it, and records the start time.

        Args:
            request: The incoming request.

        Returns:
            Tuple of (correlation_id, start_time). ``start_time`` comes from
            ``time.perf_counter()`` and is only meaningful as input to
            ``_get_duration_ms``.
        """
        # First 8 characters of a random UUID4: short enough to read in logs.
        correlation_id = str(uuid.uuid4())[:8]
        request.state.correlation_id = correlation_id
        # perf_counter is monotonic, so measured durations cannot go negative
        # or jump when the system wall clock is adjusted (unlike time.time()).
        start_time = time.perf_counter()
        return correlation_id, start_time

    def _get_duration_ms(self, start_time: float) -> float:
        """Calculate the request duration in milliseconds.

        Args:
            start_time: The start time returned by ``_prepare_request_context``
                (a ``time.perf_counter()`` reading).

        Returns:
            Duration in milliseconds.
        """
        return (time.perf_counter() - start_time) * 1000

    def _log_success(
        self,
        request: Request,
        response: Response,
        correlation_id: str,
        duration_ms: float,
        client_ip: str,
        user_agent: str,
    ) -> None:
        """Log a request that produced a response, whatever its status code.

        The log level is chosen by ``_log_request`` from the status code.

        Args:
            request: The incoming request.
            response: The response from downstream.
            correlation_id: The request correlation ID.
            duration_ms: Request duration in milliseconds.
            client_ip: Client IP address.
            user_agent: User-Agent header value.
        """
        self._log_request(
            method=request.method,
            path=request.url.path,
            status_code=response.status_code,
            duration_ms=duration_ms,
            client_ip=client_ip,
            user_agent=user_agent,
            correlation_id=correlation_id,
        )

    def _log_error(
        self,
        request: Request,
        exc: Exception,
        correlation_id: str,
        duration_ms: float,
        client_ip: str,
    ) -> None:
        """Log a request whose downstream handler raised, and escalate the error.

        Escalation goes through ``infrastructure.error_capture.capture_error``.
        It is best-effort: any failure there is logged as a warning and never
        propagates.

        Args:
            request: The incoming request.
            exc: The exception that was raised.
            correlation_id: The request correlation ID.
            duration_ms: Request duration in milliseconds.
            client_ip: Client IP address.
        """
        logger.error(
            f"[{correlation_id}] {request.method} {request.url.path} "
            f"- ERROR - {duration_ms:.2f}ms - {client_ip} - {str(exc)}"
        )

        # Auto-escalate: create bug report task from unhandled exception.
        try:
            # Imported lazily so a missing or broken error_capture module
            # cannot break middleware import; failures land in the except below.
            from infrastructure.error_capture import capture_error

            capture_error(
                exc,
                source="http",
                context={
                    "method": request.method,
                    "path": request.url.path,
                    "correlation_id": correlation_id,
                    "client_ip": client_ip,
                    "duration_ms": f"{duration_ms:.0f}",
                },
            )
        except Exception:
            # Never let escalation break the request, but keep the traceback
            # so escalation failures are diagnosable.
            logger.warning("Escalation logging error: capture failed", exc_info=True)

    async def dispatch(self, request: Request, call_next) -> Response:
        """Log the request and response details.

        Requests whose path is in ``skip_paths`` are passed through untouched.
        Otherwise the response is logged and tagged with an
        ``X-Correlation-ID`` header. Downstream exceptions are logged,
        escalated, and re-raised.

        Args:
            request: The incoming request.
            call_next: Callable to get the response from downstream.

        Returns:
            The response from downstream.
        """
        if self._should_skip_path(request.url.path):
            return await call_next(request)

        correlation_id, start_time = self._prepare_request_context(request)
        client_ip = self._get_client_ip(request)
        user_agent = request.headers.get("user-agent", "-")

        try:
            response = await call_next(request)
            duration_ms = self._get_duration_ms(start_time)
            self._log_success(request, response, correlation_id, duration_ms, client_ip, user_agent)
            response.headers["X-Correlation-ID"] = correlation_id
            return response

        except Exception as exc:
            duration_ms = self._get_duration_ms(start_time)
            self._log_error(request, exc, correlation_id, duration_ms, client_ip)
            raise

    def _get_client_ip(self, request: Request) -> str:
        """Extract the client IP address from the request.

        Checks X-Forwarded-For and X-Real-IP headers first for proxied requests,
        then falls back to the direct client IP.

        NOTE: both headers can be set by the client unless a trusted proxy
        overwrites them. The value is used only for logging in this module and
        must not be relied on for access control.

        Args:
            request: The incoming request.

        Returns:
            Client IP address string, or ``"-"`` if none is available.
        """
        # Check for forwarded IP (behind proxy/load balancer).
        forwarded_for = request.headers.get("x-forwarded-for")
        if forwarded_for:
            # X-Forwarded-For can list several hops ("client, proxy1, ...");
            # the left-most entry is the originating client. Skip it if it is
            # empty (e.g. ", 10.0.0.1") rather than logging a blank IP.
            first_hop = forwarded_for.split(",")[0].strip()
            if first_hop:
                return first_hop

        real_ip = request.headers.get("x-real-ip")
        if real_ip:
            return real_ip

        # Fall back to direct connection.
        if request.client:
            return request.client.host

        return "-"

    def _log_request(
        self,
        method: str,
        path: str,
        status_code: int,
        duration_ms: float,
        client_ip: str,
        user_agent: str,
        correlation_id: str,
    ) -> None:
        """Format and log the request details.

        Status >= 500 logs at ERROR, status >= 400 at WARNING, and anything
        else at ``self.log_level``.

        Args:
            method: HTTP method.
            path: Request path.
            status_code: HTTP status code.
            duration_ms: Request duration in milliseconds.
            client_ip: Client IP address.
            user_agent: User-Agent header value (truncated to 50 characters).
            correlation_id: Request correlation ID.
        """
        # Determine log level based on status code.
        level = self.log_level
        if status_code >= 500:
            level = logging.ERROR
        elif status_code >= 400:
            level = logging.WARNING

        message = (
            f"[{correlation_id}] {method} {path} - {status_code} "
            f"- {duration_ms:.2f}ms - {client_ip}"
        )

        # dispatch() returns early for skip_paths, so on that path the user
        # agent is always appended; the check only matters for direct callers.
        if path not in self.skip_paths:
            message += f" - {user_agent[:50]}"

        logger.log(level, message)