99 lines
4.9 KiB
Python
99 lines
4.9 KiB
Python
"""A2A Client - send tasks to remote agents via JSON-RPC 2.0."""
|
|
|
|
from __future__ import annotations
|
|
import asyncio, json, logging, time
|
|
from typing import Any, Dict, List, Optional
|
|
import aiohttp
|
|
from a2a.types import (AgentCard, Artifact, JSONRPCError, JSONRPCRequest, JSONRPCResponse, Message, Part, Task, TaskState, TaskStatus, TextPart, A2AError)
|
|
|
|
# Module-level logger; A2AClient uses it to warn about failed RPC attempts.
logger = logging.getLogger(__name__)
|
|
|
|
class A2AClientConfig:
    """Tunable settings shared by every request an ``A2AClient`` makes.

    Attributes:
        timeout: Total seconds allowed per HTTP request (the client passes it
            as ``aiohttp.ClientTimeout(total=...)``).
        max_retries: Total number of HTTP attempts per RPC call; despite the
            name, the first attempt is included in the count.
        retry_delay: Base back-off in seconds; the client waits
            ``retry_delay * attempt`` before trying again.
        auth_token: Credential for the ``Authorization`` header; when falsy
            (e.g. ``None``) no such header is sent.
        auth_scheme: Word placed before the token in the header value
            (``"Bearer"`` by default).
    """

    def __init__(self, timeout=30.0, max_retries=3, retry_delay=2.0, auth_token=None, auth_scheme="Bearer"):
        # Credentials first, then transport/retry tuning. These are plain
        # attributes so callers may adjust them after construction.
        self.auth_token, self.auth_scheme = auth_token, auth_scheme
        self.timeout, self.max_retries, self.retry_delay = (
            timeout,
            max_retries,
            retry_delay,
        )
