diff --git a/src/dashboard/middleware/csrf.py b/src/dashboard/middleware/csrf.py index e8e4a73a..d370c439 100644 --- a/src/dashboard/middleware/csrf.py +++ b/src/dashboard/middleware/csrf.py @@ -131,7 +131,6 @@ class CSRFMiddleware(BaseHTTPMiddleware): For safe methods: Set a CSRF token cookie if not present. For unsafe methods: Validate the CSRF token or check if exempt. """ - # Bypass CSRF if explicitly disabled (e.g. in tests) from config import settings if settings.timmy_disable_csrf: @@ -141,47 +140,55 @@ class CSRFMiddleware(BaseHTTPMiddleware): if request.headers.get("upgrade", "").lower() == "websocket": return await call_next(request) - # Get existing CSRF token from cookie csrf_cookie = request.cookies.get(self.cookie_name) - # For safe methods, just ensure a token exists if request.method in self.SAFE_METHODS: - response = await call_next(request) + return await self._handle_safe_method(request, call_next, csrf_cookie) - # Set CSRF token cookie if not present - if not csrf_cookie: - new_token = generate_csrf_token() - response.set_cookie( - key=self.cookie_name, - value=new_token, - httponly=False, # Must be readable by JavaScript - secure=settings.csrf_cookie_secure, - samesite="Lax", - max_age=86400, # 24 hours - ) + return await self._handle_unsafe_method(request, call_next, csrf_cookie) - return response + async def _handle_safe_method( + self, request: Request, call_next, csrf_cookie: str | None + ) -> Response: + """Handle safe HTTP methods (GET, HEAD, OPTIONS, TRACE). - # For unsafe methods, we need to validate or check if exempt - # First, try to validate the CSRF token + Forwards the request and sets a CSRF token cookie if not present. + """ + from config import settings + + response = await call_next(request) + + if not csrf_cookie: + new_token = generate_csrf_token() + response.set_cookie( + key=self.cookie_name, + value=new_token, + httponly=False, # Must be readable by JavaScript + secure=settings.csrf_cookie_secure, + samesite="Lax", + max_age=86400, # 24 hours + ) + + return response + + async def _handle_unsafe_method( + self, request: Request, call_next, csrf_cookie: str | None + ) -> Response: + """Handle unsafe HTTP methods (POST, PUT, DELETE, PATCH). + + Validates the CSRF token, checks path and endpoint exemptions, + or returns a 403 error. + """ if await self._validate_request(request, csrf_cookie): - # Token is valid, allow the request return await call_next(request) - # Token validation failed, check if the path is exempt - path = request.url.path - if self._is_likely_exempt(path): - # Path is exempt, allow the request + if self._is_likely_exempt(request.url.path): return await call_next(request) - # Token validation failed and path is not exempt - # Resolve the endpoint WITHOUT executing it to check @csrf_exempt endpoint = self._resolve_endpoint(request) if endpoint and is_csrf_exempt(endpoint): return await call_next(request) - # Endpoint is not exempt and token validation failed - # Return 403 error return JSONResponse( status_code=403, content={ diff --git a/src/timmy/backends.py b/src/timmy/backends.py index 2ec05ede..3102a6b5 100644 --- a/src/timmy/backends.py +++ b/src/timmy/backends.py @@ -99,11 +99,11 @@ class GrokBackend: def _get_client(self): """Create OpenAI client configured for xAI endpoint.""" - from config import settings - import httpx from openai import OpenAI + from config import settings + return OpenAI( api_key=self._api_key, base_url=settings.xai_base_url, @@ -112,11 +112,11 @@ class GrokBackend: async def _get_async_client(self): """Create async OpenAI client configured for xAI endpoint.""" - from config import settings - import httpx from openai import AsyncOpenAI + from config import settings + return AsyncOpenAI( api_key=self._api_key, base_url=settings.xai_base_url, diff --git a/tests/conftest.py b/tests/conftest.py index c503e643..3db5de56 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,6 +31,8 @@ for _mod in [ "pyzbar.pyzbar", "pyttsx3", "sentence_transformers", + "swarm", + "swarm.event_log", ]: sys.modules.setdefault(_mod, MagicMock())