Compare commits
1 Commits
fix/test-l
...
fix/csrf-c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
59e789e2da |
@@ -175,18 +175,12 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
return await call_next(request)
|
||||
|
||||
# Token validation failed and path is not exempt
|
||||
# We still need to call the app to check if the endpoint is decorated
|
||||
# with @csrf_exempt, so we'll let it through and check after routing
|
||||
response = await call_next(request)
|
||||
|
||||
# After routing, check if the endpoint is marked as exempt
|
||||
endpoint = request.scope.get("endpoint")
|
||||
# Resolve the endpoint from routes BEFORE executing to avoid side effects
|
||||
endpoint = self._resolve_endpoint(request)
|
||||
if endpoint and is_csrf_exempt(endpoint):
|
||||
# Endpoint is marked as exempt, allow the response
|
||||
return response
|
||||
return await call_next(request)
|
||||
|
||||
# Endpoint is not exempt and token validation failed
|
||||
# Return 403 error
|
||||
# Endpoint is not exempt and token validation failed — reject without executing
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={
|
||||
@@ -196,6 +190,42 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
},
|
||||
)
|
||||
|
||||
def _resolve_endpoint(self, request: Request) -> Callable | None:
|
||||
"""Resolve the endpoint for a request without executing it.
|
||||
|
||||
Walks the app chain to find routes, then matches against the request
|
||||
scope. This allows checking @csrf_exempt before the handler runs
|
||||
(avoiding side effects on CSRF rejection).
|
||||
|
||||
Returns:
|
||||
The endpoint callable if found, None otherwise.
|
||||
"""
|
||||
try:
|
||||
from starlette.routing import Match
|
||||
|
||||
# Walk the middleware/app chain to find something with routes
|
||||
routes = None
|
||||
current = self.app
|
||||
for _ in range(10): # Safety limit
|
||||
routes = getattr(current, "routes", None)
|
||||
if routes:
|
||||
break
|
||||
current = getattr(current, "app", None)
|
||||
if current is None:
|
||||
break
|
||||
|
||||
if not routes:
|
||||
return None
|
||||
|
||||
scope = dict(request.scope)
|
||||
for route in routes:
|
||||
match, child_scope = route.matches(scope)
|
||||
if match == Match.FULL:
|
||||
return child_scope.get("endpoint")
|
||||
except Exception:
|
||||
logger.debug("Failed to resolve endpoint for CSRF check")
|
||||
return None
|
||||
|
||||
def _is_likely_exempt(self, path: str) -> bool:
|
||||
"""Check if a path is likely to be CSRF exempt.
|
||||
|
||||
|
||||
100
tests/dashboard/middleware/test_csrf_no_side_effects.py
Normal file
100
tests/dashboard/middleware/test_csrf_no_side_effects.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Tests that CSRF rejection does NOT execute the endpoint handler.
|
||||
|
||||
Regression test for #626: the middleware was calling call_next() before
|
||||
checking @csrf_exempt, causing side effects even on CSRF-rejected requests.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from dashboard.middleware.csrf import CSRFMiddleware, csrf_exempt
|
||||
|
||||
|
||||
class TestCSRFNoSideEffects:
|
||||
"""Verify endpoints are NOT executed when CSRF validation fails."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_csrf(self):
|
||||
"""Re-enable CSRF for these tests."""
|
||||
from config import settings
|
||||
|
||||
original = settings.timmy_disable_csrf
|
||||
settings.timmy_disable_csrf = False
|
||||
yield
|
||||
settings.timmy_disable_csrf = original
|
||||
|
||||
def test_protected_endpoint_not_executed_on_csrf_failure(self):
|
||||
"""A protected endpoint must NOT run when CSRF token is missing.
|
||||
|
||||
Before the fix, the middleware called call_next() to resolve the
|
||||
endpoint, executing its side effects before returning 403.
|
||||
"""
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
side_effect_log = []
|
||||
|
||||
@app.post("/transfer")
|
||||
def transfer_money():
|
||||
side_effect_log.append("money_transferred")
|
||||
return {"message": "transferred"}
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.post("/transfer")
|
||||
|
||||
assert response.status_code == 403
|
||||
assert side_effect_log == [], (
|
||||
"Endpoint was executed despite CSRF failure — side effects occurred!"
|
||||
)
|
||||
|
||||
def test_csrf_exempt_endpoint_still_executes(self):
|
||||
"""A @csrf_exempt endpoint should still execute without a CSRF token."""
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
side_effect_log = []
|
||||
|
||||
@app.post("/webhook-handler")
|
||||
@csrf_exempt
|
||||
def webhook_handler():
|
||||
side_effect_log.append("webhook_processed")
|
||||
return {"message": "processed"}
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.post("/webhook-handler")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert side_effect_log == ["webhook_processed"]
|
||||
|
||||
def test_exempt_and_protected_no_cross_contamination(self):
|
||||
"""Mixed exempt/protected: only exempt endpoints execute without tokens."""
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
execution_log = []
|
||||
|
||||
@app.post("/safe-webhook")
|
||||
@csrf_exempt
|
||||
def safe_webhook():
|
||||
execution_log.append("safe")
|
||||
return {"message": "safe"}
|
||||
|
||||
@app.post("/dangerous-action")
|
||||
def dangerous_action():
|
||||
execution_log.append("dangerous")
|
||||
return {"message": "danger"}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Exempt endpoint runs
|
||||
resp1 = client.post("/safe-webhook")
|
||||
assert resp1.status_code == 200
|
||||
|
||||
# Protected endpoint blocked WITHOUT executing
|
||||
resp2 = client.post("/dangerous-action")
|
||||
assert resp2.status_code == 403
|
||||
|
||||
assert execution_log == ["safe"], (
|
||||
f"Expected only 'safe' execution, got: {execution_log}"
|
||||
)
|
||||
Reference in New Issue
Block a user