166 lines
5.3 KiB
Python
166 lines
5.3 KiB
Python
"""
|
|
A2A Mutual TLS Verification — #806
|
|
|
|
Provides TLS context factories for mTLS-secured agent-to-agent communication.
|
|
Each agent presents its cert, the server verifies against the Fleet CA.
|
|
|
|
Usage:
|
|
from agent.a2a_mtls import get_server_ssl_context, get_client_ssl_context
|
|
|
|
# Server side (A2A server)
|
|
ssl_ctx = get_server_ssl_context(
|
|
cert_file="/path/to/agent.crt",
|
|
key_file="/path/to/agent.key",
|
|
ca_file="/path/to/fleet-ca.crt",
|
|
)
|
|
|
|
# Client side (A2A client)
|
|
ssl_ctx = get_client_ssl_context(
|
|
cert_file="/path/to/agent.crt",
|
|
key_file="/path/to/agent.key",
|
|
ca_file="/path/to/fleet-ca.crt",
|
|
)
|
|
"""
|
|
|
|
import os
|
|
import ssl
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
# Root directory for fleet certificates. The FLEET_CERTS_DIR environment
# variable wins when set; otherwise fall back to ~/.hermes/fleet-certs.
DEFAULT_CERTS_DIR = Path(
    os.environ.get(
        "FLEET_CERTS_DIR",
        str(Path.home().joinpath(".hermes", "fleet-certs")),
    )
)
|
|
|
|
|
|
def get_server_ssl_context(
    cert_file: Optional[str] = None,
    key_file: Optional[str] = None,
    ca_file: Optional[str] = None,
    agent_name: Optional[str] = None,
) -> ssl.SSLContext:
    """Create the SSL context for the server side of an mTLS connection.

    The returned context presents the server's own certificate and requires
    every connecting client to present a certificate that chains to the CA
    in ``ca_file``.

    Args:
        cert_file: Path to the server certificate (PEM).
        key_file: Path to the server private key (PEM).
        ca_file: Path to the CA certificate used to verify client certs.
        agent_name: When given and ``cert_file`` is not, the certificate is
            looked up under ``DEFAULT_CERTS_DIR / agent_name``. ``key_file``
            and ``ca_file`` are taken from the same directory only if the
            caller did not supply them.

    Returns:
        A configured ``ssl.SSLContext`` using ``PROTOCOL_TLS_SERVER``.

    Raises:
        Whatever ``SSLContext.load_cert_chain`` / ``load_verify_locations``
        raise for missing or unreadable files; errors are not translated.
    """
    if agent_name and not cert_file:
        agent_dir = DEFAULT_CERTS_DIR / agent_name
        cert_file = str(agent_dir / f"{agent_name}.crt")
        # Fill in only what the caller left out, so an explicitly supplied
        # key or CA path is not silently replaced by the default location.
        if not key_file:
            key_file = str(agent_dir / f"{agent_name}.key")
        if not ca_file:
            ca_file = str(agent_dir / "fleet-ca.crt")

    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)

    # Server identity presented to connecting agents.
    ctx.load_cert_chain(certfile=cert_file, keyfile=key_file)

    # mTLS: reject clients that do not present a certificate signed by the CA.
    ctx.verify_mode = ssl.CERT_REQUIRED
    ctx.load_verify_locations(cafile=ca_file)

    # Security settings
    ctx.minimum_version = ssl.TLSVersion.TLSv1_2
    ctx.check_hostname = False  # Internal fleet, not public DNS

    return ctx
|
|
|
|
|
|
def get_client_ssl_context(
    cert_file: Optional[str] = None,
    key_file: Optional[str] = None,
    ca_file: Optional[str] = None,
    agent_name: Optional[str] = None,
) -> ssl.SSLContext:
    """Create the SSL context for the client side of an mTLS connection.

    The returned context presents the client's certificate to the server and
    verifies the server's certificate chain against the CA in ``ca_file``.

    Args:
        cert_file: Path to the client certificate (PEM).
        key_file: Path to the client private key (PEM).
        ca_file: Path to the CA certificate used to verify the server cert.
        agent_name: When given and ``cert_file`` is not, the certificate is
            looked up under ``DEFAULT_CERTS_DIR / agent_name``. ``key_file``
            and ``ca_file`` are taken from the same directory only if the
            caller did not supply them.

    Returns:
        A configured ``ssl.SSLContext`` using ``PROTOCOL_TLS_CLIENT``.

    Raises:
        Whatever ``SSLContext.load_cert_chain`` / ``load_verify_locations``
        raise for missing or unreadable files; errors are not translated.
    """
    if agent_name and not cert_file:
        agent_dir = DEFAULT_CERTS_DIR / agent_name
        cert_file = str(agent_dir / f"{agent_name}.crt")
        # Fill in only what the caller left out, so an explicitly supplied
        # key or CA path is not silently replaced by the default location.
        if not key_file:
            key_file = str(agent_dir / f"{agent_name}.key")
        if not ca_file:
            ca_file = str(agent_dir / "fleet-ca.crt")

    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

    # Client identity presented to the server.
    ctx.load_cert_chain(certfile=cert_file, keyfile=key_file)

    # Trust anchor for the server's certificate.
    ctx.load_verify_locations(cafile=ca_file)

    # Security settings
    ctx.minimum_version = ssl.TLSVersion.TLSv1_2
    # Only hostname matching is turned off; PROTOCOL_TLS_CLIENT keeps
    # verify_mode at CERT_REQUIRED, so the chain is still checked against
    # the CA loaded above.
    ctx.check_hostname = False  # Internal fleet

    return ctx
