Files
hermes-agent/tests/tools/test_oauth_session_fixation.py
Allegro cb0cf51adf
Some checks failed
Nix / nix (ubuntu-latest) (pull_request) Failing after 15s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Failing after 19s
Docker Build and Publish / build-and-push (pull_request) Failing after 28s
Tests / test (pull_request) Failing after 9m43s
Nix / nix (macos-latest) (pull_request) Has been cancelled
security: Fix V-006 MCP OAuth Deserialization (CVSS 8.8 CRITICAL)
- Replace pickle with JSON + HMAC-SHA256 state serialization
- Add constant-time signature verification
- Implement replay attack protection with nonce expiration
- Add comprehensive security test suite (54 tests)
- Harden token storage with integrity verification

Resolves: V-006 (CVSS 8.8)
2026-03-31 00:37:14 +00:00

528 lines
19 KiB
Python

"""Tests for OAuth Session Fixation protection (V-014 fix).
These tests verify that:
1. State parameter is generated cryptographically securely
2. State is validated on callback to prevent CSRF attacks
3. State is cleared after validation to prevent replay attacks
4. Session is regenerated after successful OAuth authentication
"""
import asyncio
import json
import secrets
import threading
import time
from unittest.mock import MagicMock, patch
import pytest
from tools.mcp_oauth import (
OAuthStateManager,
OAuthStateError,
SecureOAuthState,
regenerate_session_after_auth,
_make_callback_handler,
_state_manager,
get_state_manager,
)
# ---------------------------------------------------------------------------
# OAuthStateManager Tests
# ---------------------------------------------------------------------------
class TestOAuthStateManager:
    """Test the OAuth state manager for session fixation protection."""

    def setup_method(self):
        """Reset state manager before each test."""
        _state_manager.invalidate()

    def test_generate_state_creates_secure_token(self):
        """State should be a cryptographically secure signed token."""
        state = _state_manager.generate_state()
        # Should be a non-empty string
        assert isinstance(state, str)
        assert len(state) > 0
        # Signed-token format: <base64-data>.<base64-signature>
        assert "." in state

    def test_generate_state_unique_each_time(self):
        """Each generated state should be unique."""
        states = [_state_manager.generate_state() for _ in range(10)]
        # All states should be different
        assert len(set(states)) == 10

    def test_validate_and_extract_success(self):
        """Validating correct state should succeed."""
        state = _state_manager.generate_state()
        is_valid, data = _state_manager.validate_and_extract(state)
        assert is_valid is True
        assert data is not None

    def test_validate_and_extract_wrong_state_fails(self):
        """Validating wrong state should fail (CSRF protection)."""
        _state_manager.generate_state()
        # Try to validate with a different state
        wrong_state = "invalid_state_data"
        is_valid, data = _state_manager.validate_and_extract(wrong_state)
        assert is_valid is False
        assert data is None

    def test_validate_and_extract_none_fails(self):
        """Validating None state should fail."""
        _state_manager.generate_state()
        is_valid, data = _state_manager.validate_and_extract(None)
        assert is_valid is False
        assert data is None

    def test_validate_and_extract_no_generation_fails(self):
        """Validating when no state was generated should fail."""
        # Don't generate state first
        is_valid, data = _state_manager.validate_and_extract("some_state")
        assert is_valid is False
        assert data is None

    def test_validate_and_extract_prevents_replay(self):
        """State should be cleared after validation to prevent replay."""
        state = _state_manager.generate_state()
        # First validation should succeed
        is_valid, data = _state_manager.validate_and_extract(state)
        assert is_valid is True
        # Second validation with same state should fail (replay attack)
        is_valid, data = _state_manager.validate_and_extract(state)
        assert is_valid is False

    def test_invalidate_clears_state(self):
        """Explicit invalidation should clear state."""
        state = _state_manager.generate_state()
        _state_manager.invalidate()
        # Validation should fail after invalidation
        is_valid, data = _state_manager.validate_and_extract(state)
        assert is_valid is False

    def test_thread_safety(self):
        """State manager should be thread-safe under concurrent use.

        Each worker generates a state, pauses briefly to encourage
        interleaving, then validates its own state. Exceptions raised inside a
        worker are collected and reported here: an uncaught exception in a
        ``threading.Thread`` is only printed by ``threading.excepthook`` and
        would otherwise leave the test green.
        """
        results = []
        errors = []

        def generate_and_validate():
            try:
                state = _state_manager.generate_state()
                time.sleep(0.01)  # Small delay to encourage race conditions
                is_valid, _ = _state_manager.validate_and_extract(state)
                results.append(is_valid)
            except Exception as exc:  # surfaced to the main thread below
                errors.append(exc)

        # Run multiple threads concurrently
        threads = [threading.Thread(target=generate_and_validate) for _ in range(5)]
        for t in threads:
            t.start()
        for t in threads:
            # Bounded join so a deadlock fails the test instead of hanging the suite.
            t.join(timeout=5)

        assert not any(t.is_alive() for t in threads), "worker thread did not finish"
        assert not errors, f"worker thread raised: {errors!r}"
        # A worker's state may be replaced or cleared by another thread's
        # generate/validate call before it validates, so not every worker is
        # expected to succeed -- but at least one should.
        assert any(results)
