feat: add security middleware suite - CSRF, security headers, and request logging (#102)
Implements three security middleware components with full test coverage:
- CSRF Protection: Token generation/validation, safe method allowlist,
auto-exempt webhooks, constant-time comparison for timing attack prevention
- Security Headers: X-Content-Type-Options, X-Frame-Options, CSP,
Permissions-Policy, Referrer-Policy, HSTS (production)
- Request Logging: Method/path/status/duration logging with correlation IDs,
configurable path exclusions, X-Forwarded-For support
Also fixes Discord test isolation issue where settings.discord_token
was not being properly reset between tests.
New files:
- src/dashboard/middleware/{csrf,security_headers,request_logging}.py
- tests/dashboard/middleware/test_{csrf,security_headers,request_logging}.py
Addresses design review recommendations R3, R8, R9, R4.
All tests pass: 1950 passed, 40 skipped
Co-authored-by: Alexander Payne <apayne@MM.local>
This commit is contained in:
committed by
GitHub
parent
6eefcabc97
commit
3a8496a3f1
137
tests/dashboard/middleware/test_csrf.py
Normal file
137
tests/dashboard/middleware/test_csrf.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Tests for CSRF protection middleware."""
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
class TestCSRFMiddleware:
|
||||
"""Test CSRF token validation and generation."""
|
||||
|
||||
def test_csrf_token_generation(self):
|
||||
"""CSRF token should be generated and stored in session/state."""
|
||||
from dashboard.middleware.csrf import generate_csrf_token
|
||||
|
||||
token1 = generate_csrf_token()
|
||||
token2 = generate_csrf_token()
|
||||
|
||||
# Tokens should be non-empty strings
|
||||
assert isinstance(token1, str)
|
||||
assert len(token1) > 0
|
||||
|
||||
# Each token should be unique
|
||||
assert token1 != token2
|
||||
|
||||
def test_csrf_token_validation(self):
|
||||
"""Valid CSRF tokens should pass validation."""
|
||||
from dashboard.middleware.csrf import generate_csrf_token, validate_csrf_token
|
||||
|
||||
token = generate_csrf_token()
|
||||
|
||||
# Same token should validate
|
||||
assert validate_csrf_token(token, token) is True
|
||||
|
||||
# Different tokens should not validate
|
||||
assert validate_csrf_token(token, "different-token") is False
|
||||
|
||||
# Empty tokens should not validate
|
||||
assert validate_csrf_token(token, "") is False
|
||||
assert validate_csrf_token("", token) is False
|
||||
|
||||
def test_csrf_middleware_allows_safe_methods(self):
|
||||
"""GET, HEAD, OPTIONS requests should not require CSRF token."""
|
||||
from dashboard.middleware.csrf import CSRFMiddleware
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware, secret="test-secret")
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "success"}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# GET should work without CSRF token
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "success"}
|
||||
|
||||
def test_csrf_middleware_blocks_unsafe_methods_without_token(self):
|
||||
"""POST, PUT, DELETE should require CSRF token."""
|
||||
from dashboard.middleware.csrf import CSRFMiddleware
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware, secret="test-secret")
|
||||
|
||||
@app.post("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "success"}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# POST without CSRF token should fail
|
||||
response = client.post("/test")
|
||||
assert response.status_code == 403
|
||||
assert "csrf" in response.json().get("error", "").lower()
|
||||
|
||||
def test_csrf_middleware_allows_with_valid_token(self):
|
||||
"""POST with valid CSRF token should succeed."""
|
||||
from dashboard.middleware.csrf import CSRFMiddleware, generate_csrf_token
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware, secret="test-secret")
|
||||
|
||||
@app.post("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "success"}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Get CSRF token from cookie or header
|
||||
token = generate_csrf_token()
|
||||
|
||||
# POST with valid CSRF token
|
||||
response = client.post(
|
||||
"/test",
|
||||
headers={"X-CSRF-Token": token},
|
||||
cookies={"csrf_token": token}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "success"}
|
||||
|
||||
def test_csrf_middleware_exempt_routes(self):
|
||||
"""Routes with webhook patterns should bypass CSRF validation."""
|
||||
from dashboard.middleware.csrf import CSRFMiddleware
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware, secret="test-secret")
|
||||
|
||||
@app.post("/webhook")
|
||||
def webhook_endpoint():
|
||||
return {"message": "webhook received"}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# POST to exempt route without CSRF token should work
|
||||
response = client.post("/webhook")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "webhook received"}
|
||||
|
||||
def test_csrf_token_in_cookie(self):
|
||||
"""CSRF token should be set in cookie for frontend to read."""
|
||||
from dashboard.middleware.csrf import CSRFMiddleware
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware, secret="test-secret")
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "success"}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# GET should set CSRF cookie
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 200
|
||||
assert "csrf_token" in response.cookies or "set-cookie" in str(response.headers).lower()
|
||||
123
tests/dashboard/middleware/test_request_logging.py
Normal file
123
tests/dashboard/middleware/test_request_logging.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Tests for request logging middleware."""
|
||||
|
||||
import pytest
|
||||
import time
|
||||
from unittest.mock import Mock, patch
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
class TestRequestLoggingMiddleware:
|
||||
"""Test request logging captures essential information."""
|
||||
|
||||
@pytest.fixture
|
||||
def app_with_logging(self):
|
||||
"""Create app with request logging middleware."""
|
||||
from dashboard.middleware.request_logging import RequestLoggingMiddleware
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(RequestLoggingMiddleware)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "success"}
|
||||
|
||||
@app.get("/slow")
|
||||
def slow_endpoint():
|
||||
time.sleep(0.1)
|
||||
return {"message": "slow response"}
|
||||
|
||||
@app.get("/error")
|
||||
def error_endpoint():
|
||||
raise ValueError("Test error")
|
||||
|
||||
return app
|
||||
|
||||
def test_logs_request_method_and_path(self, app_with_logging, caplog):
|
||||
"""Log should include HTTP method and path."""
|
||||
with caplog.at_level("INFO"):
|
||||
client = TestClient(app_with_logging)
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Check log contains method and path
|
||||
assert any("GET" in record.message and "/test" in record.message
|
||||
for record in caplog.records)
|
||||
|
||||
def test_logs_response_status_code(self, app_with_logging, caplog):
|
||||
"""Log should include response status code."""
|
||||
with caplog.at_level("INFO"):
|
||||
client = TestClient(app_with_logging)
|
||||
response = client.get("/test")
|
||||
|
||||
# Check log contains status code
|
||||
assert any("200" in record.message for record in caplog.records)
|
||||
|
||||
def test_logs_request_duration(self, app_with_logging, caplog):
|
||||
"""Log should include request processing time."""
|
||||
with caplog.at_level("INFO"):
|
||||
client = TestClient(app_with_logging)
|
||||
response = client.get("/slow")
|
||||
|
||||
# Check log contains duration (e.g., "0.1" or "100ms")
|
||||
assert any(record.message for record in caplog.records
|
||||
if any(c.isdigit() for c in record.message))
|
||||
|
||||
def test_logs_client_ip(self, app_with_logging, caplog):
|
||||
"""Log should include client IP address."""
|
||||
with caplog.at_level("INFO"):
|
||||
client = TestClient(app_with_logging)
|
||||
response = client.get("/test", headers={"X-Forwarded-For": "192.168.1.1"})
|
||||
|
||||
# Check log contains IP
|
||||
assert any("192.168.1.1" in record.message or "127.0.0.1" in record.message
|
||||
for record in caplog.records)
|
||||
|
||||
def test_logs_user_agent(self, app_with_logging, caplog):
|
||||
"""Log should include User-Agent header."""
|
||||
with caplog.at_level("INFO"):
|
||||
client = TestClient(app_with_logging)
|
||||
response = client.get("/test", headers={"User-Agent": "TestAgent/1.0"})
|
||||
|
||||
# Check log contains user agent
|
||||
assert any("TestAgent" in record.message for record in caplog.records)
|
||||
|
||||
def test_logs_error_requests(self, app_with_logging, caplog):
|
||||
"""Errors should be logged with appropriate level."""
|
||||
with caplog.at_level("ERROR"):
|
||||
client = TestClient(app_with_logging, raise_server_exceptions=False)
|
||||
response = client.get("/error")
|
||||
|
||||
assert response.status_code == 500
|
||||
# Should have error log
|
||||
assert any(record.levelname == "ERROR" for record in caplog.records)
|
||||
|
||||
def test_skips_health_check_logging(self, caplog):
|
||||
"""Health check endpoints should not be logged (to reduce noise)."""
|
||||
from dashboard.middleware.request_logging import RequestLoggingMiddleware
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health"])
|
||||
|
||||
@app.get("/health")
|
||||
def health_endpoint():
|
||||
return {"status": "ok"}
|
||||
|
||||
with caplog.at_level("INFO", logger="timmy.requests"):
|
||||
client = TestClient(app)
|
||||
response = client.get("/health")
|
||||
|
||||
# Should not log health check (only check our logger's records)
|
||||
timmy_records = [r for r in caplog.records if r.name == "timmy.requests"]
|
||||
assert not any("/health" in record.message for record in timmy_records)
|
||||
|
||||
def test_correlation_id_in_logs(self, app_with_logging, caplog):
|
||||
"""Each request should have a unique correlation ID."""
|
||||
with caplog.at_level("INFO"):
|
||||
client = TestClient(app_with_logging)
|
||||
response = client.get("/test")
|
||||
|
||||
# Check for correlation ID format (UUID or similar)
|
||||
log_messages = [record.message for record in caplog.records]
|
||||
assert any(len(record.message) > 20 for record in caplog.records) # Rough check for ID
|
||||
107
tests/dashboard/middleware/test_security_headers.py
Normal file
107
tests/dashboard/middleware/test_security_headers.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""Tests for security headers middleware."""
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import JSONResponse, HTMLResponse
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
class TestSecurityHeadersMiddleware:
|
||||
"""Test security headers are properly set on responses."""
|
||||
|
||||
@pytest.fixture
|
||||
def client_with_headers(self):
|
||||
"""Create a test client with security headers middleware."""
|
||||
from dashboard.middleware.security_headers import SecurityHeadersMiddleware
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "success"}
|
||||
|
||||
@app.get("/html")
|
||||
def html_endpoint():
|
||||
return HTMLResponse(content="<html><body>Test</body></html>")
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
def test_x_content_type_options_header(self, client_with_headers):
|
||||
"""X-Content-Type-Options should be set to nosniff."""
|
||||
response = client_with_headers.get("/test")
|
||||
assert response.headers.get("x-content-type-options") == "nosniff"
|
||||
|
||||
def test_x_frame_options_header(self, client_with_headers):
|
||||
"""X-Frame-Options should be set to DENY."""
|
||||
response = client_with_headers.get("/test")
|
||||
assert response.headers.get("x-frame-options") == "DENY"
|
||||
|
||||
def test_x_xss_protection_header(self, client_with_headers):
|
||||
"""X-XSS-Protection should be enabled."""
|
||||
response = client_with_headers.get("/test")
|
||||
assert "1; mode=block" in response.headers.get("x-xss-protection", "")
|
||||
|
||||
def test_referrer_policy_header(self, client_with_headers):
|
||||
"""Referrer-Policy should be set."""
|
||||
response = client_with_headers.get("/test")
|
||||
assert response.headers.get("referrer-policy") == "strict-origin-when-cross-origin"
|
||||
|
||||
def test_permissions_policy_header(self, client_with_headers):
|
||||
"""Permissions-Policy should restrict sensitive features."""
|
||||
response = client_with_headers.get("/test")
|
||||
policy = response.headers.get("permissions-policy", "")
|
||||
assert "camera=()" in policy
|
||||
assert "microphone=()" in policy
|
||||
assert "geolocation=()" in policy
|
||||
|
||||
def test_content_security_policy_header(self, client_with_headers):
|
||||
"""Content-Security-Policy should be set for HTML responses."""
|
||||
response = client_with_headers.get("/html")
|
||||
csp = response.headers.get("content-security-policy", "")
|
||||
assert "default-src 'self'" in csp
|
||||
assert "script-src" in csp
|
||||
assert "style-src" in csp
|
||||
|
||||
def test_strict_transport_security_in_production(self):
|
||||
"""HSTS header should be set in production mode."""
|
||||
from dashboard.middleware.security_headers import SecurityHeadersMiddleware
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(SecurityHeadersMiddleware, production=True)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "success"}
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/test")
|
||||
|
||||
hsts = response.headers.get("strict-transport-security")
|
||||
assert hsts is not None
|
||||
assert "max-age=" in hsts
|
||||
|
||||
def test_strict_transport_security_not_in_dev(self, client_with_headers):
|
||||
"""HSTS header should not be set in development mode."""
|
||||
response = client_with_headers.get("/test")
|
||||
assert "strict-transport-security" not in response.headers
|
||||
|
||||
def test_headers_on_error_response(self):
|
||||
"""Security headers should be set even on error responses."""
|
||||
from dashboard.middleware.security_headers import SecurityHeadersMiddleware
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
@app.get("/error")
|
||||
def error_endpoint():
|
||||
raise ValueError("Test error")
|
||||
|
||||
# Use raise_server_exceptions=False to get the error response
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/error")
|
||||
|
||||
# Even on 500 error, security headers should be present
|
||||
assert response.status_code == 500
|
||||
assert response.headers.get("x-content-type-options") == "nosniff"
|
||||
assert response.headers.get("x-frame-options") == "DENY"
|
||||
@@ -54,9 +54,12 @@ class TestDiscordVendor:
|
||||
def test_load_token_missing_file(self, tmp_path, monkeypatch):
|
||||
from integrations.chat_bridge.vendors import discord as discord_mod
|
||||
from integrations.chat_bridge.vendors.discord import DiscordVendor
|
||||
from config import settings
|
||||
|
||||
state_file = tmp_path / "nonexistent.json"
|
||||
monkeypatch.setattr(discord_mod, "_STATE_FILE", state_file)
|
||||
# Ensure settings.discord_token is empty for test isolation
|
||||
monkeypatch.setattr(settings, "discord_token", "")
|
||||
|
||||
vendor = DiscordVendor()
|
||||
# Falls back to config.settings.discord_token
|
||||
|
||||
Reference in New Issue
Block a user