forked from Rockachopa/Timmy-time-dashboard
feat: code quality audit + autoresearch integration + infra hardening (#150)
This commit is contained in:
committed by
GitHub
parent
fd0ede0d51
commit
ae3bb1cc21
@@ -13,6 +13,7 @@ class TestCSRFMiddleware:
|
||||
def enable_csrf(self):
|
||||
"""Re-enable CSRF for these tests."""
|
||||
from config import settings
|
||||
|
||||
original = settings.timmy_disable_csrf
|
||||
settings.timmy_disable_csrf = False
|
||||
yield
|
||||
@@ -21,29 +22,29 @@ class TestCSRFMiddleware:
|
||||
def test_csrf_token_generation(self):
|
||||
"""CSRF token should be generated and stored in session/state."""
|
||||
from dashboard.middleware.csrf import generate_csrf_token
|
||||
|
||||
|
||||
token1 = generate_csrf_token()
|
||||
token2 = generate_csrf_token()
|
||||
|
||||
|
||||
# Tokens should be non-empty strings
|
||||
assert isinstance(token1, str)
|
||||
assert len(token1) > 0
|
||||
|
||||
|
||||
# Each token should be unique
|
||||
assert token1 != token2
|
||||
|
||||
def test_csrf_token_validation(self):
|
||||
"""Valid CSRF tokens should pass validation."""
|
||||
from dashboard.middleware.csrf import generate_csrf_token, validate_csrf_token
|
||||
|
||||
|
||||
token = generate_csrf_token()
|
||||
|
||||
|
||||
# Same token should validate
|
||||
assert validate_csrf_token(token, token) is True
|
||||
|
||||
|
||||
# Different tokens should not validate
|
||||
assert validate_csrf_token(token, "different-token") is False
|
||||
|
||||
|
||||
# Empty tokens should not validate
|
||||
assert validate_csrf_token(token, "") is False
|
||||
assert validate_csrf_token("", token) is False
|
||||
@@ -51,16 +52,16 @@ class TestCSRFMiddleware:
|
||||
def test_csrf_middleware_allows_safe_methods(self):
|
||||
"""GET, HEAD, OPTIONS requests should not require CSRF token."""
|
||||
from dashboard.middleware.csrf import CSRFMiddleware
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware, secret="test-secret")
|
||||
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "success"}
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
# GET should work without CSRF token
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 200
|
||||
@@ -69,16 +70,16 @@ class TestCSRFMiddleware:
|
||||
def test_csrf_middleware_blocks_unsafe_methods_without_token(self):
|
||||
"""POST, PUT, DELETE should require CSRF token."""
|
||||
from dashboard.middleware.csrf import CSRFMiddleware
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware, secret="test-secret")
|
||||
|
||||
|
||||
@app.post("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "success"}
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
# POST without CSRF token should fail
|
||||
response = client.post("/test")
|
||||
assert response.status_code == 403
|
||||
@@ -87,24 +88,22 @@ class TestCSRFMiddleware:
|
||||
def test_csrf_middleware_allows_with_valid_token(self):
|
||||
"""POST with valid CSRF token should succeed."""
|
||||
from dashboard.middleware.csrf import CSRFMiddleware, generate_csrf_token
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware, secret="test-secret")
|
||||
|
||||
|
||||
@app.post("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "success"}
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
# Get CSRF token from cookie or header
|
||||
token = generate_csrf_token()
|
||||
|
||||
|
||||
# POST with valid CSRF token
|
||||
response = client.post(
|
||||
"/test",
|
||||
headers={"X-CSRF-Token": token},
|
||||
cookies={"csrf_token": token}
|
||||
"/test", headers={"X-CSRF-Token": token}, cookies={"csrf_token": token}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "success"}
|
||||
@@ -112,16 +111,16 @@ class TestCSRFMiddleware:
|
||||
def test_csrf_middleware_exempt_routes(self):
|
||||
"""Routes with webhook patterns should bypass CSRF validation."""
|
||||
from dashboard.middleware.csrf import CSRFMiddleware
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware, secret="test-secret")
|
||||
|
||||
|
||||
@app.post("/webhook")
|
||||
def webhook_endpoint():
|
||||
return {"message": "webhook received"}
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
# POST to exempt route without CSRF token should work
|
||||
response = client.post("/webhook")
|
||||
assert response.status_code == 200
|
||||
@@ -130,16 +129,16 @@ class TestCSRFMiddleware:
|
||||
def test_csrf_token_in_cookie(self):
|
||||
"""CSRF token should be set in cookie for frontend to read."""
|
||||
from dashboard.middleware.csrf import CSRFMiddleware
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware, secret="test-secret")
|
||||
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "success"}
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
# GET should set CSRF cookie
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 200
|
||||
@@ -148,22 +147,20 @@ class TestCSRFMiddleware:
|
||||
def test_csrf_middleware_allows_with_form_field(self):
|
||||
"""POST with valid CSRF token in form field should succeed."""
|
||||
from dashboard.middleware.csrf import CSRFMiddleware, generate_csrf_token
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
|
||||
@app.post("/test-form")
|
||||
async def test_endpoint(request: Request):
|
||||
return {"message": "success"}
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
token = generate_csrf_token()
|
||||
|
||||
|
||||
# POST with valid CSRF token in form field
|
||||
response = client.post(
|
||||
"/test-form",
|
||||
data={"csrf_token": token, "other": "data"},
|
||||
cookies={"csrf_token": token}
|
||||
"/test-form", data={"csrf_token": token, "other": "data"}, cookies={"csrf_token": token}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "success"}
|
||||
@@ -171,23 +168,21 @@ class TestCSRFMiddleware:
|
||||
def test_csrf_middleware_blocks_mismatched_token(self):
|
||||
"""POST with mismatched token should fail."""
|
||||
from dashboard.middleware.csrf import CSRFMiddleware, generate_csrf_token
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
|
||||
@app.post("/test")
|
||||
async def test_endpoint():
|
||||
return {"message": "success"}
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
token1 = generate_csrf_token()
|
||||
token2 = generate_csrf_token()
|
||||
|
||||
|
||||
# POST with token from one session and cookie from another
|
||||
response = client.post(
|
||||
"/test",
|
||||
headers={"X-CSRF-Token": token1},
|
||||
cookies={"csrf_token": token2}
|
||||
"/test", headers={"X-CSRF-Token": token1}, cookies={"csrf_token": token2}
|
||||
)
|
||||
assert response.status_code == 403
|
||||
assert "CSRF" in response.json().get("error", "")
|
||||
@@ -195,20 +190,17 @@ class TestCSRFMiddleware:
|
||||
def test_csrf_middleware_blocks_missing_cookie(self):
|
||||
"""POST with header token but missing cookie should fail."""
|
||||
from dashboard.middleware.csrf import CSRFMiddleware, generate_csrf_token
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
|
||||
@app.post("/test")
|
||||
async def test_endpoint():
|
||||
return {"message": "success"}
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
token = generate_csrf_token()
|
||||
|
||||
|
||||
# POST with header token but no cookie
|
||||
response = client.post(
|
||||
"/test",
|
||||
headers={"X-CSRF-Token": token}
|
||||
)
|
||||
response = client.post("/test", headers={"X-CSRF-Token": token})
|
||||
assert response.status_code == 403
|
||||
|
||||
@@ -3,8 +3,10 @@
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from dashboard.middleware.csrf import CSRFMiddleware
|
||||
|
||||
|
||||
class TestCSRFBypass:
|
||||
"""Test potential CSRF bypasses."""
|
||||
|
||||
@@ -12,6 +14,7 @@ class TestCSRFBypass:
|
||||
def enable_csrf(self):
|
||||
"""Re-enable CSRF for these tests."""
|
||||
from config import settings
|
||||
|
||||
original = settings.timmy_disable_csrf
|
||||
settings.timmy_disable_csrf = False
|
||||
yield
|
||||
@@ -21,19 +24,16 @@ class TestCSRFBypass:
|
||||
"""POST should require CSRF token even with AJAX headers (if not explicitly allowed)."""
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
|
||||
@app.post("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "success"}
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
# POST with X-Requested-With should STILL fail if it's not a valid CSRF token
|
||||
# Some older middlewares used to trust this header blindly.
|
||||
response = client.post(
|
||||
"/test",
|
||||
headers={"X-Requested-With": "XMLHttpRequest"}
|
||||
)
|
||||
response = client.post("/test", headers={"X-Requested-With": "XMLHttpRequest"})
|
||||
# This should fail with 403 because no CSRF token is provided
|
||||
assert response.status_code == 403
|
||||
|
||||
@@ -41,32 +41,32 @@ class TestCSRFBypass:
|
||||
"""Test if path traversal can bypass CSRF exempt patterns."""
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
|
||||
@app.post("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "success"}
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# If the middleware checks path starts with /webhook,
|
||||
|
||||
# If the middleware checks path starts with /webhook,
|
||||
# can we use /webhook/../test to bypass?
|
||||
# Note: TestClient/FastAPI might normalize this, but we should check the logic.
|
||||
response = client.post("/webhook/../test")
|
||||
|
||||
|
||||
# If it bypassed, it would return 200 (if normalized to /test) or 404 (if not).
|
||||
# But it should definitely not return 200 success without CSRF.
|
||||
if response.status_code == 200:
|
||||
assert response.json() != {"message": "success"}
|
||||
|
||||
|
||||
def test_csrf_middleware_null_byte_bypass(self):
|
||||
"""Test if null byte in path can bypass CSRF exempt patterns."""
|
||||
app = FastAPI()
|
||||
middleware = CSRFMiddleware(app)
|
||||
|
||||
|
||||
# Test directly since TestClient blocks null bytes
|
||||
path = "/webhook\0/test"
|
||||
is_exempt = middleware._is_likely_exempt(path)
|
||||
|
||||
|
||||
# It should either be not exempt or the null byte should be handled
|
||||
# In our current implementation, it might still be exempt if normalized to /webhook\0/test
|
||||
# But it's better than /webhook/../test
|
||||
|
||||
@@ -3,8 +3,10 @@
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from dashboard.middleware.csrf import CSRFMiddleware
|
||||
|
||||
|
||||
class TestCSRFBypassVulnerability:
|
||||
"""Test CSRF bypass via path normalization and suffix matching."""
|
||||
|
||||
@@ -12,6 +14,7 @@ class TestCSRFBypassVulnerability:
|
||||
def enable_csrf(self):
|
||||
"""Re-enable CSRF for these tests."""
|
||||
from config import settings
|
||||
|
||||
original = settings.timmy_disable_csrf
|
||||
settings.timmy_disable_csrf = False
|
||||
yield
|
||||
@@ -19,25 +22,25 @@ class TestCSRFBypassVulnerability:
|
||||
|
||||
def test_csrf_bypass_via_traversal_to_exempt_pattern(self):
|
||||
"""Test if a non-exempt route can be accessed by traversing to an exempt pattern.
|
||||
|
||||
The middleware uses os.path.normpath() on the request path and then checks
|
||||
|
||||
The middleware uses os.path.normpath() on the request path and then checks
|
||||
if it starts with an exempt pattern. If the request is to '/webhook/../api/chat',
|
||||
normpath makes it '/api/chat', which DOES NOT start with '/webhook'.
|
||||
|
||||
|
||||
Wait, the vulnerability is actually the OTHER way around:
|
||||
If I want to access '/api/chat' (protected) but I use '/webhook/../api/chat',
|
||||
normpath makes it '/api/chat', which is NOT exempt.
|
||||
|
||||
|
||||
HOWEVER, if the middleware DOES NOT use normpath, then '/webhook/../api/chat'
|
||||
WOULD start with '/webhook' and be exempt.
|
||||
|
||||
|
||||
The current code DOES use normpath:
|
||||
```python
|
||||
normalized_path = os.path.normpath(path)
|
||||
if not normalized_path.startswith("/"):
|
||||
normalized_path = "/" + normalized_path
|
||||
```
|
||||
|
||||
|
||||
Let's look at the exempt patterns again:
|
||||
exempt_patterns = [
|
||||
"/webhook",
|
||||
@@ -45,23 +48,23 @@ class TestCSRFBypassVulnerability:
|
||||
"/lightning/webhook",
|
||||
"/_internal/",
|
||||
]
|
||||
|
||||
If I have a route '/webhook_attacker' that is NOT exempt,
|
||||
|
||||
If I have a route '/webhook_attacker' that is NOT exempt,
|
||||
but it starts with '/webhook', it WILL be exempt.
|
||||
"""
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
|
||||
@app.post("/webhook_attacker")
|
||||
def sensitive_endpoint():
|
||||
return {"message": "sensitive data accessed"}
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
# This route should NOT be exempt, but it starts with '/webhook'
|
||||
# CSRF validation should fail (403) because we provide no token.
|
||||
response = client.post("/webhook_attacker")
|
||||
|
||||
|
||||
# If it's 200, it's a bypass!
|
||||
assert response.status_code == 403, "Route /webhook_attacker should be protected by CSRF"
|
||||
|
||||
@@ -76,13 +79,13 @@ class TestCSRFBypassVulnerability:
|
||||
"""Test if /webhook_secret is exempt because it starts with /webhook."""
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
|
||||
@app.post("/webhook_secret")
|
||||
def sensitive_endpoint():
|
||||
return {"message": "sensitive data accessed"}
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
# Should be 403
|
||||
response = client.post("/webhook_secret")
|
||||
assert response.status_code == 403, "Route /webhook_secret should be protected by CSRF"
|
||||
@@ -91,21 +94,21 @@ class TestCSRFBypassVulnerability:
|
||||
"""Test that legitimate exempt paths still work."""
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
|
||||
@app.post("/webhook")
|
||||
def webhook():
|
||||
return {"message": "webhook received"}
|
||||
|
||||
|
||||
@app.post("/api/v1/chat")
|
||||
def api_chat():
|
||||
return {"message": "api chat"}
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
# Legitimate /webhook (exact match)
|
||||
response = client.post("/webhook")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# Legitimate /api/v1/chat (prefix match)
|
||||
response = client.post("/api/v1/chat")
|
||||
assert response.status_code == 200
|
||||
|
||||
@@ -3,8 +3,10 @@
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from dashboard.middleware.csrf import CSRFMiddleware
|
||||
|
||||
|
||||
class TestCSRFTraversal:
|
||||
"""Test path traversal CSRF bypass."""
|
||||
|
||||
@@ -12,6 +14,7 @@ class TestCSRFTraversal:
|
||||
def enable_csrf(self):
|
||||
"""Re-enable CSRF for these tests."""
|
||||
from config import settings
|
||||
|
||||
original = settings.timmy_disable_csrf
|
||||
settings.timmy_disable_csrf = False
|
||||
yield
|
||||
@@ -21,21 +24,21 @@ class TestCSRFTraversal:
|
||||
"""Test if path traversal can bypass CSRF exempt patterns."""
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
|
||||
@app.post("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "success"}
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
# We want to check if the middleware logic is flawed.
|
||||
# Since TestClient might normalize, we can test the _is_likely_exempt method directly.
|
||||
middleware = CSRFMiddleware(app)
|
||||
|
||||
|
||||
# This path starts with /webhook, but resolves to /test
|
||||
traversal_path = "/webhook/../test"
|
||||
|
||||
|
||||
# If this returns True, it's a vulnerability because /test is not supposed to be exempt.
|
||||
is_exempt = middleware._is_likely_exempt(traversal_path)
|
||||
|
||||
|
||||
assert is_exempt is False, f"Path {traversal_path} should not be exempt"
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Tests for request logging middleware."""
|
||||
|
||||
import pytest
|
||||
import time
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.testclient import TestClient
|
||||
@@ -15,23 +16,23 @@ class TestRequestLoggingMiddleware:
|
||||
def app_with_logging(self):
|
||||
"""Create app with request logging middleware."""
|
||||
from dashboard.middleware.request_logging import RequestLoggingMiddleware
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(RequestLoggingMiddleware)
|
||||
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "success"}
|
||||
|
||||
|
||||
@app.get("/slow")
|
||||
def slow_endpoint():
|
||||
time.sleep(0.1)
|
||||
return {"message": "slow response"}
|
||||
|
||||
|
||||
@app.get("/error")
|
||||
def error_endpoint():
|
||||
raise ValueError("Test error")
|
||||
|
||||
|
||||
return app
|
||||
|
||||
def test_logs_request_method_and_path(self, app_with_logging, caplog):
|
||||
@@ -40,17 +41,18 @@ class TestRequestLoggingMiddleware:
|
||||
client = TestClient(app_with_logging)
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# Check log contains method and path
|
||||
assert any("GET" in record.message and "/test" in record.message
|
||||
for record in caplog.records)
|
||||
assert any(
|
||||
"GET" in record.message and "/test" in record.message for record in caplog.records
|
||||
)
|
||||
|
||||
def test_logs_response_status_code(self, app_with_logging, caplog):
|
||||
"""Log should include response status code."""
|
||||
with caplog.at_level("INFO"):
|
||||
client = TestClient(app_with_logging)
|
||||
response = client.get("/test")
|
||||
|
||||
|
||||
# Check log contains status code
|
||||
assert any("200" in record.message for record in caplog.records)
|
||||
|
||||
@@ -59,27 +61,30 @@ class TestRequestLoggingMiddleware:
|
||||
with caplog.at_level("INFO"):
|
||||
client = TestClient(app_with_logging)
|
||||
response = client.get("/slow")
|
||||
|
||||
|
||||
# Check log contains duration (e.g., "0.1" or "100ms")
|
||||
assert any(record.message for record in caplog.records
|
||||
if any(c.isdigit() for c in record.message))
|
||||
assert any(
|
||||
record.message for record in caplog.records if any(c.isdigit() for c in record.message)
|
||||
)
|
||||
|
||||
def test_logs_client_ip(self, app_with_logging, caplog):
|
||||
"""Log should include client IP address."""
|
||||
with caplog.at_level("INFO"):
|
||||
client = TestClient(app_with_logging)
|
||||
response = client.get("/test", headers={"X-Forwarded-For": "192.168.1.1"})
|
||||
|
||||
|
||||
# Check log contains IP
|
||||
assert any("192.168.1.1" in record.message or "127.0.0.1" in record.message
|
||||
for record in caplog.records)
|
||||
assert any(
|
||||
"192.168.1.1" in record.message or "127.0.0.1" in record.message
|
||||
for record in caplog.records
|
||||
)
|
||||
|
||||
def test_logs_user_agent(self, app_with_logging, caplog):
|
||||
"""Log should include User-Agent header."""
|
||||
with caplog.at_level("INFO"):
|
||||
client = TestClient(app_with_logging)
|
||||
response = client.get("/test", headers={"User-Agent": "TestAgent/1.0"})
|
||||
|
||||
|
||||
# Check log contains user agent
|
||||
assert any("TestAgent" in record.message for record in caplog.records)
|
||||
|
||||
@@ -88,7 +93,7 @@ class TestRequestLoggingMiddleware:
|
||||
with caplog.at_level("ERROR"):
|
||||
client = TestClient(app_with_logging, raise_server_exceptions=False)
|
||||
response = client.get("/error")
|
||||
|
||||
|
||||
assert response.status_code == 500
|
||||
# Should have error log
|
||||
assert any(record.levelname == "ERROR" for record in caplog.records)
|
||||
@@ -96,18 +101,18 @@ class TestRequestLoggingMiddleware:
|
||||
def test_skips_health_check_logging(self, caplog):
|
||||
"""Health check endpoints should not be logged (to reduce noise)."""
|
||||
from dashboard.middleware.request_logging import RequestLoggingMiddleware
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health"])
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health_endpoint():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
with caplog.at_level("INFO", logger="timmy.requests"):
|
||||
client = TestClient(app)
|
||||
response = client.get("/health")
|
||||
|
||||
|
||||
# Should not log health check (only check our logger's records)
|
||||
timmy_records = [r for r in caplog.records if r.name == "timmy.requests"]
|
||||
assert not any("/health" in record.message for record in timmy_records)
|
||||
@@ -117,7 +122,7 @@ class TestRequestLoggingMiddleware:
|
||||
with caplog.at_level("INFO"):
|
||||
client = TestClient(app_with_logging)
|
||||
response = client.get("/test")
|
||||
|
||||
|
||||
# Check for correlation ID format (UUID or similar)
|
||||
log_messages = [record.message for record in caplog.records]
|
||||
assert any(len(record.message) > 20 for record in caplog.records) # Rough check for ID
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import JSONResponse, HTMLResponse
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
@@ -13,18 +13,18 @@ class TestSecurityHeadersMiddleware:
|
||||
def client_with_headers(self):
|
||||
"""Create a test client with security headers middleware."""
|
||||
from dashboard.middleware.security_headers import SecurityHeadersMiddleware
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "success"}
|
||||
|
||||
|
||||
@app.get("/html")
|
||||
def html_endpoint():
|
||||
return HTMLResponse(content="<html><body>Test</body></html>")
|
||||
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
def test_x_content_type_options_header(self, client_with_headers):
|
||||
@@ -66,17 +66,17 @@ class TestSecurityHeadersMiddleware:
|
||||
def test_strict_transport_security_in_production(self):
|
||||
"""HSTS header should be set in production mode."""
|
||||
from dashboard.middleware.security_headers import SecurityHeadersMiddleware
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(SecurityHeadersMiddleware, production=True)
|
||||
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "success"}
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/test")
|
||||
|
||||
|
||||
hsts = response.headers.get("strict-transport-security")
|
||||
assert hsts is not None
|
||||
assert "max-age=" in hsts
|
||||
@@ -89,18 +89,18 @@ class TestSecurityHeadersMiddleware:
|
||||
def test_headers_on_error_response(self):
|
||||
"""Security headers should be set even on error responses."""
|
||||
from dashboard.middleware.security_headers import SecurityHeadersMiddleware
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
|
||||
@app.get("/error")
|
||||
def error_endpoint():
|
||||
raise ValueError("Test error")
|
||||
|
||||
|
||||
# Use raise_server_exceptions=False to get the error response
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/error")
|
||||
|
||||
|
||||
# Even on 500 error, security headers should be present
|
||||
assert response.status_code == 500
|
||||
assert response.headers.get("x-content-type-options") == "nosniff"
|
||||
|
||||
Reference in New Issue
Block a user