diff --git a/src/dashboard/middleware/csrf.py b/src/dashboard/middleware/csrf.py index 826ea461..e8e4a73a 100644 --- a/src/dashboard/middleware/csrf.py +++ b/src/dashboard/middleware/csrf.py @@ -175,15 +175,10 @@ class CSRFMiddleware(BaseHTTPMiddleware): return await call_next(request) # Token validation failed and path is not exempt - # We still need to call the app to check if the endpoint is decorated - # with @csrf_exempt, so we'll let it through and check after routing - response = await call_next(request) - - # After routing, check if the endpoint is marked as exempt - endpoint = request.scope.get("endpoint") + # Resolve the endpoint WITHOUT executing it to check @csrf_exempt + endpoint = self._resolve_endpoint(request) if endpoint and is_csrf_exempt(endpoint): - # Endpoint is marked as exempt, allow the response - return response + return await call_next(request) # Endpoint is not exempt and token validation failed # Return 403 error @@ -196,6 +191,41 @@ class CSRFMiddleware(BaseHTTPMiddleware): }, ) + def _resolve_endpoint(self, request: Request) -> Callable | None: + """Resolve the route endpoint without executing it. + + Walks the Starlette/FastAPI router to find which endpoint function + handles this request, so we can check @csrf_exempt before any + side effects occur. + + Returns: + The endpoint callable, or None if no route matched. + """ + # If routing already happened (endpoint in scope), use it + endpoint = request.scope.get("endpoint") + if endpoint: + return endpoint + + # Walk the middleware/app chain to find something with routes + from starlette.routing import Match + + app = self.app + while app is not None: + if hasattr(app, "routes"): + for route in app.routes: + match, _ = route.matches(request.scope) + if match == Match.FULL: + return getattr(route, "endpoint", None) + # Try .router (FastAPI stores routes on app.router) + if hasattr(app, "router") and hasattr(app.router, "routes"): + for route in app.router.routes: + match, _ = route.matches(request.scope) + if match == Match.FULL: + return getattr(route, "endpoint", None) + app = getattr(app, "app", None) + + return None + def _is_likely_exempt(self, path: str) -> bool: """Check if a path is likely to be CSRF exempt. diff --git a/tests/dashboard/middleware/test_csrf_decorator_support.py b/tests/dashboard/middleware/test_csrf_decorator_support.py index ddb042e4..77094ac6 100644 --- a/tests/dashboard/middleware/test_csrf_decorator_support.py +++ b/tests/dashboard/middleware/test_csrf_decorator_support.py @@ -120,3 +120,50 @@ class TestCSRFDecoratorSupport: # Protected endpoint should be 403 response2 = client.post("/protected") assert response2.status_code == 403 + + def test_csrf_exempt_endpoint_not_executed_before_check(self): + """Regression test for #626: endpoint must NOT execute before CSRF check. + + Previously the middleware called call_next() first, executing the endpoint + and its side effects, then checked @csrf_exempt afterward. This meant + non-exempt endpoints would execute even when CSRF validation failed. + """ + app = FastAPI() + app.add_middleware(CSRFMiddleware) + + side_effect_log: list[str] = [] + + @app.post("/protected-with-side-effects") + def protected_with_side_effects(): + side_effect_log.append("executed") + return {"message": "should not run"} + + client = TestClient(app) + + # POST without CSRF token — should be blocked with 403 + response = client.post("/protected-with-side-effects") + assert response.status_code == 403 + # The critical assertion: the endpoint must NOT have executed + assert side_effect_log == [], ( + "Endpoint executed before CSRF validation! Side effects occurred " + "despite CSRF failure (see issue #626)." + ) + + def test_csrf_exempt_endpoint_does_execute(self): + """Ensure @csrf_exempt endpoints still execute normally.""" + app = FastAPI() + app.add_middleware(CSRFMiddleware) + + side_effect_log: list[str] = [] + + @app.post("/exempt-webhook") + @csrf_exempt + def exempt_webhook(): + side_effect_log.append("executed") + return {"message": "webhook ok"} + + client = TestClient(app) + + response = client.post("/exempt-webhook") + assert response.status_code == 200 + assert side_effect_log == ["executed"]