[kimi] Refactor request_logging.py::dispatch (#616) #765
@@ -42,6 +42,114 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
self.skip_paths = set(skip_paths or [])
|
||||
self.log_level = log_level
|
||||
|
||||
def _should_skip_path(self, path: str) -> bool:
|
||||
"""Check if the request path should be skipped from logging.
|
||||
|
||||
Args:
|
||||
path: The request URL path.
|
||||
|
||||
Returns:
|
||||
True if the path should be skipped, False otherwise.
|
||||
"""
|
||||
return path in self.skip_paths
|
||||
|
||||
def _prepare_request_context(self, request: Request) -> tuple[str, float]:
|
||||
"""Prepare context for request processing.
|
||||
|
||||
Generates a correlation ID and records the start time.
|
||||
|
||||
Args:
|
||||
request: The incoming request.
|
||||
|
||||
Returns:
|
||||
Tuple of (correlation_id, start_time).
|
||||
"""
|
||||
correlation_id = str(uuid.uuid4())[:8]
|
||||
request.state.correlation_id = correlation_id
|
||||
start_time = time.time()
|
||||
return correlation_id, start_time
|
||||
|
||||
def _get_duration_ms(self, start_time: float) -> float:
|
||||
"""Calculate the request duration in milliseconds.
|
||||
|
||||
Args:
|
||||
start_time: The start time from time.time().
|
||||
|
||||
Returns:
|
||||
Duration in milliseconds.
|
||||
"""
|
||||
return (time.time() - start_time) * 1000
|
||||
|
||||
def _log_success(
|
||||
self,
|
||||
request: Request,
|
||||
response: Response,
|
||||
correlation_id: str,
|
||||
duration_ms: float,
|
||||
client_ip: str,
|
||||
user_agent: str,
|
||||
) -> None:
|
||||
"""Log a successful request.
|
||||
|
||||
Args:
|
||||
request: The incoming request.
|
||||
response: The response from downstream.
|
||||
correlation_id: The request correlation ID.
|
||||
duration_ms: Request duration in milliseconds.
|
||||
client_ip: Client IP address.
|
||||
user_agent: User-Agent header value.
|
||||
"""
|
||||
self._log_request(
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
status_code=response.status_code,
|
||||
duration_ms=duration_ms,
|
||||
client_ip=client_ip,
|
||||
user_agent=user_agent,
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
|
||||
def _log_error(
|
||||
self,
|
||||
request: Request,
|
||||
exc: Exception,
|
||||
correlation_id: str,
|
||||
duration_ms: float,
|
||||
client_ip: str,
|
||||
) -> None:
|
||||
"""Log a failed request and capture the error.
|
||||
|
||||
Args:
|
||||
request: The incoming request.
|
||||
exc: The exception that was raised.
|
||||
correlation_id: The request correlation ID.
|
||||
duration_ms: Request duration in milliseconds.
|
||||
client_ip: Client IP address.
|
||||
"""
|
||||
logger.error(
|
||||
f"[{correlation_id}] {request.method} {request.url.path} "
|
||||
f"- ERROR - {duration_ms:.2f}ms - {client_ip} - {str(exc)}"
|
||||
)
|
||||
|
||||
# Auto-escalate: create bug report task from unhandled exception
|
||||
try:
|
||||
from infrastructure.error_capture import capture_error
|
||||
|
||||
capture_error(
|
||||
exc,
|
||||
source="http",
|
||||
context={
|
||||
"method": request.method,
|
||||
"path": request.url.path,
|
||||
"correlation_id": correlation_id,
|
||||
"client_ip": client_ip,
|
||||
"duration_ms": f"{duration_ms:.0f}",
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Escalation logging error: capture failed")
|
||||
# never let escalation break the request
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Log the request and response details.
|
||||
|
||||
@@ -52,74 +160,23 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
Returns:
|
||||
The response from downstream.
|
||||
"""
|
||||
# Check if we should skip logging this path
|
||||
if request.url.path in self.skip_paths:
|
||||
if self._should_skip_path(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# Generate correlation ID
|
||||
correlation_id = str(uuid.uuid4())[:8]
|
||||
request.state.correlation_id = correlation_id
|
||||
|
||||
# Record start time
|
||||
start_time = time.time()
|
||||
|
||||
# Get client info
|
||||
correlation_id, start_time = self._prepare_request_context(request)
|
||||
client_ip = self._get_client_ip(request)
|
||||
user_agent = request.headers.get("user-agent", "-")
|
||||
|
||||
try:
|
||||
# Process the request
|
||||
response = await call_next(request)
|
||||
|
||||
# Calculate duration
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
# Log the request
|
||||
self._log_request(
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
status_code=response.status_code,
|
||||
duration_ms=duration_ms,
|
||||
client_ip=client_ip,
|
||||
user_agent=user_agent,
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
|
||||
# Add correlation ID to response headers
|
||||
duration_ms = self._get_duration_ms(start_time)
|
||||
self._log_success(request, response, correlation_id, duration_ms, client_ip, user_agent)
|
||||
response.headers["X-Correlation-ID"] = correlation_id
|
||||
|
||||
return response
|
||||
|
||||
except Exception as exc:
|
||||
# Calculate duration even for failed requests
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
# Log the error
|
||||
logger.error(
|
||||
f"[{correlation_id}] {request.method} {request.url.path} "
|
||||
f"- ERROR - {duration_ms:.2f}ms - {client_ip} - {str(exc)}"
|
||||
)
|
||||
|
||||
# Auto-escalate: create bug report task from unhandled exception
|
||||
try:
|
||||
from infrastructure.error_capture import capture_error
|
||||
|
||||
capture_error(
|
||||
exc,
|
||||
source="http",
|
||||
context={
|
||||
"method": request.method,
|
||||
"path": request.url.path,
|
||||
"correlation_id": correlation_id,
|
||||
"client_ip": client_ip,
|
||||
"duration_ms": f"{duration_ms:.0f}",
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("Escalation logging error: %s", exc)
|
||||
pass # never let escalation break the request
|
||||
|
||||
# Re-raise the exception
|
||||
duration_ms = self._get_duration_ms(start_time)
|
||||
self._log_error(request, exc, correlation_id, duration_ms, client_ip)
|
||||
raise
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
|
||||
Reference in New Issue
Block a user