Files
hermes-agent/agent/a2a_mtls.py

166 lines
5.3 KiB
Python

"""
A2A Mutual TLS Verification — #806
Provides TLS context factories for mTLS-secured agent-to-agent communication.
Each agent presents its cert, the server verifies against the Fleet CA.
Usage:
from agent.a2a_mtls import get_server_ssl_context, get_client_ssl_context
# Server side (A2A server)
ssl_ctx = get_server_ssl_context(
cert_file="/path/to/agent.crt",
key_file="/path/to/agent.key",
ca_file="/path/to/fleet-ca.crt",
)
# Client side (A2A client)
ssl_ctx = get_client_ssl_context(
cert_file="/path/to/agent.crt",
key_file="/path/to/agent.key",
ca_file="/path/to/fleet-ca.crt",
)
"""
import os
import ssl
from pathlib import Path
from typing import Optional
# Base directory for fleet certificates. The FLEET_CERTS_DIR environment
# variable takes precedence; otherwise ~/.hermes/fleet-certs is used.
DEFAULT_CERTS_DIR = Path(
    os.environ.get(
        "FLEET_CERTS_DIR",
        str(Path.home().joinpath(".hermes", "fleet-certs")),
    )
)
def get_server_ssl_context(
    cert_file: Optional[str] = None,
    key_file: Optional[str] = None,
    ca_file: Optional[str] = None,
    agent_name: Optional[str] = None,
) -> ssl.SSLContext:
    """
    Create SSL context for the mTLS server side of A2A.

    The context presents this agent's certificate and requires every client
    to present a certificate that verifies against the Fleet CA.

    Args:
        cert_file: Path to this agent's certificate (PEM).
        key_file: Path to the matching private key. ``None`` makes
            ``load_cert_chain`` read the key from *cert_file* as well.
        ca_file: Path to the Fleet CA certificate used to verify clients.
        agent_name: Fills in paths that were not passed explicitly. When
            *cert_file* is omitted, cert and key default to
            ``DEFAULT_CERTS_DIR/<agent_name>/<agent_name>.{crt,key}``.
            When *ca_file* is omitted, it defaults to
            ``DEFAULT_CERTS_DIR/<agent_name>/fleet-ca.crt``. Explicit
            arguments are never overridden.

    Returns:
        A server ``ssl.SSLContext`` requiring client certs, TLS 1.2 minimum.

    Raises:
        ValueError: If no certificate or CA path could be determined.
        OSError / ssl.SSLError: If a file cannot be read or parsed.
    """
    if agent_name:
        agent_dir = DEFAULT_CERTS_DIR / agent_name
        if not cert_file:
            cert_file = str(agent_dir / f"{agent_name}.crt")
            # Keep an explicitly supplied key path instead of clobbering it.
            key_file = key_file or str(agent_dir / f"{agent_name}.key")
        # Keep an explicitly supplied CA: silently swapping the trust anchor
        # would verify clients against a CA the caller did not choose.
        ca_file = ca_file or str(agent_dir / "fleet-ca.crt")
    if not cert_file:
        raise ValueError("get_server_ssl_context: cert_file or agent_name is required")
    if not ca_file:
        raise ValueError("get_server_ssl_context: ca_file or agent_name is required")

    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    # Load server cert and key
    ctx.load_cert_chain(certfile=cert_file, keyfile=key_file)
    # Require client certificate
    ctx.verify_mode = ssl.CERT_REQUIRED
    # Load CA for verifying client certs
    ctx.load_verify_locations(cafile=ca_file)
    # Security settings
    ctx.minimum_version = ssl.TLSVersion.TLSv1_2
    # Internal fleet, not public DNS. Any cert chaining to the Fleet CA is
    # accepted; identify the peer separately (e.g. from the certificate CN).
    ctx.check_hostname = False
    return ctx
def get_client_ssl_context(
    cert_file: Optional[str] = None,
    key_file: Optional[str] = None,
    ca_file: Optional[str] = None,
    agent_name: Optional[str] = None,
) -> ssl.SSLContext:
    """
    Create SSL context for the mTLS client side of A2A.

    The context presents this agent's certificate to the server and verifies
    the server's certificate against the Fleet CA.

    Args:
        cert_file: Path to this agent's certificate (PEM).
        key_file: Path to the matching private key. ``None`` makes
            ``load_cert_chain`` read the key from *cert_file* as well.
        ca_file: Path to the Fleet CA certificate used to verify the server.
        agent_name: Fills in paths that were not passed explicitly. When
            *cert_file* is omitted, cert and key default to
            ``DEFAULT_CERTS_DIR/<agent_name>/<agent_name>.{crt,key}``.
            When *ca_file* is omitted, it defaults to
            ``DEFAULT_CERTS_DIR/<agent_name>/fleet-ca.crt``. Explicit
            arguments are never overridden.

    Returns:
        A client ``ssl.SSLContext`` that requires a verified server
        certificate, TLS 1.2 minimum.

    Raises:
        ValueError: If no certificate or CA path could be determined.
        OSError / ssl.SSLError: If a file cannot be read or parsed.
    """
    if agent_name:
        agent_dir = DEFAULT_CERTS_DIR / agent_name
        if not cert_file:
            cert_file = str(agent_dir / f"{agent_name}.crt")
            # Keep an explicitly supplied key path instead of clobbering it.
            key_file = key_file or str(agent_dir / f"{agent_name}.key")
        # Keep an explicitly supplied CA: silently swapping the trust anchor
        # would verify servers against a CA the caller did not choose.
        ca_file = ca_file or str(agent_dir / "fleet-ca.crt")
    if not cert_file:
        raise ValueError("get_client_ssl_context: cert_file or agent_name is required")
    if not ca_file:
        raise ValueError("get_client_ssl_context: ca_file or agent_name is required")

    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    # Load client cert and key
    ctx.load_cert_chain(certfile=cert_file, keyfile=key_file)
    # Load CA for verifying server cert
    ctx.load_verify_locations(cafile=ca_file)
    # Security settings
    ctx.minimum_version = ssl.TLSVersion.TLSv1_2
    # Internal fleet: hostname matching is off, so any server cert chaining to
    # the Fleet CA is accepted. Chain verification itself stays mandatory.
    ctx.check_hostname = False
    # Already the PROTOCOL_TLS_CLIENT default; stated so it cannot regress.
    ctx.verify_mode = ssl.CERT_REQUIRED
    return ctx
