390 lines
14 KiB
Python
390 lines
14 KiB
Python
|
|
"""
|
||
|
|
Tests for agent/mtls.py — mutual TLS between fleet agents.

Covers:
- is_mtls_configured() with various env combinations
- build_server_ssl_context() / build_client_ssl_context() with real certs
- MTLSMiddleware: authorized agent accepted, unauthorized agent rejected,
  non-A2A routes passed through, and no-op behavior when mTLS is not configured
- get_peer_cn(): CN extraction and failure handling
|
||
|
|
"""
|
||
|
|
|
||
|
|
import ssl
|
||
|
|
import datetime
|
||
|
|
import ipaddress
|
||
|
|
import os
|
||
|
|
import pytest
|
||
|
|
from pathlib import Path
|
||
|
|
from unittest.mock import patch
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Helpers: generate real in-memory certs using the `cryptography` library
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
# Probe for the optional `cryptography` package. The names imported here are
# used by the PKI helpers below to mint real certificates at test time.
try:
    from cryptography import x509
    from cryptography.x509.oid import NameOID, ExtendedKeyUsageOID
    from cryptography.hazmat.primitives import hashes, serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
except ImportError:
    _CRYPTO_AVAILABLE = False
else:
    _CRYPTO_AVAILABLE = True

# Module-wide marker: when `cryptography` is missing, every test in this file
# is skipped.
# NOTE(review): only the tests that call make_fleet_pki() visibly depend on
# `cryptography`; consider scoping this skip more narrowly.
pytestmark = pytest.mark.skipif(
    not _CRYPTO_AVAILABLE,
    reason="cryptography package required for mTLS tests",
)
|
||
|
|
|
||
|
|
|
||
|
|
def _make_key():
    """Generate and return a new RSA private key (2048-bit, public exponent 65537)."""
    exponent, bits = 65537, 2048
    return rsa.generate_private_key(public_exponent=exponent, key_size=bits)
