787 lines
30 KiB
Python
787 lines
30 KiB
Python
|
|
"""
|
||
|
|
Security tests for OAuth state handling and token storage (V-006 Fix).
|
||
|
|
|
||
|
|
Tests verify that:
|
||
|
|
1. JSON serialization is used instead of pickle
|
||
|
|
2. HMAC signatures are properly verified for both state and tokens
|
||
|
|
3. State structure is validated
|
||
|
|
4. Token schema is validated
|
||
|
|
5. Tampering is detected
|
||
|
|
6. Replay attacks are prevented
|
||
|
|
7. Timing attacks are mitigated via constant-time comparison
|
||
|
|
"""
|
||
|
|
|
||
|
|
import base64
|
||
|
|
import hashlib
|
||
|
|
import hmac
|
||
|
|
import json
|
||
|
|
import os
|
||
|
|
import secrets
|
||
|
|
import sys
|
||
|
|
import tempfile
|
||
|
|
import threading
|
||
|
|
import time
|
||
|
|
from pathlib import Path
|
||
|
|
from unittest.mock import patch
|
||
|
|
|
||
|
|
# Put the parent of this file's directory (presumably the repository root,
# where the ``tools`` package lives) at the front of sys.path so that
# ``tools.mcp_oauth`` can be imported even when this file is run directly.
_PROJECT_ROOT = str(Path(__file__).parent.parent)
sys.path.insert(0, _PROJECT_ROOT)
|
||
|
|
|
||
|
|
from tools.mcp_oauth import (
|
||
|
|
OAuthStateError,
|
||
|
|
OAuthStateManager,
|
||
|
|
SecureOAuthState,
|
||
|
|
HermesTokenStorage,
|
||
|
|
_validate_token_schema,
|
||
|
|
_OAUTH_TOKEN_SCHEMA,
|
||
|
|
_OAUTH_CLIENT_SCHEMA,
|
||
|
|
_sign_token_data,
|
||
|
|
_verify_token_signature,
|
||
|
|
_get_token_storage_key,
|
||
|
|
_state_manager,
|
||
|
|
get_state_manager,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# SecureOAuthState Tests
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
class TestSecureOAuthState:
    """Tests for the SecureOAuthState class."""

    # Every character URL-safe base64 may emit, including '=' padding.
    _URLSAFE_B64_CHARS = frozenset(
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_="
    )

    @staticmethod
    def _forge_signed_state(raw: bytes) -> str:
        """Wrap ``raw`` in a ``data.signature`` string signed with the real key.

        The encoding matches what these tests build by hand: URL-safe base64
        with ``=`` padding stripped, and an HMAC-SHA256 signature computed
        over the raw payload bytes with ``SecureOAuthState._get_secret_key()``.
        Tests use it for payloads that carry a valid signature but malformed
        content, so the checks after signature verification are exercised.
        """
        key = SecureOAuthState._get_secret_key()
        signature = hmac.new(key, raw, hashlib.sha256).digest()
        encoded_data = base64.urlsafe_b64encode(raw).decode().rstrip("=")
        encoded_sig = base64.urlsafe_b64encode(signature).decode().rstrip("=")
        return f"{encoded_data}.{encoded_sig}"

    def test_generate_creates_valid_state(self):
        """Test that generated state has all required fields."""
        state = SecureOAuthState()

        assert state.token is not None
        assert len(state.token) >= 16
        assert state.timestamp is not None
        assert isinstance(state.timestamp, float)
        assert state.nonce is not None
        assert len(state.nonce) >= 8
        assert isinstance(state.data, dict)

    def test_generate_unique_tokens(self):
        """Test that generated tokens are unique."""
        tokens = {SecureOAuthState._generate_token() for _ in range(100)}
        assert len(tokens) == 100

    def test_serialization_format(self):
        """Test that serialized state has correct format."""
        state = SecureOAuthState(data={"test": "value"})
        serialized = state.serialize()

        # Should have format: data.signature
        parts = serialized.split(".")
        assert len(parts) == 2

        # Both parts should be URL-safe base64
        data_part, sig_part = parts
        assert set(data_part) <= self._URLSAFE_B64_CHARS
        assert set(sig_part) <= self._URLSAFE_B64_CHARS

    def test_serialize_deserialize_roundtrip(self):
        """Test that serialize/deserialize preserves state."""
        original = SecureOAuthState(data={"server": "test123", "user": "alice"})
        serialized = original.serialize()
        deserialized = SecureOAuthState.deserialize(serialized)

        assert deserialized.token == original.token
        assert deserialized.timestamp == original.timestamp
        assert deserialized.nonce == original.nonce
        assert deserialized.data == original.data

    def test_deserialize_empty_raises_error(self):
        """Test that deserializing empty state raises OAuthStateError."""
        try:
            SecureOAuthState.deserialize("")
            assert False, "Should have raised OAuthStateError"
        except OAuthStateError as e:
            assert "empty or wrong type" in str(e)

        try:
            SecureOAuthState.deserialize(None)
            assert False, "Should have raised OAuthStateError"
        except OAuthStateError as e:
            assert "empty or wrong type" in str(e)

    def test_deserialize_missing_signature_raises_error(self):
        """Test that missing signature is detected."""
        data = json.dumps({"test": "data"})
        encoded = base64.urlsafe_b64encode(data.encode()).decode()

        try:
            SecureOAuthState.deserialize(encoded)
            assert False, "Should have raised OAuthStateError"
        except OAuthStateError as e:
            assert "missing signature" in str(e)

    def test_deserialize_invalid_base64_raises_error(self):
        """Test that invalid data is rejected (base64 or signature)."""
        # Invalid characters may be accepted by Python's base64 decoder
        # but signature verification should fail
        try:
            SecureOAuthState.deserialize("!!!invalid!!!.!!!data!!!")
            assert False, "Should have raised OAuthStateError"
        except OAuthStateError as e:
            # Error could be from encoding or signature verification
            assert "Invalid state" in str(e)

    def test_deserialize_tampered_signature_detected(self):
        """Test that tampered signature is detected."""
        state = SecureOAuthState()
        serialized = state.serialize()

        # Tamper with the signature
        data_part, _sig_part = serialized.split(".")
        tampered_sig = base64.urlsafe_b64encode(b"tampered").decode().rstrip("=")
        tampered = f"{data_part}.{tampered_sig}"

        try:
            SecureOAuthState.deserialize(tampered)
            assert False, "Should have raised OAuthStateError"
        except OAuthStateError as e:
            assert "tampering detected" in str(e)

    def test_deserialize_tampered_data_detected(self):
        """Test that tampered data is detected via HMAC verification."""
        state = SecureOAuthState()
        serialized = state.serialize()

        # Tamper with the data but keep signature
        _data_part, sig_part = serialized.split(".")
        tampered_data = json.dumps({"hacked": True})
        tampered_encoded = base64.urlsafe_b64encode(tampered_data.encode()).decode().rstrip("=")
        tampered = f"{tampered_encoded}.{sig_part}"

        try:
            SecureOAuthState.deserialize(tampered)
            assert False, "Should have raised OAuthStateError"
        except OAuthStateError as e:
            assert "tampering detected" in str(e)

    def test_deserialize_expired_state_raises_error(self):
        """Test that expired states are rejected."""
        # Create a state with old timestamp
        old_state = SecureOAuthState()
        old_state.timestamp = time.time() - 1000  # 1000 seconds ago
        serialized = old_state.serialize()

        try:
            SecureOAuthState.deserialize(serialized)
            assert False, "Should have raised OAuthStateError"
        except OAuthStateError as e:
            assert "expired" in str(e)

    def test_deserialize_invalid_json_raises_error(self):
        """Test that invalid JSON raises OAuthStateError."""
        forged = self._forge_signed_state(b"not valid json {{{")

        try:
            SecureOAuthState.deserialize(forged)
            assert False, "Should have raised OAuthStateError"
        except OAuthStateError as e:
            assert "Invalid state JSON" in str(e)

    def test_deserialize_missing_fields_raises_error(self):
        """Test that missing required fields are detected."""
        # missing timestamp, nonce
        forged = self._forge_signed_state(json.dumps({"token": "test"}).encode())

        try:
            SecureOAuthState.deserialize(forged)
            assert False, "Should have raised OAuthStateError"
        except OAuthStateError as e:
            assert "missing fields" in str(e)

    def test_deserialize_invalid_token_type_raises_error(self):
        """Test that non-string tokens are rejected."""
        forged = self._forge_signed_state(json.dumps({
            "token": 12345,  # should be string
            "timestamp": time.time(),
            "nonce": "abc123",
        }).encode())

        try:
            SecureOAuthState.deserialize(forged)
            assert False, "Should have raised OAuthStateError"
        except OAuthStateError as e:
            assert "token must be a string" in str(e)

    def test_deserialize_short_token_raises_error(self):
        """Test that short tokens are rejected."""
        forged = self._forge_signed_state(json.dumps({
            "token": "short",  # too short
            "timestamp": time.time(),
            "nonce": "abc123",
        }).encode())

        try:
            SecureOAuthState.deserialize(forged)
            assert False, "Should have raised OAuthStateError"
        except OAuthStateError as e:
            # NOTE(review): this expects the length failure to share the
            # type-check wording ("token must be a string ..."); confirm
            # against the implementation's error message.
            assert "token must be a string" in str(e)

    def test_deserialize_invalid_timestamp_raises_error(self):
        """Test that non-numeric timestamps are rejected."""
        forged = self._forge_signed_state(json.dumps({
            "token": "a" * 32,
            "timestamp": "not a number",
            "nonce": "abc123",
        }).encode())

        try:
            SecureOAuthState.deserialize(forged)
            assert False, "Should have raised OAuthStateError"
        except OAuthStateError as e:
            assert "timestamp must be numeric" in str(e)

    def test_validate_against_correct_token(self):
        """Test token validation with matching token."""
        state = SecureOAuthState()
        assert state.validate_against(state.token) is True

    def test_validate_against_wrong_token(self):
        """Test token validation with non-matching token."""
        state = SecureOAuthState()
        assert state.validate_against("wrong-token") is False

    def test_validate_against_non_string(self):
        """Test token validation with non-string input."""
        state = SecureOAuthState()
        assert state.validate_against(None) is False
        assert state.validate_against(12345) is False

    def test_validate_uses_constant_time_comparison(self):
        """Test that validate_against rejects same-length and different-length mismatches."""
        state = SecureOAuthState(token="test-token-for-comparison")

        # NOTE(review): this only checks results, not timing; constant-time
        # behaviour is expected to come from a compare_digest call in the
        # implementation and is not directly verified here.
        result1 = state.validate_against("wrong-token-for-comparison")
        result2 = state.validate_against("another-wrong-token-here")

        assert result1 is False
        assert result2 is False

    def test_to_dict_format(self):
        """Test that to_dict returns correct format."""
        state = SecureOAuthState(data={"custom": "data"})
        d = state.to_dict()

        assert set(d.keys()) == {"token", "timestamp", "nonce", "data"}
        assert d["token"] == state.token
        assert d["timestamp"] == state.timestamp
        assert d["nonce"] == state.nonce
        assert d["data"] == {"custom": "data"}
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# OAuthStateManager Tests
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
class TestOAuthStateManager:
    """Tests for the OAuthStateManager class."""

    def setup_method(self):
        """Reset the shared module-level state manager before each test."""
        # The shared manager is only mutated through its attributes, never
        # rebound, so no ``global`` declaration is needed here.
        _state_manager.invalidate()
        _state_manager._used_nonces.clear()

    def test_generate_state_returns_serialized(self):
        """Test that generate_state returns a serialized state string."""
        state_str = _state_manager.generate_state()

        # Should be a string with format: data.signature
        assert isinstance(state_str, str)
        assert "." in state_str
        parts = state_str.split(".")
        assert len(parts) == 2

    def test_generate_state_with_data(self):
        """Test that extra data is included in state."""
        extra = {"server_name": "test-server", "user_id": "123"}
        state_str = _state_manager.generate_state(extra_data=extra)

        # Validate and extract
        is_valid, data = _state_manager.validate_and_extract(state_str)
        assert is_valid is True
        assert data == extra

    def test_validate_and_extract_valid_state(self):
        """Test validation with a valid state."""
        extra = {"test": "data"}
        state_str = _state_manager.generate_state(extra_data=extra)

        is_valid, data = _state_manager.validate_and_extract(state_str)

        assert is_valid is True
        assert data == extra

    def test_validate_and_extract_none_state(self):
        """Test validation with None state."""
        is_valid, data = _state_manager.validate_and_extract(None)

        assert is_valid is False
        assert data is None

    def test_validate_and_extract_invalid_state(self):
        """Test validation with an invalid state."""
        is_valid, data = _state_manager.validate_and_extract("invalid.state.here")

        assert is_valid is False
        assert data is None

    def test_state_cleared_after_validation(self):
        """Test that state is cleared after successful validation."""
        state_str = _state_manager.generate_state()

        # First validation should succeed
        is_valid1, _ = _state_manager.validate_and_extract(state_str)
        assert is_valid1 is True

        # Second validation should fail (replay)
        is_valid2, _ = _state_manager.validate_and_extract(state_str)
        assert is_valid2 is False

    def test_nonce_tracking_prevents_replay(self):
        """Test that nonce tracking prevents replay attacks."""
        state = SecureOAuthState()
        serialized = state.serialize()

        # Manually add to used nonces
        with _state_manager._lock:
            _state_manager._used_nonces.add(state.nonce)

        # Validation should fail due to nonce replay
        is_valid, _ = _state_manager.validate_and_extract(serialized)
        assert is_valid is False

    def test_invalidate_clears_state(self):
        """Test that invalidate clears the stored state."""
        _state_manager.generate_state()
        assert _state_manager._state is not None

        _state_manager.invalidate()
        assert _state_manager._state is None

    def test_thread_safety(self):
        """Test thread safety of state manager."""
        results = []

        def generate():
            state_str = _state_manager.generate_state(extra_data={"thread": threading.current_thread().name})
            results.append(state_str)

        threads = [threading.Thread(target=generate) for _ in range(10)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # All states should be unique
        assert len(set(results)) == 10

    def test_max_nonce_limit(self):
        """Test that nonce set is limited to prevent memory growth."""
        manager = OAuthStateManager()
        manager._max_used_nonces = 5
        attempts = 10

        # Record nonces through the real generate/validate path instead of
        # adding them to the set directly, so the manager's own limit
        # handling is what runs.
        for _ in range(attempts):
            state_str = manager.generate_state()
            is_valid, _data = manager.validate_and_extract(state_str)
            assert is_valid is True

        # Allow one nonce of slack, because whether the set is reset before
        # or after inserting is implementation-defined. Without any limit
        # the set would hold all ``attempts`` (10) nonces and fail here.
        assert len(manager._used_nonces) <= manager._max_used_nonces + 1
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# Schema Validation Tests
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
class TestSchemaValidation:
    """Check _validate_token_schema against the token and client schemas (V-006)."""

    def test_valid_token_schema_accepted(self):
        """A token carrying every known field validates cleanly."""
        full_token = dict(
            access_token="secret_token_123",
            token_type="Bearer",
            refresh_token="refresh_456",
            expires_in=3600,
            expires_at=1234567890.0,
            scope="read write",
            id_token="id_token_789",
        )
        # Passing simply means no exception is raised.
        _validate_token_schema(full_token, _OAUTH_TOKEN_SCHEMA, "token")

    def test_minimal_valid_token_schema(self):
        """A token holding only the required fields validates cleanly."""
        _validate_token_schema(
            {"access_token": "secret", "token_type": "Bearer"},
            _OAUTH_TOKEN_SCHEMA,
            "token",
        )

    def test_missing_required_field_rejected(self):
        """Omitting access_token is reported as a missing required field."""
        no_access_token = {"token_type": "Bearer"}
        caught = None
        try:
            _validate_token_schema(no_access_token, _OAUTH_TOKEN_SCHEMA, "token")
        except OAuthStateError as err:
            caught = err
        assert caught is not None, "Should have raised OAuthStateError"
        message = str(caught)
        assert "missing required fields" in message
        assert "access_token" in message

    def test_wrong_type_rejected(self):
        """A field whose value has the wrong type is rejected."""
        int_access_token = {"access_token": 12345, "token_type": "Bearer"}
        caught = None
        try:
            _validate_token_schema(int_access_token, _OAUTH_TOKEN_SCHEMA, "token")
        except OAuthStateError as err:
            caught = err
        assert caught is not None, "Should have raised OAuthStateError"
        assert "has wrong type" in str(caught)

    def test_null_values_accepted(self):
        """Optional fields may be explicitly set to None."""
        nullable_token = dict(
            access_token="secret",
            token_type="Bearer",
            refresh_token=None,
            expires_in=None,
        )
        _validate_token_schema(nullable_token, _OAUTH_TOKEN_SCHEMA, "token")

    def test_non_dict_data_rejected(self):
        """Anything that is not a dict is rejected outright."""
        caught = None
        try:
            _validate_token_schema("not a dict", _OAUTH_TOKEN_SCHEMA, "token")
        except OAuthStateError as err:
            caught = err
        assert caught is not None, "Should have raised OAuthStateError"
        assert "must be a dictionary" in str(caught)

    def test_valid_client_schema(self):
        """Well-formed client registration info validates cleanly."""
        client_info = dict(
            client_id="client_123",
            client_secret="secret_456",
            client_name="Test Client",
            redirect_uris=["http://localhost/callback"],
        )
        _validate_token_schema(client_info, _OAUTH_CLIENT_SCHEMA, "client")

    def test_client_missing_required_rejected(self):
        """Client info lacking client_id is rejected."""
        nameless_client = {"client_name": "Test"}
        caught = None
        try:
            _validate_token_schema(nameless_client, _OAUTH_CLIENT_SCHEMA, "client")
        except OAuthStateError as err:
            caught = err
        assert caught is not None, "Should have raised OAuthStateError"
        assert "missing required fields" in str(caught)
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# Token Storage Security Tests
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
class TestTokenStorageSecurity:
    """Signing and verification of stored token payloads (V-006)."""

    def test_sign_and_verify_token_data(self):
        """A freshly produced signature verifies against the data it signed."""
        payload = {"access_token": "test123", "token_type": "Bearer"}
        signature = _sign_token_data(payload)

        assert signature is not None
        assert len(signature) > 0
        assert _verify_token_signature(payload, signature) is True

    def test_tampered_token_data_rejected(self):
        """Changing any value after signing makes verification fail."""
        original = {"access_token": "test123", "token_type": "Bearer"}
        signature = _sign_token_data(original)

        # Same keys in the same order; only the access_token value differs.
        forged = {**original, "access_token": "hacked"}
        assert _verify_token_signature(forged, signature) is False

    def test_empty_signature_rejected(self):
        """An empty signature string never verifies."""
        payload = {"access_token": "test", "token_type": "Bearer"}
        verdict = _verify_token_signature(payload, "")
        assert verdict is False

    def test_invalid_signature_rejected(self):
        """A garbage signature string never verifies."""
        payload = {"access_token": "test", "token_type": "Bearer"}
        verdict = _verify_token_signature(payload, "invalid")
        assert verdict is False

    def test_signature_deterministic(self):
        """Signing identical data twice yields identical signatures."""
        payload = {"access_token": "test123", "token_type": "Bearer"}
        assert _sign_token_data(payload) == _sign_token_data(payload)

    def test_different_data_different_signatures(self):
        """Distinct payloads produce distinct signatures."""
        first = {"access_token": "test1", "token_type": "Bearer"}
        second = {"access_token": "test2", "token_type": "Bearer"}
        sig_first, sig_second = (_sign_token_data(p) for p in (first, second))
        assert sig_first != sig_second
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# Pickle Security Tests
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
class TestNoPickleUsage:
    """Regression tests confirming state handling uses JSON, never pickle (V-006)."""

    def test_serialization_does_not_use_pickle(self):
        """The serialized payload decodes as JSON and carries no pickle markers."""
        payload = "__import__('os').system('rm -rf /')"
        serialized = SecureOAuthState(data={"malicious": payload}).serialize()

        # Restore the stripped base64 padding before decoding the data segment.
        encoded, _sig = serialized.split(".")
        raw = base64.urlsafe_b64decode(encoded + "=" * (-len(encoded) % 4))

        # Should be valid JSON, not pickle
        parsed = json.loads(raw.decode('utf-8'))
        assert parsed["data"]["malicious"] == payload

        # Should NOT start with pickle protocol markers
        assert not raw.startswith(b'\x80')  # Pickle protocol marker
        assert b'cos\n' not in raw  # Pickle module load pattern

    def test_deserialize_rejects_pickle_payload(self):
        """A correctly signed pickle blob is still rejected as non-JSON."""
        import pickle

        # Only pickle.dumps is used here; nothing in this test unpickles.
        pickled = pickle.dumps({"cmd": "whoami"})

        key = SecureOAuthState._get_secret_key()
        mac = hmac.new(key, pickled, hashlib.sha256).digest()
        forged = "{}.{}".format(
            base64.urlsafe_b64encode(pickled).decode().rstrip("="),
            base64.urlsafe_b64encode(mac).decode().rstrip("="),
        )

        # Should fail because it's not valid JSON
        caught = None
        try:
            SecureOAuthState.deserialize(forged)
        except OAuthStateError as err:
            caught = err
        assert caught is not None, "Should have raised OAuthStateError"
        assert "Invalid state JSON" in str(caught)
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# Key Management Tests
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
class TestSecretKeyManagement:
    """Tests for HMAC secret key management."""

    def test_get_secret_key_from_env(self):
        """Test that HERMES_OAUTH_SECRET environment variable is used."""
        with patch.dict(os.environ, {"HERMES_OAUTH_SECRET": "test-secret-key-32bytes-long!!"}):
            key = SecureOAuthState._get_secret_key()
            assert key == b"test-secret-key-32bytes-long!!"

    def test_get_token_storage_key_from_env(self):
        """Test that HERMES_TOKEN_STORAGE_SECRET environment variable is used."""
        with patch.dict(os.environ, {"HERMES_TOKEN_STORAGE_SECRET": "storage-secret-key-32bytes!!"}):
            key = _get_token_storage_key()
            assert key == b"storage-secret-key-32bytes!!"

    def test_get_secret_key_creates_file(self):
        """Test that a key is generated when none is configured, and stays stable."""
        with tempfile.TemporaryDirectory() as tmpdir:
            home = Path(tmpdir)
            # Redirect Path.home() to an empty temp dir and clear every env
            # var (including HERMES_OAUTH_SECRET) so the env source is absent.
            # NOTE(review): this isolation only holds if the implementation
            # locates its key file via Path.home(); confirm.
            with patch('pathlib.Path.home', return_value=home):
                with patch.dict(os.environ, {}, clear=True):
                    key = SecureOAuthState._get_secret_key()
                    assert len(key) == 64

                    # A key regenerated on every call would invalidate each
                    # state signed earlier, so a second call must return the
                    # same (persisted) key.
                    assert SecureOAuthState._get_secret_key() == key
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# Integration Tests
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
class TestOAuthFlowIntegration:
    """Integration tests for the OAuth flow with secure state."""

    def setup_method(self):
        """Reset the shared module-level state manager before each test."""
        # The shared manager is only mutated through its attributes, never
        # rebound, so no ``global`` declaration is needed here.
        _state_manager.invalidate()
        _state_manager._used_nonces.clear()

    def test_full_oauth_state_flow(self):
        """Test the full OAuth state generation and validation flow."""
        # Step 1: Generate state for OAuth request
        server_name = "test-mcp-server"
        state = _state_manager.generate_state(extra_data={"server_name": server_name})

        # Step 2: Simulate OAuth callback with state
        # (In real flow, this comes back from OAuth provider)
        is_valid, data = _state_manager.validate_and_extract(state)

        # Step 3: Verify validation succeeded
        assert is_valid is True
        assert data["server_name"] == server_name

        # Step 4: Verify state cannot be replayed
        is_valid_replay, _ = _state_manager.validate_and_extract(state)
        assert is_valid_replay is False

    def test_csrf_attack_prevention(self):
        """Test that CSRF attacks using different states are detected."""
        # Attacker generates their own state
        attacker_state = _state_manager.generate_state(extra_data={"malicious": True})

        # Victim generates their state; the victim manager now holds a
        # pending state that the attacker's state does not match.
        victim_manager = OAuthStateManager()
        victim_manager.generate_state(extra_data={"legitimate": True})

        # Attacker tries to use their state with victim's session
        # This would fail because the tokens don't match
        is_valid, _ = victim_manager.validate_and_extract(attacker_state)
        assert is_valid is False

    def test_mitm_attack_detection(self):
        """Test that tampered states from MITM attacks are detected."""
        # Generate legitimate state
        state = _state_manager.generate_state()

        # Modify the state (simulating MITM tampering)
        parts = state.split(".")
        tampered_state = parts[0] + ".tampered-signature-here"

        # Validation should fail
        is_valid, _ = _state_manager.validate_and_extract(tampered_state)
        assert is_valid is False
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# Performance Tests
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
class TestPerformance:
    """Performance tests for state operations."""

    # Operations timed per test, and the elapsed-time budget for all of them.
    ITERATIONS = 1000
    BUDGET_SECONDS = 1.0

    def test_serialize_performance(self):
        """Test that serialization is fast."""
        state = SecureOAuthState(data={"key": "value" * 100})

        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump with wall-clock adjustments and make this flaky.
        start = time.perf_counter()
        for _ in range(self.ITERATIONS):
            state.serialize()
        elapsed = time.perf_counter() - start

        # Should complete 1000 serializations in under 1 second
        assert elapsed < self.BUDGET_SECONDS

    def test_deserialize_performance(self):
        """Test that deserialization is fast."""
        state = SecureOAuthState(data={"key": "value" * 100})
        serialized = state.serialize()

        start = time.perf_counter()
        for _ in range(self.ITERATIONS):
            SecureOAuthState.deserialize(serialized)
        elapsed = time.perf_counter() - start

        # Should complete 1000 deserializations in under 1 second
        assert elapsed < self.BUDGET_SECONDS
|
||
|
|
|
||
|
|
|
||
|
|
def run_tests():
|
||
|
|
"""Run all tests."""
|
||
|
|
import inspect
|
||
|
|
|
||
|
|
test_classes = [
|
||
|
|
TestSecureOAuthState,
|
||
|
|
TestOAuthStateManager,
|
||
|
|
TestSchemaValidation,
|
||
|
|
TestTokenStorageSecurity,
|
||
|
|
TestNoPickleUsage,
|
||
|
|
TestSecretKeyManagement,
|
||
|
|
TestOAuthFlowIntegration,
|
||
|
|
TestPerformance,
|
||
|
|
]
|
||
|
|
|
||
|
|
total_tests = 0
|
||
|
|
passed_tests = 0
|
||
|
|
failed_tests = []
|
||
|
|
|
||
|
|
for cls in test_classes:
|
||
|
|
print(f"\n{'='*60}")
|
||
|
|
print(f"Running {cls.__name__}")
|
||
|
|
print('='*60)
|
||
|
|
|
||
|
|
instance = cls()
|
||
|
|
|
||
|
|
# Run setup if exists
|
||
|
|
if hasattr(instance, 'setup_method'):
|
||
|
|
instance.setup_method()
|
||
|
|
|
||
|
|
for name, method in inspect.getmembers(cls, predicate=inspect.isfunction):
|
||
|
|
if name.startswith('test_'):
|
||
|
|
total_tests += 1
|
||
|
|
try:
|
||
|
|
method(instance)
|
||
|
|
print(f" ✓ {name}")
|
||
|
|
passed_tests += 1
|
||
|
|
except Exception as e:
|
||
|
|
print(f" ✗ {name}: {e}")
|
||
|
|
failed_tests.append((cls.__name__, name, str(e)))
|
||
|
|
|
||
|
|
print(f"\n{'='*60}")
|
||
|
|
print(f"Results: {passed_tests}/{total_tests} tests passed")
|
||
|
|
print('='*60)
|
||
|
|
|
||
|
|
if failed_tests:
|
||
|
|
print("\nFailed tests:")
|
||
|
|
for cls_name, test_name, error in failed_tests:
|
||
|
|
print(f" - {cls_name}.{test_name}: {error}")
|
||
|
|
return 1
|
||
|
|
else:
|
||
|
|
print("\nAll tests passed!")
|
||
|
|
return 0
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
|
||
|
|
sys.exit(run_tests())
|