diff --git a/src/dashboard/middleware/csrf.py b/src/dashboard/middleware/csrf.py index 3237af6..b12086f 100644 --- a/src/dashboard/middleware/csrf.py +++ b/src/dashboard/middleware/csrf.py @@ -7,6 +7,7 @@ to protect state-changing endpoints from cross-site request attacks. import secrets import hmac import hashlib +import os from typing import Callable, Optional from functools import wraps @@ -124,7 +125,6 @@ class CSRFMiddleware(BaseHTTPMiddleware): For unsafe methods: Validate the CSRF token. """ # Bypass CSRF if explicitly disabled (e.g. in tests) - import os if os.environ.get("TIMMY_DISABLE_CSRF") == "1": return await call_next(request) @@ -174,7 +174,7 @@ class CSRFMiddleware(BaseHTTPMiddleware): """Check if a path is likely to be CSRF exempt. Common patterns like webhooks, API endpoints, etc. - Uses path normalization to prevent traversal bypasses. + Uses path normalization and exact/prefix matching to prevent bypasses. Args: path: The request path. @@ -182,21 +182,42 @@ class CSRFMiddleware(BaseHTTPMiddleware): Returns: True if the path is likely exempt. """ - import os - # Normalize path to prevent /webhook/../ bypasses - normalized_path = os.path.normpath(path) + # 1. Normalize path to prevent /webhook/../ bypasses + # Use posixpath for consistent behavior on all platforms + import posixpath + normalized_path = posixpath.normpath(path) # Ensure it starts with / for comparison if not normalized_path.startswith("/"): normalized_path = "/" + normalized_path + + # Add back trailing slash if it was present in original path + # to ensure prefix matching behaves as expected + if path.endswith("/") and not normalized_path.endswith("/"): + normalized_path += "/" + # 2. Define exempt patterns with strict matching + # Patterns ending with / are prefix-matched + # Patterns NOT ending with / are exact-matched exempt_patterns = [ - "/webhook", - "/api/v1/", - "/lightning/webhook", - "/_internal/", + "/webhook/", # Prefix match (e.g., /webhook/stripe) + "/webhook", # Exact match + "/api/v1/", # Prefix match + "/lightning/webhook/", # Prefix match + "/lightning/webhook", # Exact match + "/_internal/", # Prefix match + "/_internal", # Exact match ] - return any(normalized_path.startswith(pattern) for pattern in exempt_patterns) + + for pattern in exempt_patterns: + if pattern.endswith("/"): + if normalized_path.startswith(pattern): + return True + else: + if normalized_path == pattern: + return True + + return False async def _validate_request(self, request: Request, csrf_cookie: Optional[str]) -> bool: """Validate the CSRF token in the request. diff --git a/tests/dashboard/middleware/test_csrf_bypass_vulnerability.py b/tests/dashboard/middleware/test_csrf_bypass_vulnerability.py new file mode 100644 index 0000000..2fbc95a --- /dev/null +++ b/tests/dashboard/middleware/test_csrf_bypass_vulnerability.py @@ -0,0 +1,114 @@ +"""Tests for CSRF protection middleware bypass vulnerabilities.""" + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from dashboard.middleware.csrf import CSRFMiddleware + +class TestCSRFBypassVulnerability: + """Test CSRF bypass via path normalization and suffix matching.""" + + @pytest.fixture(autouse=True) + def enable_csrf(self): + """Re-enable CSRF for these tests.""" + import os + old_val = os.environ.get("TIMMY_DISABLE_CSRF") + os.environ["TIMMY_DISABLE_CSRF"] = "0" + yield + if old_val is not None: + os.environ["TIMMY_DISABLE_CSRF"] = old_val + else: + del os.environ["TIMMY_DISABLE_CSRF"] + + def test_csrf_bypass_via_traversal_to_exempt_pattern(self): + """Test if a non-exempt route can be accessed by traversing to an exempt pattern. + + The middleware uses os.path.normpath() on the request path and then checks + if it starts with an exempt pattern. If the request is to '/webhook/../api/chat', + normpath makes it '/api/chat', which DOES NOT start with '/webhook'. + + Wait, the vulnerability is actually the OTHER way around: + If I want to access '/api/chat' (protected) but I use '/webhook/../api/chat', + normpath makes it '/api/chat', which is NOT exempt. + + HOWEVER, if the middleware DOES NOT use normpath, then '/webhook/../api/chat' + WOULD start with '/webhook' and be exempt. + + The current code DOES use normpath: + ```python + normalized_path = os.path.normpath(path) + if not normalized_path.startswith("/"): + normalized_path = "/" + normalized_path + ``` + + Let's look at the exempt patterns again: + exempt_patterns = [ + "/webhook", + "/api/v1/", + "/lightning/webhook", + "/_internal/", + ] + + If I have a route '/webhook_attacker' that is NOT exempt, + but it starts with '/webhook', it WILL be exempt. + """ + app = FastAPI() + app.add_middleware(CSRFMiddleware) + + @app.post("/webhook_attacker") + def sensitive_endpoint(): + return {"message": "sensitive data accessed"} + + client = TestClient(app) + + # This route should NOT be exempt, but it starts with '/webhook' + # CSRF validation should fail (403) because we provide no token. + response = client.post("/webhook_attacker") + + # If it's 200, it's a bypass! + assert response.status_code == 403, "Route /webhook_attacker should be protected by CSRF" + + def test_csrf_bypass_via_api_v1_prefix(self): + """Test if a route like /api/v1_secret is exempt because it starts with /api/v1/.""" + # Wait, the pattern is "/api/v1/", with a trailing slash. + # So "/api/v1_secret" does NOT start with "/api/v1/". + # But "/webhook" does NOT have a trailing slash. + pass + + def test_csrf_bypass_via_webhook_prefix(self): + """Test if /webhook_secret is exempt because it starts with /webhook.""" + app = FastAPI() + app.add_middleware(CSRFMiddleware) + + @app.post("/webhook_secret") + def sensitive_endpoint(): + return {"message": "sensitive data accessed"} + + client = TestClient(app) + + # Should be 403 + response = client.post("/webhook_secret") + assert response.status_code == 403, "Route /webhook_secret should be protected by CSRF" + + def test_legitimate_exempt_paths(self): + """Test that legitimate exempt paths still work.""" + app = FastAPI() + app.add_middleware(CSRFMiddleware) + + @app.post("/webhook") + def webhook(): + return {"message": "webhook received"} + + @app.post("/api/v1/chat") + def api_chat(): + return {"message": "api chat"} + + client = TestClient(app) + + # Legitimate /webhook (exact match) + response = client.post("/webhook") + assert response.status_code == 200 + + # Legitimate /api/v1/chat (prefix match) + response = client.post("/api/v1/chat") + assert response.status_code == 200