feat: mTLS server for A2A (#806)
This commit is contained in:
40
tools/mtls_server.py
Normal file
40
tools/mtls_server.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""mTLS server for A2A auth (#806)."""
|
||||
import asyncio, logging, ssl
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def create_ssl_context(ca_cert_path, server_cert_path, server_key_path):
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
ctx.load_cert_chain(certfile=server_cert_path, keyfile=server_key_path)
|
||||
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||
ctx.load_verify_locations(cafile=ca_cert_path)
|
||||
return ctx
|
||||
|
||||
def get_client_identity(ssl_obj):
|
||||
try:
|
||||
cert = ssl_obj.getpeercert()
|
||||
if cert:
|
||||
for rdn in cert.get("subject", ()):
|
||||
for attr in rdn:
|
||||
if attr[0] == "commonName": return attr[1]
|
||||
except Exception: pass
|
||||
return None
|
||||
|
||||
async def create_mtls_server(handler, host="127.0.0.1", port=8766, ca_cert="", server_cert="", server_key=""):
|
||||
ca_cert = str(Path(ca_cert).expanduser())
|
||||
server_cert = str(Path(server_cert).expanduser())
|
||||
server_key = str(Path(server_key).expanduser())
|
||||
ssl_ctx = create_ssl_context(ca_cert, server_cert, server_key)
|
||||
async def _wrapper(reader, writer):
|
||||
ssl_obj = writer.transport.get_extra_info("ssl_object")
|
||||
agent = get_client_identity(ssl_obj) or "unknown"
|
||||
logger.info("mTLS connection from: %s", agent)
|
||||
try: await handler(ssl_obj, reader, writer)
|
||||
except Exception as e: logger.error("Handler error: %s", e)
|
||||
finally: writer.close()
|
||||
server = await asyncio.start_server(_wrapper, host, port, ssl=ssl_ctx)
|
||||
logger.info("mTLS server on %s:%d", host, port)
|
||||
return server
|
||||
Reference in New Issue
Block a user