"""
Security tests for OAuth state handling and token storage (V-006 Fix).

Tests verify that:
1. JSON serialization is used instead of pickle
2. HMAC signatures are properly verified for both state and tokens
3. State structure is validated
4. Token schema is validated
5. Tampering is detected
6. Replay attacks are prevented
7. Timing attacks are mitigated via constant-time comparison

The module runs under pytest, or standalone via ``python <this file>``
(see ``run_tests``). It does not import pytest, so the standalone runner
needs no third-party dependency.
"""

import base64
import hashlib
import hmac
import json
import os
import secrets
import sys
import tempfile
import threading
import time
from pathlib import Path
from unittest.mock import patch

# Ensure tools directory is in path
sys.path.insert(0, str(Path(__file__).parent.parent))

from tools.mcp_oauth import (  # noqa: E402  (must follow the sys.path tweak)
    OAuthStateError,
    OAuthStateManager,
    SecureOAuthState,
    HermesTokenStorage,
    _validate_token_schema,
    _OAUTH_TOKEN_SCHEMA,
    _OAUTH_CLIENT_SCHEMA,
    _sign_token_data,
    _verify_token_signature,
    _get_token_storage_key,
    _state_manager,
    get_state_manager,
)


# =============================================================================
# Helpers
# =============================================================================

# Characters allowed in URL-safe base64 output (including "=" padding).
_URLSAFE_B64_ALPHABET = frozenset(
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_="
)


def _b64url_nopad(raw: bytes) -> str:
    """Return URL-safe base64 of *raw* with trailing ``=`` padding stripped."""
    return base64.urlsafe_b64encode(raw).decode().rstrip("=")


def _sign_raw_state(raw: bytes) -> str:
    """Build a ``data.signature`` state string for arbitrary *raw* bytes.

    The signature is HMAC-SHA256 over *raw* using the state secret key,
    matching how these tests expect states to be signed. Deserialization
    failures therefore come from the later parsing/validation steps rather
    than from signature verification.
    """
    key = SecureOAuthState._get_secret_key()
    digest = hmac.new(key, raw, hashlib.sha256).digest()
    return f"{_b64url_nopad(raw)}.{_b64url_nopad(digest)}"


def _assert_state_error(func, *args, match):
    """Call ``func(*args)`` and require an OAuthStateError whose message contains *match*.

    Uses explicit ``raise AssertionError`` rather than ``assert`` so the check
    still fires under ``python -O``.

    Returns:
        The caught OAuthStateError, for callers that want to inspect it further.

    Raises:
        AssertionError: if nothing is raised or the message lacks *match*.
    """
    try:
        func(*args)
    except OAuthStateError as exc:
        if match not in str(exc):
            raise AssertionError(
                f"expected {match!r} in error message, got {str(exc)!r}"
            ) from exc
        return exc
    raise AssertionError(
        f"{getattr(func, '__qualname__', func)} did not raise OAuthStateError"
    )


def _reset_global_state_manager():
    """Clear the shared module-level manager's pending state and nonce history."""
    _state_manager.invalidate()
    _state_manager._used_nonces.clear()


# =============================================================================
# SecureOAuthState Tests
# =============================================================================