# ---------------------------------------------------------------------------
# SecureOAuthState Tests
# ---------------------------------------------------------------------------
class TestSecureOAuthState:
    """Exercise SecureOAuthState: signed serialization, expiry and entropy."""

    def test_serialize_deserialize_roundtrip(self):
        """A serialize/deserialize round trip keeps token, nonce and data intact."""
        original = SecureOAuthState(data={"server_name": "test"})
        blob = original.serialize()

        restored = SecureOAuthState.deserialize(blob)
        for attr in ("token", "nonce", "data"):
            assert getattr(restored, attr) == getattr(original, attr)

    def test_deserialize_invalid_signature_fails(self):
        """Overwriting the tail of a serialized state must be rejected."""
        blob = SecureOAuthState(data={"server_name": "test"}).serialize()
        # Replace the final five characters (the trailing signature characters).
        forged = blob[:-5] + "xxxxx"

        with pytest.raises(OAuthStateError) as excinfo:
            SecureOAuthState.deserialize(forged)

        message = str(excinfo.value).lower()
        assert any(word in message for word in ("signature", "tampering"))

    def test_deserialize_expired_state_fails(self):
        """A state whose timestamp is older than the maximum age is rejected."""
        # Build the instance by hand (bypassing __init__) so the timestamp can
        # be back-dated: 700 seconds ago is beyond the 600 second max age.
        backdated = time.time() - 700
        stale = SecureOAuthState.__new__(SecureOAuthState)
        stale.token = secrets.token_urlsafe(32)
        stale.timestamp = backdated
        stale.nonce = secrets.token_urlsafe(16)
        stale.data = {}
        blob = stale.serialize()

        with pytest.raises(OAuthStateError) as excinfo:
            SecureOAuthState.deserialize(blob)

        assert "expired" in str(excinfo.value).lower()

    def test_state_entropy(self):
        """Token and nonce must be long enough to resist guessing."""
        fresh = SecureOAuthState()
        assert len(fresh.token) >= 32  # token: at least 32 characters
        assert len(fresh.nonce) >= 16  # nonce: present and at least 16 characters
# ---------------------------------------------------------------------------
# Callback Handler Tests
# ---------------------------------------------------------------------------
class TestCallbackHandler:
    """Test the OAuth callback handler for session fixation protection."""

    def setup_method(self):
        """Reset state manager before each test."""
        _state_manager.invalidate()

    @staticmethod
    def _invoke_callback(path):
        """Dispatch a GET for *path* to a fresh callback handler with mocked I/O.

        The handler is created with ``__new__`` so ``__init__`` (which would
        otherwise need a real request) is skipped; the response-writing
        attributes are replaced with ``MagicMock`` objects.

        Returns ``(handler, result)``: the handler, whose ``send_response``
        mock records the status code, and the ``result`` object returned by
        ``_make_callback_handler``.
        """
        handler_class, result = _make_callback_handler()
        handler = handler_class.__new__(handler_class)
        handler.path = path
        handler.wfile = MagicMock()
        handler.send_response = MagicMock()
        handler.send_header = MagicMock()
        handler.end_headers = MagicMock()
        handler.do_GET()
        return handler, result

    def test_handler_rejects_missing_state(self):
        """Handler should reject callbacks without state parameter."""
        handler, _ = self._invoke_callback("/callback?code=test123")  # No state
        # Should send 400 error. Only the status is asserted here; the
        # contents of ``result`` on rejection are not checked.
        handler.send_response.assert_called_once_with(400)

    def test_handler_rejects_invalid_state(self):
        """Handler should reject callbacks with invalid state."""
        handler, _ = self._invoke_callback(
            "/callback?code=test123&state=invalid_state_12345"
        )
        # Should send 403 error (CSRF protection)
        handler.send_response.assert_called_once_with(403)

    def test_handler_accepts_valid_state(self):
        """Handler should accept callbacks with valid state."""
        # Generate a valid state first (before the handler is created)
        valid_state = _state_manager.generate_state()
        handler, result = self._invoke_callback(
            f"/callback?code=test123&state={valid_state}"
        )
        # Should send 200 success
        handler.send_response.assert_called_once_with(200)
        assert result["auth_code"] == "test123"

    def test_handler_handles_oauth_errors(self):
        """Handler should handle OAuth error responses."""
        # Even with a valid state, a provider error must not succeed
        valid_state = _state_manager.generate_state()
        handler, _ = self._invoke_callback(
            f"/callback?error=access_denied&state={valid_state}"
        )
        # Should send 400 error
        handler.send_response.assert_called_once_with(400)