def verify_agent_cert(cert_pem: str, ca_file: Optional[str] = None) -> tuple[bool, str]:
    """
    Verify an agent certificate against the Fleet CA.

    Args:
        cert_pem: PEM-encoded agent certificate (``str``; ``bytes`` also accepted).
        ca_file: Path to the Fleet CA certificate (PEM). Defaults to
            ``DEFAULT_CERTS_DIR / "fleet-ca.crt"``.

    Returns:
        ``(True, subject_cn)`` when the certificate verifies against the CA.
        Otherwise returns ``(False, reason)``. Failures to load the certificate
        or the CA file are reported this way too, rather than raised.
    """
    if ca_file is None:
        ca_file = str(DEFAULT_CERTS_DIR / "fleet-ca.crt")

    try:
        from cryptography import x509
    except ImportError as e:
        return False, str(e)

    try:
        pem = cert_pem.encode() if isinstance(cert_pem, str) else cert_pem
        cert = x509.load_pem_x509_certificate(pem)
        with open(ca_file, "rb") as f:
            ca_cert = x509.load_pem_x509_certificate(f.read())
    except (OSError, TypeError, ValueError) as e:
        return False, str(e)

    try:
        from cryptography.x509.verification import PolicyBuilder, Store
    except ImportError:
        # cryptography < 42 has no path-building verifier.
        PolicyBuilder = Store = None

    try:
        if PolicyBuilder is not None:
            # NOTE(review): a server-auth policy for DNS name "fleet.local" is
            # used, so agent certs presumably must carry that SAN — confirm
            # against the certificate-issuing tooling.
            verifier = PolicyBuilder().store(Store([ca_cert])).build_server_verifier(
                x509.DNSName("fleet.local")
            )
            # The leaf is not its own intermediate; the Fleet CA is the anchor.
            verifier.verify(cert, [])
        else:
            # Fallback: the CA's signature MUST be checked. Comparing only the
            # issuer name accepts forged certificates. verify_directly_issued_by
            # needs cryptography >= 40; on older versions the AttributeError
            # below fails closed.
            import datetime

            cert.verify_directly_issued_by(ca_cert)
            # Pre-42 cryptography exposes naive-UTC validity datetimes.
            now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            if not cert.not_valid_before <= now <= cert.not_valid_after:
                return False, "Certificate is expired or not yet valid"
    except Exception as e:  # verification errors come in several exception types
        # Some failures (e.g. InvalidSignature) carry an empty message.
        return False, str(e) or type(e).__name__

    cn_attrs = cert.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME)
    if not cn_attrs:
        return False, "Certificate has no Common Name"
    return True, cn_attrs[0].value
def get_agent_cn_from_context(ssl_context: "ssl.SSLSocket | ssl.SSLObject") -> Optional[str]:
    """
    Extract agent Common Name from a TLS connection's peer certificate.

    Used by the server to identify which agent is connecting.

    Args:
        ssl_context: The established TLS connection, i.e. an ``ssl.SSLSocket``
            or ``ssl.SSLObject`` (with asyncio, e.g.
            ``transport.get_extra_info("ssl_object")``). Despite the parameter
            name, an ``ssl.SSLContext`` has no peer certificate. Passing one
            returns ``None``.

    Returns:
        The peer certificate's CN with a leading ``"agent-"`` stripped, or
        ``None`` when there is no peer certificate, it has no CN, or it
        cannot be read.
    """
    # Only connection objects expose getpeercert(); SSLContext does not.
    getpeercert = getattr(ssl_context, "getpeercert", None)
    if getpeercert is None:
        return None
    try:
        peer_cert = getpeercert(binary_form=True)
        if not peer_cert:
            return None
        from cryptography import x509

        cert = x509.load_der_x509_certificate(peer_cert)
        cn_attrs = cert.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME)
        if not cn_attrs:
            return None
        # Strip "agent-" prefix if present
        return cn_attrs[0].value.removeprefix("agent-")
    except (ImportError, OSError, TypeError, ValueError):
        # Best-effort: an unidentifiable peer yields None rather than an error
        # (handshake not done, cryptography missing, malformed certificate).
        return None