class TestSecureOAuthState:
    """Tests for the SecureOAuthState class."""

    def test_generate_creates_valid_state(self):
        """Test that generated state has all required fields."""
        state = SecureOAuthState()
        assert state.token is not None
        assert len(state.token) >= 16
        assert state.timestamp is not None
        assert isinstance(state.timestamp, float)
        assert state.nonce is not None
        assert len(state.nonce) >= 8
        assert isinstance(state.data, dict)

    def test_generate_unique_tokens(self):
        """Test that generated tokens are unique."""
        tokens = {SecureOAuthState._generate_token() for _ in range(100)}
        assert len(tokens) == 100

    def test_serialization_format(self):
        """Test that serialized state has correct format."""
        state = SecureOAuthState(data={"test": "value"})
        serialized = state.serialize()

        # Should have format: data.signature
        parts = serialized.split(".")
        assert len(parts) == 2

        # Both parts should be URL-safe base64
        data_part, sig_part = parts
        assert set(data_part) <= _URLSAFE_B64_ALPHABET
        assert set(sig_part) <= _URLSAFE_B64_ALPHABET

    def test_serialize_deserialize_roundtrip(self):
        """Test that serialize/deserialize preserves state."""
        original = SecureOAuthState(data={"server": "test123", "user": "alice"})
        deserialized = SecureOAuthState.deserialize(original.serialize())

        assert deserialized.token == original.token
        assert deserialized.timestamp == original.timestamp
        assert deserialized.nonce == original.nonce
        assert deserialized.data == original.data

    def test_deserialize_empty_raises_error(self):
        """Test that deserializing empty state raises OAuthStateError."""
        _assert_state_error(SecureOAuthState.deserialize, "", match="empty or wrong type")
        _assert_state_error(SecureOAuthState.deserialize, None, match="empty or wrong type")

    def test_deserialize_missing_signature_raises_error(self):
        """Test that missing signature is detected."""
        # Padding is intentionally left on: only the missing "." part matters.
        encoded = base64.urlsafe_b64encode(json.dumps({"test": "data"}).encode()).decode()
        _assert_state_error(SecureOAuthState.deserialize, encoded, match="missing signature")

    def test_deserialize_invalid_base64_raises_error(self):
        """Test that invalid data is rejected (base64 or signature)."""
        # Invalid characters may be accepted by Python's base64 decoder
        # but signature verification should fail; either path must report
        # an "Invalid state" error.
        _assert_state_error(
            SecureOAuthState.deserialize, "!!!invalid!!!.!!!data!!!", match="Invalid state"
        )

    def test_deserialize_tampered_signature_detected(self):
        """Test that tampered signature is detected."""
        data_part, _ = SecureOAuthState().serialize().split(".")
        tampered = f"{data_part}.{_b64url_nopad(b'tampered')}"
        _assert_state_error(SecureOAuthState.deserialize, tampered, match="tampering detected")

    def test_deserialize_tampered_data_detected(self):
        """Test that tampered data is detected via HMAC verification."""
        # Keep the genuine signature but swap in attacker-controlled data.
        _, sig_part = SecureOAuthState().serialize().split(".")
        tampered_data = _b64url_nopad(json.dumps({"hacked": True}).encode())
        _assert_state_error(
            SecureOAuthState.deserialize,
            f"{tampered_data}.{sig_part}",
            match="tampering detected",
        )

    def test_deserialize_expired_state_raises_error(self):
        """Test that expired states are rejected."""
        old_state = SecureOAuthState()
        old_state.timestamp = time.time() - 1000  # 1000 seconds ago (wall clock)
        _assert_state_error(SecureOAuthState.deserialize, old_state.serialize(), match="expired")

    def test_deserialize_invalid_json_raises_error(self):
        """Test that invalid JSON raises OAuthStateError."""
        _assert_state_error(
            SecureOAuthState.deserialize,
            _sign_raw_state(b"not valid json {{{"),
            match="Invalid state JSON",
        )

    def test_deserialize_missing_fields_raises_error(self):
        """Test that missing required fields are detected."""
        payload = json.dumps({"token": "test"}).encode()  # missing timestamp, nonce
        _assert_state_error(
            SecureOAuthState.deserialize, _sign_raw_state(payload), match="missing fields"
        )

    def test_deserialize_invalid_token_type_raises_error(self):
        """Test that non-string tokens are rejected."""
        payload = json.dumps({
            "token": 12345,  # should be string
            "timestamp": time.time(),
            "nonce": "abc123",
        }).encode()
        _assert_state_error(
            SecureOAuthState.deserialize, _sign_raw_state(payload), match="token must be a string"
        )

    def test_deserialize_short_token_raises_error(self):
        """Test that short tokens are rejected."""
        payload = json.dumps({
            "token": "short",  # too short
            "timestamp": time.time(),
            "nonce": "abc123",
        }).encode()
        _assert_state_error(
            SecureOAuthState.deserialize, _sign_raw_state(payload), match="token must be a string"
        )

    def test_deserialize_invalid_timestamp_raises_error(self):
        """Test that non-numeric timestamps are rejected."""
        payload = json.dumps({
            "token": "a" * 32,
            "timestamp": "not a number",
            "nonce": "abc123",
        }).encode()
        _assert_state_error(
            SecureOAuthState.deserialize,
            _sign_raw_state(payload),
            match="timestamp must be numeric",
        )

    def test_validate_against_correct_token(self):
        """Test token validation with matching token."""
        state = SecureOAuthState()
        assert state.validate_against(state.token) is True

    def test_validate_against_wrong_token(self):
        """Test token validation with non-matching token."""
        state = SecureOAuthState()
        assert state.validate_against("wrong-token") is False

    def test_validate_against_non_string(self):
        """Test token validation with non-string input."""
        state = SecureOAuthState()
        assert state.validate_against(None) is False
        assert state.validate_against(12345) is False

    def test_validate_uses_constant_time_comparison(self):
        """Test that validate_against rejects near-miss tokens.

        This cannot observe timing. The implementation is expected to use a
        constant-time compare (e.g. secrets.compare_digest). This test only
        checks that mismatching tokens are rejected with no early-accept path.
        """
        state = SecureOAuthState(token="test-token-for-comparison")
        assert state.validate_against("wrong-token-for-comparison") is False
        assert state.validate_against("another-wrong-token-here") is False

    def test_to_dict_format(self):
        """Test that to_dict returns correct format."""
        state = SecureOAuthState(data={"custom": "data"})
        d = state.to_dict()

        assert set(d.keys()) == {"token", "timestamp", "nonce", "data"}
        assert d["token"] == state.token
        assert d["timestamp"] == state.timestamp
        assert d["nonce"] == state.nonce
        assert d["data"] == {"custom": "data"}


