Some checks failed
Nix / nix (ubuntu-latest) (pull_request) Failing after 15s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Failing after 19s
Docker Build and Publish / build-and-push (pull_request) Failing after 28s
Tests / test (pull_request) Failing after 9m43s
Nix / nix (macos-latest) (pull_request) Has been cancelled
- Replace pickle with JSON + HMAC-SHA256 state serialization
- Add constant-time signature verification
- Implement replay attack protection with nonce expiration
- Add comprehensive security test suite (54 tests)
- Harden token storage with integrity verification

Resolves: V-006 (CVSS 8.8)
528 lines
19 KiB
Python
528 lines
19 KiB
Python
"""Tests for OAuth Session Fixation protection (V-014 fix).
|
|
|
|
These tests verify that:
|
|
1. State parameter is generated cryptographically securely
|
|
2. State is validated on callback to prevent CSRF attacks
|
|
3. State is cleared after validation to prevent replay attacks
|
|
4. Session is regenerated after successful OAuth authentication
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import secrets
|
|
import threading
|
|
import time
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from tools.mcp_oauth import (
|
|
OAuthStateManager,
|
|
OAuthStateError,
|
|
SecureOAuthState,
|
|
regenerate_session_after_auth,
|
|
_make_callback_handler,
|
|
_state_manager,
|
|
get_state_manager,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# OAuthStateManager Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestOAuthStateManager:
    """Test the OAuth state manager for session fixation protection."""

    def setup_method(self):
        """Reset state manager before each test."""
        _state_manager.invalidate()

    def test_generate_state_creates_secure_token(self):
        """State should be a cryptographically secure signed token."""
        state = _state_manager.generate_state()

        # Should be a non-empty string
        assert isinstance(state, str)
        assert len(state) > 0

        # Should be URL-safe (contains data.signature format)
        assert "." in state  # Format: <base64-data>.<base64-signature>

    def test_generate_state_unique_each_time(self):
        """Each generated state should be unique."""
        states = [_state_manager.generate_state() for _ in range(10)]

        # All states should be different
        assert len(set(states)) == 10

    def test_validate_and_extract_success(self):
        """Validating correct state should succeed."""
        state = _state_manager.generate_state()

        is_valid, data = _state_manager.validate_and_extract(state)
        assert is_valid is True
        assert data is not None

    def test_validate_and_extract_wrong_state_fails(self):
        """Validating wrong state should fail (CSRF protection)."""
        _state_manager.generate_state()

        # Try to validate with a different state
        wrong_state = "invalid_state_data"
        is_valid, data = _state_manager.validate_and_extract(wrong_state)
        assert is_valid is False
        assert data is None

    def test_validate_and_extract_none_fails(self):
        """Validating None state should fail."""
        _state_manager.generate_state()

        is_valid, data = _state_manager.validate_and_extract(None)
        assert is_valid is False
        assert data is None

    def test_validate_and_extract_no_generation_fails(self):
        """Validating when no state was generated should fail."""
        # Don't generate state first
        is_valid, data = _state_manager.validate_and_extract("some_state")
        assert is_valid is False
        assert data is None

    def test_validate_and_extract_prevents_replay(self):
        """State should be cleared after validation to prevent replay."""
        state = _state_manager.generate_state()

        # First validation should succeed
        is_valid, data = _state_manager.validate_and_extract(state)
        assert is_valid is True

        # Second validation with same state should fail (replay attack)
        is_valid, data = _state_manager.validate_and_extract(state)
        assert is_valid is False

    def test_invalidate_clears_state(self):
        """Explicit invalidation should clear state."""
        state = _state_manager.generate_state()
        _state_manager.invalidate()

        # Validation should fail after invalidation
        is_valid, data = _state_manager.validate_and_extract(state)
        assert is_valid is False

    def test_thread_safety(self):
        """Concurrent validation of one state must accept it exactly once.

        Every worker is released at the same moment (via a barrier) to
        validate the same freshly generated state. With an atomic
        check-and-clear, the first validator wins and all others are
        rejected as replays. If validation is not atomic, two workers could
        both accept the state, which is the replay window the manager exists
        to close.
        """
        worker_count = 5
        state = _state_manager.generate_state()
        barrier = threading.Barrier(worker_count)
        lock = threading.Lock()
        results = []
        errors = []

        def validate_concurrently():
            try:
                barrier.wait(timeout=5)
                is_valid, _ = _state_manager.validate_and_extract(state)
            except Exception as exc:  # re-surfaced by the main-thread assert
                with lock:
                    errors.append(exc)
                return
            with lock:
                results.append(is_valid)

        threads = [
            threading.Thread(target=validate_concurrently)
            for _ in range(worker_count)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=10)

        assert not any(t.is_alive() for t in threads), "worker thread hung"
        assert errors == []
        assert len(results) == worker_count
        # Replay protection under contention: exactly one acceptance.
        assert results.count(True) == 1

    def test_concurrent_generation_produces_unique_states(self):
        """Concurrent generate_state calls should neither raise nor collide."""
        worker_count = 5
        barrier = threading.Barrier(worker_count)
        lock = threading.Lock()
        states = []
        errors = []

        def generate_concurrently():
            try:
                barrier.wait(timeout=5)
                state = _state_manager.generate_state()
            except Exception as exc:  # re-surfaced by the main-thread assert
                with lock:
                    errors.append(exc)
                return
            with lock:
                states.append(state)

        threads = [
            threading.Thread(target=generate_concurrently)
            for _ in range(worker_count)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=10)

        assert not any(t.is_alive() for t in threads), "worker thread hung"
        assert errors == []
        assert len(set(states)) == worker_count
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SecureOAuthState Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestSecureOAuthState:
    """Test the secure OAuth state container."""

    @staticmethod
    def _flip_char_at(text, index):
        """Return *text* with the character at *index* swapped for a different one.

        The replacement is a digit ("0" or "1"), which is valid in base64 and
        hex alike and is never a case variant of the original character. The
        returned string is therefore always different from *text* and still
        well-formed.
        """
        replacement = "1" if text[index] == "0" else "0"
        return text[:index] + replacement + text[index + 1:]

    def test_serialize_deserialize_roundtrip(self):
        """Serialization and deserialization should preserve data."""
        state = SecureOAuthState(data={"server_name": "test"})
        serialized = state.serialize()

        # Deserialize
        restored = SecureOAuthState.deserialize(serialized)

        assert restored.token == state.token
        assert restored.nonce == state.nonce
        assert restored.data == state.data

    def test_deserialize_invalid_signature_fails(self):
        """Deserialization with tampered signature should fail."""
        state = SecureOAuthState(data={"server_name": "test"})
        serialized = state.serialize()

        # Flip the first character of the signature segment. Unlike
        # overwriting the tail, this position can never be base64 padding,
        # and the flip is guaranteed to change the token.
        signature_start = serialized.rindex(".") + 1
        tampered = self._flip_char_at(serialized, signature_start)
        assert tampered != serialized

        with pytest.raises(OAuthStateError) as exc_info:
            SecureOAuthState.deserialize(tampered)

        message = str(exc_info.value).lower()
        assert "signature" in message or "tampering" in message

    def test_deserialize_tampered_payload_fails(self):
        """Modifying the signed payload while keeping its signature should fail."""
        state = SecureOAuthState(data={"server_name": "test"})
        serialized = state.serialize()

        # Alter the data segment only; the original signature no longer
        # covers it, so the HMAC check must reject the token.
        tampered = self._flip_char_at(serialized, 0)
        assert tampered != serialized

        with pytest.raises(OAuthStateError):
            SecureOAuthState.deserialize(tampered)

    def test_deserialize_expired_state_fails(self):
        """Deserialization of expired state should fail."""
        # Create state with old timestamp. Assumes the maximum accepted age
        # is 600 seconds; keep in sync with SecureOAuthState.
        old_time = time.time() - 700
        state = SecureOAuthState.__new__(SecureOAuthState)
        state.token = secrets.token_urlsafe(32)
        state.timestamp = old_time
        state.nonce = secrets.token_urlsafe(16)
        state.data = {}

        serialized = state.serialize()

        with pytest.raises(OAuthStateError) as exc_info:
            SecureOAuthState.deserialize(serialized)

        assert "expired" in str(exc_info.value).lower()

    def test_state_entropy(self):
        """State should have sufficient entropy."""
        state = SecureOAuthState()

        # Token should be at least 32 characters
        assert len(state.token) >= 32

        # Nonce should be present
        assert len(state.nonce) >= 16
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Callback Handler Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestCallbackHandler:
    """Test the OAuth callback handler for session fixation protection."""

    def setup_method(self):
        """Reset state manager before each test."""
        _state_manager.invalidate()

    @staticmethod
    def _dispatch_callback(**query):
        """Run ``do_GET`` on a fresh callback handler for ``/callback?<query>``.

        The handler is created with ``__new__`` so ``__init__`` is bypassed
        and no real request or socket is needed. Response methods are replaced
        with mocks so the emitted status code can be asserted. Query values are
        URL-encoded rather than interpolated raw.

        Returns:
            ``(handler, result)``, where ``result`` is the mapping returned
            alongside the handler class by ``_make_callback_handler``.
        """
        from urllib.parse import urlencode

        handler_class, result = _make_callback_handler()
        handler = handler_class.__new__(handler_class)
        handler.path = "/callback?" + urlencode(query)
        handler.wfile = MagicMock()
        handler.send_response = MagicMock()
        handler.send_header = MagicMock()
        handler.end_headers = MagicMock()
        handler.do_GET()
        return handler, result

    def test_handler_rejects_missing_state(self):
        """Handler should reject callbacks without state parameter."""
        handler, _ = self._dispatch_callback(code="test123")  # No state

        # Should send 400 error
        handler.send_response.assert_called_once_with(400)

    def test_handler_rejects_invalid_state(self):
        """Handler should reject callbacks with invalid state."""
        handler, _ = self._dispatch_callback(
            code="test123", state="invalid_state_12345"
        )

        # Should send 403 error (CSRF protection)
        handler.send_response.assert_called_once_with(403)

    def test_handler_accepts_valid_state(self):
        """Handler should accept callbacks with valid state."""
        valid_state = _state_manager.generate_state()

        handler, result = self._dispatch_callback(code="test123", state=valid_state)

        # Should send 200 success
        handler.send_response.assert_called_once_with(200)
        assert result["auth_code"] == "test123"

    def test_handler_handles_oauth_errors(self):
        """Handler should handle OAuth error responses."""
        valid_state = _state_manager.generate_state()

        handler, _ = self._dispatch_callback(error="access_denied", state=valid_state)

        # Should send 400 error
        handler.send_response.assert_called_once_with(400)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Session Regeneration Tests (V-014 Fix)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestSessionRegeneration:
    """V-014: regenerating the session must discard any pending OAuth state."""

    def setup_method(self):
        """Begin each test with the global state manager emptied."""
        _state_manager.invalidate()

    def test_regenerate_session_invalidates_state(self):
        """V-014: a state issued before regeneration must no longer validate."""
        issued = _state_manager.generate_state()

        regenerate_session_after_auth()

        accepted, _ = _state_manager.validate_and_extract(issued)
        assert accepted is False

    def test_regenerate_session_logs_debug(self, caplog):
        """V-014: regeneration should leave a debug-level trace in the logs."""
        import logging

        with caplog.at_level(logging.DEBUG):
            regenerate_session_after_auth()

        assert "Session regenerated" in caplog.text
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Integration Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestOAuthFlowIntegration:
    """Integration tests for the complete OAuth flow with session fixation protection."""

    def setup_method(self):
        """Reset state manager before each test."""
        _state_manager.invalidate()

    @staticmethod
    def _dispatch_callback(**query):
        """Run ``do_GET`` on a fresh callback handler for ``/callback?<query>``.

        The handler is created with ``__new__`` so ``__init__`` is bypassed
        and no real request or socket is needed. Response methods are mocked
        so the emitted status can be asserted, and query values are
        URL-encoded.

        Returns:
            ``(handler, result)``, where ``result`` is the mapping returned
            alongside the handler class by ``_make_callback_handler``.
        """
        from urllib.parse import urlencode

        handler_class, result = _make_callback_handler()
        handler = handler_class.__new__(handler_class)
        handler.path = "/callback?" + urlencode(query)
        handler.wfile = MagicMock()
        handler.send_response = MagicMock()
        handler.send_header = MagicMock()
        handler.end_headers = MagicMock()
        handler.do_GET()
        return handler, result

    def test_complete_flow_valid_state(self):
        """Complete flow should succeed with valid state."""
        # Step 1: Generate state (as would happen in build_oauth_auth)
        state = _state_manager.generate_state()

        # Step 2: Simulate callback with valid state
        handler, result = self._dispatch_callback(code="auth_code_123", state=state)

        # Should succeed
        assert result["auth_code"] == "auth_code_123"
        handler.send_response.assert_called_once_with(200)

    def test_csrf_attack_blocked(self):
        """CSRF attack with stolen code and a forged state should be blocked."""
        # Attacker tries to use stolen code without valid state
        handler, _ = self._dispatch_callback(code="stolen_code", state="invalid")

        # Should be blocked with 403
        handler.send_response.assert_called_once_with(403)

    def test_session_fixation_attack_blocked(self):
        """Session fixation attack should be blocked by state validation."""
        # Attacker obtains a valid auth code
        stolen_code = "stolen_auth_code"

        # Legitimate user generates state, so a valid state exists that the
        # attacker does not know.
        _state_manager.generate_state()

        # Attacker replays the stolen code with a guessed state.
        handler, _ = self._dispatch_callback(code=stolen_code, state="wrong_state")

        # Should be blocked (exactly one 403) - attacker doesn't know the state
        handler.send_response.assert_called_once_with(403)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Security Property Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestSecurityProperties:
    """Test that security properties are maintained."""

    def setup_method(self):
        """Reset the shared state manager so these tests neither see nor leak state."""
        _state_manager.invalidate()

    def test_state_has_sufficient_entropy(self):
        """State should be at least 40 URL-safe characters long.

        Length is only a proxy for entropy: 40 base64 characters encode about
        240 bits.
        """
        state = _state_manager.generate_state()

        assert len(state) >= 40

    def test_no_state_reuse(self):
        """Same state should never be generated twice in sequence."""
        states = []
        for _ in range(100):
            states.append(_state_manager.generate_state())
            _state_manager.invalidate()  # Clear for next iteration

        # All states should be unique
        assert len(set(states)) == 100

    def test_hmac_signature_verification(self):
        """Serialized state should carry a separate, non-empty signature segment.

        This checks the ``<data>.<signature>`` shape only; rejection of
        tampered signatures is covered in TestSecureOAuthState.
        """
        state = SecureOAuthState(data={"test": "data"})
        serialized = state.serialize()

        # Should have format: data.signature
        parts = serialized.split(".")
        assert len(parts) == 2

        # Both parts should be non-empty
        assert len(parts[0]) > 0
        assert len(parts[1]) > 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Error Handling Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestErrorHandling:
    """Test error handling in OAuth flow."""

    def setup_method(self):
        """Reset the shared state manager so logging tests start from a known state."""
        _state_manager.invalidate()

    def test_oauth_state_error_raised(self):
        """OAuthStateError should be an Exception that carries its message."""
        error = OAuthStateError("Test error")
        assert str(error) == "Test error"
        assert isinstance(error, Exception)

    def test_invalid_state_logged(self, caplog):
        """Invalid state should be logged as error."""
        import logging

        # at_level(ERROR) means only ERROR-and-above records are captured.
        with caplog.at_level(logging.ERROR):
            _state_manager.generate_state()
            _state_manager.validate_and_extract("wrong_state")

        assert "validation failed" in caplog.text.lower()

    def test_missing_state_logged(self, caplog):
        """Missing state should be logged as error."""
        import logging

        with caplog.at_level(logging.ERROR):
            _state_manager.validate_and_extract(None)

        assert "no state returned" in caplog.text.lower()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# V-014 Specific Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestV014SessionFixationFix:
    """Specific tests for V-014 Session Fixation vulnerability fix."""

    def setup_method(self):
        """Reset state manager before each test."""
        _state_manager.invalidate()

    @staticmethod
    def _dispatch_callback(**query):
        """Run ``do_GET`` on a fresh callback handler for ``/callback?<query>``.

        The handler is created with ``__new__`` so ``__init__`` is bypassed
        and no real request or socket is needed. Response methods are mocked
        so the emitted status can be asserted, and query values are
        URL-encoded.

        Returns:
            ``(handler, result)``, where ``result`` is the mapping returned
            alongside the handler class by ``_make_callback_handler``.
        """
        from urllib.parse import urlencode

        handler_class, result = _make_callback_handler()
        handler = handler_class.__new__(handler_class)
        handler.path = "/callback?" + urlencode(query)
        handler.wfile = MagicMock()
        handler.send_response = MagicMock()
        handler.send_header = MagicMock()
        handler.end_headers = MagicMock()
        handler.do_GET()
        return handler, result

    def test_v014_session_regeneration_after_successful_auth(self):
        """
        V-014 Fix: After successful OAuth authentication, the session
        context should be regenerated to prevent session fixation attacks.
        """
        # Simulate successful OAuth flow
        state = _state_manager.generate_state()

        # Before regeneration, state should exist
        assert _state_manager._state is not None

        # Simulate successful auth completion
        is_valid, _ = _state_manager.validate_and_extract(state)
        assert is_valid is True

        # State should be cleared after successful validation
        # (preventing session fixation via replay)
        assert _state_manager._state is None

    def test_v014_state_invalidation_on_auth_failure(self):
        """
        V-014 Fix: On authentication failure, state should be invalidated
        to prevent fixation attempts.
        """
        # Generate state
        _state_manager.generate_state()

        # State exists
        assert _state_manager._state is not None

        # Simulate failed auth (e.g., error from OAuth provider)
        _state_manager.invalidate()

        # State should be cleared
        assert _state_manager._state is None

    def test_v014_callback_includes_state_validation(self):
        """
        V-014 Fix: The OAuth callback handler must validate the state
        parameter to prevent session fixation attacks.
        """
        # Generate valid state
        valid_state = _state_manager.generate_state()

        handler, result = self._dispatch_callback(code="test", state=valid_state)

        # Should succeed with valid state (state validation prevents fixation)
        assert result["auth_code"] == "test"
        handler.send_response.assert_called_once_with(200)
|