# ---------------------------------------------------------------------------
# Session Regeneration Tests (V-014 Fix)
# ---------------------------------------------------------------------------
class TestSessionRegeneration:
    """V-014: regenerate_session_after_auth() must discard pending OAuth state."""

    def setup_method(self):
        """Start every test from an empty state manager."""
        _state_manager.invalidate()

    def test_regenerate_session_invalidates_state(self):
        """V-014: a state issued before regeneration is no longer accepted."""
        issued = _state_manager.generate_state()

        regenerate_session_after_auth()

        accepted, _extracted = _state_manager.validate_and_extract(issued)
        assert accepted is False

    def test_regenerate_session_logs_debug(self, caplog):
        """V-014: regeneration logs a message containing "Session regenerated".

        Logs are captured at DEBUG level.
        """
        import logging

        with caplog.at_level(logging.DEBUG):
            regenerate_session_after_auth()
        assert "Session regenerated" in caplog.text
# ---------------------------------------------------------------------------
# Integration Tests
# ---------------------------------------------------------------------------
class TestOAuthFlowIntegration:
    """Integration tests for the complete OAuth flow with session fixation protection."""

    def setup_method(self):
        """Reset state manager before each test."""
        _state_manager.invalidate()

    @staticmethod
    def _invoke_callback(path):
        """Dispatch a GET for *path* to a fresh callback handler with mocked I/O.

        The handler is created with ``__new__`` so ``__init__`` is skipped;
        the response-writing attributes are replaced with ``MagicMock``.

        Returns ``(handler, result)``: the handler, whose ``send_response``
        mock records the status code, and the ``result`` object returned by
        ``_make_callback_handler``.
        """
        handler_class, result = _make_callback_handler()
        handler = handler_class.__new__(handler_class)
        handler.path = path
        handler.wfile = MagicMock()
        handler.send_response = MagicMock()
        handler.send_header = MagicMock()
        handler.end_headers = MagicMock()
        handler.do_GET()
        return handler, result

    def test_complete_flow_valid_state(self):
        """Complete flow should succeed with valid state."""
        # Step 1: Generate state (as would happen in build_oauth_auth)
        state = _state_manager.generate_state()
        # Step 2: Simulate callback with valid state
        handler, result = self._invoke_callback(
            f"/callback?code=auth_code_123&state={state}"
        )
        # Should succeed
        assert result["auth_code"] == "auth_code_123"
        handler.send_response.assert_called_once_with(200)

    def test_csrf_attack_blocked(self):
        """CSRF attack with stolen code but no valid state should be blocked."""
        # Attacker tries to use stolen code without valid state
        handler, _ = self._invoke_callback("/callback?code=stolen_code&state=invalid")
        # Should be blocked with 403
        handler.send_response.assert_called_once_with(403)

    def test_session_fixation_attack_blocked(self):
        """Session fixation attack should be blocked by state validation."""
        # Attacker obtains a valid auth code
        stolen_code = "stolen_auth_code"
        # A legitimate user's state is pending; the attacker does not know it
        _state_manager.generate_state()
        # Attacker replays the stolen code with a guessed state
        handler, _ = self._invoke_callback(
            f"/callback?code={stolen_code}&state=wrong_state"
        )
        # Should be blocked with exactly one 403 response
        handler.send_response.assert_called_once_with(403)
