[loop-cycle-2] fix: resolve endpoint before execution in CSRF middleware (#626) (#656)

This commit is contained in:
2026-03-20 23:05:09 +00:00
parent d2a5866650
commit 9d0f5c778e
2 changed files with 85 additions and 8 deletions

View File

@@ -175,15 +175,10 @@ class CSRFMiddleware(BaseHTTPMiddleware):
return await call_next(request)
# Token validation failed and path is not exempt
# We still need to call the app to check if the endpoint is decorated
# with @csrf_exempt, so we'll let it through and check after routing
response = await call_next(request)
# After routing, check if the endpoint is marked as exempt
endpoint = request.scope.get("endpoint")
# Resolve the endpoint WITHOUT executing it to check @csrf_exempt
endpoint = self._resolve_endpoint(request)
if endpoint and is_csrf_exempt(endpoint):
# Endpoint is marked as exempt, allow the response
return response
return await call_next(request)
# Endpoint is not exempt and token validation failed
# Return 403 error
@@ -196,6 +191,41 @@ class CSRFMiddleware(BaseHTTPMiddleware):
},
)
def _resolve_endpoint(self, request: Request) -> Callable | None:
"""Resolve the route endpoint without executing it.
Walks the Starlette/FastAPI router to find which endpoint function
handles this request, so we can check @csrf_exempt before any
side effects occur.
Returns:
The endpoint callable, or None if no route matched.
"""
# If routing already happened (endpoint in scope), use it
endpoint = request.scope.get("endpoint")
if endpoint:
return endpoint
# Walk the middleware/app chain to find something with routes
from starlette.routing import Match
app = self.app
while app is not None:
if hasattr(app, "routes"):
for route in app.routes:
match, _ = route.matches(request.scope)
if match == Match.FULL:
return getattr(route, "endpoint", None)
# Try .router (FastAPI stores routes on app.router)
if hasattr(app, "router") and hasattr(app.router, "routes"):
for route in app.router.routes:
match, _ = route.matches(request.scope)
if match == Match.FULL:
return getattr(route, "endpoint", None)
app = getattr(app, "app", None)
return None
def _is_likely_exempt(self, path: str) -> bool:
"""Check if a path is likely to be CSRF exempt.