41 lines
1.6 KiB
Python
41 lines
1.6 KiB
Python
"""mTLS server for A2A auth (#806)."""
|
|
import asyncio, logging, ssl
|
|
from pathlib import Path
|
|
from typing import Callable, Optional
|
|
|
|
# Module-level logger named after this module; used by the server code below.
logger = logging.getLogger(__name__)
|
|
|
|
def create_ssl_context(ca_cert_path, server_cert_path, server_key_path):
    """Build a server-side TLS context that requires client certificates.

    The context rejects protocol versions older than TLS 1.2, presents the
    server certificate/key pair found at ``server_cert_path`` and
    ``server_key_path``, and verifies peer certificates against the CA file
    at ``ca_cert_path``.

    Returns:
        The configured ``ssl.SSLContext``.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.minimum_version = ssl.TLSVersion.TLSv1_2
    # Mutual TLS: a client certificate is mandatory, and it is checked
    # against the CA loaded below.
    context.verify_mode = ssl.CERT_REQUIRED
    context.load_cert_chain(server_cert_path, server_key_path)
    context.load_verify_locations(ca_cert_path)
    return context
|
|
|
|
def get_client_identity(ssl_obj) -> Optional[str]:
    """Return the ``commonName`` from the peer certificate's subject, if any.

    Args:
        ssl_obj: An object exposing ``getpeercert()``, or ``None`` when no
            TLS object is available for the connection.

    Returns:
        The first ``commonName`` value found in the certificate's
        ``subject``, or ``None`` when there is no TLS object, no peer
        certificate, no commonName entry, or the certificate could not be
        read. This function never raises.
    """
    if ssl_obj is None:
        return None
    try:
        cert = ssl_obj.getpeercert()
        if not cert:
            return None
        # "subject" is a sequence of RDNs, each a sequence of (name, value) pairs.
        for rdn in cert.get("subject", ()):
            for attr in rdn:
                if attr[0] == "commonName":
                    return attr[1]
    except Exception:
        # Identity lookup is best-effort: callers get None, but the failure
        # is recorded instead of being silently discarded.
        logging.getLogger(__name__).debug(
            "Could not read client identity from peer certificate", exc_info=True
        )
    return None
|
|
|
|
async def create_mtls_server(handler, host="127.0.0.1", port=8766, ca_cert="", server_cert="", server_key=""):
    """Start an asyncio stream server that requires client certificates.

    Args:
        handler: Awaited as ``handler(ssl_obj, reader, writer)`` for every
            accepted connection.
        host: Interface passed to ``asyncio.start_server``.
        port: Port passed to ``asyncio.start_server``.
        ca_cert: Path to the CA file used to verify client certificates.
        server_cert: Path to the server certificate.
        server_key: Path to the server private key.
            A leading ``~`` in any of the three paths is expanded.

    Returns:
        The server object returned by ``asyncio.start_server``.
    """
    ca_cert = str(Path(ca_cert).expanduser())
    server_cert = str(Path(server_cert).expanduser())
    server_key = str(Path(server_key).expanduser())
    ssl_ctx = create_ssl_context(ca_cert, server_cert, server_key)

    async def _wrapper(reader, writer):
        # Per-connection entry point: identify the peer, run the handler,
        # and always close the stream afterwards.
        ssl_obj = writer.transport.get_extra_info("ssl_object")
        agent = get_client_identity(ssl_obj) or "unknown"
        logger.info("mTLS connection from: %s", agent)
        try:
            await handler(ssl_obj, reader, writer)
        except Exception as e:
            # logger.exception keeps the traceback; a plain error() call
            # would reduce a handler failure to just its message.
            logger.exception("Handler error: %s", e)
        finally:
            writer.close()

    server = await asyncio.start_server(_wrapper, host, port, ssl=ssl_ctx)
    logger.info("mTLS server on %s:%d", host, port)
    return server
|