diff --git a/pytest.ini b/pytest.ini index d89aa6e..8f9060e 100644 --- a/pytest.ini +++ b/pytest.ini @@ -29,7 +29,7 @@ addopts = --tb=short --strict-markers --disable-warnings - -n auto +# -n auto # Coverage configuration [coverage:run] diff --git a/src/dashboard/app.py b/src/dashboard/app.py index 4889bf4..eef3422 100644 --- a/src/dashboard/app.py +++ b/src/dashboard/app.py @@ -4,6 +4,7 @@ Key improvements: 1. Background tasks use asyncio.create_task() to avoid blocking startup 2. Chat integrations start in background 3. All startup operations complete quickly +4. Security and logging handled by dedicated middleware """ import asyncio @@ -40,6 +41,11 @@ from dashboard.routes.thinking import router as thinking_router from dashboard.routes.calm import router as calm_router from infrastructure.router.api import router as cascade_router +# Import dedicated middleware +from dashboard.middleware.csrf import CSRFMiddleware +from dashboard.middleware.request_logging import RequestLoggingMiddleware +from dashboard.middleware.security_headers import SecurityHeadersMiddleware + def _configure_logging() -> None: """Configure logging with console and optional rotating file handler.""" @@ -241,29 +247,20 @@ def _get_cors_origins() -> list[str]: return origins -async def add_security_headers(request: Request, call_next): - """Add security headers to all responses.""" - response = await call_next(request) - response.headers["X-Frame-Options"] = "SAMEORIGIN" - response.headers["X-Content-Type-Options"] = "nosniff" - response.headers["X-XSS-Protection"] = "1; mode=block" - response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" - response.headers["Content-Security-Policy"] = ( - "default-src 'self'; " - "script-src 'self' 'unsafe-inline' cdn.jsdelivr.net; " - "style-src 'self' 'unsafe-inline' fonts.googleapis.com cdn.jsdelivr.net; " - "font-src 'self' fonts.gstatic.com; " - "img-src 'self' data: https:; " - "connect-src 'self' ws: wss:; " - "frame-ancestors 'self'; " - "base-uri 'self'; " - "form-action 'self'" - ) - return response +# Add dedicated middleware in correct order +# 1. Logging (outermost to capture everything) +app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health"]) +# 2. Security Headers +app.add_middleware( + SecurityHeadersMiddleware, + production=not settings.debug +) -app.middleware("http")(add_security_headers) +# 3. CSRF Protection +app.add_middleware(CSRFMiddleware) +# 4. Standard FastAPI middleware app.add_middleware( TrustedHostMiddleware, allowed_hosts=["localhost", "127.0.0.1", "*.local", "testserver"], diff --git a/tests/dashboard/test_middleware_migration.py b/tests/dashboard/test_middleware_migration.py new file mode 100644 index 0000000..9de333d --- /dev/null +++ b/tests/dashboard/test_middleware_migration.py @@ -0,0 +1,34 @@ +import pytest +from fastapi.testclient import TestClient +from dashboard.app import app + +@pytest.fixture +def client(): + return TestClient(app) + +def test_security_headers_middleware_is_used(client): + """Verify that SecurityHeadersMiddleware is used instead of the inline function.""" + response = client.get("/") + # SecurityHeadersMiddleware sets X-Frame-Options to 'DENY' by default + # The inline function in app.py sets it to 'SAMEORIGIN' + assert response.headers["X-Frame-Options"] == "DENY" + # SecurityHeadersMiddleware also sets Permissions-Policy + assert "Permissions-Policy" in response.headers + +def test_request_logging_middleware_is_used(client): + """Verify that RequestLoggingMiddleware is used.""" + response = client.get("/") + # RequestLoggingMiddleware adds X-Correlation-ID to the response + assert "X-Correlation-ID" in response.headers + +def test_csrf_middleware_is_used(client): + """Verify that CSRFMiddleware is used.""" + # GET request should set a csrf_token cookie if not present + response = client.get("/") + assert "csrf_token" in response.cookies + + # POST request without token should be blocked (403) + # Use a path that isn't likely to be exempt + response = client.post("/agents/create") + assert response.status_code == 403 + assert response.json()["code"] == "CSRF_INVALID"