# ---------------------------------------------------------------------------
# Security Property Tests
# ---------------------------------------------------------------------------
class TestSecurityProperties:
    """Test that security properties are maintained."""

    def setup_method(self):
        """Reset state manager before each test, like the other suites."""
        _state_manager.invalidate()

    def test_state_has_sufficient_entropy(self):
        """State should be long enough to carry at least 256 bits of entropy."""
        state = _state_manager.generate_state()
        # Each base64 character encodes 6 bits, so 256 bits needs at least
        # ceil(256 / 6) = 43 characters. Length is only a lower-bound proxy
        # for entropy, not a proof of it.
        assert len(state) >= 43

    def test_no_state_reuse(self):
        """Same state should never be generated twice in sequence."""
        states = []
        for _ in range(100):
            states.append(_state_manager.generate_state())
            _state_manager.invalidate()  # Clear for next iteration
        # All states should be unique
        assert len(set(states)) == 100

    def test_hmac_signature_verification(self):
        """State should be protected by an HMAC signature that is checked on load."""
        state = SecureOAuthState(data={"test": "data"})
        serialized = state.serialize()
        # Should have format: data.signature, both parts non-empty
        parts = serialized.split(".")
        assert len(parts) == 2
        assert len(parts[0]) > 0
        assert len(parts[1]) > 0

        # Splicing another state's payload onto this signature must be rejected.
        other_payload = SecureOAuthState(data={"test": "other"}).serialize().split(".")[0]
        forged = f"{other_payload}.{parts[1]}"
        with pytest.raises(OAuthStateError):
            SecureOAuthState.deserialize(forged)

        # The genuine, unmodified token must still verify.
        assert SecureOAuthState.deserialize(serialized).data == {"test": "data"}
# ---------------------------------------------------------------------------
# Error Handling Tests
# ---------------------------------------------------------------------------
class TestErrorHandling:
    """Test error handling in OAuth flow."""

    def setup_method(self):
        """Reset state manager so results don't depend on earlier suites."""
        _state_manager.invalidate()

    def test_oauth_state_error_raised(self):
        """OAuthStateError should be raisable and catchable with its message."""
        assert issubclass(OAuthStateError, Exception)
        with pytest.raises(OAuthStateError) as exc_info:
            raise OAuthStateError("Test error")
        assert str(exc_info.value) == "Test error"

    def test_invalid_state_logged(self, caplog):
        """Invalid state should be logged as error."""
        import logging

        with caplog.at_level(logging.ERROR):
            _state_manager.generate_state()
            _state_manager.validate_and_extract("wrong_state")
        assert "validation failed" in caplog.text.lower()

    def test_missing_state_logged(self, caplog):
        """Missing state in the callback should be logged as error."""
        import logging

        # Issue a state first so this exercises "callback returned no state"
        # while one is pending (mirrors test_validate_and_extract_none_fails),
        # independent of test execution order.
        _state_manager.generate_state()
        with caplog.at_level(logging.ERROR):
            _state_manager.validate_and_extract(None)
        assert "no state returned" in caplog.text.lower()
# ---------------------------------------------------------------------------
# V-014 Specific Tests
# ---------------------------------------------------------------------------
class TestV014SessionFixationFix:
    """Targeted regression tests for the V-014 session fixation fix."""

    def setup_method(self):
        """Clear any state left over from a previous test."""
        _state_manager.invalidate()

    def test_v014_session_regeneration_after_successful_auth(self):
        """
        V-014: once a state has been validated successfully, the stored state
        is cleared, so the same value cannot be replayed to fix a session.
        """
        issued = _state_manager.generate_state()
        assert _state_manager._state is not None  # pending before the callback

        accepted, _extracted = _state_manager.validate_and_extract(issued)
        assert accepted is True

        assert _state_manager._state is None  # consumed by validation

    def test_v014_state_invalidation_on_auth_failure(self):
        """
        V-014: when authentication fails, the pending state is discarded so
        it cannot be reused in a fixation attempt.
        """
        _state_manager.generate_state()
        assert _state_manager._state is not None

        # Stand-in for a failed authorization (e.g. provider error): drop state.
        _state_manager.invalidate()

        assert _state_manager._state is None

    def test_v014_callback_includes_state_validation(self):
        """
        V-014: the OAuth callback handler validates the state parameter and
        accepts a request that carries the currently issued state.
        """
        issued = _state_manager.generate_state()
        handler_cls, outcome = _make_callback_handler()

        request = handler_cls.__new__(handler_cls)
        request.path = f"/callback?code=test&state={issued}"
        request.wfile = MagicMock()
        request.send_response = MagicMock()
        request.send_header = MagicMock()
        request.end_headers = MagicMock()
        request.do_GET()

        assert outcome["auth_code"] == "test"
        status_code = request.send_response.call_args.args[0]
        assert status_code == 200