fix: resolve CSRF exempt endpoint before execution (#626)

This commit is contained in:
hermes
2026-03-20 18:57:51 -04:00
parent d2a5866650
commit 59e789e2da
2 changed files with 140 additions and 10 deletions

View File

@@ -175,18 +175,12 @@ class CSRFMiddleware(BaseHTTPMiddleware):
return await call_next(request)
# Token validation failed and path is not exempt
# We still need to call the app to check if the endpoint is decorated
# with @csrf_exempt, so we'll let it through and check after routing
response = await call_next(request)
# After routing, check if the endpoint is marked as exempt
endpoint = request.scope.get("endpoint")
# Resolve the endpoint from routes BEFORE executing to avoid side effects
endpoint = self._resolve_endpoint(request)
if endpoint and is_csrf_exempt(endpoint):
# Endpoint is marked as exempt, allow the response
return response
return await call_next(request)
# Endpoint is not exempt and token validation failed
# Return 403 error
# Endpoint is not exempt and token validation failed — reject without executing
return JSONResponse(
status_code=403,
content={
@@ -196,6 +190,42 @@ class CSRFMiddleware(BaseHTTPMiddleware):
},
)
def _resolve_endpoint(self, request: Request) -> Callable | None:
"""Resolve the endpoint for a request without executing it.
Walks the app chain to find routes, then matches against the request
scope. This allows checking @csrf_exempt before the handler runs
(avoiding side effects on CSRF rejection).
Returns:
The endpoint callable if found, None otherwise.
"""
try:
from starlette.routing import Match
# Walk the middleware/app chain to find something with routes
routes = None
current = self.app
for _ in range(10): # Safety limit
routes = getattr(current, "routes", None)
if routes:
break
current = getattr(current, "app", None)
if current is None:
break
if not routes:
return None
scope = dict(request.scope)
for route in routes:
match, child_scope = route.matches(scope)
if match == Match.FULL:
return child_scope.get("endpoint")
except Exception:
logger.debug("Failed to resolve endpoint for CSRF check")
return None
def _is_likely_exempt(self, path: str) -> bool:
"""Check if a path is likely to be CSRF exempt.