|
||
|
|
|
||
|
|
|
||
|
|
def _write_pem(path: Path, data: bytes) -> None:
|
||
|
|
path.write_bytes(data)
|
||
|
|
path.chmod(0o600)
|
||
|
|
|
||
|
|
|
||
|
|
def make_fleet_pki(tmp_path: Path):
    """
    Create a minimal Fleet PKI in tmp_path:
      - fleet-ca.key / fleet-ca.crt  (self-signed CA)
      - agent.key    / agent.crt     (signed by fleet CA, CN=test-agent)
      - rogue.key    / rogue.crt     (self-signed, NOT signed by fleet CA)

    Returns a dict mapping "ca_key", "ca_cert", "agent_key", "agent_cert",
    "rogue_key" and "rogue_cert" to the Path of each written file.

    The CA and agent certificates carry Subject/Authority Key Identifier
    extensions. OpenSSL's strict X.509 verification (enabled by default in
    ssl.create_default_context() since Python 3.13) rejects chains without
    them, which would matter if a test ever performs a real handshake.
    """
    now = datetime.datetime.now(datetime.timezone.utc)
    # Backdate validity slightly so that a small clock skew between creation
    # and verification cannot make the certificates "not yet valid".
    not_before = now - datetime.timedelta(minutes=5)

    # --- Fleet CA ---
    ca_key = _make_key()
    ca_name = x509.Name([
        x509.NameAttribute(NameOID.COMMON_NAME, "Hermes Fleet CA"),
        x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Hermes Fleet"),
    ])
    ca_cert = (
        x509.CertificateBuilder()
        .subject_name(ca_name)
        .issuer_name(ca_name)
        .public_key(ca_key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(not_before)
        .not_valid_after(now + datetime.timedelta(days=3650))
        .add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
        .add_extension(
            x509.KeyUsage(
                digital_signature=False, content_commitment=False,
                key_encipherment=False, data_encipherment=False,
                key_agreement=False, key_cert_sign=True, crl_sign=True,
                encipher_only=False, decipher_only=False,
            ),
            critical=True,
        )
        # Strict verification requires CA certificates to carry an SKI.
        .add_extension(
            x509.SubjectKeyIdentifier.from_public_key(ca_key.public_key()),
            critical=False,
        )
        .sign(ca_key, hashes.SHA256())
    )

    # --- Fleet agent cert (client + server auth, issued by the CA) ---
    agent_key = _make_key()
    agent_name = x509.Name([
        x509.NameAttribute(NameOID.COMMON_NAME, "test-agent"),
        x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Hermes Fleet"),
    ])
    agent_cert = (
        x509.CertificateBuilder()
        .subject_name(agent_name)
        .issuer_name(ca_name)
        .public_key(agent_key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(not_before)
        .not_valid_after(now + datetime.timedelta(days=730))
        .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
        .add_extension(
            x509.SubjectAlternativeName([
                x509.DNSName("test-agent"),
                x509.DNSName("localhost"),
                x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")),
            ]),
            critical=False,
        )
        .add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.CLIENT_AUTH,
                ExtendedKeyUsageOID.SERVER_AUTH,
            ]),
            critical=False,
        )
        .add_extension(
            x509.SubjectKeyIdentifier.from_public_key(agent_key.public_key()),
            critical=False,
        )
        # Strict verification requires CA-issued certificates to carry an AKI.
        .add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_public_key(ca_key.public_key()),
            critical=False,
        )
        .sign(ca_key, hashes.SHA256())
    )

    # --- Rogue cert (self-signed, not from fleet CA) ---
    rogue_key = _make_key()
    rogue_name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "rogue-agent")])
    rogue_cert = (
        x509.CertificateBuilder()
        .subject_name(rogue_name)
        .issuer_name(rogue_name)
        .public_key(rogue_key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(not_before)
        .not_valid_after(now + datetime.timedelta(days=365))
        .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
        .sign(rogue_key, hashes.SHA256())
    )

    # --- Write everything to tmp_path ---
    pem = serialization.Encoding.PEM

    def private_pem(key) -> bytes:
        """Serialize *key* as an unencrypted TraditionalOpenSSL PEM."""
        return key.private_bytes(
            pem,
            serialization.PrivateFormat.TraditionalOpenSSL,
            serialization.NoEncryption(),
        )

    # label -> (file name, PEM bytes); dict order fixes the write order.
    artifacts = {
        "ca_key": ("fleet-ca.key", private_pem(ca_key)),
        "ca_cert": ("fleet-ca.crt", ca_cert.public_bytes(pem)),
        "agent_key": ("agent.key", private_pem(agent_key)),
        "agent_cert": ("agent.crt", agent_cert.public_bytes(pem)),
        "rogue_key": ("rogue.key", private_pem(rogue_key)),
        "rogue_cert": ("rogue.crt", rogue_cert.public_bytes(pem)),
    }

    paths = {}
    for label, (filename, data) in artifacts.items():
        paths[label] = tmp_path / filename
        _write_pem(paths[label], data)
    return paths
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Tests: is_mtls_configured
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
class TestIsMtlsConfigured:
    """
    is_mtls_configured(): the cases below expect True only when all three
    HERMES_MTLS_* variables are set AND each one points at an existing file.
    """

    _VARS = ("HERMES_MTLS_CERT", "HERMES_MTLS_KEY", "HERMES_MTLS_CA")

    def test_all_vars_missing(self):
        """Variables set to empty strings count as "not configured"."""
        from agent.mtls import is_mtls_configured

        env = dict.fromkeys(self._VARS, "")
        with patch.dict(os.environ, env, clear=False):
            assert not is_mtls_configured()

    def test_all_vars_unset(self):
        """Variables absent from the environment count as "not configured"."""
        from agent.mtls import is_mtls_configured

        # patch.dict snapshots os.environ and restores it on exit, so keys
        # popped here (e.g. exported by a developer's shell) come back afterwards.
        with patch.dict(os.environ, clear=False):
            for name in self._VARS:
                os.environ.pop(name, None)
            assert not is_mtls_configured()

    def test_partial_vars(self, tmp_path):
        """Only the cert variable pointing at a real file is not enough."""
        from agent.mtls import is_mtls_configured

        f = tmp_path / "cert.pem"
        f.write_text("x")
        env = {"HERMES_MTLS_CERT": str(f), "HERMES_MTLS_KEY": "", "HERMES_MTLS_CA": ""}
        with patch.dict(os.environ, env, clear=False):
            assert not is_mtls_configured()

    def test_all_vars_set_but_file_missing(self, tmp_path):
        """All three variables set, but none of the referenced files exist."""
        from agent.mtls import is_mtls_configured

        env = {
            "HERMES_MTLS_CERT": str(tmp_path / "no.crt"),
            "HERMES_MTLS_KEY": str(tmp_path / "no.key"),
            "HERMES_MTLS_CA": str(tmp_path / "no-ca.crt"),
        }
        with patch.dict(os.environ, env, clear=False):
            assert not is_mtls_configured()

    def test_all_vars_set_and_files_exist(self, tmp_path):
        """All three variables set and every referenced file exists."""
        from agent.mtls import is_mtls_configured

        for name in ("cert.pem", "key.pem", "ca.pem"):
            (tmp_path / name).write_text("x")
        env = {
            "HERMES_MTLS_CERT": str(tmp_path / "cert.pem"),
            "HERMES_MTLS_KEY": str(tmp_path / "key.pem"),
            "HERMES_MTLS_CA": str(tmp_path / "ca.pem"),
        }
        with patch.dict(os.environ, env, clear=False):
            assert is_mtls_configured()
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Tests: build_server_ssl_context / build_client_ssl_context
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
class TestBuildSslContexts:
    """Context builders: RuntimeError when unconfigured, CERT_REQUIRED otherwise."""

    @staticmethod
    def _env_for(pki):
        """Point the mTLS env vars at the agent identity produced by make_fleet_pki()."""
        return {
            "HERMES_MTLS_CERT": str(pki["agent_cert"]),
            "HERMES_MTLS_KEY": str(pki["agent_key"]),
            "HERMES_MTLS_CA": str(pki["ca_cert"]),
        }

    def test_raises_when_not_configured(self):
        from agent.mtls import build_server_ssl_context, build_client_ssl_context

        blank = dict.fromkeys(("HERMES_MTLS_CERT", "HERMES_MTLS_KEY", "HERMES_MTLS_CA"), "")
        with patch.dict(os.environ, blank, clear=False):
            for builder in (build_server_ssl_context, build_client_ssl_context):
                with pytest.raises(RuntimeError, match="not configured"):
                    builder()

    def test_server_context_requires_client_cert(self, tmp_path):
        from agent.mtls import build_server_ssl_context

        with patch.dict(os.environ, self._env_for(make_fleet_pki(tmp_path)), clear=False):
            ctx = build_server_ssl_context()
        assert isinstance(ctx, ssl.SSLContext)
        assert ctx.verify_mode == ssl.CERT_REQUIRED

    def test_client_context_has_cert_required(self, tmp_path):
        from agent.mtls import build_client_ssl_context

        with patch.dict(os.environ, self._env_for(make_fleet_pki(tmp_path)), clear=False):
            ctx = build_client_ssl_context()
        assert isinstance(ctx, ssl.SSLContext)
        assert ctx.verify_mode == ssl.CERT_REQUIRED
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Tests: MTLSMiddleware
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
def _make_scope(path: str, peer_cert=None) -> dict:
|
||
|
|
"""Build a minimal ASGI HTTP scope, optionally with a fake TLS peer_cert."""
|
||
|
|
scope = {
|
||
|
|
"type": "http",
|
||
|
|
"path": path,
|
||
|
|
"extensions": {},
|
||
|
|
}
|
||
|
|
if peer_cert is not None:
|
||
|
|
scope["extensions"]["tls"] = {"peer_cert": peer_cert}
|
||
|
|
return scope
|
||
|
|
|
||
|
|
|
||
|
|
async def _collect_response(middleware, scope):
|
||
|
|
"""Drive the middleware and capture (status, body)."""
|
||
|
|
status = None
|
||
|
|
body = b""
|
||
|
|
|
||
|
|
async def receive():
|
||
|
|
return {"type": "http.request", "body": b""}
|
||
|
|
|
||
|
|
async def send(event):
|
||
|
|
nonlocal status, body
|
||
|
|
if event["type"] == "http.response.start":
|
||
|
|
status = event["status"]
|
||
|
|
elif event["type"] == "http.response.body":
|
||
|
|
body += event.get("body", b"")
|
||
|
|
|
||
|
|
await middleware(scope, receive, send)
|
||
|
|
return status, body
|
||
|
|
|
||
|
|
|
||
|
|
class TestMTLSMiddleware:
    """
    Unit-test the MTLSMiddleware without spinning up a real server.

    We inject mTLS configuration through env-var patching so the middleware
    believes it is enabled, and use the ASGI scope's tls extension to simulate
    whether a client cert was presented.
    """

    # Subject in the nested shape returned by ssl.SSLSocket.getpeercert():
    # a tuple of RDNs, each RDN being a tuple of (attribute, value) pairs.
    _TIMMY_CERT = {"subject": ((("commonName", "timmy"),),)}

    @staticmethod
    async def _passthrough(scope, receive, send):
        """Downstream app stub: always answers 200 with body b"ok"."""
        await send({"type": "http.response.start", "status": 200, "headers": []})
        await send({"type": "http.response.body", "body": b"ok"})

    def _make_middleware(self, tmp_path, app=None):
        """Return an MTLSMiddleware constructed while mTLS env vars are set.

        The referenced files only contain placeholder text, so this assumes
        the middleware does not parse them when it is constructed.
        """
        from agent.mtls import MTLSMiddleware

        for name in ("cert.pem", "key.pem", "ca.pem"):
            (tmp_path / name).write_text("x")

        env = {
            "HERMES_MTLS_CERT": str(tmp_path / "cert.pem"),
            "HERMES_MTLS_KEY": str(tmp_path / "key.pem"),
            "HERMES_MTLS_CA": str(tmp_path / "ca.pem"),
        }
        with patch.dict(os.environ, env, clear=False):
            return MTLSMiddleware(app or self._passthrough)

    @pytest.mark.asyncio
    async def test_authorized_agent_accepted(self, tmp_path):
        """An A2A route with a valid client cert passes through (200)."""
        mw = self._make_middleware(tmp_path)
        scope = _make_scope("/.well-known/agent-card.json", peer_cert=self._TIMMY_CERT)
        status, body = await _collect_response(mw, scope)
        assert status == 200
        # The body proves the downstream app ran, not a middleware-made 200.
        assert body == b"ok"

    @pytest.mark.asyncio
    async def test_unauthorized_agent_rejected(self, tmp_path):
        """An A2A route with NO client cert is rejected (403)."""
        mw = self._make_middleware(tmp_path)
        scope = _make_scope("/.well-known/agent-card.json", peer_cert=None)
        status, body = await _collect_response(mw, scope)
        assert status == 403
        assert b"certificate" in body.lower()

    @pytest.mark.asyncio
    async def test_non_a2a_route_not_gated(self, tmp_path):
        """Non-A2A routes (like /api/status) pass through even without a cert."""
        mw = self._make_middleware(tmp_path)
        scope = _make_scope("/api/status", peer_cert=None)
        status, body = await _collect_response(mw, scope)
        assert status == 200
        assert body == b"ok"

    @pytest.mark.asyncio
    async def test_agent_card_api_route_gated(self, tmp_path):
        """The /api/agent-card route also requires a client cert."""
        mw = self._make_middleware(tmp_path)
        scope = _make_scope("/api/agent-card", peer_cert=None)
        status, _ = await _collect_response(mw, scope)
        assert status == 403

    @pytest.mark.asyncio
    async def test_middleware_disabled_when_not_configured(self):
        """When mTLS env vars are absent, the middleware is a no-op."""
        from agent.mtls import MTLSMiddleware

        env = {"HERMES_MTLS_CERT": "", "HERMES_MTLS_KEY": "", "HERMES_MTLS_CA": ""}
        with patch.dict(os.environ, env, clear=False):
            mw = MTLSMiddleware(self._passthrough)

        # Even an A2A route with no cert should pass through
        scope = _make_scope("/.well-known/agent-card.json", peer_cert=None)
        status, body = await _collect_response(mw, scope)
        assert status == 200
        assert body == b"ok"
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Tests: get_peer_cn
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
class TestGetPeerCn:
    """
    get_peer_cn(): extract the commonName from an SSL object's getpeercert()
    dict, returning None when there is no cert or getpeercert() raises.

    The fakes accept ``binary_form=False`` to match the signature of
    ssl.SSLSocket.getpeercert(), so the tests do not depend on how the
    implementation calls it.
    """

    def test_returns_cn_from_subject(self):
        from agent.mtls import get_peer_cn

        class FakeSSL:
            def getpeercert(self, binary_form=False):
                return {"subject": ((("commonName", "timmy"),),)}

        assert get_peer_cn(FakeSSL()) == "timmy"

    def test_returns_cn_when_not_first_rdn(self):
        """Real subjects carry several RDNs; the CN need not be the first one."""
        from agent.mtls import get_peer_cn

        class FakeSSL:
            def getpeercert(self, binary_form=False):
                return {
                    "subject": (
                        (("organizationName", "Hermes Fleet"),),
                        (("commonName", "timmy"),),
                    )
                }

        assert get_peer_cn(FakeSSL()) == "timmy"

    def test_returns_none_when_no_cert(self):
        from agent.mtls import get_peer_cn

        class FakeSSL:
            def getpeercert(self, binary_form=False):
                return None

        assert get_peer_cn(FakeSSL()) is None

    def test_returns_none_on_exception(self):
        from agent.mtls import get_peer_cn

        class BrokenSSL:
            def getpeercert(self, binary_form=False):
                raise RuntimeError("no ssl")

        assert get_peer_cn(BrokenSSL()) is None
|