[kimi] refactor: extract helpers from CSRFMiddleware.dispatch (#628) #691

Merged
kimi merged 1 commits from kimi/issue-628 into main 2026-03-21 03:41:10 +00:00
3 changed files with 40 additions and 31 deletions

View File

@@ -131,7 +131,6 @@ class CSRFMiddleware(BaseHTTPMiddleware):
For safe methods: Set a CSRF token cookie if not present.
For unsafe methods: Validate the CSRF token or check if exempt.
"""
# Bypass CSRF if explicitly disabled (e.g. in tests)
from config import settings
if settings.timmy_disable_csrf:
@@ -141,47 +140,55 @@ class CSRFMiddleware(BaseHTTPMiddleware):
if request.headers.get("upgrade", "").lower() == "websocket":
return await call_next(request)
# Get existing CSRF token from cookie
csrf_cookie = request.cookies.get(self.cookie_name)
# For safe methods, just ensure a token exists
if request.method in self.SAFE_METHODS:
response = await call_next(request)
return await self._handle_safe_method(request, call_next, csrf_cookie)
# Set CSRF token cookie if not present
if not csrf_cookie:
new_token = generate_csrf_token()
response.set_cookie(
key=self.cookie_name,
value=new_token,
httponly=False, # Must be readable by JavaScript
secure=settings.csrf_cookie_secure,
samesite="Lax",
max_age=86400, # 24 hours
)
return await self._handle_unsafe_method(request, call_next, csrf_cookie)
return response
async def _handle_safe_method(
self, request: Request, call_next, csrf_cookie: str | None
) -> Response:
"""Handle safe HTTP methods (GET, HEAD, OPTIONS, TRACE).
# For unsafe methods, we need to validate or check if exempt
# First, try to validate the CSRF token
Forwards the request and sets a CSRF token cookie if not present.
"""
from config import settings
response = await call_next(request)
if not csrf_cookie:
new_token = generate_csrf_token()
response.set_cookie(
key=self.cookie_name,
value=new_token,
httponly=False, # Must be readable by JavaScript
secure=settings.csrf_cookie_secure,
samesite="Lax",
max_age=86400, # 24 hours
)
return response
async def _handle_unsafe_method(
self, request: Request, call_next, csrf_cookie: str | None
) -> Response:
"""Handle unsafe HTTP methods (POST, PUT, DELETE, PATCH).
Validates the CSRF token, checks path and endpoint exemptions,
or returns a 403 error.
"""
if await self._validate_request(request, csrf_cookie):
# Token is valid, allow the request
return await call_next(request)
# Token validation failed, check if the path is exempt
path = request.url.path
if self._is_likely_exempt(path):
# Path is exempt, allow the request
if self._is_likely_exempt(request.url.path):
return await call_next(request)
# Token validation failed and path is not exempt
# Resolve the endpoint WITHOUT executing it to check @csrf_exempt
endpoint = self._resolve_endpoint(request)
if endpoint and is_csrf_exempt(endpoint):
return await call_next(request)
# Endpoint is not exempt and token validation failed
# Return 403 error
return JSONResponse(
status_code=403,
content={

View File

@@ -99,11 +99,11 @@ class GrokBackend:
def _get_client(self):
"""Create OpenAI client configured for xAI endpoint."""
from config import settings
import httpx
from openai import OpenAI
from config import settings
return OpenAI(
api_key=self._api_key,
base_url=settings.xai_base_url,
@@ -112,11 +112,11 @@ class GrokBackend:
async def _get_async_client(self):
"""Create async OpenAI client configured for xAI endpoint."""
from config import settings
import httpx
from openai import AsyncOpenAI
from config import settings
return AsyncOpenAI(
api_key=self._api_key,
base_url=settings.xai_base_url,

View File

@@ -31,6 +31,8 @@ for _mod in [
"pyzbar.pyzbar",
"pyttsx3",
"sentence_transformers",
"swarm",
"swarm.event_log",
]:
sys.modules.setdefault(_mod, MagicMock())