# Files
# hermes-agent/tests/test_oauth_state_security.py
#
# 787 lines
# 30 KiB
# Python
# Raw Normal View History

"""
Security tests for OAuth state handling and token storage (V-006 Fix).
Tests verify that:
1. JSON serialization is used instead of pickle
2. HMAC signatures are properly verified for both state and tokens
3. State structure is validated
4. Token schema is validated
5. Tampering is detected
6. Replay attacks are prevented
7. Timing attacks are mitigated via constant-time comparison
"""
import base64
import hashlib
import hmac
import json
import os
import secrets
import sys
import tempfile
import threading
import time
from pathlib import Path
from unittest.mock import patch
# Ensure tools directory is in path
sys.path.insert(0, str(Path(__file__).parent.parent))
from tools.mcp_oauth import (
OAuthStateError,
OAuthStateManager,
SecureOAuthState,
HermesTokenStorage,
_validate_token_schema,
_OAUTH_TOKEN_SCHEMA,
_OAUTH_CLIENT_SCHEMA,
_sign_token_data,
_verify_token_signature,
_get_token_storage_key,
_state_manager,
get_state_manager,
)
# =============================================================================
# SecureOAuthState Tests
# =============================================================================
class TestSecureOAuthState:
    """Tests for the SecureOAuthState class."""

    # URL-safe base64 alphabet (including "=" padding) allowed in each
    # segment of a serialized state.
    _URLSAFE_B64_CHARS = frozenset(
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_="
    )

    @staticmethod
    def _signed_payload(raw):
        """Encode ``raw`` bytes as ``<data>.<signature>`` signed with the state key.

        Uses the ``data.signature`` layout checked in test_serialization_format,
        with an HMAC-SHA256 signature computed from
        ``SecureOAuthState._get_secret_key()``. This lets tests hand
        ``deserialize()`` content that is correctly signed but structurally
        invalid.

        Args:
            raw: The exact bytes to place in the data segment.

        Returns:
            The unpadded, URL-safe base64 ``data.signature`` string.
        """
        key = SecureOAuthState._get_secret_key()
        digest = hmac.new(key, raw, hashlib.sha256).digest()
        encoded_data = base64.urlsafe_b64encode(raw).decode().rstrip("=")
        encoded_sig = base64.urlsafe_b64encode(digest).decode().rstrip("=")
        return f"{encoded_data}.{encoded_sig}"

    @staticmethod
    def _assert_deserialize_fails(payload, fragment):
        """Assert ``deserialize(payload)`` raises OAuthStateError mentioning ``fragment``.

        Raises:
            AssertionError: if no OAuthStateError is raised, or if its message
                does not contain ``fragment``.
        """
        try:
            SecureOAuthState.deserialize(payload)
        except OAuthStateError as e:
            assert fragment in str(e), f"unexpected error message: {e}"
        else:
            # Explicit raise (not ``assert False``) so it also fires under -O.
            raise AssertionError("Should have raised OAuthStateError")

    def test_generate_creates_valid_state(self):
        """Test that generated state has all required fields."""
        state = SecureOAuthState()
        assert state.token is not None
        assert len(state.token) >= 16
        assert state.timestamp is not None
        assert isinstance(state.timestamp, float)
        assert state.nonce is not None
        assert len(state.nonce) >= 8
        assert isinstance(state.data, dict)

    def test_generate_unique_tokens(self):
        """Test that generated tokens are unique."""
        tokens = {SecureOAuthState._generate_token() for _ in range(100)}
        assert len(tokens) == 100

    def test_serialization_format(self):
        """Test that serialized state has correct format."""
        state = SecureOAuthState(data={"test": "value"})
        serialized = state.serialize()
        # Should have format: data.signature
        parts = serialized.split(".")
        assert len(parts) == 2
        # Both parts should be URL-safe base64
        data_part, sig_part = parts
        assert set(data_part) <= self._URLSAFE_B64_CHARS
        assert set(sig_part) <= self._URLSAFE_B64_CHARS

    def test_serialize_deserialize_roundtrip(self):
        """Test that serialize/deserialize preserves state."""
        original = SecureOAuthState(data={"server": "test123", "user": "alice"})
        serialized = original.serialize()
        deserialized = SecureOAuthState.deserialize(serialized)
        assert deserialized.token == original.token
        assert deserialized.timestamp == original.timestamp
        assert deserialized.nonce == original.nonce
        assert deserialized.data == original.data

    def test_deserialize_empty_raises_error(self):
        """Test that deserializing empty state raises OAuthStateError."""
        for empty in ("", None):
            self._assert_deserialize_fails(empty, "empty or wrong type")

    def test_deserialize_missing_signature_raises_error(self):
        """Test that missing signature is detected."""
        data = json.dumps({"test": "data"})
        encoded = base64.urlsafe_b64encode(data.encode()).decode()
        self._assert_deserialize_fails(encoded, "missing signature")

    def test_deserialize_invalid_base64_raises_error(self):
        """Test that invalid data is rejected (base64 or signature)."""
        # Python's base64 decoder may accept the invalid characters; if so,
        # signature verification must reject the payload instead. Either way
        # the error message starts with "Invalid state".
        self._assert_deserialize_fails("!!!invalid!!!.!!!data!!!", "Invalid state")

    def test_deserialize_tampered_signature_detected(self):
        """Test that tampered signature is detected."""
        serialized = SecureOAuthState().serialize()
        data_part, _ = serialized.split(".")
        tampered_sig = base64.urlsafe_b64encode(b"tampered").decode().rstrip("=")
        self._assert_deserialize_fails(f"{data_part}.{tampered_sig}", "tampering detected")

    def test_deserialize_tampered_data_detected(self):
        """Test that tampered data is detected via HMAC verification."""
        serialized = SecureOAuthState().serialize()
        # Replace the data but keep the original signature.
        _, sig_part = serialized.split(".")
        tampered_data = json.dumps({"hacked": True})
        tampered_encoded = base64.urlsafe_b64encode(tampered_data.encode()).decode().rstrip("=")
        self._assert_deserialize_fails(f"{tampered_encoded}.{sig_part}", "tampering detected")

    def test_deserialize_expired_state_raises_error(self):
        """Test that expired states are rejected."""
        old_state = SecureOAuthState()
        old_state.timestamp = time.time() - 1000  # 1000 seconds ago
        self._assert_deserialize_fails(old_state.serialize(), "expired")

    def test_deserialize_invalid_json_raises_error(self):
        """Test that invalid JSON raises OAuthStateError."""
        payload = self._signed_payload(b"not valid json {{{")
        self._assert_deserialize_fails(payload, "Invalid state JSON")

    def test_deserialize_missing_fields_raises_error(self):
        """Test that missing required fields are detected."""
        # Missing timestamp and nonce.
        payload = self._signed_payload(json.dumps({"token": "test"}).encode())
        self._assert_deserialize_fails(payload, "missing fields")

    def test_deserialize_invalid_token_type_raises_error(self):
        """Test that non-string tokens are rejected."""
        raw = json.dumps({
            "token": 12345,  # should be string
            "timestamp": time.time(),
            "nonce": "abc123",
        }).encode()
        self._assert_deserialize_fails(self._signed_payload(raw), "token must be a string")

    def test_deserialize_short_token_raises_error(self):
        """Test that short tokens are rejected."""
        raw = json.dumps({
            "token": "short",  # too short
            "timestamp": time.time(),
            "nonce": "abc123",
        }).encode()
        self._assert_deserialize_fails(self._signed_payload(raw), "token must be a string")

    def test_deserialize_invalid_timestamp_raises_error(self):
        """Test that non-numeric timestamps are rejected."""
        raw = json.dumps({
            "token": "a" * 32,
            "timestamp": "not a number",
            "nonce": "abc123",
        }).encode()
        self._assert_deserialize_fails(self._signed_payload(raw), "timestamp must be numeric")

    def test_validate_against_correct_token(self):
        """Test token validation with matching token."""
        state = SecureOAuthState()
        assert state.validate_against(state.token) is True

    def test_validate_against_wrong_token(self):
        """Test token validation with non-matching token."""
        state = SecureOAuthState()
        assert state.validate_against("wrong-token") is False

    def test_validate_against_non_string(self):
        """Test token validation with non-string input."""
        state = SecureOAuthState()
        assert state.validate_against(None) is False
        assert state.validate_against(12345) is False

    def test_validate_uses_constant_time_comparison(self):
        """Check that similar-length mismatched tokens are rejected.

        NOTE(review): this does not measure timing and does not prove a
        constant-time primitive is used. It only checks results for near-miss
        inputs. Confirm in tools.mcp_oauth that validate_against uses a
        compare_digest-style comparison.
        """
        state = SecureOAuthState(token="test-token-for-comparison")
        assert state.validate_against("wrong-token-for-comparison") is False
        assert state.validate_against("another-wrong-token-here") is False

    def test_to_dict_format(self):
        """Test that to_dict returns correct format."""
        state = SecureOAuthState(data={"custom": "data"})
        d = state.to_dict()
        assert set(d.keys()) == {"token", "timestamp", "nonce", "data"}
        assert d["token"] == state.token
        assert d["timestamp"] == state.timestamp
        assert d["nonce"] == state.nonce
        assert d["data"] == {"custom": "data"}
