""" Tests for A2A mutual-TLS authentication. Scenarios covered: - authorized agent (valid fleet-CA-signed cert) is accepted - unauthorized agent (self-signed cert) is rejected with SSLError - missing client cert is rejected - build_server_ssl_context raises FileNotFoundError for missing paths - build_client_ssl_context raises FileNotFoundError for missing paths - A2AServer.start() / stop() lifecycle (no network I/O) All TLS I/O is done in-process against a loopback server so no ports need to be opened on a CI runner. Refs #806 """ from __future__ import annotations import datetime import ipaddress import ssl import threading import time import urllib.request import urllib.error from pathlib import Path from typing import Tuple import pytest # --------------------------------------------------------------------------- # Helpers — generate self-signed certs in-memory with Python's ``cryptography`` # library (dev extra). If cryptography is unavailable we skip the network # tests gracefully. # --------------------------------------------------------------------------- try: from cryptography import x509 from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.x509.oid import NameOID import cryptography.hazmat.backends as _backends _CRYPTO_AVAILABLE = True except ImportError: _CRYPTO_AVAILABLE = False _requires_crypto = pytest.mark.skipif( not _CRYPTO_AVAILABLE, reason="cryptography package not installed", ) # --------------------------------------------------------------------------- # Fixture helpers # --------------------------------------------------------------------------- def _make_ca_keypair(tmp_path: Path) -> Tuple[Path, Path]: """Generate a self-signed CA cert+key and write to *tmp_path*.""" key = rsa.generate_private_key(public_exponent=65537, key_size=2048) name = x509.Name([ x509.NameAttribute(NameOID.COMMON_NAME, "Test Fleet CA"), x509.NameAttribute(NameOID.ORGANIZATION_NAME, "TestOrg"), ]) now = datetime.datetime.now(datetime.timezone.utc) cert = ( x509.CertificateBuilder() .subject_name(name) .issuer_name(name) .public_key(key.public_key()) .serial_number(x509.random_serial_number()) .not_valid_before(now) .not_valid_after(now + datetime.timedelta(days=3650)) .add_extension(x509.BasicConstraints(ca=True, path_length=0), critical=True) .add_extension( x509.KeyUsage( digital_signature=False, key_cert_sign=True, crl_sign=True, content_commitment=False, key_encipherment=False, data_encipherment=False, key_agreement=False, encipher_only=False, decipher_only=False, ), critical=True, ) .sign(key, hashes.SHA256()) ) key_path = tmp_path / "ca.key" cert_path = tmp_path / "ca.crt" key_path.write_bytes(key.private_bytes( serialization.Encoding.PEM, serialization.PrivateFormat.TraditionalOpenSSL, serialization.NoEncryption(), )) cert_path.write_bytes(cert.public_bytes(serialization.Encoding.PEM)) return cert_path, key_path def _make_agent_keypair( tmp_path: Path, name: str, ca_cert_path: Path, ca_key_path: Path, ) -> Tuple[Path, Path]: """Generate an agent cert signed by the test CA.""" ca_cert = x509.load_pem_x509_certificate(ca_cert_path.read_bytes()) ca_key = serialization.load_pem_private_key( ca_key_path.read_bytes(), password=None ) key = rsa.generate_private_key(public_exponent=65537, key_size=2048) subject = x509.Name([ x509.NameAttribute(NameOID.COMMON_NAME, f"{name}.fleet.hermes"), x509.NameAttribute(NameOID.ORGANIZATION_NAME, "TestOrg"), ]) now = datetime.datetime.now(datetime.timezone.utc) cert = ( x509.CertificateBuilder() .subject_name(subject) .issuer_name(ca_cert.subject) .public_key(key.public_key()) .serial_number(x509.random_serial_number()) .not_valid_before(now) .not_valid_after(now + datetime.timedelta(days=365)) .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True) .add_extension( x509.SubjectAlternativeName([ x509.DNSName(f"{name}.fleet.hermes"), x509.DNSName(name), x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")), ]), critical=False, ) .add_extension( x509.ExtendedKeyUsage([ x509.ExtendedKeyUsageOID.CLIENT_AUTH, x509.ExtendedKeyUsageOID.SERVER_AUTH, ]), critical=False, ) .sign(ca_key, hashes.SHA256()) ) key_path = tmp_path / f"{name}.key" cert_path = tmp_path / f"{name}.crt" key_path.write_bytes(key.private_bytes( serialization.Encoding.PEM, serialization.PrivateFormat.TraditionalOpenSSL, serialization.NoEncryption(), )) cert_path.write_bytes(cert.public_bytes(serialization.Encoding.PEM)) return cert_path, key_path def _make_self_signed_keypair(tmp_path: Path, name: str) -> Tuple[Path, Path]: """Generate a self-signed cert NOT signed by the test CA (unauthorized).""" key = rsa.generate_private_key(public_exponent=65537, key_size=2048) subject = x509.Name([ x509.NameAttribute(NameOID.COMMON_NAME, f"{name}.rogue"), ]) now = datetime.datetime.now(datetime.timezone.utc) cert = ( x509.CertificateBuilder() .subject_name(subject) .issuer_name(subject) .public_key(key.public_key()) .serial_number(x509.random_serial_number()) .not_valid_before(now) .not_valid_after(now + datetime.timedelta(days=365)) .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True) .add_extension( x509.SubjectAlternativeName([x509.IPAddress(ipaddress.IPv4Address("127.0.0.1"))]), critical=False, ) .sign(key, hashes.SHA256()) ) key_path = tmp_path / f"{name}_rogue.key" cert_path = tmp_path / f"{name}_rogue.crt" key_path.write_bytes(key.private_bytes( serialization.Encoding.PEM, serialization.PrivateFormat.TraditionalOpenSSL, serialization.NoEncryption(), )) cert_path.write_bytes(cert.public_bytes(serialization.Encoding.PEM)) return cert_path, key_path # --------------------------------------------------------------------------- # Unit tests — no network I/O # --------------------------------------------------------------------------- class TestBuildSslContextErrors: def test_server_context_missing_cert(self, tmp_path): from agent.a2a_mtls import build_server_ssl_context with pytest.raises(FileNotFoundError, match="mTLS"): build_server_ssl_context( cert=tmp_path / "nope.crt", key=tmp_path / "nope.key", ca=tmp_path / "nope.crt", ) def test_client_context_missing_cert(self, tmp_path): from agent.a2a_mtls import build_client_ssl_context with pytest.raises(FileNotFoundError, match="mTLS client"): build_client_ssl_context( cert=tmp_path / "nope.crt", key=tmp_path / "nope.key", ca=tmp_path / "nope.crt", ) @_requires_crypto def test_server_context_builds_with_valid_certs(self, tmp_path): from agent.a2a_mtls import build_server_ssl_context ca_dir = tmp_path / "ca" ca_dir.mkdir() ca_crt, ca_key = _make_ca_keypair(ca_dir) srv_crt, srv_key = _make_agent_keypair( tmp_path, "srv", ca_crt, ca_key ) ctx = build_server_ssl_context(cert=srv_crt, key=srv_key, ca=ca_crt) assert isinstance(ctx, ssl.SSLContext) assert ctx.verify_mode == ssl.CERT_REQUIRED @_requires_crypto def test_client_context_builds_with_valid_certs(self, tmp_path): from agent.a2a_mtls import build_client_ssl_context ca_dir = tmp_path / "ca" ca_dir.mkdir() ca_crt, ca_key = _make_ca_keypair(ca_dir) cli_crt, cli_key = _make_agent_keypair( tmp_path, "cli", ca_crt, ca_key ) ctx = build_client_ssl_context(cert=cli_crt, key=cli_key, ca=ca_crt) assert isinstance(ctx, ssl.SSLContext) assert ctx.verify_mode == ssl.CERT_REQUIRED # --------------------------------------------------------------------------- # Integration tests — loopback mTLS server # --------------------------------------------------------------------------- def _find_free_port() -> int: import socket with socket.socket() as s: s.bind(("127.0.0.1", 0)) return s.getsockname()[1] def _https_get(url: str, ssl_ctx: ssl.SSLContext) -> int: """Return the HTTP status code for a GET request, or raise SSLError.""" req = urllib.request.urlopen(url, context=ssl_ctx, timeout=5) return req.status @_requires_crypto class TestMutualTLSAuth: """End-to-end mTLS auth over a loopback connection.""" @pytest.fixture(autouse=True) def _pki(self, tmp_path): """Set up a fleet CA and agent certs for timmy (server) and allegro (authorized client).""" ca_dir = tmp_path / "ca" ca_dir.mkdir() self.ca_crt, self.ca_key = _make_ca_keypair(ca_dir) agent_dir = tmp_path / "agents" agent_dir.mkdir() # Server agent: timmy self.srv_crt, self.srv_key = _make_agent_keypair( agent_dir, "timmy", self.ca_crt, self.ca_key ) # Authorized client agent: allegro self.cli_crt, self.cli_key = _make_agent_keypair( agent_dir, "allegro", self.ca_crt, self.ca_key ) # Unauthorized (self-signed) client: rogue self.rogue_crt, self.rogue_key = _make_self_signed_keypair(agent_dir, "rogue") @pytest.fixture() def running_server(self): """Start an A2AServer on a free loopback port, yield the URL, stop after test.""" from agent.a2a_mtls import A2AServer port = _find_free_port() server = A2AServer( cert=self.srv_crt, key=self.srv_key, ca=self.ca_crt, host="127.0.0.1", port=port, ) server.start(daemon=True) time.sleep(0.15) # let the thread bind yield f"https://127.0.0.1:{port}" server.stop() def _authorized_ctx(self) -> ssl.SSLContext: from agent.a2a_mtls import build_client_ssl_context ctx = build_client_ssl_context( cert=self.cli_crt, key=self.cli_key, ca=self.ca_crt ) ctx.check_hostname = False # loopback IP doesn't match DNS SAN return ctx def _unauthorized_ctx(self) -> ssl.SSLContext: """Client context with a self-signed cert not trusted by the server CA.""" ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.minimum_version = ssl.TLSVersion.TLSv1_2 ctx.load_cert_chain(certfile=str(self.rogue_crt), keyfile=str(self.rogue_key)) # Load the real fleet CA so server cert is accepted — but our client # cert is self-signed and will be rejected by the server. ctx.load_verify_locations(cafile=str(self.ca_crt)) ctx.check_hostname = False return ctx def _no_client_cert_ctx(self) -> ssl.SSLContext: """Client context with no client certificate at all.""" ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.minimum_version = ssl.TLSVersion.TLSv1_2 ctx.load_verify_locations(cafile=str(self.ca_crt)) ctx.check_hostname = False return ctx # ------------------------------------------------------------------ # Authorized agent accepted # ------------------------------------------------------------------ def test_authorized_agent_accepted(self, running_server): """An agent with a fleet-CA-signed cert gets a 200-range response.""" status = _https_get( running_server + "/.well-known/agent-card.json", self._authorized_ctx(), ) assert status == 200 def test_authorized_agent_task_endpoint(self, running_server): """POST /a2a/task returns 202 for an authorized agent.""" import urllib.request req = urllib.request.Request( running_server + "/a2a/task", data=b'{"hello":"world"}', method="POST", ) req.add_header("Content-Type", "application/json") resp = urllib.request.urlopen(req, context=self._authorized_ctx(), timeout=5) assert resp.status == 202 # ------------------------------------------------------------------ # Unauthorized agent rejected # ------------------------------------------------------------------ def test_unauthorized_agent_rejected(self, running_server): """A self-signed cert not signed by the fleet CA is rejected at TLS handshake.""" with pytest.raises((ssl.SSLError, OSError)): _https_get(running_server + "/", self._unauthorized_ctx()) def test_no_client_cert_rejected(self, running_server): """A client with no cert at all is rejected at TLS handshake.""" with pytest.raises((ssl.SSLError, OSError)): _https_get(running_server + "/", self._no_client_cert_ctx()) # ------------------------------------------------------------------ # Server lifecycle # ------------------------------------------------------------------ def test_server_stop_is_idempotent(self): """Calling stop() twice does not raise.""" from agent.a2a_mtls import A2AServer port = _find_free_port() server = A2AServer( cert=self.srv_crt, key=self.srv_key, ca=self.ca_crt, host="127.0.0.1", port=port, ) server.start(daemon=True) time.sleep(0.1) server.stop() server.stop() # second call must not raise # --------------------------------------------------------------------------- # server_from_env() — environment variable wiring # --------------------------------------------------------------------------- class TestServerFromEnv: def test_reads_env_vars(self, tmp_path, monkeypatch): # Create dummy files so FileNotFoundError isn't triggered cert = tmp_path / "a.crt" key = tmp_path / "a.key" ca = tmp_path / "ca.crt" for f in (cert, key, ca): f.write_text("PLACEHOLDER") monkeypatch.setenv("HERMES_A2A_CERT", str(cert)) monkeypatch.setenv("HERMES_A2A_KEY", str(key)) monkeypatch.setenv("HERMES_A2A_CA", str(ca)) monkeypatch.setenv("HERMES_A2A_HOST", "127.0.0.2") monkeypatch.setenv("HERMES_A2A_PORT", "19443") from agent.a2a_mtls import server_from_env srv = server_from_env() assert srv.cert == cert assert srv.key == key assert srv.ca == ca assert srv.host == "127.0.0.2" assert srv.port == 19443 def test_uses_agent_name_for_defaults(self, tmp_path, monkeypatch): monkeypatch.setenv("HERMES_HOME", str(tmp_path)) monkeypatch.setenv("HERMES_AGENT_NAME", "ezra") # Unset explicit cert overrides monkeypatch.delenv("HERMES_A2A_CERT", raising=False) monkeypatch.delenv("HERMES_A2A_KEY", raising=False) monkeypatch.delenv("HERMES_A2A_CA", raising=False) from agent.a2a_mtls import server_from_env srv = server_from_env() assert "ezra" in str(srv.cert) assert "ezra" in str(srv.key) assert "fleet-ca" in str(srv.ca) # --------------------------------------------------------------------------- # A2AMTLSServer and A2AMTLSClient — routing server + client helper # --------------------------------------------------------------------------- @_requires_crypto class TestA2AMTLSServerAndClient: """Tests for the routing-based A2AMTLSServer and A2AMTLSClient.""" @pytest.fixture(autouse=True) def _pki(self, tmp_path): ca_dir = tmp_path / "ca" ca_dir.mkdir() self.ca_crt, self.ca_key = _make_ca_keypair(ca_dir) agent_dir = tmp_path / "agents" agent_dir.mkdir() self.srv_crt, self.srv_key = _make_agent_keypair( agent_dir, "timmy", self.ca_crt, self.ca_key ) self.cli_crt, self.cli_key = _make_agent_keypair( agent_dir, "allegro", self.ca_crt, self.ca_key ) self.rogue_crt, self.rogue_key = _make_self_signed_keypair(agent_dir, "rogue") @pytest.fixture() def routing_server(self): from agent.a2a_mtls import A2AMTLSServer port = _find_free_port() server = A2AMTLSServer( cert=self.srv_crt, key=self.srv_key, ca=self.ca_crt, host="127.0.0.1", port=port, ) server.add_route("/echo", lambda p, *, peer_cn=None: {"echo": p, "peer": peer_cn}) server.add_route("/tasks/send", lambda p, *, peer_cn=None: {"status": "ok", "echo": p}) with server: time.sleep(0.1) yield server, port def _authorized_ctx(self) -> ssl.SSLContext: from agent.a2a_mtls import build_client_ssl_context ctx = build_client_ssl_context( cert=self.cli_crt, key=self.cli_key, ca=self.ca_crt ) ctx.check_hostname = False return ctx def test_routing_server_get(self, routing_server): server, port = routing_server ctx = self._authorized_ctx() req = urllib.request.Request(f"https://127.0.0.1:{port}/echo") with urllib.request.urlopen(req, context=ctx, timeout=5) as resp: import json data = json.loads(resp.read()) assert data["peer"] is not None # CN present def test_routing_server_post_payload(self, routing_server): server, port = routing_server ctx = self._authorized_ctx() import json payload = {"task_id": "abc", "action": "delegate"} req = urllib.request.Request( f"https://127.0.0.1:{port}/tasks/send", data=json.dumps(payload).encode(), headers={"Content-Type": "application/json"}, method="POST", ) with urllib.request.urlopen(req, context=ctx, timeout=5) as resp: data = json.loads(resp.read()) assert data["status"] == "ok" assert data["echo"]["task_id"] == "abc" def test_routing_server_unknown_route_404(self, routing_server): server, port = routing_server ctx = self._authorized_ctx() req = urllib.request.Request(f"https://127.0.0.1:{port}/nonexistent") with pytest.raises(urllib.error.URLError) as exc_info: urllib.request.urlopen(req, context=ctx, timeout=5) assert "404" in str(exc_info.value) def test_routing_server_context_manager_stops(self): from agent.a2a_mtls import A2AMTLSServer port = _find_free_port() server = A2AMTLSServer( cert=self.srv_crt, key=self.srv_key, ca=self.ca_crt, host="127.0.0.1", port=port, ) server.add_route("/ping", lambda p, *, peer_cn=None: {"pong": True}) with server: time.sleep(0.05) assert server._httpd is not None assert server._httpd is None # stopped after __exit__ def test_routing_server_rogue_client_rejected(self, routing_server): server, port = routing_server ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_verify_locations(cafile=str(self.ca_crt)) ctx.load_cert_chain(certfile=str(self.rogue_crt), keyfile=str(self.rogue_key)) ctx.check_hostname = False req = urllib.request.Request(f"https://127.0.0.1:{port}/echo") with pytest.raises((ssl.SSLError, OSError, urllib.error.URLError)): urllib.request.urlopen(req, context=ctx, timeout=5) def test_a2a_mtls_client_get(self, routing_server): from agent.a2a_mtls import A2AMTLSClient server, port = routing_server client = A2AMTLSClient( cert=self.cli_crt, key=self.cli_key, ca=self.ca_crt ) result = client.get(f"https://127.0.0.1:{port}/echo") assert result["peer"] is not None def test_a2a_mtls_client_post(self, routing_server): from agent.a2a_mtls import A2AMTLSClient server, port = routing_server client = A2AMTLSClient( cert=self.cli_crt, key=self.cli_key, ca=self.ca_crt ) result = client.post(f"https://127.0.0.1:{port}/tasks/send", json={"x": 1}) assert result["status"] == "ok" assert result["echo"]["x"] == 1 def test_a2a_mtls_client_rogue_cert_raises(self, routing_server): from agent.a2a_mtls import A2AMTLSClient server, port = routing_server client = A2AMTLSClient( cert=self.rogue_crt, key=self.rogue_key, ca=self.ca_crt ) with pytest.raises((ConnectionError, ssl.SSLError, OSError)): client.get(f"https://127.0.0.1:{port}/echo") def test_concurrent_fleet_agents(self, routing_server): """timmy (server) accepts concurrent connections from multiple authorized clients.""" from agent.a2a_mtls import build_client_ssl_context server, port = routing_server results: dict = {} errors: dict = {} def connect(name: str) -> None: try: ctx = build_client_ssl_context( cert=self.cli_crt, key=self.cli_key, ca=self.ca_crt ) ctx.check_hostname = False req = urllib.request.Request(f"https://127.0.0.1:{port}/echo") with urllib.request.urlopen(req, context=ctx, timeout=5) as resp: import json results[name] = json.loads(resp.read()) except Exception as exc: errors[name] = exc threads = [threading.Thread(target=connect, args=(n,)) for n in ("t1", "t2", "t3")] for t in threads: t.start() for t in threads: t.join(timeout=10) assert not errors, f"Concurrent connection errors: {errors}" assert len(results) == 3