1
0

fix: implement @csrf_exempt decorator support in CSRFMiddleware (#159)

This commit is contained in:
Alexander Whitestone
2026-03-10 15:26:40 -04:00
committed by GitHub
parent 4a4c9be1eb
commit 2a5f317a12
2 changed files with 217 additions and 19 deletions

View File

@@ -0,0 +1,179 @@
"""Tests for CSRF middleware @csrf_exempt decorator support.
This test suite ensures that the @csrf_exempt decorator works correctly
to mark endpoints as exempt from CSRF validation.
"""
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from dashboard.middleware.csrf import CSRFMiddleware, csrf_exempt
class TestCSRFDecoratorSupport:
"""Test @csrf_exempt decorator functionality."""
@pytest.fixture(autouse=True)
def enable_csrf(self):
"""Re-enable CSRF for these tests."""
from config import settings
original = settings.timmy_disable_csrf
settings.timmy_disable_csrf = False
yield
settings.timmy_disable_csrf = original
def test_csrf_exempt_decorator_allows_post_without_token(self):
"""Test that @csrf_exempt decorator allows POST without CSRF token."""
app = FastAPI()
app.add_middleware(CSRFMiddleware)
@app.post("/webhook")
@csrf_exempt
def webhook_endpoint():
return {"message": "webhook received"}
client = TestClient(app)
# Should be 200 because endpoint is decorated with @csrf_exempt
response = client.post("/webhook")
assert response.status_code == 200
assert response.json() == {"message": "webhook received"}
def test_csrf_exempt_decorator_with_async_endpoint(self):
"""Test that @csrf_exempt decorator works with async endpoints."""
app = FastAPI()
app.add_middleware(CSRFMiddleware)
@app.post("/async-webhook")
@csrf_exempt
async def async_webhook_endpoint():
return {"message": "async webhook received"}
client = TestClient(app)
# Should be 200 because endpoint is decorated with @csrf_exempt
response = client.post("/async-webhook")
assert response.status_code == 200
assert response.json() == {"message": "async webhook received"}
def test_non_exempt_endpoint_requires_csrf_token(self):
"""Test that endpoints without @csrf_exempt require CSRF token."""
app = FastAPI()
app.add_middleware(CSRFMiddleware)
@app.post("/protected")
def protected_endpoint():
return {"message": "protected"}
client = TestClient(app)
# Should be 403 because endpoint is not exempt and no CSRF token provided
response = client.post("/protected")
assert response.status_code == 403
def test_csrf_exempt_endpoint_ignores_invalid_token(self):
"""Test that @csrf_exempt endpoints ignore invalid CSRF tokens."""
app = FastAPI()
app.add_middleware(CSRFMiddleware)
@app.post("/webhook")
@csrf_exempt
def webhook_endpoint():
return {"message": "webhook received"}
client = TestClient(app)
# Should be 200 even with invalid token
response = client.post(
"/webhook",
headers={"X-CSRF-Token": "invalid_token"},
)
assert response.status_code == 200
def test_exempt_endpoint_with_form_data(self):
"""Test that @csrf_exempt works with form data."""
app = FastAPI()
app.add_middleware(CSRFMiddleware)
@app.post("/webhook")
@csrf_exempt
def webhook_endpoint():
return {"message": "webhook received"}
client = TestClient(app)
# Should be 200 even with form data and no CSRF token
response = client.post(
"/webhook",
data={"key": "value"},
)
assert response.status_code == 200
def test_exempt_endpoint_with_json_data(self):
"""Test that @csrf_exempt works with JSON data."""
app = FastAPI()
app.add_middleware(CSRFMiddleware)
@app.post("/webhook")
@csrf_exempt
def webhook_endpoint():
return {"message": "webhook received"}
client = TestClient(app)
# Should be 200 even with JSON data and no CSRF token
response = client.post(
"/webhook",
json={"key": "value"},
)
assert response.status_code == 200
def test_multiple_exempt_endpoints(self):
"""Test multiple @csrf_exempt endpoints."""
app = FastAPI()
app.add_middleware(CSRFMiddleware)
@app.post("/webhook1")
@csrf_exempt
def webhook1():
return {"message": "webhook1"}
@app.post("/webhook2")
@csrf_exempt
def webhook2():
return {"message": "webhook2"}
client = TestClient(app)
# Both should be 200
response1 = client.post("/webhook1")
assert response1.status_code == 200
response2 = client.post("/webhook2")
assert response2.status_code == 200
def test_exempt_and_protected_endpoints_coexist(self):
"""Test that exempt and protected endpoints can coexist."""
app = FastAPI()
app.add_middleware(CSRFMiddleware)
@app.post("/webhook")
@csrf_exempt
def webhook_endpoint():
return {"message": "webhook"}
@app.post("/protected")
def protected_endpoint():
return {"message": "protected"}
client = TestClient(app)
# Exempt endpoint should be 200
response1 = client.post("/webhook")
assert response1.status_code == 200
# Protected endpoint should be 403
response2 = client.post("/protected")
assert response2.status_code == 403