|
|
|
|
class A2AClient:
    """Client that sends A2A tasks to a remote agent over JSON-RPC 2.0.

    Every RPC is an HTTP POST to ``{base_url}/a2a/v1/rpc``. Transport errors
    and transient HTTP failures are retried according to ``config``. Every
    attempt is recorded in an in-memory audit log (see :attr:`audit_log`).
    """

    # 4xx statuses that signal a transient condition worth retrying (request
    # timeout, rate limiting). Every 5xx is retried as well. Other 4xx
    # statuses (bad request, auth failure, not found) will not change on retry.
    _RETRYABLE_4XX = frozenset({408, 429})

    def __init__(self, base_url, config=None):
        """Create a client for the agent at *base_url*.

        Args:
            base_url: Agent root URL; any trailing ``/`` is stripped.
            config: Connection/retry/auth settings. A default
                ``A2AClientConfig()`` is used when omitted or falsy.
        """
        self.base_url = base_url.rstrip("/")
        self.config = config or A2AClientConfig()
        # One dict per HTTP attempt (method, timing, status or error).
        # NOTE(review): grows without bound and keeps raw params (including
        # message text) for the lifetime of the client.
        self._audit_log = []

    def _headers(self):
        """Return request headers, adding ``Authorization`` when a token is configured."""
        h = {"Content-Type": "application/json"}
        if self.config.auth_token:
            h["Authorization"] = f"{self.config.auth_scheme} {self.config.auth_token}"
        return h

    def _attempts(self):
        """Return how many HTTP attempts one RPC call may make (always >= 1).

        ``config.max_retries`` counts total attempts, first one included. The
        value is clamped so a non-positive setting still sends the request
        once. Without the clamp, the call would fail without ever contacting
        the agent.
        """
        return max(1, self.config.max_retries)

    @classmethod
    def _is_retryable_status(cls, status):
        """Return True if HTTP *status* is transient: any 5xx, 408 or 429."""
        return status >= 500 or status in cls._RETRYABLE_4XX

    async def _rpc_call(self, method, params=None):
        """Send one JSON-RPC request and return the parsed response.

        Up to ``self._attempts()`` HTTP attempts are made over a single
        ``aiohttp.ClientSession``. A failed attempt number *n* is followed by
        a pause of ``config.retry_delay * n`` seconds. Transport/parse errors
        and transient statuses (5xx, 408, 429) are retried. Any other non-200
        status fails immediately.

        Args:
            method: JSON-RPC method name, e.g. ``"SendMessage"``.
            params: Optional JSON-serialisable params for the request.

        Returns:
            A ``JSONRPCResponse``. HTTP and transport failures are reported
            through its ``error`` field (built with
            ``A2AError.internal_error``) rather than raised.
        """
        req = JSONRPCRequest(method=method, params=params)
        payload = json.dumps(req.to_dict())
        url = f"{self.base_url}/a2a/v1/rpc"
        attempts = self._attempts()
        last_error = None
        timeout = aiohttp.ClientTimeout(total=self.config.timeout)
        # Reuse one session (and its connection pool) for every attempt of
        # this call, rather than building and tearing one down per attempt.
        async with aiohttp.ClientSession(timeout=timeout) as session:
            for attempt in range(1, attempts + 1):
                t0 = time.monotonic()
                try:
                    # Read everything we need, then leave the response context
                    # so the connection is released before any back-off sleep.
                    async with session.post(url, data=payload, headers=self._headers()) as resp:
                        status = resp.status
                        body = await resp.text()
                    elapsed = time.monotonic() - t0
                    self._audit_log.append({
                        "method": method,
                        "params": params,
                        "status": status,
                        "elapsed_ms": round(elapsed * 1000),
                        "attempt": attempt,
                    })
                    if status != 200:
                        logger.warning("A2A %s -> HTTP %d (attempt %d/%d)", method, status, attempt, attempts)
                        if attempt < attempts and self._is_retryable_status(status):
                            await asyncio.sleep(self.config.retry_delay * attempt)
                            continue
                        return JSONRPCResponse(id=req.id, error=A2AError.internal_error(f"HTTP {status}: {body[:200]}"))
                    return JSONRPCResponse.from_dict(json.loads(body))
                except Exception as exc:
                    # Broad on purpose: network errors, timeouts and malformed
                    # bodies all lead to a retry and, finally, an error response.
                    elapsed = time.monotonic() - t0
                    last_error = exc
                    self._audit_log.append({
                        "method": method,
                        "error": str(exc),
                        "elapsed_ms": round(elapsed * 1000),
                        "attempt": attempt,
                    })
                    logger.warning("A2A %s failed: %r (attempt %d/%d)", method, exc, attempt, attempts)
                    if attempt < attempts:
                        await asyncio.sleep(self.config.retry_delay * attempt)
        # Only reached when the final attempt raised, so last_error is set.
        return JSONRPCResponse(id=req.id, error=A2AError.internal_error(str(last_error)))

    async def get_agent_card(self):
        """Fetch the remote agent's card via the ``GetAgentCard`` RPC.

        Raises:
            RuntimeError: If the RPC response carries an error.
        """
        resp = await self._rpc_call("GetAgentCard")
        if resp.error:
            raise RuntimeError(f"Failed to get agent card: {resp.error.message}")
        return AgentCard.from_dict(resp.result)

    async def send_message(self, text, context_id=None, skill_id=None, metadata=None):
        """Send *text* as a user message and return the resulting ``Task``.

        Args:
            text: Message body, sent as a single ``TextPart``.
            context_id: Optional conversation context to attach to the message.
            skill_id: Optional target skill; sent as ``skillId`` only when truthy.
            metadata: Optional extra params; sent only when truthy.

        Raises:
            RuntimeError: If the ``SendMessage`` RPC returns an error.
        """
        msg = Message(role="user", parts=[TextPart(text=text)], context_id=context_id)
        params = {"message": msg.to_dict()}
        if skill_id:
            params["skillId"] = skill_id
        if metadata:
            params["metadata"] = metadata
        resp = await self._rpc_call("SendMessage", params)
        if resp.error:
            raise RuntimeError(f"SendMessage failed: {resp.error.message}")
        return Task.from_dict(resp.result)

    async def get_task(self, task_id):
        """Return the current state of task *task_id*.

        Raises:
            RuntimeError: If the ``GetTask`` RPC returns an error.
        """
        resp = await self._rpc_call("GetTask", {"taskId": task_id})
        if resp.error:
            raise RuntimeError(f"GetTask failed: {resp.error.message}")
        return Task.from_dict(resp.result)

    async def cancel_task(self, task_id):
        """Request cancellation of task *task_id* and return the updated ``Task``.

        Raises:
            RuntimeError: If the ``CancelTask`` RPC returns an error.
        """
        resp = await self._rpc_call("CancelTask", {"taskId": task_id})
        if resp.error:
            raise RuntimeError(f"CancelTask failed: {resp.error.message}")
        return Task.from_dict(resp.result)

    async def wait_for_completion(self, task_id, poll_interval=2.0, timeout=300.0):
        """Poll ``get_task`` until the task reaches a terminal state.

        Assumes ``task.status.state`` exposes a boolean ``terminal`` flag.

        Args:
            task_id: Task to watch.
            poll_interval: Seconds between polls.
            timeout: Overall budget in seconds. The final sleep is shortened so
                polling never oversleeps the deadline by a whole interval.

        Raises:
            TimeoutError: If the task is still non-terminal after *timeout*.
            RuntimeError: Propagated from ``get_task`` on RPC errors.
        """
        deadline = time.monotonic() + timeout
        while True:
            task = await self.get_task(task_id)
            if task.status.state.terminal:
                return task
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise TimeoutError(f"Task {task_id} did not complete within {timeout}s")
            await asyncio.sleep(min(poll_interval, remaining))

    async def delegate(self, text, skill_id=None, wait=True, timeout=300.0, poll_interval=2.0, context_id=None, metadata=None):
        """Send *text* to the agent and optionally wait for the task to finish.

        Args:
            text: Message body for ``send_message``.
            skill_id: Optional target skill.
            wait: When true, block until the task is terminal (or times out).
            timeout: Wait budget in seconds, used only when *wait* is true.
            poll_interval: Seconds between status polls while waiting.
            context_id: Optional conversation context for the message.
            metadata: Optional extra params for the message.

        Returns:
            The terminal ``Task`` when *wait* is true, else the task as
            returned by ``send_message``.
        """
        task = await self.send_message(text, context_id=context_id, skill_id=skill_id, metadata=metadata)
        if not wait:
            return task
        return await self.wait_for_completion(task.id, poll_interval=poll_interval, timeout=timeout)

    @property
    def audit_log(self):
        """Shallow copy of the audit log (a new list holding the same entry dicts)."""
        return list(self._audit_log)
|