# Source: hermes-agent/a2a/client.py (Python, 99 lines, 4.9 KiB)

"""A2A Client - send tasks to remote agents via JSON-RPC 2.0."""
from __future__ import annotations
import asyncio, json, logging, time
from typing import Any, Dict, List, Optional
import aiohttp
from a2a.types import (AgentCard, Artifact, JSONRPCError, JSONRPCRequest, JSONRPCResponse, Message, Part, Task, TaskState, TaskStatus, TextPart, A2AError)
# Module-scoped logger; the client reports retried RPC attempts through it.
logger = logging.getLogger(name=__name__)
class A2AClientConfig:
    """Settings for A2AClient: per-request timeout, retry policy and auth header.

    Every constructor argument is stored unchanged as a public, mutable
    instance attribute of the same name.
    """

    def __init__(
        self,
        timeout: float = 30.0,
        max_retries: int = 3,
        retry_delay: float = 2.0,
        auth_token: Optional[str] = None,
        auth_scheme: str = "Bearer",
    ) -> None:
        """Record the client settings.

        Args:
            timeout: Total time budget in seconds for one HTTP request.
            max_retries: Number of attempts the client makes per RPC call.
            retry_delay: Base pause in seconds; the client waits
                ``retry_delay * attempt`` before the next attempt.
            auth_token: Credential for the ``Authorization`` header; when
                falsy, no such header is sent.
            auth_scheme: Prefix written before the token, e.g. ``"Bearer"``.
        """
        # Transport
        self.timeout = timeout
        # Retry policy
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        # Authentication
        self.auth_token = auth_token
        self.auth_scheme = auth_scheme
