From b8ff534ad811f89fd0627e27c99b1833709e129a Mon Sep 17 00:00:00 2001 From: Alexander Whitestone <8633216+AlexanderWhitestone@users.noreply.github.com> Date: Thu, 5 Mar 2026 07:04:30 -0500 Subject: [PATCH] Security: Enhance CSRF protection with form field support and stricter validation (#128) --- src/dashboard/middleware/csrf.py | 35 ++++++++----- tests/dashboard/middleware/test_csrf.py | 68 +++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 14 deletions(-) diff --git a/src/dashboard/middleware/csrf.py b/src/dashboard/middleware/csrf.py index 9d17070..6a3b5e3 100644 --- a/src/dashboard/middleware/csrf.py +++ b/src/dashboard/middleware/csrf.py @@ -154,7 +154,7 @@ class CSRFMiddleware(BaseHTTPMiddleware): # since FastAPI routes are resolved after middleware # Try to validate token early - if not self._validate_request(request, csrf_cookie): + if not await self._validate_request(request, csrf_cookie): # Check if this might be an exempt route by checking path patterns # that are commonly exempt (like webhooks) path = request.url.path @@ -164,7 +164,7 @@ class CSRFMiddleware(BaseHTTPMiddleware): content={ "error": "CSRF validation failed", "code": "CSRF_INVALID", - "message": "Missing or invalid CSRF token. Include the token from the csrf_token cookie in the X-CSRF-Token header." + "message": "Missing or invalid CSRF token. Include the token from the csrf_token cookie in the X-CSRF-Token header or as a form field." } ) @@ -189,7 +189,7 @@ class CSRFMiddleware(BaseHTTPMiddleware): ] return any(pattern in path for pattern in exempt_patterns) - def _validate_request(self, request: Request, csrf_cookie: Optional[str]) -> bool: + async def _validate_request(self, request: Request, csrf_cookie: Optional[str]) -> bool: """Validate the CSRF token in the request. Checks for token in: @@ -203,19 +203,26 @@ class CSRFMiddleware(BaseHTTPMiddleware): Returns: True if the token is valid, False otherwise. """ + # Validate against cookie + if not csrf_cookie: + return False + # Get token from header header_token = request.headers.get(self.header_name) + if header_token and validate_csrf_token(header_token, csrf_cookie): + return True # If no header token, try form data (for non-JSON POSTs) - form_token = None - if not header_token: - # Note: Reading form data requires async, handled separately - pass + # Check Content-Type to avoid hanging on non-form requests + content_type = request.headers.get("Content-Type", "") + if "application/x-www-form-urlencoded" in content_type or "multipart/form-data" in content_type: + try: + form_data = await request.form() + form_token = form_data.get(self.form_field) + if form_token and validate_csrf_token(str(form_token), csrf_cookie): + return True + except Exception: + # Error parsing form data, treat as invalid + pass - token = header_token or form_token - - # Validate against cookie - if not token or not csrf_cookie: - return False - - return validate_csrf_token(token, csrf_cookie) + return False diff --git a/tests/dashboard/middleware/test_csrf.py b/tests/dashboard/middleware/test_csrf.py index de762ee..ab65579 100644 --- a/tests/dashboard/middleware/test_csrf.py +++ b/tests/dashboard/middleware/test_csrf.py @@ -147,3 +147,71 @@ class TestCSRFMiddleware: response = client.get("/test") assert response.status_code == 200 assert "csrf_token" in response.cookies or "set-cookie" in str(response.headers).lower() + + def test_csrf_middleware_allows_with_form_field(self): + """POST with valid CSRF token in form field should succeed.""" + from dashboard.middleware.csrf import CSRFMiddleware, generate_csrf_token + + app = FastAPI() + app.add_middleware(CSRFMiddleware) + + @app.post("/test-form") + async def test_endpoint(request: Request): + return {"message": "success"} + + client = TestClient(app) + token = generate_csrf_token() + + # POST with valid CSRF token in form field + response = client.post( + "/test-form", + data={"csrf_token": token, "other": "data"}, + cookies={"csrf_token": token} + ) + assert response.status_code == 200 + assert response.json() == {"message": "success"} + + def test_csrf_middleware_blocks_mismatched_token(self): + """POST with mismatched token should fail.""" + from dashboard.middleware.csrf import CSRFMiddleware, generate_csrf_token + + app = FastAPI() + app.add_middleware(CSRFMiddleware) + + @app.post("/test") + async def test_endpoint(): + return {"message": "success"} + + client = TestClient(app) + token1 = generate_csrf_token() + token2 = generate_csrf_token() + + # POST with token from one session and cookie from another + response = client.post( + "/test", + headers={"X-CSRF-Token": token1}, + cookies={"csrf_token": token2} + ) + assert response.status_code == 403 + assert "CSRF" in response.json().get("error", "") + + def test_csrf_middleware_blocks_missing_cookie(self): + """POST with header token but missing cookie should fail.""" + from dashboard.middleware.csrf import CSRFMiddleware, generate_csrf_token + + app = FastAPI() + app.add_middleware(CSRFMiddleware) + + @app.post("/test") + async def test_endpoint(): + return {"message": "success"} + + client = TestClient(app) + token = generate_csrf_token() + + # POST with header token but no cookie + response = client.post( + "/test", + headers={"X-CSRF-Token": token} + ) + assert response.status_code == 403