|
|
|
|
|
|
def _verify_directly_issued(cert, ca_cert) -> None:
    """Raise ValueError unless *cert* is signed by *ca_cert* and currently valid.

    Used when ``cryptography.x509.verification`` is unavailable. Comparing
    issuer and subject names alone is not sufficient, because names can be
    copied into a forged certificate, so the signature is checked as well.
    """
    from datetime import datetime, timezone

    from cryptography.exceptions import InvalidSignature

    if cert.issuer != ca_cert.subject:
        raise ValueError("Issuer mismatch")

    check_signature = getattr(cert, "verify_directly_issued_by", None)
    if check_signature is None:
        # Fail closed: without a signature check the result would be forgeable.
        raise ValueError(
            "Installed 'cryptography' cannot verify certificate signatures; "
            "upgrade to version 40.0 or newer"
        )
    try:
        check_signature(ca_cert)
    except InvalidSignature:
        # InvalidSignature carries no message of its own.
        raise ValueError("Certificate signature was not produced by the Fleet CA") from None

    if hasattr(cert, "not_valid_before_utc"):
        starts, ends = cert.not_valid_before_utc, cert.not_valid_after_utc
    else:
        # Older releases expose naive datetimes that are expressed in UTC.
        starts = cert.not_valid_before.replace(tzinfo=timezone.utc)
        ends = cert.not_valid_after.replace(tzinfo=timezone.utc)
    if not starts <= datetime.now(timezone.utc) <= ends:
        raise ValueError("Certificate is expired or not yet valid")


def verify_agent_cert(cert_pem: str, ca_file: Optional[str] = None) -> tuple[bool, str]:
    """Verify an agent certificate against the Fleet CA.

    Args:
        cert_pem: The agent certificate in PEM form (``str`` or ``bytes``).
        ca_file: Path to the Fleet CA certificate (PEM). Defaults to
            ``DEFAULT_CERTS_DIR / "fleet-ca.crt"``.

    Returns:
        ``(True, subject_cn)`` when the certificate verifies, otherwise
        ``(False, reason)`` where ``reason`` describes the failure. The
        function never raises for a bad certificate, a missing CA file or a
        missing ``cryptography`` package; all of those yield ``False``.
    """
    if ca_file is None:
        ca_file = str(DEFAULT_CERTS_DIR / "fleet-ca.crt")

    try:
        from cryptography import x509
    except ImportError as exc:
        return False, str(exc)

    # Parsing and file errors are reported the same way as verification
    # failures so callers only have to inspect the returned tuple.
    try:
        pem_bytes = cert_pem.encode() if isinstance(cert_pem, str) else cert_pem
        cert = x509.load_pem_x509_certificate(pem_bytes)
        with open(ca_file, "rb") as f:
            ca_cert = x509.load_pem_x509_certificate(f.read())
    except Exception as exc:
        return False, str(exc)

    try:
        try:
            from cryptography.x509.verification import PolicyBuilder, Store
        except ImportError:
            # Older cryptography without the path-validation API.
            _verify_directly_issued(cert, ca_cert)
        else:
            # NOTE(review): a server verifier is built for "fleet.local", which
            # presumably requires agent certs to carry that name as a
            # subjectAltName — confirm against how the Fleet CA issues certs.
            verifier = (
                PolicyBuilder()
                .store(Store([ca_cert]))
                .build_server_verifier(x509.DNSName("fleet.local"))
            )
            # NOTE(review): the leaf is also passed as the (untrusted)
            # intermediates list, as in the original code.
            verifier.verify(cert, [cert])
    except Exception as exc:
        # Any verification error means "not trusted"; never raise to callers.
        return False, str(exc)

    cn_attrs = cert.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME)
    if not cn_attrs:
        return False, "Certificate subject has no Common Name"
    return True, cn_attrs[0].value
|
|
|
|
|
|
def get_agent_cn_from_context(ssl_context: "ssl.SSLSocket | ssl.SSLObject") -> Optional[str]:
    """Extract the agent name from the peer certificate of a TLS connection.

    Used by the server to identify which agent is connecting.

    Despite the parameter name, the argument must be a connected
    ``ssl.SSLSocket`` or ``ssl.SSLObject``: the peer certificate is read via
    ``getpeercert()``, which a bare ``ssl.SSLContext`` does not provide.
    Objects without that method yield ``None``.

    Args:
        ssl_context: The TLS connection object to read the peer cert from.

    Returns:
        The certificate's Common Name with a leading ``"agent-"`` removed,
        or ``None`` when there is no peer certificate, it cannot be parsed,
        it has no Common Name, or ``cryptography`` is not installed.
    """
    getpeercert = getattr(ssl_context, "getpeercert", None)
    if getpeercert is None:
        return None

    try:
        peer_der = getpeercert(binary_form=True)
        if not peer_der:
            return None

        from cryptography import x509

        cert = x509.load_der_x509_certificate(peer_der)
        cn_attrs = cert.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME)
    except Exception:
        # Deliberately best-effort: identification failure must not take
        # down the connection handler, so every error maps to "unknown".
        return None

    if not cn_attrs:
        return None

    cn = cn_attrs[0].value
    if not isinstance(cn, str):
        return None
    return cn.removeprefix("agent-")
|