"""mTLS server for A2A auth (#806).

Helpers to build a server-side TLS context that requires client certificates,
to read the client's commonName from the peer certificate, and to start an
asyncio server that runs a user-supplied handler over that context.
"""
import asyncio
import logging
import ssl
from pathlib import Path
from typing import Awaitable, Callable, Optional

logger = logging.getLogger(__name__)


def create_ssl_context(
    ca_cert_path: str,
    server_cert_path: str,
    server_key_path: str,
) -> ssl.SSLContext:
    """Build a server-side TLS context that requires a client certificate.

    Args:
        ca_cert_path: CA bundle used to verify client certificates.
        server_cert_path: Server certificate (chain) presented to clients.
        server_key_path: Private key matching ``server_cert_path``.

    Returns:
        An ``ssl.SSLContext`` with a TLS 1.2 minimum and ``CERT_REQUIRED``,
        so clients without a certificate signed by the CA are rejected
        during the handshake.

    Raises:
        OSError: (including ``ssl.SSLError``) if a file is missing or invalid.
    """
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    ctx.minimum_version = ssl.TLSVersion.TLSv1_2
    ctx.load_cert_chain(certfile=server_cert_path, keyfile=server_key_path)
    # Mutual TLS: the server refuses clients that present no valid certificate.
    ctx.verify_mode = ssl.CERT_REQUIRED
    ctx.load_verify_locations(cafile=ca_cert_path)
    return ctx


def get_client_identity(ssl_obj: Optional[ssl.SSLObject]) -> Optional[str]:
    """Return the commonName (CN) from the peer certificate's subject.

    Best-effort: returns ``None`` when there is no SSL object, no peer
    certificate, no commonName attribute, or the certificate cannot be read.
    It never raises, because callers use it before any error handling.
    """
    if ssl_obj is None:
        # No TLS object on this transport, so there is nothing to inspect.
        return None
    try:
        cert = ssl_obj.getpeercert()
        if not cert:
            return None
        # ``subject`` is a tuple of RDNs; each RDN is a tuple of (name, value).
        for rdn in cert.get("subject", ()):
            for attr in rdn:
                if attr[0] == "commonName":
                    return attr[1]
    except Exception:
        # Deliberately best-effort, but leave a trace instead of silently passing.
        logger.debug("Could not read client identity from peer cert", exc_info=True)
    return None


async def create_mtls_server(
    handler: Callable[..., Awaitable[None]],
    host: str = "127.0.0.1",
    port: int = 8766,
    ca_cert: str = "",
    server_cert: str = "",
    server_key: str = "",
) -> asyncio.AbstractServer:
    """Start an asyncio TCP server that requires mutual TLS.

    Args:
        handler: Coroutine function called as
            ``await handler(ssl_obj, reader, writer)`` for every connection.
            The writer is always closed after the handler returns or fails.
        host: Interface to bind.
        port: TCP port to bind.
        ca_cert: CA bundle path used to verify clients; ``~`` is expanded.
        server_cert: Server certificate path; ``~`` is expanded.
        server_key: Server private key path; ``~`` is expanded.
            The three path arguments are effectively required. An empty
            string expands to ``"."`` and fails when the context loads it.

    Returns:
        The started server, as returned by ``asyncio.start_server``.

    Raises:
        OSError: (including ``ssl.SSLError``) if a certificate or key file
            cannot be loaded, or if the address cannot be bound.
    """
    ca_cert = str(Path(ca_cert).expanduser())
    server_cert = str(Path(server_cert).expanduser())
    server_key = str(Path(server_key).expanduser())
    ssl_ctx = create_ssl_context(ca_cert, server_cert, server_key)

    async def _wrapper(
        reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        # Run ``handler`` for one connection, log failures, and always clean up.
        ssl_obj = writer.get_extra_info("ssl_object")
        agent = get_client_identity(ssl_obj) or "unknown"
        logger.info("mTLS connection from: %s", agent)
        try:
            await handler(ssl_obj, reader, writer)
        except Exception as e:
            # Unlike logger.error, logger.exception keeps the traceback.
            logger.exception("Handler error: %s", e)
        finally:
            writer.close()
            try:
                # Wait until the transport (and TLS shutdown) has actually finished.
                await writer.wait_closed()
            except OSError:
                # The peer may already have reset the connection; nothing to recover.
                logger.debug(
                    "Error while closing connection from %s", agent, exc_info=True
                )

    server = await asyncio.start_server(_wrapper, host, port, ssl=ssl_ctx)
    logger.info("mTLS server on %s:%d", host, port)
    return server