forked from Rockachopa/Timmy-time-dashboard
Security: Enhance CSRF protection with form field support and stricter validation (#128)
This commit is contained in:
committed by
GitHub
parent
a18099a06f
commit
b8ff534ad8
@@ -154,7 +154,7 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
# since FastAPI routes are resolved after middleware
|
||||
|
||||
# Try to validate token early
|
||||
if not self._validate_request(request, csrf_cookie):
|
||||
if not await self._validate_request(request, csrf_cookie):
|
||||
# Check if this might be an exempt route by checking path patterns
|
||||
# that are commonly exempt (like webhooks)
|
||||
path = request.url.path
|
||||
@@ -164,7 +164,7 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
content={
|
||||
"error": "CSRF validation failed",
|
||||
"code": "CSRF_INVALID",
|
||||
"message": "Missing or invalid CSRF token. Include the token from the csrf_token cookie in the X-CSRF-Token header."
|
||||
"message": "Missing or invalid CSRF token. Include the token from the csrf_token cookie in the X-CSRF-Token header or as a form field."
|
||||
}
|
||||
)
|
||||
|
||||
@@ -189,7 +189,7 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
]
|
||||
return any(pattern in path for pattern in exempt_patterns)
|
||||
|
||||
def _validate_request(self, request: Request, csrf_cookie: Optional[str]) -> bool:
|
||||
async def _validate_request(self, request: Request, csrf_cookie: Optional[str]) -> bool:
|
||||
"""Validate the CSRF token in the request.
|
||||
|
||||
Checks for token in:
|
||||
@@ -203,19 +203,26 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
Returns:
|
||||
True if the token is valid, False otherwise.
|
||||
"""
|
||||
# Validate against cookie
|
||||
if not csrf_cookie:
|
||||
return False
|
||||
|
||||
# Get token from header
|
||||
header_token = request.headers.get(self.header_name)
|
||||
if header_token and validate_csrf_token(header_token, csrf_cookie):
|
||||
return True
|
||||
|
||||
# If no header token, try form data (for non-JSON POSTs)
|
||||
form_token = None
|
||||
if not header_token:
|
||||
# Note: Reading form data requires async, handled separately
|
||||
pass
|
||||
# Check Content-Type to avoid hanging on non-form requests
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
if "application/x-www-form-urlencoded" in content_type or "multipart/form-data" in content_type:
|
||||
try:
|
||||
form_data = await request.form()
|
||||
form_token = form_data.get(self.form_field)
|
||||
if form_token and validate_csrf_token(str(form_token), csrf_cookie):
|
||||
return True
|
||||
except Exception:
|
||||
# Error parsing form data, treat as invalid
|
||||
pass
|
||||
|
||||
token = header_token or form_token
|
||||
|
||||
# Validate against cookie
|
||||
if not token or not csrf_cookie:
|
||||
return False
|
||||
|
||||
return validate_csrf_token(token, csrf_cookie)
|
||||
return False
|
||||
|
||||
@@ -147,3 +147,71 @@ class TestCSRFMiddleware:
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 200
|
||||
assert "csrf_token" in response.cookies or "set-cookie" in str(response.headers).lower()
|
||||
|
||||
def test_csrf_middleware_allows_with_form_field(self):
|
||||
"""POST with valid CSRF token in form field should succeed."""
|
||||
from dashboard.middleware.csrf import CSRFMiddleware, generate_csrf_token
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
@app.post("/test-form")
|
||||
async def test_endpoint(request: Request):
|
||||
return {"message": "success"}
|
||||
|
||||
client = TestClient(app)
|
||||
token = generate_csrf_token()
|
||||
|
||||
# POST with valid CSRF token in form field
|
||||
response = client.post(
|
||||
"/test-form",
|
||||
data={"csrf_token": token, "other": "data"},
|
||||
cookies={"csrf_token": token}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "success"}
|
||||
|
||||
def test_csrf_middleware_blocks_mismatched_token(self):
|
||||
"""POST with mismatched token should fail."""
|
||||
from dashboard.middleware.csrf import CSRFMiddleware, generate_csrf_token
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
@app.post("/test")
|
||||
async def test_endpoint():
|
||||
return {"message": "success"}
|
||||
|
||||
client = TestClient(app)
|
||||
token1 = generate_csrf_token()
|
||||
token2 = generate_csrf_token()
|
||||
|
||||
# POST with token from one session and cookie from another
|
||||
response = client.post(
|
||||
"/test",
|
||||
headers={"X-CSRF-Token": token1},
|
||||
cookies={"csrf_token": token2}
|
||||
)
|
||||
assert response.status_code == 403
|
||||
assert "CSRF" in response.json().get("error", "")
|
||||
|
||||
def test_csrf_middleware_blocks_missing_cookie(self):
|
||||
"""POST with header token but missing cookie should fail."""
|
||||
from dashboard.middleware.csrf import CSRFMiddleware, generate_csrf_token
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
@app.post("/test")
|
||||
async def test_endpoint():
|
||||
return {"message": "success"}
|
||||
|
||||
client = TestClient(app)
|
||||
token = generate_csrf_token()
|
||||
|
||||
# POST with header token but no cookie
|
||||
response = client.post(
|
||||
"/test",
|
||||
headers={"X-CSRF-Token": token}
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
Reference in New Issue
Block a user