[loop-cycle-2] fix: resolve endpoint before execution in CSRF middleware (#626) #656

Merged
Timmy merged 1 commits from fix/csrf-exempt-check-before-dispatch into main 2026-03-20 23:05:10 +00:00
2 changed files with 85 additions and 8 deletions

View File

@@ -175,15 +175,10 @@ class CSRFMiddleware(BaseHTTPMiddleware):
return await call_next(request)
# Token validation failed and path is not exempt
# We still need to call the app to check if the endpoint is decorated
# with @csrf_exempt, so we'll let it through and check after routing
response = await call_next(request)
# After routing, check if the endpoint is marked as exempt
endpoint = request.scope.get("endpoint")
# Resolve the endpoint WITHOUT executing it to check @csrf_exempt
endpoint = self._resolve_endpoint(request)
if endpoint and is_csrf_exempt(endpoint):
# Endpoint is marked as exempt, allow the response
return response
return await call_next(request)
# Endpoint is not exempt and token validation failed
# Return 403 error
@@ -196,6 +191,41 @@ class CSRFMiddleware(BaseHTTPMiddleware):
},
)
def _resolve_endpoint(self, request: Request) -> Callable | None:
"""Resolve the route endpoint without executing it.
Walks the Starlette/FastAPI router to find which endpoint function
handles this request, so we can check @csrf_exempt before any
side effects occur.
Returns:
The endpoint callable, or None if no route matched.
"""
# If routing already happened (endpoint in scope), use it
endpoint = request.scope.get("endpoint")
if endpoint:
return endpoint
# Walk the middleware/app chain to find something with routes
from starlette.routing import Match
app = self.app
while app is not None:
if hasattr(app, "routes"):
for route in app.routes:
match, _ = route.matches(request.scope)
if match == Match.FULL:
return getattr(route, "endpoint", None)
# Try .router (FastAPI stores routes on app.router)
if hasattr(app, "router") and hasattr(app.router, "routes"):
for route in app.router.routes:
match, _ = route.matches(request.scope)
if match == Match.FULL:
return getattr(route, "endpoint", None)
app = getattr(app, "app", None)
return None
def _is_likely_exempt(self, path: str) -> bool:
"""Check if a path is likely to be CSRF exempt.

View File

@@ -120,3 +120,50 @@ class TestCSRFDecoratorSupport:
# Protected endpoint should be 403
response2 = client.post("/protected")
assert response2.status_code == 403
def test_csrf_exempt_endpoint_not_executed_before_check(self):
"""Regression test for #626: endpoint must NOT execute before CSRF check.
Previously the middleware called call_next() first, executing the endpoint
and its side effects, then checked @csrf_exempt afterward. This meant
non-exempt endpoints would execute even when CSRF validation failed.
"""
app = FastAPI()
app.add_middleware(CSRFMiddleware)
side_effect_log: list[str] = []
@app.post("/protected-with-side-effects")
def protected_with_side_effects():
side_effect_log.append("executed")
return {"message": "should not run"}
client = TestClient(app)
# POST without CSRF token — should be blocked with 403
response = client.post("/protected-with-side-effects")
assert response.status_code == 403
# The critical assertion: the endpoint must NOT have executed
assert side_effect_log == [], (
"Endpoint executed before CSRF validation! Side effects occurred "
"despite CSRF failure (see issue #626)."
)
def test_csrf_exempt_endpoint_does_execute(self):
"""Ensure @csrf_exempt endpoints still execute normally."""
app = FastAPI()
app.add_middleware(CSRFMiddleware)
side_effect_log: list[str] = []
@app.post("/exempt-webhook")
@csrf_exempt
def exempt_webhook():
side_effect_log.append("executed")
return {"message": "webhook ok"}
client = TestClient(app)
response = client.post("/exempt-webhook")
assert response.status_code == 200
assert side_effect_log == ["executed"]