forked from Rockachopa/Timmy-time-dashboard
This commit is contained in:
@@ -175,15 +175,10 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
return await call_next(request)
|
||||
|
||||
# Token validation failed and path is not exempt
|
||||
# We still need to call the app to check if the endpoint is decorated
|
||||
# with @csrf_exempt, so we'll let it through and check after routing
|
||||
response = await call_next(request)
|
||||
|
||||
# After routing, check if the endpoint is marked as exempt
|
||||
endpoint = request.scope.get("endpoint")
|
||||
# Resolve the endpoint WITHOUT executing it to check @csrf_exempt
|
||||
endpoint = self._resolve_endpoint(request)
|
||||
if endpoint and is_csrf_exempt(endpoint):
|
||||
# Endpoint is marked as exempt, allow the response
|
||||
return response
|
||||
return await call_next(request)
|
||||
|
||||
# Endpoint is not exempt and token validation failed
|
||||
# Return 403 error
|
||||
@@ -196,6 +191,41 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
},
|
||||
)
|
||||
|
||||
def _resolve_endpoint(self, request: Request) -> Callable | None:
|
||||
"""Resolve the route endpoint without executing it.
|
||||
|
||||
Walks the Starlette/FastAPI router to find which endpoint function
|
||||
handles this request, so we can check @csrf_exempt before any
|
||||
side effects occur.
|
||||
|
||||
Returns:
|
||||
The endpoint callable, or None if no route matched.
|
||||
"""
|
||||
# If routing already happened (endpoint in scope), use it
|
||||
endpoint = request.scope.get("endpoint")
|
||||
if endpoint:
|
||||
return endpoint
|
||||
|
||||
# Walk the middleware/app chain to find something with routes
|
||||
from starlette.routing import Match
|
||||
|
||||
app = self.app
|
||||
while app is not None:
|
||||
if hasattr(app, "routes"):
|
||||
for route in app.routes:
|
||||
match, _ = route.matches(request.scope)
|
||||
if match == Match.FULL:
|
||||
return getattr(route, "endpoint", None)
|
||||
# Try .router (FastAPI stores routes on app.router)
|
||||
if hasattr(app, "router") and hasattr(app.router, "routes"):
|
||||
for route in app.router.routes:
|
||||
match, _ = route.matches(request.scope)
|
||||
if match == Match.FULL:
|
||||
return getattr(route, "endpoint", None)
|
||||
app = getattr(app, "app", None)
|
||||
|
||||
return None
|
||||
|
||||
def _is_likely_exempt(self, path: str) -> bool:
|
||||
"""Check if a path is likely to be CSRF exempt.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user