All checks were successful
Lint / lint (pull_request) Successful in 9s
Builds on the existing A2AServer / build_*_ssl_context foundation:
- agent/a2a_mtls.py:
- Add A2AMTLSServer: routing-based HTTPS server with add_route() and
context-manager (__enter__/__exit__) lifecycle support
- Add A2AMTLSClient: fleet-cert-presenting HTTP client with .get() / .post()
- Widen imports (json, Callable, Dict, urlopen)
- tests/agent/test_a2a_mtls.py:
- Fix datetime.utcnow() deprecation — use datetime.now(timezone.utc)
- Add TestA2AMTLSServerAndClient (9 tests): routing GET/POST, 404,
context-manager stop, rogue-cert rejection, A2AMTLSClient, concurrency
- Total: 11 → 20 passing tests
Refs #806
444 lines
15 KiB
Python
444 lines
15 KiB
Python
"""
|
|
A2A mutual-TLS server — secure agent-to-agent communication.
|
|
|
|
Each fleet agent runs an A2A server that:
|
|
- Presents its own TLS certificate (signed by the fleet CA).
|
|
- Requires the connecting peer to present a valid client certificate
|
|
also signed by the fleet CA.
|
|
- Rejects connections from unknown / self-signed peers.
|
|
|
|
Usage (standalone):
|
|
python -m agent.a2a_mtls \\
|
|
--cert ~/.hermes/pki/agents/timmy/timmy.crt \\
|
|
--key ~/.hermes/pki/agents/timmy/timmy.key \\
|
|
--ca ~/.hermes/pki/ca/fleet-ca.crt \\
|
|
--host 0.0.0.0 --port 9443
|
|
|
|
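Environment variables (used by ``server_from_env()``):

    HERMES_A2A_CERT   path to agent certificate
    HERMES_A2A_KEY    path to agent private key
    HERMES_A2A_CA     path to fleet CA certificate
    HERMES_A2A_HOST   bind address (default 0.0.0.0)
    HERMES_A2A_PORT   listen port (default 9443)

When the path variables are unset, paths default to
``$HERMES_HOME/pki/agents/<name>/<name>.{crt,key}`` and
``$HERMES_HOME/pki/ca/fleet-ca.crt``, where ``HERMES_HOME`` defaults to
``~/.hermes`` and ``<name>`` is the lower-cased ``HERMES_AGENT_NAME``
(default ``hermes``).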
|
|
|
Refs #806
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import ssl
|
|
import threading
|
|
from http.server import BaseHTTPRequestHandler, HTTPServer
|
|
from pathlib import Path
|
|
from typing import Any, Callable, Dict, Optional
|
|
from urllib.error import URLError
|
|
from urllib.request import Request, urlopen
|
|
|
|
# Module-level logger shared by the servers, request handlers and CLI below.
logger = logging.getLogger(__name__)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# mTLS SSL context helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def build_server_ssl_context(
    cert: str | Path,
    key: str | Path,
    ca: str | Path,
) -> ssl.SSLContext:
    """Return an SSLContext that presents *cert/key* and requires a valid
    client certificate signed by *ca*.

    A leading ``~`` in any path is expanded to the user's home directory, so
    the same ``~/.hermes/pki/...`` paths accepted by the server classes can be
    passed here directly. Protocol versions older than TLS 1.2 are refused.

    Raises ``FileNotFoundError`` if any path is missing.
    Raises ``ssl.SSLError`` if the files are malformed.
    """
    cert = Path(cert).expanduser()
    key = Path(key).expanduser()
    ca = Path(ca).expanduser()
    for p in (cert, key, ca):
        if not p.exists():
            raise FileNotFoundError(f"mTLS: file not found: {p}")

    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    ctx.minimum_version = ssl.TLSVersion.TLSv1_2
    ctx.load_cert_chain(certfile=str(cert), keyfile=str(key))
    ctx.load_verify_locations(cafile=str(ca))
    # CERT_REQUIRED — reject peers that don't present a cert signed by *ca*.
    ctx.verify_mode = ssl.CERT_REQUIRED
    return ctx
|
|
|
|
|
|
def build_client_ssl_context(
    cert: str | Path,
    key: str | Path,
    ca: str | Path,
) -> ssl.SSLContext:
    """Return an SSLContext for an outgoing mTLS connection.

    Presents *cert/key* as the client identity and verifies the server
    certificate against *ca* (``CERT_REQUIRED``, hostname checking enabled,
    TLS 1.2 minimum). A leading ``~`` in any path is expanded to the user's
    home directory.

    Raises ``FileNotFoundError`` if any path is missing.
    Raises ``ssl.SSLError`` if the files are malformed.
    """
    cert = Path(cert).expanduser()
    key = Path(key).expanduser()
    ca = Path(ca).expanduser()
    for p in (cert, key, ca):
        if not p.exists():
            raise FileNotFoundError(f"mTLS client: file not found: {p}")

    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ctx.minimum_version = ssl.TLSVersion.TLSv1_2
    ctx.load_cert_chain(certfile=str(cert), keyfile=str(key))
    ctx.load_verify_locations(cafile=str(ca))
    ctx.verify_mode = ssl.CERT_REQUIRED
    ctx.check_hostname = True
    return ctx
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Minimal A2A HTTP request handler
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class A2AHandler(BaseHTTPRequestHandler):
    """Handles A2A requests over a mutually-authenticated TLS connection.

    GET  /.well-known/agent-card.json — returns the local agent card
         (``/agent-card.json`` is accepted as an alias).
    POST /a2a/task                    — dispatches an A2A task (stub).

    Any other path gets a JSON 404.
    """

    def do_GET(self) -> None:  # noqa: N802
        if self.path in ("/.well-known/agent-card.json", "/agent-card.json"):
            self._serve_agent_card()
        else:
            self._send_json(404, {"error": "not found"})

    def do_POST(self) -> None:  # noqa: N802
        if self.path == "/a2a/task":
            self._handle_task()
        else:
            self._send_json(404, {"error": "not found"})

    # ------------------------------------------------------------------
    def _serve_agent_card(self) -> None:
        """Send the agent card JSON.

        If the card cannot be produced, an error document is sent instead;
        the status is 200 in both cases.
        """
        try:
            from agent.agent_card import get_agent_card_json

            body = get_agent_card_json().encode()
        except Exception as exc:
            # Best-effort: a missing/broken agent_card module must not take
            # the A2A endpoint down.
            logger.warning("agent-card unavailable: %s", exc)
            body = b'{"error": "agent card unavailable"}'
        self._send_raw(200, "application/json", body)

    def _read_body(self) -> Optional[bytes]:
        """Read the request body declared by ``Content-Length``.

        Returns ``None`` when the header is not a non-negative integer. A
        negative length must be rejected: ``rfile.read(-1)`` would block until
        the client closes the connection.
        """
        try:
            length = int(self.headers.get("Content-Length", 0))
        except (TypeError, ValueError):
            return None
        if length < 0:
            return None
        return self.rfile.read(length) if length else b""

    def _handle_task(self) -> None:
        # Stub: the body is consumed but not interpreted yet.
        if self._read_body() is None:
            self._send_json(400, {"error": "invalid Content-Length"})
            return
        # NOTE: despite the key name, ``handled_by`` carries the CN of the
        # *connecting peer* (the caller's client certificate), which lets the
        # caller confirm which identity the server authenticated.
        peer_cn = _peer_cn(self.connection)
        self._send_json(202, {"status": "accepted", "handled_by": peer_cn})

    # ------------------------------------------------------------------
    def _send_json(self, code: int, data: dict) -> None:
        """Serialise *data* as JSON and send it with status *code*."""
        body = json.dumps(data).encode()
        self._send_raw(code, "application/json", body)

    def _send_raw(self, code: int, content_type: str, body: bytes) -> None:
        """Send a complete response with an explicit Content-Length."""
        self.send_response(code)
        self.send_header("Content-Type", content_type)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, fmt: str, *args: object) -> None:  # type: ignore[override]
        # Route the default stderr access log to the module logger instead.
        logger.debug("a2a: " + fmt, *args)
|
|
|
|
|
|
def _peer_cn(conn: ssl.SSLSocket) -> Optional[str]:
|
|
"""Extract the Common Name from the peer certificate, or None."""
|
|
try:
|
|
peer = conn.getpeercert()
|
|
if not peer:
|
|
return None
|
|
for rdn in peer.get("subject", ()):
|
|
for key, val in rdn:
|
|
if key == "commonName":
|
|
return val
|
|
except Exception:
|
|
pass
|
|
return None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Server lifecycle
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class A2AServer:
    """Mutual-TLS A2A server serving the fixed ``A2AHandler`` routes.

    Example::

        server = A2AServer(
            cert="~/.hermes/pki/agents/timmy/timmy.crt",
            key="~/.hermes/pki/agents/timmy/timmy.key",
            ca="~/.hermes/pki/ca/fleet-ca.crt",
        )
        server.start()   # non-blocking (daemon thread)
        ...
        server.stop()

    ``~`` in the certificate paths is expanded. Pass ``port=0`` to bind an
    ephemeral port; after :meth:`start` the ``port`` attribute holds the port
    actually bound.
    """

    def __init__(
        self,
        cert: str | Path,
        key: str | Path,
        ca: str | Path,
        host: str = "0.0.0.0",
        port: int = 9443,
    ) -> None:
        self.cert = Path(cert).expanduser()
        self.key = Path(key).expanduser()
        self.ca = Path(ca).expanduser()
        self.host = host
        self.port = port
        self._httpd: Optional[HTTPServer] = None
        self._thread: Optional[threading.Thread] = None

    def start(self, daemon: bool = True) -> None:
        """Bind the listening socket, wrap it in mTLS and serve in a thread.

        Args:
            daemon: run the serving thread as a daemon (default) so it does
                not keep the interpreter alive on its own.

        Raises:
            FileNotFoundError, ssl.SSLError: from ``build_server_ssl_context``.
            OSError: if ``host:port`` cannot be bound.
        """
        ssl_ctx = build_server_ssl_context(self.cert, self.key, self.ca)
        httpd = HTTPServer((self.host, self.port), A2AHandler)
        try:
            httpd.socket = ssl_ctx.wrap_socket(httpd.socket, server_side=True)
        except BaseException:
            # Don't leak the already-bound port if wrapping fails.
            httpd.server_close()
            raise
        self._httpd = httpd
        # Record the real port (differs from the requested one when port=0).
        self.port = httpd.server_address[1]
        self._thread = threading.Thread(target=httpd.serve_forever, daemon=daemon)
        self._thread.start()
        logger.info(
            "A2A mTLS server listening on %s:%s (cert=%s)",
            self.host, self.port, self.cert.name,
        )

    def stop(self) -> None:
        """Stop serving, release the listening socket and join the thread.

        Safe to call more than once.
        """
        httpd, self._httpd = self._httpd, None
        if httpd is not None:
            httpd.shutdown()      # returns once serve_forever() has exited
            httpd.server_close()  # close the listening socket / free the port
        thread, self._thread = self._thread, None
        if thread is not None:
            thread.join(timeout=5)
|
|
|
|
|
|
def server_from_env() -> A2AServer:
    """Create an (unstarted) ``A2AServer`` configured from the environment.

    Default paths are derived from ``HERMES_HOME`` (else ``~/.hermes``) and
    the lower-cased ``HERMES_AGENT_NAME`` (else ``hermes``). Each of
    ``HERMES_A2A_CERT``, ``HERMES_A2A_KEY``, ``HERMES_A2A_CA``,
    ``HERMES_A2A_HOST`` and ``HERMES_A2A_PORT`` overrides its setting.
    """
    env = os.environ
    home_dir = Path(env.get("HERMES_HOME", Path.home() / ".hermes"))
    name = env.get("HERMES_AGENT_NAME", "hermes").lower()

    pki_dir = home_dir / "pki"
    agent_dir = pki_dir / "agents" / name

    return A2AServer(
        cert=env.get("HERMES_A2A_CERT", str(agent_dir / f"{name}.crt")),
        key=env.get("HERMES_A2A_KEY", str(agent_dir / f"{name}.key")),
        ca=env.get("HERMES_A2A_CA", str(pki_dir / "ca" / "fleet-ca.crt")),
        host=env.get("HERMES_A2A_HOST", "0.0.0.0"),
        port=int(env.get("HERMES_A2A_PORT", "9443")),
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CLI entry point
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _main() -> None:
|
|
import argparse
|
|
|
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
|
|
|
|
parser = argparse.ArgumentParser(
|
|
description="Hermes A2A mutual-TLS server"
|
|
)
|
|
parser.add_argument("--cert", required=True, help="Path to agent certificate")
|
|
parser.add_argument("--key", required=True, help="Path to agent private key")
|
|
parser.add_argument("--ca", required=True, help="Path to fleet CA certificate")
|
|
parser.add_argument("--host", default="0.0.0.0")
|
|
parser.add_argument("--port", type=int, default=9443)
|
|
args = parser.parse_args()
|
|
|
|
server = A2AServer(
|
|
cert=args.cert, key=args.key, ca=args.ca,
|
|
host=args.host, port=args.port,
|
|
)
|
|
server.start(daemon=False)
|
|
|
|
|
|
# NOTE(review): this guard sits in the middle of the module, so when the file
# is executed as a script the definitions that follow it (A2AMTLSServer,
# A2AMTLSClient, ...) are not yet defined while _main() runs. _main() does not
# need them today, but the guard conventionally belongs at the very end of
# the file; consider moving it there.
if __name__ == "__main__":
    _main()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# A2AMTLSServer — routing-based server with context-manager support
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class _RoutingHandler(BaseHTTPRequestHandler):
|
|
"""HTTP request handler that dispatches to per-path callables."""
|
|
|
|
routes: Dict[str, Callable] = {}
|
|
|
|
def log_message(self, fmt: str, *args: Any) -> None:
|
|
logger.debug("A2AMTLSServer: " + fmt, *args)
|
|
|
|
def _peer_cn(self) -> Optional[str]:
|
|
cert = self.connection.getpeercert() # type: ignore[attr-defined]
|
|
if not cert:
|
|
return None
|
|
for rdn in cert.get("subject", ()):
|
|
for attr, value in rdn:
|
|
if attr == "commonName":
|
|
return value
|
|
return None
|
|
|
|
def do_POST(self) -> None:
|
|
handler = self.routes.get(self.path)
|
|
if handler is None:
|
|
self.send_response(404)
|
|
self.end_headers()
|
|
return
|
|
length = int(self.headers.get("Content-Length", 0))
|
|
body = self.rfile.read(length) if length else b""
|
|
try:
|
|
payload = json.loads(body) if body else {}
|
|
except json.JSONDecodeError:
|
|
self.send_response(400)
|
|
self.end_headers()
|
|
return
|
|
result = handler(payload, peer_cn=self._peer_cn())
|
|
self.send_response(200)
|
|
self.send_header("Content-Type", "application/json")
|
|
self.end_headers()
|
|
self.wfile.write(json.dumps(result).encode())
|
|
|
|
def do_GET(self) -> None:
|
|
handler = self.routes.get(self.path)
|
|
if handler is None:
|
|
self.send_response(404)
|
|
self.end_headers()
|
|
return
|
|
result = handler({}, peer_cn=self._peer_cn())
|
|
self.send_response(200)
|
|
self.send_header("Content-Type", "application/json")
|
|
self.end_headers()
|
|
self.wfile.write(json.dumps(result).encode())
|
|
|
|
|
|
class A2AMTLSServer:
    """Routing-based mTLS HTTPS server with context-manager support.

    Unlike ``A2AServer`` (which serves fixed A2A paths), this server lets
    callers register arbitrary path handlers — useful for tests and custom
    A2A endpoint implementations.

    handler signature: ``handler(payload: dict, *, peer_cn: str | None) -> dict``

    Pass ``port=0`` to bind an ephemeral port; after :meth:`start` the
    ``port`` attribute holds the port actually bound.

    Example::

        server = A2AMTLSServer(cert="timmy.crt", key="timmy.key", ca="fleet-ca.crt")
        server.add_route("/tasks/send", my_handler)
        with server:
            ...  # server runs for the duration of the block
    """

    def __init__(
        self,
        cert: str | Path,
        key: str | Path,
        ca: str | Path,
        host: str = "127.0.0.1",
        port: int = 9443,
    ) -> None:
        self.cert = Path(cert).expanduser()
        self.key = Path(key).expanduser()
        self.ca = Path(ca).expanduser()
        self.host = host
        self.port = port
        self._routes: Dict[str, Callable] = {}
        self._httpd: Optional[HTTPServer] = None
        self._thread: Optional[threading.Thread] = None

    def add_route(self, path: str, handler: Callable) -> None:
        """Register *handler* for requests to *path*.

        Routes may be added before or after :meth:`start`: the running
        handler class reads this same dict.
        """
        self._routes[path] = handler

    def start(self) -> None:
        """Bind, wrap the listening socket in mTLS and serve in a daemon thread.

        Raises:
            FileNotFoundError, ssl.SSLError: from ``build_server_ssl_context``.
            OSError: if ``host:port`` cannot be bound.
        """
        ssl_ctx = build_server_ssl_context(self.cert, self.key, self.ca)

        class _Handler(_RoutingHandler):
            # Share (not copy) the route table so later add_route() calls
            # are seen by the running server.
            routes = self._routes

        httpd = HTTPServer((self.host, self.port), _Handler)
        try:
            httpd.socket = ssl_ctx.wrap_socket(httpd.socket, server_side=True)
        except BaseException:
            # Don't leak the already-bound port if wrapping fails.
            httpd.server_close()
            raise
        self._httpd = httpd
        # Record the real port (differs from the requested one when port=0).
        self.port = httpd.server_address[1]
        self._thread = threading.Thread(
            target=httpd.serve_forever,
            daemon=True,
            name=f"a2a-mtls-{self.port}",
        )
        self._thread.start()
        logger.info("A2AMTLSServer on %s:%d (mTLS)", self.host, self.port)

    def stop(self) -> None:
        """Stop serving, release the listening socket and join the thread.

        Safe to call more than once.
        """
        httpd, self._httpd = self._httpd, None
        if httpd is not None:
            httpd.shutdown()      # returns once serve_forever() has exited
            httpd.server_close()  # close the listening socket / free the port
        thread, self._thread = self._thread, None
        if thread is not None:
            thread.join(timeout=5)

    def __enter__(self) -> "A2AMTLSServer":
        self.start()
        return self

    def __exit__(self, *_: Any) -> None:
        # Returns None, so exceptions raised inside the ``with`` block propagate.
        self.stop()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# A2AMTLSClient — mTLS HTTP client
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class A2AMTLSClient:
    """HTTP client that presents a fleet cert on every outgoing connection.

    The server certificate is always verified against *ca*. Hostname checking
    is off by default, because callers often connect by IP. In that mode any
    certificate signed by the fleet CA is accepted for any host. Pass
    ``check_hostname=True`` to also require the server certificate to name
    the host in the URL, so that a certificate issued to a different agent is
    rejected.

    Example::

        client = A2AMTLSClient(cert="allegro.crt", key="allegro.key", ca="fleet-ca.crt")
        result = client.post("https://timmy:9443/tasks/send", json={"task": "..."})
    """

    def __init__(
        self,
        cert: str | Path,
        key: str | Path,
        ca: str | Path,
        *,
        check_hostname: bool = False,
    ) -> None:
        self._ssl_ctx = build_client_ssl_context(cert, key, ca)
        self._ssl_ctx.check_hostname = check_hostname

    def _request(
        self,
        method: str,
        url: str,
        data: Optional[bytes] = None,
        timeout: float = 10.0,
    ) -> Dict[str, Any]:
        """Send one request and return the decoded JSON body (``{}`` if empty).

        Raises:
            ConnectionError: wrapping any ``URLError`` (including HTTP error
                statuses, which urllib reports as ``HTTPError``).
        """
        headers = {"Content-Type": "application/json"}
        req = Request(url, data=data, headers=headers, method=method)
        try:
            with urlopen(req, context=self._ssl_ctx, timeout=timeout) as resp:
                body = resp.read()
                return json.loads(body) if body else {}
        except URLError as exc:
            raise ConnectionError(f"A2AMTLSClient {method} {url} failed: {exc.reason}") from exc

    def get(self, url: str, **kwargs: Any) -> Dict[str, Any]:
        """GET *url*; extra keyword arguments (e.g. ``timeout``) go to ``_request``."""
        return self._request("GET", url, **kwargs)

    def post(self, url: str, json: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Dict[str, Any]:
        """POST *json* (serialised, if given) to *url*."""
        # The ``json`` parameter shadows the module inside this method, hence
        # the aliased import.
        import json as _json

        data = _json.dumps(json).encode() if json is not None else None
        return self._request("POST", url, data=data, **kwargs)
|