Compare commits

...

5 Commits

Author SHA1 Message Date
84d247d320 feat: add A2A tests (#804)
Some checks failed
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Contributor Attribution Check / check-attribution (pull_request) Failing after 37s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 39s
Tests / e2e (pull_request) Successful in 2m57s
Tests / test (pull_request) Failing after 58m20s
2026-04-15 23:03:16 +00:00
5a3f53c8b5 feat: add A2A server (#804) 2026-04-15 23:02:27 +00:00
fec32b5659 feat: add A2A client (#804) 2026-04-15 22:58:03 +00:00
c2821055d9 feat: add A2A types (#804) 2026-04-15 22:57:00 +00:00
5a03eba791 feat: add A2A protocol implementation (#804) 2026-04-15 22:55:21 +00:00
5 changed files with 526 additions and 0 deletions

22
a2a/__init__.py Normal file
View File

@@ -0,0 +1,22 @@
"""A2A — Agent2Agent Protocol v1.0 for Hermes task delegation.
Usage:
from a2a import A2AClient, A2AServer, AgentCard, Task, TextPart
"""
from a2a.types import (
AgentCard, AgentSkill, Artifact, DataPart, FilePart,
JSONRPCError, JSONRPCRequest, JSONRPCResponse,
Message, Part, Task, TaskState, TaskStatus, TextPart, A2AError,
part_from_dict,
)
from a2a.client import A2AClient, A2AClientConfig
from a2a.server import A2AServer, TaskHandler
__all__ = [
"A2AClient", "A2AClientConfig", "A2AServer", "TaskHandler",
"AgentCard", "AgentSkill", "Artifact",
"DataPart", "FilePart", "TextPart", "Part", "part_from_dict",
"JSONRPCError", "JSONRPCRequest", "JSONRPCResponse",
"Message", "Task", "TaskState", "TaskStatus", "A2AError",
]

98
a2a/client.py Normal file
View File

@@ -0,0 +1,98 @@
"""A2A Client - send tasks to remote agents via JSON-RPC 2.0."""
from __future__ import annotations
import asyncio, json, logging, time
from typing import Any, Dict, List, Optional
import aiohttp
from a2a.types import (AgentCard, Artifact, JSONRPCError, JSONRPCRequest, JSONRPCResponse, Message, Part, Task, TaskState, TaskStatus, TextPart, A2AError)
logger = logging.getLogger(__name__)
class A2AClientConfig:
def __init__(self, timeout=30.0, max_retries=3, retry_delay=2.0, auth_token=None, auth_scheme="Bearer"):
self.timeout = timeout
self.max_retries = max_retries
self.retry_delay = retry_delay
self.auth_token = auth_token
self.auth_scheme = auth_scheme
class A2AClient:
def __init__(self, base_url, config=None):
self.base_url = base_url.rstrip("/")
self.config = config or A2AClientConfig()
self._audit_log = []
def _headers(self):
h = {"Content-Type": "application/json"}
if self.config.auth_token:
h["Authorization"] = f"{self.config.auth_scheme} {self.config.auth_token}"
return h
async def _rpc_call(self, method, params=None):
req = JSONRPCRequest(method=method, params=params)
payload = json.dumps(req.to_dict())
last_error = None
for attempt in range(1, self.config.max_retries + 1):
t0 = time.monotonic()
try:
timeout = aiohttp.ClientTimeout(total=self.config.timeout)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.post(f"{self.base_url}/a2a/v1/rpc", data=payload, headers=self._headers()) as resp:
body = await resp.text()
elapsed = time.monotonic() - t0
self._audit_log.append({"method": method, "params": params, "status": resp.status, "elapsed_ms": round(elapsed*1000), "attempt": attempt})
if resp.status != 200:
logger.warning("A2A %s -> HTTP %d (attempt %d/%d)", method, resp.status, attempt, self.config.max_retries)
if attempt < self.config.max_retries:
await asyncio.sleep(self.config.retry_delay * attempt)
continue
return JSONRPCResponse(id=req.id, error=A2AError.internal_error(f"HTTP {resp.status}: {body[:200]}"))
return JSONRPCResponse.from_dict(json.loads(body))
except Exception as exc:
elapsed = time.monotonic() - t0
last_error = exc
self._audit_log.append({"method": method, "error": str(exc), "elapsed_ms": round(elapsed*1000), "attempt": attempt})
if attempt < self.config.max_retries:
await asyncio.sleep(self.config.retry_delay * attempt)
continue
return JSONRPCResponse(id=req.id, error=A2AError.internal_error(str(last_error)))
async def get_agent_card(self):
resp = await self._rpc_call("GetAgentCard")
if resp.error: raise RuntimeError(f"Failed to get agent card: {resp.error.message}")
return AgentCard.from_dict(resp.result)
async def send_message(self, text, context_id=None, skill_id=None, metadata=None):
msg = Message(role="user", parts=[TextPart(text=text)], context_id=context_id)
params = {"message": msg.to_dict()}
if skill_id: params["skillId"] = skill_id
if metadata: params["metadata"] = metadata
resp = await self._rpc_call("SendMessage", params)
if resp.error: raise RuntimeError(f"SendMessage failed: {resp.error.message}")
return Task.from_dict(resp.result)
async def get_task(self, task_id):
resp = await self._rpc_call("GetTask", {"taskId": task_id})
if resp.error: raise RuntimeError(f"GetTask failed: {resp.error.message}")
return Task.from_dict(resp.result)
async def cancel_task(self, task_id):
resp = await self._rpc_call("CancelTask", {"taskId": task_id})
if resp.error: raise RuntimeError(f"CancelTask failed: {resp.error.message}")
return Task.from_dict(resp.result)
async def wait_for_completion(self, task_id, poll_interval=2.0, timeout=300.0):
t0 = time.monotonic()
while True:
task = await self.get_task(task_id)
if task.status.state.terminal: return task
if time.monotonic() - t0 > timeout: raise TimeoutError(f"Task {task_id} did not complete within {timeout}s")
await asyncio.sleep(poll_interval)
async def delegate(self, text, skill_id=None, wait=True, timeout=300.0):
task = await self.send_message(text, skill_id=skill_id)
if wait: return await self.wait_for_completion(task.id, timeout=timeout)
return task
@property
def audit_log(self): return list(self._audit_log)

60
a2a/server.py Normal file
View File

@@ -0,0 +1,60 @@
"""A2A Server - receive and execute tasks via JSON-RPC 2.0."""
from __future__ import annotations
import json,logging,uuid
from datetime import datetime,timezone
from typing import Any,Callable,Dict,List,Optional,Awaitable
from a2a.types import AgentCard,Artifact,JSONRPCError,JSONRPCRequest,JSONRPCResponse,Message,Task,TaskState,TaskStatus,TextPart,A2AError,part_from_dict
logger=logging.getLogger(__name__)
TaskHandler=Callable[[Task,AgentCard],Awaitable[Task]]
class A2AServer:
def __init__(s,card):s.card=card;s._tasks={};s._handlers={};s._default_handler=None;s._audit_log=[]
def register_handler(s,skill_id,handler):s._handlers[skill_id]=handler
def set_default_handler(s,handler):s._default_handler=handler
async def handle_rpc(s,raw):
try:data=json.loads(raw)
except (json.JSONDecodeError,TypeError):return json.dumps(JSONRPCResponse(id="",error=A2AError.parse_error()).to_dict())
req_id=data.get("id","");method=data.get("method","");params=data.get("params")
s._audit_log.append({"method":method,"id":req_id,"ts":datetime.now(timezone.utc).isoformat()})
try:
if method=="SendMessage":result=await s._handle_send_message(params)
elif method=="GetTask":result=s._handle_get_task(params)
elif method=="CancelTask":result=s._handle_cancel_task(params)
elif method=="GetAgentCard":result=s.card.to_dict()
elif method=="ListTasks":result=s._handle_list_tasks(params)
else:return json.dumps(JSONRPCResponse(id=req_id,error=A2AError.method_not_found()).to_dict())
return json.dumps(JSONRPCResponse(id=req_id,result=result).to_dict())
except Exception as exc:
logger.exception("A2A handler error for %s",method)
return json.dumps(JSONRPCResponse(id=req_id,error=A2AError.internal_error(str(exc))).to_dict())
async def _handle_send_message(s,params):
if not params or "message" not in params:raise ValueError("SendMessage requires message param")
msg=Message.from_dict(params["message"])
task=Task(id=str(uuid.uuid4()),context_id=msg.context_id,status=TaskStatus(state=TaskState.SUBMITTED),history=[msg])
s._tasks[task.id]=task
skill_id=params.get("skillId");handler=s._handlers.get(skill_id) if skill_id else None
if handler is None:handler=s._default_handler
if handler is None:
text=msg.parts[0].text if msg.parts and hasattr(msg.parts[0],"text") else ""
task.status=TaskStatus(state=TaskState.COMPLETED)
task.artifacts=[Artifact(parts=[TextPart(text=f"Received: {text}")])]
else:
task.status=TaskStatus(state=TaskState.WORKING);s._tasks[task.id]=task
task=await handler(task,s.card);s._tasks[task.id]=task
return task.to_dict()
def _handle_get_task(s,params):
task_id=(params or {}).get("taskId")
if not task_id or task_id not in s._tasks:raise KeyError(f"Task not found: {task_id}")
return s._tasks[task_id].to_dict()
def _handle_cancel_task(s,params):
task_id=(params or {}).get("taskId")
if not task_id or task_id not in s._tasks:raise KeyError(f"Task not found: {task_id}")
task=s._tasks[task_id]
if task.status.state.terminal:raise ValueError(f"Task already {task.status.state.value}")
task.status=TaskStatus(state=TaskState.CANCELED);s._tasks[task_id]=task
return task.to_dict()
def _handle_list_tasks(s,params):
context_id=(params or {}).get("contextId")
return {"tasks":[t.to_dict() for t in s._tasks.values() if not context_id or t.context_id==context_id]}
def add_task(s,task):s._tasks[task.id]=task
@property
def audit_log(s):return list(s._audit_log)

230
a2a/types.py Normal file
View File

@@ -0,0 +1,230 @@
"""A2A Protocol Types - Agent2Agent v1.0 data structures."""
from __future__ import annotations
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Optional, Union
class TaskState(str, Enum):
SUBMITTED = "TASK_STATE_SUBMITTED"
WORKING = "TASK_STATE_WORKING"
INPUT_REQUIRED = "TASK_STATE_INPUT_REQUIRED"
COMPLETED = "TASK_STATE_COMPLETED"
FAILED = "TASK_STATE_FAILED"
CANCELED = "TASK_STATE_CANCELED"
REJECTED = "TASK_STATE_REJECTED"
@property
def terminal(self) -> bool:
return self in {TaskState.COMPLETED, TaskState.FAILED, TaskState.CANCELED, TaskState.REJECTED}
@dataclass
class TextPart:
text: str
metadata: Optional[Dict[str, Any]] = None
def to_dict(self) -> dict:
d = {"text": self.text}
if self.metadata: d["metadata"] = self.metadata
return d
@classmethod
def from_dict(cls, d): return cls(text=d["text"], metadata=d.get("metadata"))
@dataclass
class FilePart:
media_type: str = "application/octet-stream"
raw: Optional[str] = None
url: Optional[str] = None
filename: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
def to_dict(self) -> dict:
d = {"mediaType": self.media_type}
if self.raw is not None: d["raw"] = self.raw
if self.url is not None: d["url"] = self.url
if self.filename: d["filename"] = self.filename
if self.metadata: d["metadata"] = self.metadata
return d
@classmethod
def from_dict(cls, d):
return cls(media_type=d.get("mediaType","application/octet-stream"), raw=d.get("raw"), url=d.get("url"), filename=d.get("filename"), metadata=d.get("metadata"))
@dataclass
class DataPart:
data: Dict[str, Any]
media_type: str = "application/json"
metadata: Optional[Dict[str, Any]] = None
def to_dict(self) -> dict:
d = {"data": self.data, "mediaType": self.media_type}
if self.metadata: d["metadata"] = self.metadata
return d
@classmethod
def from_dict(cls, d): return cls(data=d["data"], media_type=d.get("mediaType","application/json"), metadata=d.get("metadata"))
Part = Union[TextPart, FilePart, DataPart]
def part_from_dict(d):
if "text" in d: return TextPart.from_dict(d)
if "raw" in d or "url" in d: return FilePart.from_dict(d)
if "data" in d: return DataPart.from_dict(d)
raise ValueError(f"Cannot discriminate Part type from keys: {list(d.keys())}")
@dataclass
class Message:
role: str
parts: List[Part]
message_id: str = field(default_factory=lambda: str(uuid.uuid4()))
context_id: Optional[str] = None
task_id: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
def to_dict(self):
d = {"role": self.role, "messageId": self.message_id, "parts": [p.to_dict() for p in self.parts]}
if self.context_id: d["contextId"] = self.context_id
if self.task_id: d["taskId"] = self.task_id
if self.metadata: d["metadata"] = self.metadata
return d
@classmethod
def from_dict(cls, d):
return cls(role=d["role"], parts=[part_from_dict(p) for p in d["parts"]], message_id=d.get("messageId",str(uuid.uuid4())), context_id=d.get("contextId"), task_id=d.get("taskId"), metadata=d.get("metadata"))
@dataclass
class Artifact:
artifact_id: str = field(default_factory=lambda: str(uuid.uuid4()))
parts: List[Part] = field(default_factory=list)
name: Optional[str] = None
description: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
def to_dict(self):
d = {"artifactId": self.artifact_id, "parts": [p.to_dict() for p in self.parts]}
if self.name: d["name"] = self.name
if self.description: d["description"] = self.description
if self.metadata: d["metadata"] = self.metadata
return d
@classmethod
def from_dict(cls, d):
return cls(artifact_id=d.get("artifactId",str(uuid.uuid4())), parts=[part_from_dict(p) for p in d.get("parts",[])], name=d.get("name"), description=d.get("description"), metadata=d.get("metadata"))
@dataclass
class TaskStatus:
state: TaskState
message: Optional[Message] = None
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
def to_dict(self):
d = {"state": self.state.value, "timestamp": self.timestamp}
if self.message: d["message"] = self.message.to_dict()
return d
@classmethod
def from_dict(cls, d):
msg = d.get("message")
return cls(state=TaskState(d["state"]), message=Message.from_dict(msg) if msg else None, timestamp=d.get("timestamp",datetime.now(timezone.utc).isoformat()))
@dataclass
class Task:
id: str = field(default_factory=lambda: str(uuid.uuid4()))
context_id: Optional[str] = None
status: TaskStatus = field(default_factory=lambda: TaskStatus(state=TaskState.SUBMITTED))
artifacts: List[Artifact] = field(default_factory=list)
history: List[Message] = field(default_factory=list)
metadata: Optional[Dict[str, Any]] = None
def to_dict(self):
d = {"id": self.id, "status": self.status.to_dict()}
if self.context_id: d["contextId"] = self.context_id
if self.artifacts: d["artifacts"] = [a.to_dict() for a in self.artifacts]
if self.history: d["history"] = [m.to_dict() for m in self.history]
if self.metadata: d["metadata"] = self.metadata
return d
@classmethod
def from_dict(cls, d):
return cls(id=d.get("id",str(uuid.uuid4())), context_id=d.get("contextId"), status=TaskStatus.from_dict(d["status"]) if "status" in d else TaskStatus(TaskState.SUBMITTED), artifacts=[Artifact.from_dict(a) for a in d.get("artifacts",[])], history=[Message.from_dict(m) for m in d.get("history",[])], metadata=d.get("metadata"))
@dataclass
class AgentSkill:
id: str
name: str
description: str = ""
tags: List[str] = field(default_factory=list)
examples: List[str] = field(default_factory=list)
input_modes: List[str] = field(default_factory=lambda: ["text"])
output_modes: List[str] = field(default_factory=lambda: ["text"])
def to_dict(self):
d = {"id": self.id, "name": self.name, "description": self.description}
if self.tags: d["tags"] = self.tags
if self.examples: d["examples"] = self.examples
if self.input_modes != ["text"]: d["inputModes"] = self.input_modes
if self.output_modes != ["text"]: d["outputModes"] = self.output_modes
return d
@classmethod
def from_dict(cls, d):
return cls(id=d["id"], name=d.get("name",d["id"]), description=d.get("description",""), tags=d.get("tags",[]), examples=d.get("examples",[]), input_modes=d.get("inputModes",["text"]), output_modes=d.get("outputModes",["text"]))
@dataclass
class AgentCard:
name: str
description: str = ""
version: str = "1.0.0"
url: str = ""
skills: List[AgentSkill] = field(default_factory=list)
capabilities: Dict[str, bool] = field(default_factory=dict)
provider: Optional[Dict[str, str]] = None
def to_dict(self):
d = {"name": self.name, "description": self.description, "version": self.version, "url": self.url, "skills": [s.to_dict() for s in self.skills]}
if self.capabilities: d["capabilities"] = self.capabilities
if self.provider: d["provider"] = self.provider
return d
@classmethod
def from_dict(cls, d):
return cls(name=d["name"], description=d.get("description",""), version=d.get("version","1.0.0"), url=d.get("url",""), skills=[AgentSkill.from_dict(s) for s in d.get("skills",[])], capabilities=d.get("capabilities",{}), provider=d.get("provider"))
@dataclass
class JSONRPCError:
code: int
message: str
data: Optional[Any] = None
def to_dict(self):
d = {"code": self.code, "message": self.message}
if self.data is not None: d["data"] = self.data
return d
@dataclass
class JSONRPCRequest:
method: str
params: Optional[Dict[str, Any]] = None
id: str = field(default_factory=lambda: str(uuid.uuid4()))
jsonrpc: str = "2.0"
def to_dict(self):
d = {"jsonrpc": self.jsonrpc, "method": self.method, "id": self.id}
if self.params is not None: d["params"] = self.params
return d
@dataclass
class JSONRPCResponse:
id: str
result: Optional[Any] = None
error: Optional[JSONRPCError] = None
jsonrpc: str = "2.0"
def to_dict(self):
d = {"jsonrpc": self.jsonrpc, "id": self.id}
if self.error: d["error"] = self.error.to_dict()
else: d["result"] = self.result
return d
@classmethod
def from_dict(cls, d):
err = d.get("error")
return cls(id=d["id"], result=d.get("result"), error=JSONRPCError(err["code"],err["message"],err.get("data")) if err else None)
class A2AError:
@staticmethod
def parse_error(data=None): return JSONRPCError(-32700, "Parse error", data)
@staticmethod
def invalid_request(data=None): return JSONRPCError(-32600, "Invalid Request", data)
@staticmethod
def method_not_found(data=None): return JSONRPCError(-32601, "Method not found", data)
@staticmethod
def invalid_params(data=None): return JSONRPCError(-32602, "Invalid params", data)
@staticmethod
def internal_error(data=None): return JSONRPCError(-32603, "Internal error", data)
@staticmethod
def task_not_found(task_id): return JSONRPCError(-32001, f"Task not found: {task_id}")
@staticmethod
def task_not_cancelable(task_id): return JSONRPCError(-32002, f"Task not cancelable: {task_id}")
@staticmethod
def agent_not_found(name): return JSONRPCError(-32009, f"Agent not found: {name}")

116
tests/test_a2a.py Normal file
View File

@@ -0,0 +1,116 @@
"""Tests for A2A protocol implementation."""
import asyncio,json,pytest
from a2a.types import AgentCard,AgentSkill,Artifact,DataPart,FilePart,JSONRPCError,JSONRPCRequest,JSONRPCResponse,Message,Task,TaskState,TaskStatus,TextPart,A2AError,part_from_dict
from a2a.client import A2AClient,A2AClientConfig
from a2a.server import A2AServer
class TestTextPart:
def test_roundtrip(self):
p=TextPart(text="hello",metadata={"k":"v"});d=p.to_dict();assert d=={"text":"hello","metadata":{"k":"v"}};p2=TextPart.from_dict(d);assert p2.text=="hello"
def test_no_metadata(self):
p=TextPart(text="hi");d=p.to_dict();assert "metadata" not in d
class TestFilePart:
def test_inline(self):
p=FilePart(media_type="text/plain",raw="SGVsbG8=",filename="hello.txt");d=p.to_dict();assert d["raw"]=="SGVsbG8=";p2=FilePart.from_dict(d);assert p2.filename=="hello.txt"
def test_url(self):
p=FilePart(url="https://x.com/f");d=p.to_dict();assert d["url"]=="https://x.com/f"
class TestDataPart:
def test_roundtrip(self):
p=DataPart(data={"key":42});d=p.to_dict();assert d["data"]=={"key":42}
class TestPartDiscrimination:
def test_text(self):assert isinstance(part_from_dict({"text":"hi"}),TextPart)
def test_file_raw(self):assert isinstance(part_from_dict({"raw":"d","mediaType":"t"}),FilePart)
def test_file_url(self):assert isinstance(part_from_dict({"url":"https://x.com"}),FilePart)
def test_data(self):assert isinstance(part_from_dict({"data":{"a":1}}),DataPart)
def test_unknown(self):
with pytest.raises(ValueError):part_from_dict({"unknown":True})
class TestMessage:
def test_roundtrip(self):
m=Message(role="user",parts=[TextPart(text="hi")],context_id="c1");d=m.to_dict();assert d["role"]=="user";m2=Message.from_dict(d);assert m2.parts[0].text=="hi"
class TestArtifact:
def test_roundtrip(self):
a=Artifact(name="r",parts=[TextPart(text="d")]);d=a.to_dict();assert d["name"]=="r"
class TestTaskStatus:
def test_roundtrip(self):
s=TaskStatus(state=TaskState.COMPLETED);d=s.to_dict();assert d["state"]=="TASK_STATE_COMPLETED";s2=TaskStatus.from_dict(d);assert s2.state==TaskState.COMPLETED
def test_terminal(self):
assert TaskState.COMPLETED.terminal;assert TaskState.FAILED.terminal;assert not TaskState.SUBMITTED.terminal
class TestTask:
def test_roundtrip(self):
t=Task(id="t1",status=TaskStatus(state=TaskState.WORKING));d=t.to_dict();assert d["id"]=="t1";t2=Task.from_dict(d);assert t2.status.state==TaskState.WORKING
class TestAgentCard:
def test_roundtrip(self):
c=AgentCard(name="a",skills=[AgentSkill(id="s1",name="S")]);d=c.to_dict();assert d["name"]=="a";c2=AgentCard.from_dict(d);assert c2.skills[0].id=="s1"
class TestJSONRPC:
def test_request(self):
r=JSONRPCRequest(method="GetTask",params={"taskId":"1"});d=r.to_dict();assert d["method"]=="GetTask"
def test_response_success(self):
r=JSONRPCResponse(id="1",result={"ok":True});d=r.to_dict();assert d["result"]["ok"]==True
def test_response_error(self):
r=JSONRPCResponse(id="1",error=A2AError.parse_error());d=r.to_dict();assert d["error"]["code"]==-32700
def test_response_from_dict(self):
r=JSONRPCResponse.from_dict({"jsonrpc":"2.0","id":"1","result":{"ok":True}});assert r.result["ok"]==True
class TestA2AError:
def test_codes(self):
assert A2AError.parse_error().code==-32700;assert A2AError.invalid_request().code==-32600
assert A2AError.method_not_found().code==-32601;assert A2AError.task_not_found("x").code==-32001
def _make_card():
return AgentCard(name="test",description="Test",url="http://localhost:9999/a2a/v1",skills=[AgentSkill(id="echo",name="Echo",tags=["test"])])
def _run(coro):return asyncio.get_event_loop().run_until_complete(coro)
class TestServer:
def test_echo(self):
s=A2AServer(_make_card());msg=Message(role="user",parts=[TextPart(text="hello")])
raw=json.dumps(JSONRPCRequest(method="SendMessage",params={"message":msg.to_dict()}).to_dict())
resp=json.loads(_run(s.handle_rpc(raw)));assert "error" not in resp;assert resp["result"]["status"]["state"]=="TASK_STATE_COMPLETED"
def test_get_task(self):
s=A2AServer(_make_card());msg=Message(role="user",parts=[TextPart(text="hi")])
raw=json.dumps(JSONRPCRequest(method="SendMessage",params={"message":msg.to_dict()}).to_dict())
task_id=json.loads(_run(s.handle_rpc(raw)))["result"]["id"]
raw2=json.dumps(JSONRPCRequest(method="GetTask",params={"taskId":task_id}).to_dict())
resp=json.loads(_run(s.handle_rpc(raw2)));assert resp["result"]["id"]==task_id
def test_cancel(self):
s=A2AServer(_make_card());task=Task(id="c1",status=TaskStatus(state=TaskState.WORKING));s.add_task(task)
raw=json.dumps(JSONRPCRequest(method="CancelTask",params={"taskId":"c1"}).to_dict())
resp=json.loads(_run(s.handle_rpc(raw)));assert resp["result"]["status"]["state"]=="TASK_STATE_CANCELED"
def test_cancel_terminal(self):
s=A2AServer(_make_card());task=Task(id="d1",status=TaskStatus(state=TaskState.COMPLETED));s.add_task(task)
raw=json.dumps(JSONRPCRequest(method="CancelTask",params={"taskId":"d1"}).to_dict())
resp=json.loads(_run(s.handle_rpc(raw)));assert "error" in resp
def test_card(self):
s=A2AServer(_make_card());raw=json.dumps(JSONRPCRequest(method="GetAgentCard").to_dict())
resp=json.loads(_run(s.handle_rpc(raw)));assert resp["result"]["name"]=="test"
def test_list(self):
s=A2AServer(_make_card());msg=Message(role="user",parts=[TextPart(text="hi")])
raw=json.dumps(JSONRPCRequest(method="SendMessage",params={"message":msg.to_dict()}).to_dict())
_run(s.handle_rpc(raw));raw2=json.dumps(JSONRPCRequest(method="ListTasks").to_dict())
resp=json.loads(_run(s.handle_rpc(raw2)));assert len(resp["result"]["tasks"])>=1
def test_unknown(self):
s=A2AServer(_make_card());raw=json.dumps(JSONRPCRequest(method="Nope").to_dict())
resp=json.loads(_run(s.handle_rpc(raw)));assert resp["error"]["code"]==-32601
def test_invalid_json(self):
s=A2AServer(_make_card());resp=json.loads(_run(s.handle_rpc("bad")));assert resp["error"]["code"]==-32700
def test_custom_handler(self):
s=A2AServer(_make_card())
async def h(task,card):task.status=TaskStatus(state=TaskState.COMPLETED);task.artifacts=[Artifact(parts=[TextPart(text="custom")])];return task
s.register_handler("echo",h)
msg=Message(role="user",parts=[TextPart(text="t")])
raw=json.dumps(JSONRPCRequest(method="SendMessage",params={"message":msg.to_dict(),"skillId":"echo"}).to_dict())
resp=json.loads(_run(s.handle_rpc(raw)));assert resp["result"]["artifacts"][0]["parts"][0]["text"]=="custom"
def test_audit(self):
s=A2AServer(_make_card());raw=json.dumps(JSONRPCRequest(method="GetAgentCard").to_dict())
_run(s.handle_rpc(raw));assert len(s.audit_log)==1;assert s.audit_log[0]["method"]=="GetAgentCard"
if __name__=="__main__":pytest.main([__file__,"-v"])