1
0

feat: code quality audit + autoresearch integration + infra hardening (#150)

This commit is contained in:
Alexander Whitestone
2026-03-08 12:50:44 -04:00
committed by GitHub
parent fd0ede0d51
commit ae3bb1cc21
186 changed files with 5129 additions and 3289 deletions

View File

@@ -1,8 +1,8 @@
"""Dashboard middleware package."""
from .csrf import CSRFMiddleware, csrf_exempt, generate_csrf_token, validate_csrf_token
from .security_headers import SecurityHeadersMiddleware
from .request_logging import RequestLoggingMiddleware
from .security_headers import SecurityHeadersMiddleware
__all__ = [
"CSRFMiddleware",

View File

@@ -4,16 +4,15 @@ Provides CSRF token generation, validation, and middleware integration
to protect state-changing endpoints from cross-site request attacks.
"""
import secrets
import hmac
import hashlib
from typing import Callable, Optional
import hmac
import secrets
from functools import wraps
from typing import Callable, Optional
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response, JSONResponse
from starlette.responses import JSONResponse, Response
# Module-level set to track exempt routes
_exempt_routes: set[str] = set()
@@ -21,26 +20,27 @@ _exempt_routes: set[str] = set()
def csrf_exempt(endpoint: Callable) -> Callable:
"""Decorator to mark an endpoint as exempt from CSRF validation.
Usage:
@app.post("/webhook")
@csrf_exempt
def webhook_endpoint():
...
"""
@wraps(endpoint)
async def async_wrapper(*args, **kwargs):
return await endpoint(*args, **kwargs)
@wraps(endpoint)
def sync_wrapper(*args, **kwargs):
return endpoint(*args, **kwargs)
# Mark the original function as exempt
endpoint._csrf_exempt = True # type: ignore
# Also mark the wrapper
if hasattr(endpoint, '__code__') and endpoint.__code__.co_flags & 0x80:
if hasattr(endpoint, "__code__") and endpoint.__code__.co_flags & 0x80:
async_wrapper._csrf_exempt = True # type: ignore
return async_wrapper
else:
@@ -50,12 +50,12 @@ def csrf_exempt(endpoint: Callable) -> Callable:
def is_csrf_exempt(endpoint: Callable) -> bool:
"""Check if an endpoint is marked as CSRF exempt."""
return getattr(endpoint, '_csrf_exempt', False)
return getattr(endpoint, "_csrf_exempt", False)
def generate_csrf_token() -> str:
"""Generate a cryptographically secure CSRF token.
Returns:
A secure random token string.
"""
@@ -64,77 +64,78 @@ def generate_csrf_token() -> str:
def validate_csrf_token(token: str, expected_token: str) -> bool:
"""Validate a CSRF token against the expected token.
Uses constant-time comparison to prevent timing attacks.
Args:
token: The token provided by the client.
expected_token: The expected token (from cookie/session).
Returns:
True if the token is valid, False otherwise.
"""
if not token or not expected_token:
return False
return hmac.compare_digest(token, expected_token)
class CSRFMiddleware(BaseHTTPMiddleware):
"""Middleware to enforce CSRF protection on state-changing requests.
Safe methods (GET, HEAD, OPTIONS, TRACE) are allowed without CSRF tokens.
State-changing methods (POST, PUT, DELETE, PATCH) require a valid CSRF token.
The token is expected to be:
- In the X-CSRF-Token header, or
- In the request body as 'csrf_token', or
- Matching the token in the csrf_token cookie
Usage:
app.add_middleware(CSRFMiddleware, secret="your-secret-key")
Attributes:
secret: Secret key for token signing (optional, for future use).
cookie_name: Name of the CSRF cookie.
header_name: Name of the CSRF header.
safe_methods: HTTP methods that don't require CSRF tokens.
"""
SAFE_METHODS = {"GET", "HEAD", "OPTIONS", "TRACE"}
def __init__(
self,
app,
secret: Optional[str] = None,
cookie_name: str = "csrf_token",
header_name: str = "X-CSRF-Token",
form_field: str = "csrf_token"
form_field: str = "csrf_token",
):
super().__init__(app)
self.secret = secret
self.cookie_name = cookie_name
self.header_name = header_name
self.form_field = form_field
async def dispatch(self, request: Request, call_next) -> Response:
"""Process the request and enforce CSRF protection.
For safe methods: Set a CSRF token cookie if not present.
For unsafe methods: Validate the CSRF token.
"""
# Bypass CSRF if explicitly disabled (e.g. in tests)
from config import settings
if settings.timmy_disable_csrf:
return await call_next(request)
# Get existing CSRF token from cookie
csrf_cookie = request.cookies.get(self.cookie_name)
# For safe methods, just ensure a token exists
if request.method in self.SAFE_METHODS:
response = await call_next(request)
# Set CSRF token cookie if not present
if not csrf_cookie:
new_token = generate_csrf_token()
@@ -144,15 +145,15 @@ class CSRFMiddleware(BaseHTTPMiddleware):
httponly=False, # Must be readable by JavaScript
secure=settings.csrf_cookie_secure,
samesite="Lax",
max_age=86400 # 24 hours
max_age=86400, # 24 hours
)
return response
# For unsafe methods, check if route is exempt first
# Note: We need to let the request proceed and check at response time
# since FastAPI routes are resolved after middleware
# Try to validate token early
if not await self._validate_request(request, csrf_cookie):
# Check if this might be an exempt route by checking path patterns
@@ -164,33 +165,34 @@ class CSRFMiddleware(BaseHTTPMiddleware):
content={
"error": "CSRF validation failed",
"code": "CSRF_INVALID",
"message": "Missing or invalid CSRF token. Include the token from the csrf_token cookie in the X-CSRF-Token header or as a form field."
}
"message": "Missing or invalid CSRF token. Include the token from the csrf_token cookie in the X-CSRF-Token header or as a form field.",
},
)
return await call_next(request)
def _is_likely_exempt(self, path: str) -> bool:
"""Check if a path is likely to be CSRF exempt.
Common patterns like webhooks, API endpoints, etc.
Uses path normalization and exact/prefix matching to prevent bypasses.
Args:
path: The request path.
Returns:
True if the path is likely exempt.
"""
# 1. Normalize path to prevent /webhook/../ bypasses
# Use posixpath for consistent behavior on all platforms
import posixpath
normalized_path = posixpath.normpath(path)
# Ensure it starts with / for comparison
if not normalized_path.startswith("/"):
normalized_path = "/" + normalized_path
# Add back trailing slash if it was present in original path
# to ensure prefix matching behaves as expected
if path.endswith("/") and not normalized_path.endswith("/"):
@@ -200,15 +202,15 @@ class CSRFMiddleware(BaseHTTPMiddleware):
# Patterns ending with / are prefix-matched
# Patterns NOT ending with / are exact-matched
exempt_patterns = [
"/webhook/", # Prefix match (e.g., /webhook/stripe)
"/webhook", # Exact match
"/api/v1/", # Prefix match
"/lightning/webhook/", # Prefix match
"/webhook/", # Prefix match (e.g., /webhook/stripe)
"/webhook", # Exact match
"/api/v1/", # Prefix match
"/lightning/webhook/", # Prefix match
"/lightning/webhook", # Exact match
"/_internal/", # Prefix match
"/_internal", # Exact match
"/_internal/", # Prefix match
"/_internal", # Exact match
]
for pattern in exempt_patterns:
if pattern.endswith("/"):
if normalized_path.startswith(pattern):
@@ -216,20 +218,20 @@ class CSRFMiddleware(BaseHTTPMiddleware):
else:
if normalized_path == pattern:
return True
return False
async def _validate_request(self, request: Request, csrf_cookie: Optional[str]) -> bool:
"""Validate the CSRF token in the request.
Checks for token in:
1. X-CSRF-Token header
2. csrf_token form field
Args:
request: The incoming request.
csrf_cookie: The expected token from the cookie.
Returns:
True if the token is valid, False otherwise.
"""
@@ -241,11 +243,14 @@ class CSRFMiddleware(BaseHTTPMiddleware):
header_token = request.headers.get(self.header_name)
if header_token and validate_csrf_token(header_token, csrf_cookie):
return True
# If no header token, try form data (for non-JSON POSTs)
# Check Content-Type to avoid hanging on non-form requests
content_type = request.headers.get("Content-Type", "")
if "application/x-www-form-urlencoded" in content_type or "multipart/form-data" in content_type:
if (
"application/x-www-form-urlencoded" in content_type
or "multipart/form-data" in content_type
):
try:
form_data = await request.form()
form_token = form_data.get(self.form_field)
@@ -254,5 +259,5 @@ class CSRFMiddleware(BaseHTTPMiddleware):
except Exception:
# Error parsing form data, treat as invalid
pass
return False

View File

@@ -4,22 +4,21 @@ Logs HTTP requests with timing, status codes, and client information
for monitoring and debugging purposes.
"""
import logging
import time
import uuid
import logging
from typing import Optional, List
from typing import List, Optional
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
logger = logging.getLogger("timmy.requests")
class RequestLoggingMiddleware(BaseHTTPMiddleware):
"""Middleware to log all HTTP requests.
Logs the following information for each request:
- HTTP method and path
- Response status code
@@ -27,60 +26,55 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
- Client IP address
- User-Agent header
- Correlation ID for tracing
Usage:
app.add_middleware(RequestLoggingMiddleware)
# Skip certain paths:
app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health", "/metrics"])
Attributes:
skip_paths: List of URL paths to skip logging.
log_level: Logging level for successful requests.
"""
def __init__(
self,
app,
skip_paths: Optional[List[str]] = None,
log_level: int = logging.INFO
):
def __init__(self, app, skip_paths: Optional[List[str]] = None, log_level: int = logging.INFO):
super().__init__(app)
self.skip_paths = set(skip_paths or [])
self.log_level = log_level
async def dispatch(self, request: Request, call_next) -> Response:
"""Log the request and response details.
Args:
request: The incoming request.
call_next: Callable to get the response from downstream.
Returns:
The response from downstream.
"""
# Check if we should skip logging this path
if request.url.path in self.skip_paths:
return await call_next(request)
# Generate correlation ID
correlation_id = str(uuid.uuid4())[:8]
request.state.correlation_id = correlation_id
# Record start time
start_time = time.time()
# Get client info
client_ip = self._get_client_ip(request)
user_agent = request.headers.get("user-agent", "-")
try:
# Process the request
response = await call_next(request)
# Calculate duration
duration_ms = (time.time() - start_time) * 1000
# Log the request
self._log_request(
method=request.method,
@@ -89,14 +83,14 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
duration_ms=duration_ms,
client_ip=client_ip,
user_agent=user_agent,
correlation_id=correlation_id
correlation_id=correlation_id,
)
# Add correlation ID to response headers
response.headers["X-Correlation-ID"] = correlation_id
return response
except Exception as exc:
# Calculate duration even for failed requests
duration_ms = (time.time() - start_time) * 1000
@@ -110,6 +104,7 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
# Auto-escalate: create bug report task from unhandled exception
try:
from infrastructure.error_capture import capture_error
capture_error(
exc,
source="http",
@@ -126,16 +121,16 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
# Re-raise the exception
raise
def _get_client_ip(self, request: Request) -> str:
"""Extract the client IP address from the request.
Checks X-Forwarded-For and X-Real-IP headers first for proxied requests,
falls back to the direct client IP.
Args:
request: The incoming request.
Returns:
Client IP address string.
"""
@@ -144,17 +139,17 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
if forwarded_for:
# X-Forwarded-For can contain multiple IPs, take the first one
return forwarded_for.split(",")[0].strip()
real_ip = request.headers.get("x-real-ip")
if real_ip:
return real_ip
# Fall back to direct connection
if request.client:
return request.client.host
return "-"
def _log_request(
self,
method: str,
@@ -163,10 +158,10 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
duration_ms: float,
client_ip: str,
user_agent: str,
correlation_id: str
correlation_id: str,
) -> None:
"""Format and log the request details.
Args:
method: HTTP method.
path: Request path.
@@ -182,14 +177,14 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
level = logging.ERROR
elif status_code >= 400:
level = logging.WARNING
message = (
f"[{correlation_id}] {method} {path} - {status_code} "
f"- {duration_ms:.2f}ms - {client_ip}"
)
# Add user agent for non-health requests
if path not in self.skip_paths:
message += f" - {user_agent[:50]}"
logger.log(level, message)

View File

@@ -4,6 +4,8 @@ Adds common security headers to all HTTP responses to improve
application security posture against various attacks.
"""
from typing import Optional
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
@@ -11,7 +13,7 @@ from starlette.responses import Response
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"""Middleware to add security headers to all responses.
Adds the following headers:
- X-Content-Type-Options: Prevents MIME type sniffing
- X-Frame-Options: Prevents clickjacking
@@ -20,41 +22,41 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
- Permissions-Policy: Restricts feature access
- Content-Security-Policy: Mitigates XSS and data injection
- Strict-Transport-Security: Enforces HTTPS (production only)
Usage:
app.add_middleware(SecurityHeadersMiddleware)
# Or with production settings:
app.add_middleware(SecurityHeadersMiddleware, production=True)
Attributes:
production: If True, adds HSTS header for HTTPS enforcement.
csp_report_only: If True, sends CSP in report-only mode.
"""
def __init__(
self,
app,
production: bool = False,
csp_report_only: bool = False,
custom_csp: str = None
custom_csp: Optional[str] = None,
):
super().__init__(app)
self.production = production
self.csp_report_only = csp_report_only
# Build CSP directive
self.csp_directive = custom_csp or self._build_csp()
def _build_csp(self) -> str:
"""Build the Content-Security-Policy directive.
Creates a restrictive default policy that allows:
- Same-origin resources by default
- Inline scripts/styles (needed for HTMX/Bootstrap)
- Data URIs for images
- WebSocket connections
Returns:
CSP directive string.
"""
@@ -73,25 +75,25 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"form-action 'self'",
]
return "; ".join(directives)
def _add_security_headers(self, response: Response) -> None:
"""Add security headers to a response.
Args:
response: The response to add headers to.
"""
# Prevent MIME type sniffing
response.headers["X-Content-Type-Options"] = "nosniff"
# Prevent clickjacking
response.headers["X-Frame-Options"] = "SAMEORIGIN"
# Enable XSS protection (legacy browsers)
response.headers["X-XSS-Protection"] = "1; mode=block"
# Control referrer information
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
# Restrict browser features
response.headers["Permissions-Policy"] = (
"camera=(), "
@@ -103,38 +105,41 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"gyroscope=(), "
"accelerometer=()"
)
# Content Security Policy
csp_header = "Content-Security-Policy-Report-Only" if self.csp_report_only else "Content-Security-Policy"
csp_header = (
"Content-Security-Policy-Report-Only"
if self.csp_report_only
else "Content-Security-Policy"
)
response.headers[csp_header] = self.csp_directive
# HTTPS enforcement (production only)
if self.production:
response.headers["Strict-Transport-Security"] = (
"max-age=31536000; includeSubDomains; preload"
)
response.headers[
"Strict-Transport-Security"
] = "max-age=31536000; includeSubDomains; preload"
async def dispatch(self, request: Request, call_next) -> Response:
"""Add security headers to the response.
Args:
request: The incoming request.
call_next: Callable to get the response from downstream.
Returns:
Response with security headers added.
"""
try:
response = await call_next(request)
self._add_security_headers(response)
return response
except Exception:
# Create a response for the error with security headers
from starlette.responses import PlainTextResponse
response = PlainTextResponse(
content="Internal Server Error",
status_code=500
import logging
logging.getLogger(__name__).debug(
"Upstream error in security headers middleware", exc_info=True
)
self._add_security_headers(response)
# Return the error response with headers (don't re-raise)
return response
from starlette.responses import PlainTextResponse
response = PlainTextResponse("Internal Server Error", status_code=500)
self._add_security_headers(response)
return response