"""Tests for OAuth Session Fixation protection (V-014 fix). These tests verify that: 1. State parameter is generated cryptographically securely 2. State is validated on callback to prevent CSRF attacks 3. State is cleared after validation to prevent replay attacks 4. Session is regenerated after successful OAuth authentication """ import asyncio import json import secrets import threading import time from unittest.mock import MagicMock, patch import pytest from tools.mcp_oauth import ( OAuthStateManager, OAuthStateError, SecureOAuthState, regenerate_session_after_auth, _make_callback_handler, _state_manager, get_state_manager, ) # --------------------------------------------------------------------------- # OAuthStateManager Tests # --------------------------------------------------------------------------- class TestOAuthStateManager: """Test the OAuth state manager for session fixation protection.""" def setup_method(self): """Reset state manager before each test.""" _state_manager.invalidate() def test_generate_state_creates_secure_token(self): """State should be a cryptographically secure signed token.""" state = _state_manager.generate_state() # Should be a non-empty string assert isinstance(state, str) assert len(state) > 0 # Should be URL-safe (contains data.signature format) assert "." in state # Format: . def test_generate_state_unique_each_time(self): """Each generated state should be unique.""" states = [_state_manager.generate_state() for _ in range(10)] # All states should be different assert len(set(states)) == 10 def test_validate_and_extract_success(self): """Validating correct state should succeed.""" state = _state_manager.generate_state() is_valid, data = _state_manager.validate_and_extract(state) assert is_valid is True assert data is not None def test_validate_and_extract_wrong_state_fails(self): """Validating wrong state should fail (CSRF protection).""" _state_manager.generate_state() # Try to validate with a different state wrong_state = "invalid_state_data" is_valid, data = _state_manager.validate_and_extract(wrong_state) assert is_valid is False assert data is None def test_validate_and_extract_none_fails(self): """Validating None state should fail.""" _state_manager.generate_state() is_valid, data = _state_manager.validate_and_extract(None) assert is_valid is False assert data is None def test_validate_and_extract_no_generation_fails(self): """Validating when no state was generated should fail.""" # Don't generate state first is_valid, data = _state_manager.validate_and_extract("some_state") assert is_valid is False assert data is None def test_validate_and_extract_prevents_replay(self): """State should be cleared after validation to prevent replay.""" state = _state_manager.generate_state() # First validation should succeed is_valid, data = _state_manager.validate_and_extract(state) assert is_valid is True # Second validation with same state should fail (replay attack) is_valid, data = _state_manager.validate_and_extract(state) assert is_valid is False def test_invalidate_clears_state(self): """Explicit invalidation should clear state.""" state = _state_manager.generate_state() _state_manager.invalidate() # Validation should fail after invalidation is_valid, data = _state_manager.validate_and_extract(state) assert is_valid is False def test_thread_safety(self): """State manager should be thread-safe.""" results = [] def generate_and_validate(): state = _state_manager.generate_state() time.sleep(0.01) # Small delay to encourage race conditions is_valid, _ = _state_manager.validate_and_extract(state) results.append(is_valid) # Run multiple threads concurrently threads = [threading.Thread(target=generate_and_validate) for _ in range(5)] for t in threads: t.start() for t in threads: t.join() # At least one should succeed (the last one to validate) # Others might fail due to state being cleared assert any(results) # --------------------------------------------------------------------------- # SecureOAuthState Tests # --------------------------------------------------------------------------- class TestSecureOAuthState: """Test the secure OAuth state container.""" def test_serialize_deserialize_roundtrip(self): """Serialization and deserialization should preserve data.""" state = SecureOAuthState(data={"server_name": "test"}) serialized = state.serialize() # Deserialize restored = SecureOAuthState.deserialize(serialized) assert restored.token == state.token assert restored.nonce == state.nonce assert restored.data == state.data def test_deserialize_invalid_signature_fails(self): """Deserialization with tampered signature should fail.""" state = SecureOAuthState(data={"server_name": "test"}) serialized = state.serialize() # Tamper with the serialized data tampered = serialized[:-5] + "xxxxx" with pytest.raises(OAuthStateError) as exc_info: SecureOAuthState.deserialize(tampered) assert "signature" in str(exc_info.value).lower() or "tampering" in str(exc_info.value).lower() def test_deserialize_expired_state_fails(self): """Deserialization of expired state should fail.""" # Create state with old timestamp old_time = time.time() - 700 # 700 seconds ago (> 600 max age) state = SecureOAuthState.__new__(SecureOAuthState) state.token = secrets.token_urlsafe(32) state.timestamp = old_time state.nonce = secrets.token_urlsafe(16) state.data = {} serialized = state.serialize() with pytest.raises(OAuthStateError) as exc_info: SecureOAuthState.deserialize(serialized) assert "expired" in str(exc_info.value).lower() def test_state_entropy(self): """State should have sufficient entropy.""" state = SecureOAuthState() # Token should be at least 32 characters assert len(state.token) >= 32 # Nonce should be present assert len(state.nonce) >= 16 # --------------------------------------------------------------------------- # Callback Handler Tests # --------------------------------------------------------------------------- class TestCallbackHandler: """Test the OAuth callback handler for session fixation protection.""" def setup_method(self): """Reset state manager before each test.""" _state_manager.invalidate() def test_handler_rejects_missing_state(self): """Handler should reject callbacks without state parameter.""" HandlerClass, result = _make_callback_handler() # Create mock handler handler = HandlerClass.__new__(HandlerClass) handler.path = "/callback?code=test123" # No state handler.wfile = MagicMock() handler.send_response = MagicMock() handler.send_header = MagicMock() handler.end_headers = MagicMock() handler.do_GET() # Should send 400 error handler.send_response.assert_called_once_with(400) # Code is captured but not processed (state validation failed) def test_handler_rejects_invalid_state(self): """Handler should reject callbacks with invalid state.""" HandlerClass, result = _make_callback_handler() # Create mock handler with wrong state handler = HandlerClass.__new__(HandlerClass) handler.path = f"/callback?code=test123&state=invalid_state_12345" handler.wfile = MagicMock() handler.send_response = MagicMock() handler.send_header = MagicMock() handler.end_headers = MagicMock() handler.do_GET() # Should send 403 error (CSRF protection) handler.send_response.assert_called_once_with(403) def test_handler_accepts_valid_state(self): """Handler should accept callbacks with valid state.""" # Generate a valid state first valid_state = _state_manager.generate_state() HandlerClass, result = _make_callback_handler() # Create mock handler with correct state handler = HandlerClass.__new__(HandlerClass) handler.path = f"/callback?code=test123&state={valid_state}" handler.wfile = MagicMock() handler.send_response = MagicMock() handler.send_header = MagicMock() handler.end_headers = MagicMock() handler.do_GET() # Should send 200 success handler.send_response.assert_called_once_with(200) assert result["auth_code"] == "test123" def test_handler_handles_oauth_errors(self): """Handler should handle OAuth error responses.""" # Generate a valid state first valid_state = _state_manager.generate_state() HandlerClass, result = _make_callback_handler() # Create mock handler with OAuth error handler = HandlerClass.__new__(HandlerClass) handler.path = f"/callback?error=access_denied&state={valid_state}" handler.wfile = MagicMock() handler.send_response = MagicMock() handler.send_header = MagicMock() handler.end_headers = MagicMock() handler.do_GET() # Should send 400 error handler.send_response.assert_called_once_with(400) # --------------------------------------------------------------------------- # Session Regeneration Tests (V-014 Fix) # --------------------------------------------------------------------------- class TestSessionRegeneration: """Test session regeneration after OAuth authentication (V-014).""" def setup_method(self): """Reset state manager before each test.""" _state_manager.invalidate() def test_regenerate_session_invalidates_state(self): """V-014: Session regeneration should invalidate OAuth state.""" # Generate a state state = _state_manager.generate_state() # Regenerate session regenerate_session_after_auth() # State should be invalidated is_valid, _ = _state_manager.validate_and_extract(state) assert is_valid is False def test_regenerate_session_logs_debug(self, caplog): """V-014: Session regeneration should log debug message.""" import logging with caplog.at_level(logging.DEBUG): regenerate_session_after_auth() assert "Session regenerated" in caplog.text # --------------------------------------------------------------------------- # Integration Tests # --------------------------------------------------------------------------- class TestOAuthFlowIntegration: """Integration tests for the complete OAuth flow with session fixation protection.""" def setup_method(self): """Reset state manager before each test.""" _state_manager.invalidate() def test_complete_flow_valid_state(self): """Complete flow should succeed with valid state.""" # Step 1: Generate state (as would happen in build_oauth_auth) state = _state_manager.generate_state() # Step 2: Simulate callback with valid state HandlerClass, result = _make_callback_handler() handler = HandlerClass.__new__(HandlerClass) handler.path = f"/callback?code=auth_code_123&state={state}" handler.wfile = MagicMock() handler.send_response = MagicMock() handler.send_header = MagicMock() handler.end_headers = MagicMock() handler.do_GET() # Should succeed assert result["auth_code"] == "auth_code_123" handler.send_response.assert_called_once_with(200) def test_csrf_attack_blocked(self): """CSRF attack with stolen code but no state should be blocked.""" HandlerClass, result = _make_callback_handler() handler = HandlerClass.__new__(HandlerClass) # Attacker tries to use stolen code without valid state handler.path = f"/callback?code=stolen_code&state=invalid" handler.wfile = MagicMock() handler.send_response = MagicMock() handler.send_header = MagicMock() handler.end_headers = MagicMock() handler.do_GET() # Should be blocked with 403 handler.send_response.assert_called_once_with(403) def test_session_fixation_attack_blocked(self): """Session fixation attack should be blocked by state validation.""" # Attacker obtains a valid auth code stolen_code = "stolen_auth_code" # Legitimate user generates state legitimate_state = _state_manager.generate_state() # Attacker tries to use stolen code without knowing the state # This would be a session fixation attack HandlerClass, result = _make_callback_handler() handler = HandlerClass.__new__(HandlerClass) handler.path = f"/callback?code={stolen_code}&state=wrong_state" handler.wfile = MagicMock() handler.send_response = MagicMock() handler.send_header = MagicMock() handler.end_headers = MagicMock() handler.do_GET() # Should be blocked - attacker doesn't know the valid state assert handler.send_response.call_args[0][0] == 403 # --------------------------------------------------------------------------- # Security Property Tests # --------------------------------------------------------------------------- class TestSecurityProperties: """Test that security properties are maintained.""" def test_state_has_sufficient_entropy(self): """State should have sufficient entropy (> 256 bits).""" state = _state_manager.generate_state() # Should be at least 40 characters (sufficient entropy for base64) assert len(state) >= 40 def test_no_state_reuse(self): """Same state should never be generated twice in sequence.""" states = [] for _ in range(100): state = _state_manager.generate_state() states.append(state) _state_manager.invalidate() # Clear for next iteration # All states should be unique assert len(set(states)) == 100 def test_hmac_signature_verification(self): """State should be protected by HMAC signature.""" state = SecureOAuthState(data={"test": "data"}) serialized = state.serialize() # Should have format: data.signature parts = serialized.split(".") assert len(parts) == 2 # Both parts should be non-empty assert len(parts[0]) > 0 assert len(parts[1]) > 0 # --------------------------------------------------------------------------- # Error Handling Tests # --------------------------------------------------------------------------- class TestErrorHandling: """Test error handling in OAuth flow.""" def test_oauth_state_error_raised(self): """OAuthStateError should be raised for state validation failures.""" error = OAuthStateError("Test error") assert str(error) == "Test error" assert isinstance(error, Exception) def test_invalid_state_logged(self, caplog): """Invalid state should be logged as error.""" import logging with caplog.at_level(logging.ERROR): _state_manager.generate_state() _state_manager.validate_and_extract("wrong_state") assert "validation failed" in caplog.text.lower() def test_missing_state_logged(self, caplog): """Missing state should be logged as error.""" import logging with caplog.at_level(logging.ERROR): _state_manager.validate_and_extract(None) assert "no state returned" in caplog.text.lower() # --------------------------------------------------------------------------- # V-014 Specific Tests # --------------------------------------------------------------------------- class TestV014SessionFixationFix: """Specific tests for V-014 Session Fixation vulnerability fix.""" def setup_method(self): """Reset state manager before each test.""" _state_manager.invalidate() def test_v014_session_regeneration_after_successful_auth(self): """ V-014 Fix: After successful OAuth authentication, the session context should be regenerated to prevent session fixation attacks. """ # Simulate successful OAuth flow state = _state_manager.generate_state() # Before regeneration, state should exist assert _state_manager._state is not None # Simulate successful auth completion is_valid, _ = _state_manager.validate_and_extract(state) assert is_valid is True # State should be cleared after successful validation # (preventing session fixation via replay) assert _state_manager._state is None def test_v014_state_invalidation_on_auth_failure(self): """ V-014 Fix: On authentication failure, state should be invalidated to prevent fixation attempts. """ # Generate state _state_manager.generate_state() # State exists assert _state_manager._state is not None # Simulate failed auth (e.g., error from OAuth provider) _state_manager.invalidate() # State should be cleared assert _state_manager._state is None def test_v014_callback_includes_state_validation(self): """ V-014 Fix: The OAuth callback handler must validate the state parameter to prevent session fixation attacks. """ # Generate valid state valid_state = _state_manager.generate_state() HandlerClass, result = _make_callback_handler() handler = HandlerClass.__new__(HandlerClass) handler.path = f"/callback?code=test&state={valid_state}" handler.wfile = MagicMock() handler.send_response = MagicMock() handler.send_header = MagicMock() handler.end_headers = MagicMock() handler.do_GET() # Should succeed with valid state (state validation prevents fixation) assert result["auth_code"] == "test" assert handler.send_response.call_args[0][0] == 200