class A2AClient:
    """Client for a remote A2A agent reachable over JSON-RPC 2.0.

    All RPCs are POSTed to ``{base_url}/a2a/v1/rpc``. Transport errors and
    retryable HTTP statuses are retried after ``retry_delay * attempt``
    seconds. Every attempt is appended to an in-memory audit log, which is
    exposed through :attr:`audit_log`.
    """

    # HTTP statuses that usually signal a transient condition (request timeout,
    # "too early", rate limiting). Any 5xx is retried as well. Other statuses
    # (e.g. 400/401/404) are returned immediately, because repeating the
    # identical request cannot change the outcome.
    _RETRYABLE_STATUSES = frozenset({408, 425, 429})

    def __init__(self, base_url, config=None):
        """Create a client for the agent at *base_url*.

        Args:
            base_url: Root URL of the remote agent; trailing slashes are stripped.
            config: Optional ``A2AClientConfig``. A default instance is used
                when this is omitted or falsy.
        """
        self.base_url = base_url.rstrip("/")
        self.config = config or A2AClientConfig()
        # One dict per request attempt, kept for the lifetime of the client.
        self._audit_log = []

    def _headers(self):
        """Return the request headers, adding ``Authorization`` when a token is set."""
        headers = {"Content-Type": "application/json"}
        if self.config.auth_token:
            headers["Authorization"] = f"{self.config.auth_scheme} {self.config.auth_token}"
        return headers

    @classmethod
    def _is_retryable_status(cls, status):
        """Return True if a non-200 HTTP *status* is worth retrying (5xx, 408, 425, 429)."""
        return 500 <= status < 600 or status in cls._RETRYABLE_STATUSES

    async def _rpc_call(self, method, params=None):
        """Send one JSON-RPC request, retrying transient failures.

        Args:
            method: JSON-RPC method name.
            params: Optional JSON-serializable params object.

        Returns:
            A ``JSONRPCResponse``. Failures are reported, not raised. A non-200
            status produces a response whose ``error`` is an internal error; this
            happens after the last attempt, or immediately if the status is not
            retryable. An exception on the final attempt is reported the same way.
        """
        req = JSONRPCRequest(method=method, params=params)
        payload = json.dumps(req.to_dict())
        url = f"{self.base_url}/a2a/v1/rpc"
        # Always make at least one attempt, even when max_retries is 0 or negative.
        attempts = max(1, int(self.config.max_retries))
        timeout = aiohttp.ClientTimeout(total=self.config.timeout)
        last_error = None
        # NOTE(review): if the first SendMessage attempt reached the server but
        # the reply was lost, a retry may create a duplicate task. Confirm
        # whether the server de-duplicates on the JSON-RPC id.
        # A single session for all attempts lets aiohttp reuse pooled connections.
        async with aiohttp.ClientSession(timeout=timeout) as session:
            for attempt in range(1, attempts + 1):
                is_last = attempt == attempts
                t0 = time.monotonic()
                try:
                    async with session.post(url, data=payload, headers=self._headers()) as resp:
                        status = resp.status
                        body = await resp.text()
                        elapsed_ms = round((time.monotonic() - t0) * 1000)
                    self._audit_log.append({
                        "method": method,
                        "params": params,
                        "status": status,
                        "elapsed_ms": elapsed_ms,
                        "attempt": attempt,
                    })
                    if status == 200:
                        # A malformed body raises here and is handled like a transport error.
                        return JSONRPCResponse.from_dict(json.loads(body))
                except Exception as exc:  # boundary: any failure becomes a JSON-RPC error
                    last_error = exc
                    self._audit_log.append({
                        "method": method,
                        "error": str(exc),
                        "elapsed_ms": round((time.monotonic() - t0) * 1000),
                        "attempt": attempt,
                    })
                    logger.warning("A2A %s failed: %s (attempt %d/%d)", method, exc, attempt, attempts)
                else:
                    # Reached only for a non-200 response (the 200 path returned above).
                    logger.warning("A2A %s -> HTTP %d (attempt %d/%d)", method, status, attempt, attempts)
                    if is_last or not self._is_retryable_status(status):
                        return JSONRPCResponse(
                            id=req.id,
                            error=A2AError.internal_error(f"HTTP {status}: {body[:200]}"),
                        )
                if not is_last:
                    # Sleep outside the response context so the connection is released first.
                    await asyncio.sleep(self.config.retry_delay * attempt)
        # Only reachable when the final attempt raised, so last_error is set.
        return JSONRPCResponse(id=req.id, error=A2AError.internal_error(str(last_error)))

    async def _call_checked(self, method, params, error_prefix):
        """Run *method* and return the response's ``result``.

        Raises:
            RuntimeError: ``"{error_prefix}: {message}"`` when the response carries an error.
        """
        resp = await self._rpc_call(method, params)
        if resp.error:
            raise RuntimeError(f"{error_prefix}: {resp.error.message}")
        return resp.result

    async def get_agent_card(self):
        """Fetch the remote agent's ``AgentCard`` via the ``GetAgentCard`` RPC.

        Raises:
            RuntimeError: if the RPC returns an error.
        """
        result = await self._call_checked("GetAgentCard", None, "Failed to get agent card")
        return AgentCard.from_dict(result)

    async def send_message(self, text, context_id=None, skill_id=None, metadata=None):
        """Send a user text message and return the resulting ``Task``.

        Args:
            text: Message text, sent as a single ``TextPart``.
            context_id: Optional conversation context id for the message.
            skill_id: Optional skill to target (sent as ``skillId`` when truthy).
            metadata: Optional metadata dict (sent only when truthy).

        Raises:
            RuntimeError: if the ``SendMessage`` RPC returns an error.
        """
        msg = Message(role="user", parts=[TextPart(text=text)], context_id=context_id)
        params = {"message": msg.to_dict()}
        if skill_id:
            params["skillId"] = skill_id
        if metadata:
            params["metadata"] = metadata
        result = await self._call_checked("SendMessage", params, "SendMessage failed")
        return Task.from_dict(result)

    async def get_task(self, task_id):
        """Return the current ``Task`` for *task_id*.

        Raises:
            RuntimeError: if the ``GetTask`` RPC returns an error.
        """
        result = await self._call_checked("GetTask", {"taskId": task_id}, "GetTask failed")
        return Task.from_dict(result)

    async def cancel_task(self, task_id):
        """Request cancellation of *task_id* and return the updated ``Task``.

        Raises:
            RuntimeError: if the ``CancelTask`` RPC returns an error.
        """
        result = await self._call_checked("CancelTask", {"taskId": task_id}, "CancelTask failed")
        return Task.from_dict(result)

    async def wait_for_completion(self, task_id, poll_interval=2.0, timeout=300.0):
        """Poll *task_id* until its state is terminal, then return the task.

        Args:
            task_id: Id of the task to poll.
            poll_interval: Seconds between polls; shortened so polling never
                sleeps past the deadline.
            timeout: Overall time budget in seconds.

        Raises:
            TimeoutError: if the task is not terminal once *timeout* has elapsed.
            RuntimeError: if a ``GetTask`` RPC returns an error.
        """
        deadline = time.monotonic() + timeout
        while True:
            task = await self.get_task(task_id)
            if task.status.state.terminal:
                return task
            remaining = deadline - time.monotonic()
            if remaining < 0:
                raise TimeoutError(f"Task {task_id} did not complete within {timeout}s")
            # Clamp the sleep so the final poll happens at the deadline, not after it.
            await asyncio.sleep(min(poll_interval, remaining))

    async def delegate(self, text, skill_id=None, wait=True, timeout=300.0, *,
                       context_id=None, metadata=None, poll_interval=2.0):
        """Send *text* to the remote agent and optionally wait for the result.

        Args:
            text: Message text to send.
            skill_id: Optional skill to target.
            wait: When True, poll until the task is terminal; otherwise return
                the task as created.
            timeout: Wait budget in seconds (used only when *wait* is True).
            context_id: Optional conversation context id, forwarded to ``send_message``.
            metadata: Optional metadata dict, forwarded to ``send_message``.
            poll_interval: Seconds between status polls while waiting.
        """
        task = await self.send_message(text, context_id=context_id, skill_id=skill_id, metadata=metadata)
        if wait:
            return await self.wait_for_completion(task.id, poll_interval=poll_interval, timeout=timeout)
        return task

    @property
    def audit_log(self):
        """Return a shallow copy of the per-attempt audit entries."""
        return list(self._audit_log)