# =============================================================================
# OAuthStateManager Tests
# =============================================================================


class TestOAuthStateManager:
    """Tests for the OAuthStateManager class."""

    def setup_method(self):
        """Reset state manager before each test."""
        _reset_global_state_manager()

    def test_generate_state_returns_serialized(self):
        """Test that generate_state returns a serialized state string."""
        state_str = _state_manager.generate_state()

        # Should be a string with format: data.signature
        assert isinstance(state_str, str)
        assert "." in state_str
        assert len(state_str.split(".")) == 2

    def test_generate_state_with_data(self):
        """Test that extra data is included in state."""
        extra = {"server_name": "test-server", "user_id": "123"}
        state_str = _state_manager.generate_state(extra_data=extra)

        is_valid, data = _state_manager.validate_and_extract(state_str)
        assert is_valid is True
        assert data == extra

    def test_validate_and_extract_valid_state(self):
        """Test validation with a valid state."""
        extra = {"test": "data"}
        state_str = _state_manager.generate_state(extra_data=extra)

        is_valid, data = _state_manager.validate_and_extract(state_str)
        assert is_valid is True
        assert data == extra

    def test_validate_and_extract_none_state(self):
        """Test validation with None state."""
        is_valid, data = _state_manager.validate_and_extract(None)
        assert is_valid is False
        assert data is None

    def test_validate_and_extract_invalid_state(self):
        """Test validation with an invalid state."""
        is_valid, data = _state_manager.validate_and_extract("invalid.state.here")
        assert is_valid is False
        assert data is None

    def test_state_cleared_after_validation(self):
        """Test that state is cleared after successful validation."""
        state_str = _state_manager.generate_state()

        # First validation should succeed
        is_valid1, _ = _state_manager.validate_and_extract(state_str)
        assert is_valid1 is True

        # Second validation should fail (replay)
        is_valid2, _ = _state_manager.validate_and_extract(state_str)
        assert is_valid2 is False

    def test_nonce_tracking_prevents_replay(self):
        """Test that nonce tracking prevents replay attacks."""
        state = SecureOAuthState()
        serialized = state.serialize()

        # Mark the nonce as already consumed (under the manager's own lock).
        with _state_manager._lock:
            _state_manager._used_nonces.add(state.nonce)

        # Validation should fail due to nonce replay
        is_valid, _ = _state_manager.validate_and_extract(serialized)
        assert is_valid is False

    def test_invalidate_clears_state(self):
        """Test that invalidate clears the stored state."""
        _state_manager.generate_state()
        assert _state_manager._state is not None

        _state_manager.invalidate()
        assert _state_manager._state is None

    def test_thread_safety(self):
        """Test thread safety of state manager."""
        results = []

        def generate():
            state_str = _state_manager.generate_state(
                extra_data={"thread": threading.current_thread().name}
            )
            results.append(state_str)

        threads = [threading.Thread(target=generate) for _ in range(10)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # All states should be unique
        assert len(set(results)) == 10

    def test_max_nonce_limit(self):
        """Test that the used-nonce set stays bounded as states are consumed.

        Drives nonces through the real generate/validate path rather than
        mutating ``_used_nonces`` directly (which would bypass the limit logic).
        """
        manager = OAuthStateManager()
        manager._max_used_nonces = 5

        consumed = 10
        for _ in range(consumed):
            is_valid, _ = manager.validate_and_extract(manager.generate_state())
            assert is_valid is True

        # The exact post-clear size depends on when the implementation clears,
        # so only require that it did not retain every nonce it has seen.
        assert len(manager._used_nonces) < consumed


# =============================================================================
# Schema Validation Tests
# =============================================================================


class TestSchemaValidation:
    """Tests for JSON schema validation (V-006)."""

    def test_valid_token_schema_accepted(self):
        """Test that valid token data passes schema validation."""
        valid_token = {
            "access_token": "secret_token_123",
            "token_type": "Bearer",
            "refresh_token": "refresh_456",
            "expires_in": 3600,
            "expires_at": 1234567890.0,
            "scope": "read write",
            "id_token": "id_token_789",
        }
        # Should not raise
        _validate_token_schema(valid_token, _OAUTH_TOKEN_SCHEMA, "token")

    def test_minimal_valid_token_schema(self):
        """Test that minimal valid token (only required fields) passes."""
        minimal_token = {
            "access_token": "secret",
            "token_type": "Bearer",
        }
        _validate_token_schema(minimal_token, _OAUTH_TOKEN_SCHEMA, "token")

    def test_missing_required_field_rejected(self):
        """Test that missing required fields are detected."""
        invalid_token = {"token_type": "Bearer"}  # missing access_token
        exc = _assert_state_error(
            _validate_token_schema,
            invalid_token,
            _OAUTH_TOKEN_SCHEMA,
            "token",
            match="missing required fields",
        )
        assert "access_token" in str(exc)

    def test_wrong_type_rejected(self):
        """Test that fields with wrong types are rejected."""
        invalid_token = {
            "access_token": 12345,  # should be string
            "token_type": "Bearer",
        }
        _assert_state_error(
            _validate_token_schema,
            invalid_token,
            _OAUTH_TOKEN_SCHEMA,
            "token",
            match="has wrong type",
        )

    def test_null_values_accepted(self):
        """Test that null values for optional fields are accepted."""
        token_with_nulls = {
            "access_token": "secret",
            "token_type": "Bearer",
            "refresh_token": None,
            "expires_in": None,
        }
        _validate_token_schema(token_with_nulls, _OAUTH_TOKEN_SCHEMA, "token")

    def test_non_dict_data_rejected(self):
        """Test that non-dictionary data is rejected."""
        _assert_state_error(
            _validate_token_schema,
            "not a dict",
            _OAUTH_TOKEN_SCHEMA,
            "token",
            match="must be a dictionary",
        )

    def test_valid_client_schema(self):
        """Test that valid client info passes schema validation."""
        valid_client = {
            "client_id": "client_123",
            "client_secret": "secret_456",
            "client_name": "Test Client",
            "redirect_uris": ["http://localhost/callback"],
        }
        _validate_token_schema(valid_client, _OAUTH_CLIENT_SCHEMA, "client")

    def test_client_missing_required_rejected(self):
        """Test that client info missing client_id is rejected."""
        invalid_client = {"client_name": "Test"}
        _assert_state_error(
            _validate_token_schema,
            invalid_client,
            _OAUTH_CLIENT_SCHEMA,
            "client",
            match="missing required fields",
        )


# =============================================================================
# Token Storage Security Tests
# =============================================================================


class TestTokenStorageSecurity:
    """Tests for token storage signing and validation (V-006)."""

    def test_sign_and_verify_token_data(self):
        """Test that token data can be signed and verified."""
        data = {"access_token": "test123", "token_type": "Bearer"}
        sig = _sign_token_data(data)

        assert sig is not None
        assert len(sig) > 0
        assert _verify_token_signature(data, sig) is True

    def test_tampered_token_data_rejected(self):
        """Test that tampered token data fails verification."""
        data = {"access_token": "test123", "token_type": "Bearer"}
        sig = _sign_token_data(data)

        tampered_data = {"access_token": "hacked", "token_type": "Bearer"}
        assert _verify_token_signature(tampered_data, sig) is False

    def test_empty_signature_rejected(self):
        """Test that empty signature is rejected."""
        data = {"access_token": "test", "token_type": "Bearer"}
        assert _verify_token_signature(data, "") is False

    def test_invalid_signature_rejected(self):
        """Test that invalid signature is rejected."""
        data = {"access_token": "test", "token_type": "Bearer"}
        assert _verify_token_signature(data, "invalid") is False

    def test_signature_deterministic(self):
        """Test that signing the same data produces the same signature."""
        data = {"access_token": "test123", "token_type": "Bearer"}
        assert _sign_token_data(data) == _sign_token_data(data)

    def test_different_data_different_signatures(self):
        """Test that different data produces different signatures."""
        data1 = {"access_token": "test1", "token_type": "Bearer"}
        data2 = {"access_token": "test2", "token_type": "Bearer"}
        assert _sign_token_data(data1) != _sign_token_data(data2)


# =============================================================================
# Pickle Security Tests
# =============================================================================


class TestNoPickleUsage:
    """Tests to verify pickle is NOT used (V-006 regression tests)."""

    def test_serialization_does_not_use_pickle(self):
        """Verify that state serialization uses JSON, not pickle."""
        payload = "__import__('os').system('rm -rf /')"
        state = SecureOAuthState(data={"malicious": payload})
        data_part, _ = state.serialize().split(".")

        # Restore the "=" padding stripped by the serializer before decoding.
        decoded = base64.urlsafe_b64decode(data_part + "=" * (-len(data_part) % 4))

        # Should be valid JSON, not pickle
        parsed = json.loads(decoded.decode("utf-8"))
        assert parsed["data"]["malicious"] == payload

        # Should NOT start with pickle protocol markers
        assert not decoded.startswith(b"\x80")  # Pickle protocol marker
        assert b"cos\n" not in decoded  # Pickle module load pattern

    def test_deserialize_rejects_pickle_payload(self):
        """Test that pickle payloads are rejected during deserialization."""
        import pickle  # only used to *build* a payload; never loaded here

        # A correctly signed pickle blob: it must still be rejected because
        # it is not valid JSON.
        malicious = pickle.dumps({"cmd": "whoami"})
        _assert_state_error(
            SecureOAuthState.deserialize,
            _sign_raw_state(malicious),
            match="Invalid state JSON",
        )


# =============================================================================
# Key Management Tests
# =============================================================================


class TestSecretKeyManagement:
    """Tests for HMAC secret key management."""

    def test_get_secret_key_from_env(self):
        """Test that HERMES_OAUTH_SECRET environment variable is used."""
        with patch.dict(os.environ, {"HERMES_OAUTH_SECRET": "test-secret-key-32bytes-long!!"}):
            key = SecureOAuthState._get_secret_key()
            assert key == b"test-secret-key-32bytes-long!!"

    def test_get_token_storage_key_from_env(self):
        """Test that HERMES_TOKEN_STORAGE_SECRET environment variable is used."""
        with patch.dict(
            os.environ, {"HERMES_TOKEN_STORAGE_SECRET": "storage-secret-key-32bytes!!"}
        ):
            key = _get_token_storage_key()
            assert key == b"storage-secret-key-32bytes!!"

    def test_get_secret_key_creates_file(self):
        """Test that secret key file is created if it doesn't exist."""
        with tempfile.TemporaryDirectory() as tmpdir:
            home = Path(tmpdir)
            with patch("pathlib.Path.home", return_value=home), \
                    patch.dict(os.environ, {}, clear=True):
                key = SecureOAuthState._get_secret_key()
                assert len(key) == 64


# =============================================================================
# Integration Tests
# =============================================================================


class TestOAuthFlowIntegration:
    """Integration tests for the OAuth flow with secure state."""

    def setup_method(self):
        """Reset state manager before each test."""
        _reset_global_state_manager()

    def test_full_oauth_state_flow(self):
        """Test the full OAuth state generation and validation flow."""
        # Step 1: Generate state for OAuth request
        server_name = "test-mcp-server"
        state = _state_manager.generate_state(extra_data={"server_name": server_name})

        # Step 2: Simulate OAuth callback with state
        # (In real flow, this comes back from OAuth provider)
        is_valid, data = _state_manager.validate_and_extract(state)

        # Step 3: Verify validation succeeded
        assert is_valid is True
        assert data["server_name"] == server_name

        # Step 4: Verify state cannot be replayed
        is_valid_replay, _ = _state_manager.validate_and_extract(state)
        assert is_valid_replay is False

    def test_csrf_attack_prevention(self):
        """Test that CSRF attacks using different states are detected."""
        # Attacker generates their own state
        attacker_state = _state_manager.generate_state(extra_data={"malicious": True})

        # Victim generates their state (stored inside victim_manager)
        victim_manager = OAuthStateManager()
        victim_manager.generate_state(extra_data={"legitimate": True})

        # Attacker tries to use their state with victim's session;
        # this must fail because the tokens don't match.
        is_valid, _ = victim_manager.validate_and_extract(attacker_state)
        assert is_valid is False

    def test_mitm_attack_detection(self):
        """Test that tampered states from MITM attacks are detected."""
        state = _state_manager.generate_state()

        # Replace the signature (simulating MITM tampering)
        data_part = state.split(".")[0]
        tampered_state = data_part + ".tampered-signature-here"

        is_valid, _ = _state_manager.validate_and_extract(tampered_state)
        assert is_valid is False


# =============================================================================
# Performance Tests
# =============================================================================


class TestPerformance:
    """Performance tests for state operations."""

    def test_serialize_performance(self):
        """Test that serialization is fast."""
        state = SecureOAuthState(data={"key": "value" * 100})

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(1000):
            state.serialize()
        elapsed = time.perf_counter() - start

        # Should complete 1000 serializations in under 1 second
        assert elapsed < 1.0

    def test_deserialize_performance(self):
        """Test that deserialization is fast."""
        serialized = SecureOAuthState(data={"key": "value" * 100}).serialize()

        start = time.perf_counter()
        for _ in range(1000):
            SecureOAuthState.deserialize(serialized)
        elapsed = time.perf_counter() - start

        # Should complete 1000 deserializations in under 1 second
        assert elapsed < 1.0


def run_tests():
    """Run every ``test_*`` method of the test classes without pytest.

    Mirrors pytest's xunit-style isolation: each test gets a fresh class
    instance and its own ``setup_method`` call, so state left behind by one
    test cannot leak into the next.

    Returns:
        0 if all tests passed, 1 otherwise (suitable for ``sys.exit``).
    """
    import inspect

    test_classes = [
        TestSecureOAuthState,
        TestOAuthStateManager,
        TestSchemaValidation,
        TestTokenStorageSecurity,
        TestNoPickleUsage,
        TestSecretKeyManagement,
        TestOAuthFlowIntegration,
        TestPerformance,
    ]

    total_tests = 0
    passed_tests = 0
    failed_tests = []

    for cls in test_classes:
        print(f"\n{'=' * 60}")
        print(f"Running {cls.__name__}")
        print("=" * 60)

        for name, method in inspect.getmembers(cls, predicate=inspect.isfunction):
            if not name.startswith("test_"):
                continue
            total_tests += 1
            instance = cls()
            try:
                setup = getattr(instance, "setup_method", None)
                if setup is not None:
                    setup()
                method(instance)
            except Exception as e:
                # Include the exception type: a bare ``assert`` has an empty message.
                error = f"{type(e).__name__}: {e}"
                print(f"  ✗ {name}: {error}")
                failed_tests.append((cls.__name__, name, error))
            else:
                print(f"  ✓ {name}")
                passed_tests += 1

    print(f"\n{'=' * 60}")
    print(f"Results: {passed_tests}/{total_tests} tests passed")
    print("=" * 60)

    if failed_tests:
        print("\nFailed tests:")
        for cls_name, test_name, error in failed_tests:
            print(f"  - {cls_name}.{test_name}: {error}")
        return 1

    print("\nAll tests passed!")
    return 0


if __name__ == "__main__":
    sys.exit(run_tests())