# =============================================================================
# OAuthStateManager Tests
# =============================================================================
class TestOAuthStateManager:
    """Tests for the OAuthStateManager class."""

    def setup_method(self):
        """Reset the shared module-level state manager before each test."""
        # The shared manager is only mutated here, never rebound, so no
        # ``global`` declaration is needed.
        _state_manager.invalidate()
        _state_manager._used_nonces.clear()

    def test_generate_state_returns_serialized(self):
        """Test that generate_state returns a serialized state string."""
        state_str = _state_manager.generate_state()
        # Should be a string with format: data.signature
        assert isinstance(state_str, str)
        assert "." in state_str
        parts = state_str.split(".")
        assert len(parts) == 2

    def test_generate_state_with_data(self):
        """Test that extra data is included in state."""
        extra = {"server_name": "test-server", "user_id": "123"}
        state_str = _state_manager.generate_state(extra_data=extra)
        is_valid, data = _state_manager.validate_and_extract(state_str)
        assert is_valid is True
        assert data == extra

    def test_validate_and_extract_valid_state(self):
        """Test validation with a valid state."""
        extra = {"test": "data"}
        state_str = _state_manager.generate_state(extra_data=extra)
        is_valid, data = _state_manager.validate_and_extract(state_str)
        assert is_valid is True
        assert data == extra

    def test_validate_and_extract_none_state(self):
        """Test validation with None state."""
        is_valid, data = _state_manager.validate_and_extract(None)
        assert is_valid is False
        assert data is None

    def test_validate_and_extract_invalid_state(self):
        """Test validation with an invalid state."""
        is_valid, data = _state_manager.validate_and_extract("invalid.state.here")
        assert is_valid is False
        assert data is None

    def test_state_cleared_after_validation(self):
        """Test that state is cleared after successful validation."""
        state_str = _state_manager.generate_state()
        # First validation should succeed
        is_valid1, _ = _state_manager.validate_and_extract(state_str)
        assert is_valid1 is True
        # Second validation should fail (replay)
        is_valid2, _ = _state_manager.validate_and_extract(state_str)
        assert is_valid2 is False

    def test_nonce_tracking_prevents_replay(self):
        """Test that nonce tracking prevents replay attacks."""
        state = SecureOAuthState()
        serialized = state.serialize()
        # Mark the nonce as already used, as a prior validation would.
        with _state_manager._lock:
            _state_manager._used_nonces.add(state.nonce)
        is_valid, _ = _state_manager.validate_and_extract(serialized)
        assert is_valid is False

    def test_invalidate_clears_state(self):
        """Test that invalidate clears the stored state."""
        _state_manager.generate_state()
        assert _state_manager._state is not None
        _state_manager.invalidate()
        assert _state_manager._state is None

    def test_thread_safety(self):
        """Test thread safety of state manager."""
        results = []

        def generate():
            state_str = _state_manager.generate_state(
                extra_data={"thread": threading.current_thread().name}
            )
            results.append(state_str)

        threads = [threading.Thread(target=generate) for _ in range(10)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        # All states should be unique
        assert len(set(results)) == 10

    def test_max_nonce_limit(self):
        """Test that the used-nonce set stays bounded to prevent memory growth.

        The previous version added nonces to ``_used_nonces`` directly,
        bypassing the manager's limit logic, and asserted nothing. This version
        runs real generate/validate cycles through the manager.
        """
        manager = OAuthStateManager()
        manager._max_used_nonces = 5
        cycles = 10
        for _ in range(cycles):
            state_str = manager.generate_state()
            is_valid, _ = manager.validate_and_extract(state_str)
            assert is_valid is True
        # NOTE(review): presumably each successful validation records its
        # nonce, so without a limit the set would hold all ``cycles`` entries.
        # One extra entry is tolerated in case the limit is checked before the
        # newest nonce is added — confirm against the implementation.
        assert len(manager._used_nonces) <= manager._max_used_nonces + 1
# =============================================================================
# Schema Validation Tests
# =============================================================================
class TestSchemaValidation:
    """Tests for JSON schema validation (V-006)."""

    @staticmethod
    def _assert_schema_error(data, schema, kind, *fragments):
        """Assert that validating ``data`` against ``schema`` fails.

        Args:
            data: Candidate token/client data.
            schema: The schema passed to ``_validate_token_schema``.
            kind: The label passed as the third argument ("token"/"client").
            *fragments: Substrings that must all appear in the error message.

        Raises:
            AssertionError: if no OAuthStateError is raised, or if a fragment
                is missing from its message.
        """
        try:
            _validate_token_schema(data, schema, kind)
        except OAuthStateError as e:
            message = str(e)
            for fragment in fragments:
                assert fragment in message, f"{fragment!r} not in error: {message}"
        else:
            raise AssertionError("Should have raised OAuthStateError")

    def test_valid_token_schema_accepted(self):
        """Test that valid token data passes schema validation."""
        valid_token = {
            "access_token": "secret_token_123",
            "token_type": "Bearer",
            "refresh_token": "refresh_456",
            "expires_in": 3600,
            "expires_at": 1234567890.0,
            "scope": "read write",
            "id_token": "id_token_789",
        }
        # Should not raise
        _validate_token_schema(valid_token, _OAUTH_TOKEN_SCHEMA, "token")

    def test_minimal_valid_token_schema(self):
        """Test that minimal valid token (only required fields) passes."""
        minimal_token = {
            "access_token": "secret",
            "token_type": "Bearer",
        }
        _validate_token_schema(minimal_token, _OAUTH_TOKEN_SCHEMA, "token")

    def test_missing_required_field_rejected(self):
        """Test that missing required fields are detected."""
        invalid_token = {"token_type": "Bearer"}  # missing access_token
        self._assert_schema_error(
            invalid_token, _OAUTH_TOKEN_SCHEMA, "token",
            "missing required fields", "access_token",
        )

    def test_wrong_type_rejected(self):
        """Test that fields with wrong types are rejected."""
        invalid_token = {
            "access_token": 12345,  # should be string
            "token_type": "Bearer",
        }
        self._assert_schema_error(
            invalid_token, _OAUTH_TOKEN_SCHEMA, "token", "has wrong type",
        )

    def test_null_values_accepted(self):
        """Test that null values for optional fields are accepted."""
        token_with_nulls = {
            "access_token": "secret",
            "token_type": "Bearer",
            "refresh_token": None,
            "expires_in": None,
        }
        _validate_token_schema(token_with_nulls, _OAUTH_TOKEN_SCHEMA, "token")

    def test_non_dict_data_rejected(self):
        """Test that non-dictionary data is rejected."""
        self._assert_schema_error(
            "not a dict", _OAUTH_TOKEN_SCHEMA, "token", "must be a dictionary",
        )

    def test_valid_client_schema(self):
        """Test that valid client info passes schema validation."""
        valid_client = {
            "client_id": "client_123",
            "client_secret": "secret_456",
            "client_name": "Test Client",
            "redirect_uris": ["http://localhost/callback"],
        }
        _validate_token_schema(valid_client, _OAUTH_CLIENT_SCHEMA, "client")

    def test_client_missing_required_rejected(self):
        """Test that client info missing client_id is rejected and named in the error."""
        invalid_client = {"client_name": "Test"}
        # Mirrors test_missing_required_field_rejected: the error should name
        # the missing field, not just report that one is missing.
        self._assert_schema_error(
            invalid_client, _OAUTH_CLIENT_SCHEMA, "client",
            "missing required fields", "client_id",
        )
# =============================================================================
# Token Storage Security Tests
# =============================================================================
class TestTokenStorageSecurity:
    """Signing and verification of stored OAuth token data (V-006)."""

    def test_sign_and_verify_token_data(self):
        """A freshly produced signature verifies against the data it signed."""
        token = {"access_token": "test123", "token_type": "Bearer"}
        signature = _sign_token_data(token)
        assert signature is not None
        assert len(signature) > 0
        assert _verify_token_signature(token, signature) is True

    def test_tampered_token_data_rejected(self):
        """Changing a field after signing makes verification fail."""
        signature = _sign_token_data({"access_token": "test123", "token_type": "Bearer"})
        forged = {"access_token": "hacked", "token_type": "Bearer"}
        assert _verify_token_signature(forged, signature) is False

    def test_empty_signature_rejected(self):
        """An empty signature never verifies."""
        token = {"access_token": "test", "token_type": "Bearer"}
        assert _verify_token_signature(token, "") is False

    def test_invalid_signature_rejected(self):
        """An arbitrary, non-HMAC signature string never verifies."""
        token = {"access_token": "test", "token_type": "Bearer"}
        assert _verify_token_signature(token, "invalid") is False

    def test_signature_deterministic(self):
        """Signing identical data twice yields identical signatures."""
        token = {"access_token": "test123", "token_type": "Bearer"}
        first, second = _sign_token_data(token), _sign_token_data(token)
        assert first == second

    def test_different_data_different_signatures(self):
        """Distinct token data yields distinct signatures."""
        sig_one, sig_two = (
            _sign_token_data({"access_token": value, "token_type": "Bearer"})
            for value in ("test1", "test2")
        )
        assert sig_one != sig_two
# =============================================================================
# Pickle Security Tests
# =============================================================================
class TestNoPickleUsage:
    """Regression tests guarding against a return to pickle (V-006)."""

    def test_serialization_does_not_use_pickle(self):
        """Verify that state serialization uses JSON, not pickle."""
        payload = "__import__('os').system('rm -rf /')"
        serialized = SecureOAuthState(data={"malicious": payload}).serialize()
        encoded, _sig = serialized.split(".")
        # Restore any "=" padding stripped from the segment before decoding.
        raw = base64.urlsafe_b64decode(encoded + "=" * (-len(encoded) % 4))
        # The payload must round-trip as plain JSON text...
        assert json.loads(raw.decode('utf-8'))["data"]["malicious"] == payload
        # ...and must contain none of pickle's recognizable opcodes.
        assert not raw.startswith(b'\x80')  # PROTO opcode (pickle protocol 2+)
        assert b'cos\n' not in raw  # GLOBAL opcode importing from module "os"

    def test_deserialize_rejects_pickle_payload(self):
        """Test that pickle payloads are rejected during deserialization."""
        import pickle

        blob = pickle.dumps({"cmd": "whoami"})
        secret = SecureOAuthState._get_secret_key()
        mac = hmac.new(secret, blob, hashlib.sha256).digest()
        token = ".".join(
            base64.urlsafe_b64encode(segment).decode().rstrip("=")
            for segment in (blob, mac)
        )
        # Signed with the real key, so rejection is expected to come from JSON
        # parsing rather than from the HMAC check.
        try:
            SecureOAuthState.deserialize(token)
        except OAuthStateError as e:
            assert "Invalid state JSON" in str(e)
        else:
            assert False, "Should have raised OAuthStateError"
# =============================================================================
# Key Management Tests
# =============================================================================
class TestSecretKeyManagement:
    """Tests for HMAC secret key management."""

    def test_get_secret_key_from_env(self):
        """Test that HERMES_OAUTH_SECRET environment variable is used."""
        with patch.dict(os.environ, {"HERMES_OAUTH_SECRET": "test-secret-key-32bytes-long!!"}):
            key = SecureOAuthState._get_secret_key()
            assert key == b"test-secret-key-32bytes-long!!"

    def test_get_token_storage_key_from_env(self):
        """Test that HERMES_TOKEN_STORAGE_SECRET environment variable is used."""
        with patch.dict(os.environ, {"HERMES_TOKEN_STORAGE_SECRET": "storage-secret-key-32bytes!!"}):
            key = _get_token_storage_key()
            assert key == b"storage-secret-key-32bytes!!"

    def test_get_secret_key_creates_file(self):
        """Test that a secret key file is created if it doesn't exist.

        With no environment overrides and ``Path.home()`` pointed at an empty
        temporary directory, generating the key should persist it to disk.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            home = Path(tmpdir)
            with patch('pathlib.Path.home', return_value=home):
                with patch.dict(os.environ, {}, clear=True):
                    key = SecureOAuthState._get_secret_key()
            assert len(key) == 64
            # The previous version only checked the key length and never
            # verified that a file was written, despite the test's name.
            # NOTE(review): assumes the key path is derived from Path.home()
            # at call time (which is what the patch relies on) — confirm.
            assert any(p.is_file() for p in home.rglob("*")), (
                "no secret key file was created under the patched home directory"
            )
# =============================================================================
# Integration Tests
# =============================================================================
class TestOAuthFlowIntegration:
    """End-to-end checks of the OAuth state round trip and common attacks."""

    def setup_method(self):
        """Start every test from a clean shared state manager."""
        # The module-level manager is mutated in place and never rebound, so
        # no ``global`` declaration is required.
        shared = _state_manager
        shared.invalidate()
        shared._used_nonces.clear()

    def test_full_oauth_state_flow(self):
        """Test the full OAuth state generation and validation flow."""
        server_name = "test-mcp-server"
        # The outbound authorization request carries the state...
        issued = _state_manager.generate_state(extra_data={"server_name": server_name})
        # ...and the provider echoes it back on the callback.
        ok, payload = _state_manager.validate_and_extract(issued)
        assert ok is True
        assert payload["server_name"] == server_name
        # A second callback carrying the same state must be refused.
        replay_ok, _ = _state_manager.validate_and_extract(issued)
        assert replay_ok is False

    def test_csrf_attack_prevention(self):
        """Test that CSRF attacks using different states are detected."""
        forged = _state_manager.generate_state(extra_data={"malicious": True})
        victim = OAuthStateManager()
        # Issued only so that the victim's manager holds a state of its own.
        victim.generate_state(extra_data={"legitimate": True})
        # The attacker's state does not match the victim's session token.
        accepted, _ = victim.validate_and_extract(forged)
        assert accepted is False

    def test_mitm_attack_detection(self):
        """Test that tampered states from MITM attacks are detected."""
        genuine = _state_manager.generate_state()
        data_part = genuine.split(".")[0]
        # Keep the data but swap in a bogus signature, as an on-path attacker might.
        forged = f"{data_part}.tampered-signature-here"
        accepted, _ = _state_manager.validate_and_extract(forged)
        assert accepted is False
# =============================================================================
# Performance Tests
# =============================================================================
class TestPerformance:
    """Performance tests for state operations."""

    # Number of operations per benchmark and the total time budget for them.
    ITERATIONS = 1000
    BUDGET_SECONDS = 1.0

    def test_serialize_performance(self):
        """Test that serialization is fast."""
        state = SecureOAuthState(data={"key": "value" * 100})
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # when the system clock is adjusted, which would skew the measurement.
        start = time.perf_counter()
        for _ in range(self.ITERATIONS):
            state.serialize()
        elapsed = time.perf_counter() - start
        assert elapsed < self.BUDGET_SECONDS, (
            f"{self.ITERATIONS} serializations took {elapsed:.3f}s"
        )

    def test_deserialize_performance(self):
        """Test that deserialization is fast."""
        state = SecureOAuthState(data={"key": "value" * 100})
        serialized = state.serialize()
        start = time.perf_counter()
        for _ in range(self.ITERATIONS):
            SecureOAuthState.deserialize(serialized)
        elapsed = time.perf_counter() - start
        assert elapsed < self.BUDGET_SECONDS, (
            f"{self.ITERATIONS} deserializations took {elapsed:.3f}s"
        )
def run_tests():
"""Run all tests."""
import inspect
test_classes = [
TestSecureOAuthState,
TestOAuthStateManager,
TestSchemaValidation,
TestTokenStorageSecurity,
TestNoPickleUsage,
TestSecretKeyManagement,
TestOAuthFlowIntegration,
TestPerformance,
]
total_tests = 0
passed_tests = 0
failed_tests = []
for cls in test_classes:
print(f"\n{'='*60}")
print(f"Running {cls.__name__}")
print('='*60)
instance = cls()
# Run setup if exists
if hasattr(instance, 'setup_method'):
instance.setup_method()
for name, method in inspect.getmembers(cls, predicate=inspect.isfunction):
if name.startswith('test_'):
total_tests += 1
try:
method(instance)
print(f"{name}")
passed_tests += 1
except Exception as e:
print(f"{name}: {e}")
failed_tests.append((cls.__name__, name, str(e)))
print(f"\n{'='*60}")
print(f"Results: {passed_tests}/{total_tests} tests passed")
print('='*60)
if failed_tests:
print("\nFailed tests:")
for cls_name, test_name, error in failed_tests:
print(f" - {cls_name}.{test_name}: {error}")
return 1
else:
print("\nAll tests passed!")
return 0
if __name__ == "__main__":
    # Exit status mirrors run_tests(): 0 when everything passed, 1 otherwise.
    raise SystemExit(run_tests())