[kimi] Refactor request_logging.py::dispatch (#616) #765

Merged
kimi merged 1 commits from kimi/issue-616 into main 2026-03-21 18:06:34 +00:00

View File

@@ -42,39 +42,63 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
self.skip_paths = set(skip_paths or [])
self.log_level = log_level
async def dispatch(self, request: Request, call_next) -> Response:
"""Log the request and response details.
def _should_skip_path(self, path: str) -> bool:
"""Check if the request path should be skipped from logging.
Args:
path: The request URL path.
Returns:
True if the path should be skipped, False otherwise.
"""
return path in self.skip_paths
def _prepare_request_context(self, request: Request) -> tuple[str, float]:
"""Prepare context for request processing.
Generates a correlation ID and records the start time.
Args:
request: The incoming request.
call_next: Callable to get the response from downstream.
Returns:
The response from downstream.
Tuple of (correlation_id, start_time).
"""
# Check if we should skip logging this path
if request.url.path in self.skip_paths:
return await call_next(request)
# Generate correlation ID
correlation_id = str(uuid.uuid4())[:8]
request.state.correlation_id = correlation_id
# Record start time
start_time = time.time()
return correlation_id, start_time
# Get client info
client_ip = self._get_client_ip(request)
user_agent = request.headers.get("user-agent", "-")
def _get_duration_ms(self, start_time: float) -> float:
"""Calculate the request duration in milliseconds.
try:
# Process the request
response = await call_next(request)
Args:
start_time: The start time from time.time().
# Calculate duration
duration_ms = (time.time() - start_time) * 1000
Returns:
Duration in milliseconds.
"""
return (time.time() - start_time) * 1000
# Log the request
def _log_success(
self,
request: Request,
response: Response,
correlation_id: str,
duration_ms: float,
client_ip: str,
user_agent: str,
) -> None:
"""Log a successful request.
Args:
request: The incoming request.
response: The response from downstream.
correlation_id: The request correlation ID.
duration_ms: Request duration in milliseconds.
client_ip: Client IP address.
user_agent: User-Agent header value.
"""
self._log_request(
method=request.method,
path=request.url.path,
@@ -85,16 +109,23 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
correlation_id=correlation_id,
)
# Add correlation ID to response headers
response.headers["X-Correlation-ID"] = correlation_id
def _log_error(
self,
request: Request,
exc: Exception,
correlation_id: str,
duration_ms: float,
client_ip: str,
) -> None:
"""Log a failed request and capture the error.
return response
except Exception as exc:
# Calculate duration even for failed requests
duration_ms = (time.time() - start_time) * 1000
# Log the error
Args:
request: The incoming request.
exc: The exception that was raised.
correlation_id: The request correlation ID.
duration_ms: Request duration in milliseconds.
client_ip: Client IP address.
"""
logger.error(
f"[{correlation_id}] {request.method} {request.url.path} "
f"- ERROR - {duration_ms:.2f}ms - {client_ip} - {str(exc)}"
@@ -115,11 +146,37 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
"duration_ms": f"{duration_ms:.0f}",
},
)
except Exception as exc:
logger.warning("Escalation logging error: %s", exc)
pass # never let escalation break the request
except Exception:
logger.warning("Escalation logging error: capture failed")
# never let escalation break the request
# Re-raise the exception
async def dispatch(self, request: Request, call_next) -> Response:
"""Log the request and response details.
Args:
request: The incoming request.
call_next: Callable to get the response from downstream.
Returns:
The response from downstream.
"""
if self._should_skip_path(request.url.path):
return await call_next(request)
correlation_id, start_time = self._prepare_request_context(request)
client_ip = self._get_client_ip(request)
user_agent = request.headers.get("user-agent", "-")
try:
response = await call_next(request)
duration_ms = self._get_duration_ms(start_time)
self._log_success(request, response, correlation_id, duration_ms, client_ip, user_agent)
response.headers["X-Correlation-ID"] = correlation_id
return response
except Exception as exc:
duration_ms = self._get_duration_ms(start_time)
self._log_error(request, exc, correlation_id, duration_ms, client_ip)
raise
def _get_client_ip(self, request: Request) -> str: