Compare commits

...

8 Commits

Author SHA1 Message Date
3659c2c57d feat: mTLS server for A2A (#806) 2026-04-15 16:28:36 +00:00
7331846f87 feat: Fleet CA for mTLS (#806) 2026-04-15 16:28:33 +00:00
6a460857bf test: A2A mTLS tests
Some checks failed
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Contributor Attribution Check / check-attribution (pull_request) Failing after 26s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 27s
Tests / e2e (pull_request) Successful in 2m31s
Tests / test (pull_request) Failing after 41m50s
Part of #806
2026-04-15 16:18:14 +00:00
9446db5ee7 feat: A2A mutual TLS verification module
Part of #806
2026-04-15 16:18:12 +00:00
301b8c296b feat: Fleet CA and agent cert generator
Part of #806
2026-04-15 16:18:08 +00:00
d86359cbb2 Merge pull request 'feat: robust tool orchestration and circuit breaking' (#811) from feat/robust-tool-orchestration-1776268138150 into main 2026-04-15 16:03:07 +00:00
f264b55b29 refactor: use ToolOrchestrator for robust tool execution
Some checks failed
Contributor Attribution Check / check-attribution (pull_request) Successful in 36s
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 38s
Tests / e2e (pull_request) Successful in 2m37s
Tests / test (pull_request) Failing after 40m19s
2026-04-15 15:49:02 +00:00
dfe23f66b1 feat: add ToolOrchestrator with circuit breaker 2026-04-15 15:49:00 +00:00
7 changed files with 722 additions and 2 deletions

165
agent/a2a_mtls.py Normal file
View File

@@ -0,0 +1,165 @@
"""
A2A Mutual TLS Verification — #806
Provides TLS context factories for mTLS-secured agent-to-agent communication.
Each agent presents its cert, the server verifies against the Fleet CA.
Usage:
from agent.a2a_mtls import get_server_ssl_context, get_client_ssl_context
# Server side (A2A server)
ssl_ctx = get_server_ssl_context(
cert_file="/path/to/agent.crt",
key_file="/path/to/agent.key",
ca_file="/path/to/fleet-ca.crt",
)
# Client side (A2A client)
ssl_ctx = get_client_ssl_context(
cert_file="/path/to/agent.crt",
key_file="/path/to/agent.key",
ca_file="/path/to/fleet-ca.crt",
)
"""
import os
import ssl
from pathlib import Path
from typing import Optional
# Default paths
DEFAULT_CERTS_DIR = Path(os.getenv("FLEET_CERTS_DIR", str(Path.home() / ".hermes" / "fleet-certs")))
def get_server_ssl_context(
cert_file: Optional[str] = None,
key_file: Optional[str] = None,
ca_file: Optional[str] = None,
agent_name: Optional[str] = None,
) -> ssl.SSLContext:
"""
Create SSL context for mTLS server.
Requires client certificate verification.
"""
if agent_name and not cert_file:
cert_file = str(DEFAULT_CERTS_DIR / agent_name / f"{agent_name}.crt")
key_file = str(DEFAULT_CERTS_DIR / agent_name / f"{agent_name}.key")
ca_file = str(DEFAULT_CERTS_DIR / agent_name / "fleet-ca.crt")
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
# Load server cert and key
ctx.load_cert_chain(certfile=cert_file, keyfile=key_file)
# Require client certificate
ctx.verify_mode = ssl.CERT_REQUIRED
# Load CA for verifying client certs
ctx.load_verify_locations(cafile=ca_file)
# Security settings
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
ctx.check_hostname = False # Internal fleet, not public DNS
return ctx
def get_client_ssl_context(
cert_file: Optional[str] = None,
key_file: Optional[str] = None,
ca_file: Optional[str] = None,
agent_name: Optional[str] = None,
) -> ssl.SSLContext:
"""
Create SSL context for mTLS client.
Presents client certificate for server verification.
"""
if agent_name and not cert_file:
cert_file = str(DEFAULT_CERTS_DIR / agent_name / f"{agent_name}.crt")
key_file = str(DEFAULT_CERTS_DIR / agent_name / f"{agent_name}.key")
ca_file = str(DEFAULT_CERTS_DIR / agent_name / "fleet-ca.crt")
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
# Load client cert and key
ctx.load_cert_chain(certfile=cert_file, keyfile=key_file)
# Load CA for verifying server cert
ctx.load_verify_locations(cafile=ca_file)
# Security settings
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
ctx.check_hostname = False # Internal fleet
return ctx
def verify_agent_cert(cert_pem: str, ca_file: Optional[str] = None) -> tuple[bool, str]:
"""
Verify an agent certificate against the Fleet CA.
Returns (valid, subject_cn).
"""
if ca_file is None:
ca_file = str(DEFAULT_CERTS_DIR / "fleet-ca.crt")
try:
from cryptography import x509
from cryptography.x509.verification import PolicyBuilder, Store
cert = x509.load_pem_x509_certificate(cert_pem.encode() if isinstance(cert_pem, str) else cert_pem)
with open(ca_file, "rb") as f:
ca_cert = x509.load_pem_x509_certificate(f.read())
store = Store([ca_cert])
builder = PolicyBuilder().store(store)
verifier = builder.build_server_verifier(x509.DNSName("fleet.local"))
try:
verifier.verify(cert, [cert])
cn = cert.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME)[0].value
return True, cn
except Exception as e:
return False, str(e)
except ImportError:
# Fallback: basic validation
try:
from cryptography import x509
cert = x509.load_pem_x509_certificate(cert_pem.encode() if isinstance(cert_pem, str) else cert_pem)
with open(ca_file, "rb") as f:
ca_cert = x509.load_pem_x509_certificate(f.read())
# Check issuer matches CA subject
if cert.issuer == ca_cert.subject:
cn = cert.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME)[0].value
return True, cn
return False, "Issuer mismatch"
except Exception as e:
return False, str(e)
def get_agent_cn_from_context(ssl_context: ssl.SSLContext) -> Optional[str]:
"""
Extract agent Common Name from an SSL context's peer certificate.
Used by the server to identify which agent is connecting.
"""
try:
peer_cert = ssl_context.getpeercert(binary_form=True)
if peer_cert:
from cryptography import x509
cert = x509.load_der_x509_certificate(peer_cert)
cn_attrs = cert.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME)
if cn_attrs:
cn = cn_attrs[0].value
# Strip "agent-" prefix if present
if cn.startswith("agent-"):
return cn[6:]
return cn
except Exception:
pass
return None

177
agent/tool_orchestrator.py Normal file
View File

@@ -0,0 +1,177 @@
"""Tool Orchestrator — Robust execution and circuit breaking for agent tools.
Provides a unified execution service that wraps the tool registry.
Implements the Circuit Breaker pattern to prevent the agent from getting
stuck in failure loops when a specific tool or its underlying service
is flapping or down.
Architecture:
Discovery (tools/registry.py) -> Orchestration (agent/tool_orchestrator.py) -> Dispatch
"""
import json
import time
import logging
import threading
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from tools.registry import registry
logger = logging.getLogger(__name__)
class CircuitState:
"""States for the tool circuit breaker."""
CLOSED = "closed" # Normal operation
OPEN = "open" # Failing, execution blocked
HALF_OPEN = "half_open" # Testing if service recovered
@dataclass
class ToolStats:
"""Execution statistics for a tool."""
name: str
state: str = CircuitState.CLOSED
failures: int = 0
successes: int = 0
last_failure_time: float = 0
total_execution_time: float = 0
call_count: int = 0
class ToolOrchestrator:
"""Orchestrates tool execution with robustness patterns."""
def __init__(
self,
failure_threshold: int = 3,
reset_timeout: int = 300,
):
"""
Args:
failure_threshold: Number of failures before opening the circuit.
reset_timeout: Seconds to wait before transitioning from OPEN to HALF_OPEN.
"""
self.failure_threshold = failure_threshold
self.reset_timeout = reset_timeout
self._stats: Dict[str, ToolStats] = {}
self._lock = threading.Lock()
def _get_stats(self, name: str) -> ToolStats:
"""Get or initialize stats for a tool with thread-safe state transition."""
with self._lock:
if name not in self._stats:
self._stats[name] = ToolStats(name=name)
stats = self._stats[name]
# Transition from OPEN to HALF_OPEN if timeout expired
if stats.state == CircuitState.OPEN:
if time.time() - stats.last_failure_time > self.reset_timeout:
stats.state = CircuitState.HALF_OPEN
logger.info("Circuit breaker HALF_OPEN for tool: %s", name)
return stats
def _record_success(self, name: str, execution_time: float):
"""Record a successful tool execution and close the circuit."""
with self._lock:
stats = self._stats[name]
stats.successes += 1
stats.call_count += 1
stats.total_execution_time += execution_time
if stats.state != CircuitState.CLOSED:
logger.info("Circuit breaker CLOSED for tool: %s (recovered)", name)
stats.state = CircuitState.CLOSED
stats.failures = 0
def _record_failure(self, name: str, execution_time: float):
"""Record a failed tool execution and potentially open the circuit."""
with self._lock:
stats = self._stats[name]
stats.failures += 1
stats.call_count += 1
stats.total_execution_time += execution_time
stats.last_failure_time = time.time()
if stats.state == CircuitState.HALF_OPEN or stats.failures >= self.failure_threshold:
stats.state = CircuitState.OPEN
logger.warning(
"Circuit breaker OPEN for tool: %s (failures: %d)",
name, stats.failures
)
def dispatch(self, name: str, args: dict, **kwargs) -> str:
"""Execute a tool via the registry with circuit breaker protection."""
stats = self._get_stats(name)
if stats.state == CircuitState.OPEN:
return json.dumps({
"error": (
f"Tool '{name}' is temporarily unavailable due to repeated failures. "
f"Circuit breaker is OPEN. Please try again in a few minutes or use an alternative tool."
),
"circuit_breaker": True,
"tool_name": name
})
start_time = time.time()
try:
# Dispatch to the underlying registry
result_str = registry.dispatch(name, args, **kwargs)
execution_time = time.time() - start_time
# Inspect result for errors. registry.dispatch catches internal
# exceptions and returns a JSON error string.
is_error = False
try:
# Lightweight check for error key in JSON
if '"error":' in result_str:
res_json = json.loads(result_str)
if isinstance(res_json, dict) and "error" in res_json:
is_error = True
except (json.JSONDecodeError, TypeError):
# If it's not valid JSON, it's a malformed result (error)
is_error = True
if is_error:
self._record_failure(name, execution_time)
else:
self._record_success(name, execution_time)
return result_str
except Exception as e:
# This should rarely be hit as registry.dispatch catches most things,
# but we guard against orchestrator-level or registry-level bugs.
execution_time = time.time() - start_time
self._record_failure(name, execution_time)
error_msg = f"Tool orchestrator error during {name}: {type(e).__name__}: {e}"
logger.exception(error_msg)
return json.dumps({
"error": error_msg,
"tool_name": name,
"execution_time": execution_time
})
def get_fleet_stats(self) -> Dict[str, Any]:
"""Return execution statistics for all tools."""
with self._lock:
return {
name: {
"state": s.state,
"failures": s.failures,
"successes": s.successes,
"avg_time": s.total_execution_time / s.call_count if s.call_count > 0 else 0,
"calls": s.call_count
}
for name, s in self._stats.items()
}
# Global orchestrator instance
orchestrator = ToolOrchestrator()

View File

@@ -28,6 +28,7 @@ from typing import Dict, Any, List, Optional, Tuple
from tools.registry import discover_builtin_tools, registry
from toolsets import resolve_toolset, validate_toolset
from agent.tool_orchestrator import orchestrator
logger = logging.getLogger(__name__)
@@ -499,13 +500,13 @@ def handle_function_call(
# Prefer the caller-provided list so subagents can't overwrite
# the parent's tool set via the process-global.
sandbox_enabled = enabled_tools if enabled_tools is not None else _last_resolved_tool_names
result = registry.dispatch(
result = orchestrator.dispatch(
function_name, function_args,
task_id=task_id,
enabled_tools=sandbox_enabled,
)
else:
result = registry.dispatch(
result = orchestrator.dispatch(
function_name, function_args,
task_id=task_id,
user_task=user_task,

View File

@@ -0,0 +1,206 @@
#!/usr/bin/env python3
"""
Fleet CA and Agent Certificate Generator — #806
Generates a Fleet CA and per-agent TLS certificates for mutual TLS
authentication between fleet agents.
Usage:
# Generate Fleet CA
python scripts/generate_fleet_ca.py --ca-dir ./fleet-ca
# Generate agent cert
python scripts/generate_fleet_ca.py --ca-dir ./fleet-ca --agent timmy
python scripts/generate_fleet_ca.py --ca-dir ./fleet-ca --agent allegro
python scripts/generate_fleet_ca.py --ca-dir ./fleet-ca --agent ezra
# Generate all fleet certs
python scripts/generate_fleet_ca.py --ca-dir ./fleet-ca --all
"""
import argparse
import datetime
import os
import sys
from pathlib import Path
try:
from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
HAS_CRYPTO = True
except ImportError:
HAS_CRYPTO = False
FLEET_AGENTS = ["timmy", "allegro", "ezra", "bezalel"]
CA_VALIDITY_DAYS = 3650 # 10 years
CERT_VALIDITY_DAYS = 365 # 1 year
KEY_SIZE = 2048
def generate_ca(ca_dir: Path) -> tuple:
"""Generate Fleet CA key and certificate."""
ca_dir.mkdir(parents=True, exist_ok=True)
# Generate CA key
ca_key = rsa.generate_private_key(public_exponent=65537, key_size=KEY_SIZE)
# Generate CA cert
subject = issuer = x509.Name([
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Timmy Foundation"),
x509.NameAttribute(NameOID.COMMON_NAME, "Fleet CA"),
])
ca_cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(ca_key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.datetime.utcnow())
.not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=CA_VALIDITY_DAYS))
.add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
.add_extension(
x509.KeyUsage(
digital_signature=True, key_cert_sign=True, crl_sign=True,
content_commitment=False, key_encipherment=False,
data_encipherment=False, key_agreement=False,
encipher_only=False, decipher_only=False,
),
critical=True,
)
.sign(ca_key, hashes.SHA256())
)
# Save
ca_key_path = ca_dir / "fleet-ca.key"
ca_cert_path = ca_dir / "fleet-ca.crt"
ca_key_path.write_bytes(ca_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
))
ca_cert_path.write_bytes(ca_cert.public_bytes(serialization.Encoding.PEM))
# Secure permissions
os.chmod(ca_key_path, 0o600)
os.chmod(ca_cert_path, 0o644)
print(f"CA key: {ca_key_path}")
print(f"CA cert: {ca_cert_path}")
return ca_key, ca_cert
def generate_agent_cert(ca_dir: Path, agent_name: str, ca_key=None, ca_cert=None) -> tuple:
"""Generate TLS certificate for an agent."""
agent_dir = ca_dir / agent_name
agent_dir.mkdir(parents=True, exist_ok=True)
# Load CA if not provided
if ca_key is None or ca_cert is None:
ca_key_path = ca_dir / "fleet-ca.key"
ca_cert_path = ca_dir / "fleet-ca.crt"
if not ca_key_path.exists() or not ca_cert_path.exists():
print(f"Error: CA not found in {ca_dir}. Run --ca first.")
return None, None
with open(ca_key_path, "rb") as f:
ca_key = serialization.load_pem_private_key(f.read(), password=None)
with open(ca_cert_path, "rb") as f:
ca_cert = x509.load_pem_x509_certificate(f.read())
# Generate agent key
agent_key = rsa.generate_private_key(public_exponent=65537, key_size=KEY_SIZE)
# Generate agent cert
subject = x509.Name([
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Timmy Foundation"),
x509.NameAttribute(NameOID.COMMON_NAME, f"agent-{agent_name}"),
])
agent_cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(ca_cert.subject)
.public_key(agent_key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.datetime.utcnow())
.not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=CERT_VALIDITY_DAYS))
.add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
.add_extension(
x509.SubjectAlternativeName([
x509.DNSName(f"{agent_name}.fleet.local"),
x509.DNSName(f"{agent_name}"),
x509.DNSName("localhost"),
]),
critical=False,
)
.sign(ca_key, hashes.SHA256())
)
# Save
key_path = agent_dir / f"{agent_name}.key"
cert_path = agent_dir / f"{agent_name}.crt"
key_path.write_bytes(agent_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
))
cert_path.write_bytes(agent_cert.public_bytes(serialization.Encoding.PEM))
# Copy CA cert to agent dir
ca_copy = agent_dir / "fleet-ca.crt"
ca_copy.write_bytes(ca_cert.public_bytes(serialization.Encoding.PEM))
# Secure permissions
os.chmod(key_path, 0o600)
os.chmod(cert_path, 0o644)
print(f"Agent {agent_name}:")
print(f" Key: {key_path}")
print(f" Cert: {cert_path}")
print(f" CA: {ca_copy}")
return agent_key, agent_cert
def main():
parser = argparse.ArgumentParser(description="Fleet CA and Agent Certificate Generator")
parser.add_argument("--ca-dir", type=Path, default=Path("./fleet-ca"), help="CA directory")
parser.add_argument("--ca", action="store_true", help="Generate Fleet CA")
parser.add_argument("--agent", type=str, help="Generate cert for agent")
parser.add_argument("--all", action="store_true", help="Generate certs for all fleet agents")
args = parser.parse_args()
if not HAS_CRYPTO:
print("Error: cryptography package required. pip install cryptography")
sys.exit(1)
if args.ca:
generate_ca(args.ca_dir)
if args.agent:
generate_agent_cert(args.ca_dir, args.agent)
if args.all:
# Generate CA first if not exists
ca_key_path = args.ca_dir / "fleet-ca.key"
if not ca_key_path.exists():
ca_key, ca_cert = generate_ca(args.ca_dir)
else:
ca_key, ca_cert = None, None
for agent in FLEET_AGENTS:
generate_agent_cert(args.ca_dir, agent, ca_key, ca_cert)
if not args.ca and not args.agent and not args.all:
parser.print_help()
if __name__ == "__main__":
main()

60
tests/test_a2a_mtls.py Normal file
View File

@@ -0,0 +1,60 @@
"""Tests for A2A mutual TLS (#806)."""
import sys
import tempfile
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
def test_import():
"""Module imports cleanly."""
from agent.a2a_mtls import get_server_ssl_context, get_client_ssl_context, verify_agent_cert
assert callable(get_server_ssl_context)
assert callable(get_client_ssl_context)
assert callable(verify_agent_cert)
def test_default_paths():
"""Default cert paths resolve correctly."""
from agent.a2a_mtls import DEFAULT_CERTS_DIR
assert DEFAULT_CERTS_DIR is not None
assert "fleet-certs" in str(DEFAULT_CERTS_DIR)
def test_server_context_creation():
"""Server SSL context can be created with agent name."""
# This will fail if certs don't exist, which is expected
from agent.a2a_mtls import get_server_ssl_context
try:
ctx = get_server_ssl_context(agent_name="timmy")
assert ctx is not None
except FileNotFoundError:
pass # Expected when certs don't exist
def test_client_context_creation():
"""Client SSL context can be created with agent name."""
from agent.a2a_mtls import get_client_ssl_context
try:
ctx = get_client_ssl_context(agent_name="timmy")
assert ctx is not None
except FileNotFoundError:
pass # Expected when certs don't exist
def test_verify_agent_cert_invalid():
"""Invalid cert returns False."""
from agent.a2a_mtls import verify_agent_cert
valid, msg = verify_agent_cert("not a cert")
assert not valid
if __name__ == "__main__":
tests = [test_import, test_default_paths, test_server_context_creation,
test_client_context_creation, test_verify_agent_cert_invalid]
for t in tests:
print(f"Running {t.__name__}...")
t()
print(" PASS")
print("\nAll tests passed.")

71
tools/fleet_ca.py Normal file
View File

@@ -0,0 +1,71 @@
"""Fleet CA for agent-to-agent mTLS (#806)."""
import argparse, datetime, os, sys
from pathlib import Path
def init_ca(output_dir, ca_name="Timmy Fleet CA", days=3650):
from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
out = Path(output_dir); out.mkdir(parents=True, exist_ok=True)
ca_key = rsa.generate_private_key(65537, 4096)
ca_cert = (x509.CertificateBuilder()
.subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, ca_name)]))
.issuer_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, ca_name)]))
.public_key(ca_key.public_key()).serial_number(x509.random_serial_number())
.not_valid_before(datetime.datetime.utcnow())
.not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=days))
.add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
.sign(ca_key, hashes.SHA256()))
with open(out/"ca.key","wb") as f: f.write(ca_key.private_bytes(serialization.Encoding.PEM, serialization.PrivateFormat.TraditionalOpenSSL, serialization.NoEncryption()))
os.chmod(out/"ca.key", 0o600)
with open(out/"ca.crt","wb") as f: f.write(ca_cert.public_bytes(serialization.Encoding.PEM))
print(f"CA created: {out}/ca.crt")
def issue_cert(agent_name, output_dir, days=365):
from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
out = Path(output_dir)
with open(out/"ca.key","rb") as f: ca_key = serialization.load_pem_private_key(f.read(), None)
with open(out/"ca.crt","rb") as f: ca_cert = x509.load_pem_x509_certificate(f.read())
key = rsa.generate_private_key(65537, 2048)
cert = (x509.CertificateBuilder()
.subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, agent_name)]))
.issuer_name(ca_cert.subject).public_key(key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.datetime.utcnow())
.not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=days))
.add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
.add_extension(x509.SubjectAlternativeName([x509.DNSName(agent_name), x509.DNSName(f"{agent_name}.local")]), critical=False)
.sign(ca_key, hashes.SHA256()))
with open(out/f"{agent_name}.key","wb") as f: f.write(key.private_bytes(serialization.Encoding.PEM, serialization.PrivateFormat.TraditionalOpenSSL, serialization.NoEncryption()))
os.chmod(out/f"{agent_name}.key", 0o600)
with open(out/f"{agent_name}.crt","wb") as f: f.write(cert.public_bytes(serialization.Encoding.PEM))
print(f"Cert issued: {out}/{agent_name}.crt")
def verify_cert(cert_path, ca_path):
from cryptography import x509
from cryptography.hazmat.primitives.asymmetric import padding
with open(ca_path,"rb") as f: ca = x509.load_pem_x509_certificate(f.read())
with open(cert_path,"rb") as f: cert = x509.load_pem_x509_certificate(f.read())
try:
ca.public_key().verify(cert.signature, cert.tbs_certificate_bytes, padding.PKCS1v15(), cert.signature_hash_algorithm)
print(f"OK: {cert.subject} signed by {ca.subject}"); return True
except Exception as e: print(f"FAIL: {e}"); return False
def main():
p = argparse.ArgumentParser(description="Fleet CA for mTLS")
sub = p.add_subparsers(dest="cmd")
pi = sub.add_parser("init"); pi.add_argument("--output-dir", default=os.path.expanduser("~/.hermes/certs"))
pi.add_argument("--ca-name", default="Timmy Fleet CA"); pi.add_argument("--days", type=int, default=3650)
pu = sub.add_parser("issue"); pu.add_argument("--agent", required=True); pu.add_argument("--output-dir", default=os.path.expanduser("~/.hermes/certs"))
pv = sub.add_parser("verify"); pv.add_argument("--cert", required=True); pv.add_argument("--ca", required=True)
args = p.parse_args()
if args.cmd == "init": init_ca(args.output_dir, args.ca_name, args.days)
elif args.cmd == "issue": issue_cert(args.agent, args.output_dir)
elif args.cmd == "verify": sys.exit(0 if verify_cert(args.cert, args.ca) else 1)
else: p.print_help()
if __name__ == "__main__": main()

40
tools/mtls_server.py Normal file
View File

@@ -0,0 +1,40 @@
"""mTLS server for A2A auth (#806)."""
import asyncio, logging, ssl
from pathlib import Path
from typing import Callable, Optional
logger = logging.getLogger(__name__)
def create_ssl_context(ca_cert_path, server_cert_path, server_key_path):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
ctx.load_cert_chain(certfile=server_cert_path, keyfile=server_key_path)
ctx.verify_mode = ssl.CERT_REQUIRED
ctx.load_verify_locations(cafile=ca_cert_path)
return ctx
def get_client_identity(ssl_obj):
try:
cert = ssl_obj.getpeercert()
if cert:
for rdn in cert.get("subject", ()):
for attr in rdn:
if attr[0] == "commonName": return attr[1]
except Exception: pass
return None
async def create_mtls_server(handler, host="127.0.0.1", port=8766, ca_cert="", server_cert="", server_key=""):
ca_cert = str(Path(ca_cert).expanduser())
server_cert = str(Path(server_cert).expanduser())
server_key = str(Path(server_key).expanduser())
ssl_ctx = create_ssl_context(ca_cert, server_cert, server_key)
async def _wrapper(reader, writer):
ssl_obj = writer.transport.get_extra_info("ssl_object")
agent = get_client_identity(ssl_obj) or "unknown"
logger.info("mTLS connection from: %s", agent)
try: await handler(ssl_obj, reader, writer)
except Exception as e: logger.error("Handler error: %s", e)
finally: writer.close()
server = await asyncio.start_server(_wrapper, host, port, ssl=ssl_ctx)
logger.info("mTLS server on %s:%d", host, port)
return server