Compare commits
5 Commits
fix/format
...
burn/804-1
| Author | SHA1 | Date | |
|---|---|---|---|
| 84d247d320 | |||
| 5a3f53c8b5 | |||
| fec32b5659 | |||
| c2821055d9 | |||
| 5a03eba791 |
22
a2a/__init__.py
Normal file
22
a2a/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
"""A2A — Agent2Agent Protocol v1.0 for Hermes task delegation.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from a2a import A2AClient, A2AServer, AgentCard, Task, TextPart
|
||||||
|
"""
|
||||||
|
|
||||||
|
from a2a.types import (
|
||||||
|
AgentCard, AgentSkill, Artifact, DataPart, FilePart,
|
||||||
|
JSONRPCError, JSONRPCRequest, JSONRPCResponse,
|
||||||
|
Message, Part, Task, TaskState, TaskStatus, TextPart, A2AError,
|
||||||
|
part_from_dict,
|
||||||
|
)
|
||||||
|
from a2a.client import A2AClient, A2AClientConfig
|
||||||
|
from a2a.server import A2AServer, TaskHandler
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"A2AClient", "A2AClientConfig", "A2AServer", "TaskHandler",
|
||||||
|
"AgentCard", "AgentSkill", "Artifact",
|
||||||
|
"DataPart", "FilePart", "TextPart", "Part", "part_from_dict",
|
||||||
|
"JSONRPCError", "JSONRPCRequest", "JSONRPCResponse",
|
||||||
|
"Message", "Task", "TaskState", "TaskStatus", "A2AError",
|
||||||
|
]
|
||||||
98
a2a/client.py
Normal file
98
a2a/client.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""A2A Client - send tasks to remote agents via JSON-RPC 2.0."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import asyncio, json, logging, time
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
import aiohttp
|
||||||
|
from a2a.types import (AgentCard, Artifact, JSONRPCError, JSONRPCRequest, JSONRPCResponse, Message, Part, Task, TaskState, TaskStatus, TextPart, A2AError)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class A2AClientConfig:
|
||||||
|
def __init__(self, timeout=30.0, max_retries=3, retry_delay=2.0, auth_token=None, auth_scheme="Bearer"):
|
||||||
|
self.timeout = timeout
|
||||||
|
self.max_retries = max_retries
|
||||||
|
self.retry_delay = retry_delay
|
||||||
|
self.auth_token = auth_token
|
||||||
|
self.auth_scheme = auth_scheme
|
||||||
|
|
||||||
|
class A2AClient:
|
||||||
|
def __init__(self, base_url, config=None):
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.config = config or A2AClientConfig()
|
||||||
|
self._audit_log = []
|
||||||
|
|
||||||
|
def _headers(self):
|
||||||
|
h = {"Content-Type": "application/json"}
|
||||||
|
if self.config.auth_token:
|
||||||
|
h["Authorization"] = f"{self.config.auth_scheme} {self.config.auth_token}"
|
||||||
|
return h
|
||||||
|
|
||||||
|
async def _rpc_call(self, method, params=None):
|
||||||
|
req = JSONRPCRequest(method=method, params=params)
|
||||||
|
payload = json.dumps(req.to_dict())
|
||||||
|
last_error = None
|
||||||
|
for attempt in range(1, self.config.max_retries + 1):
|
||||||
|
t0 = time.monotonic()
|
||||||
|
try:
|
||||||
|
timeout = aiohttp.ClientTimeout(total=self.config.timeout)
|
||||||
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
|
async with session.post(f"{self.base_url}/a2a/v1/rpc", data=payload, headers=self._headers()) as resp:
|
||||||
|
body = await resp.text()
|
||||||
|
elapsed = time.monotonic() - t0
|
||||||
|
self._audit_log.append({"method": method, "params": params, "status": resp.status, "elapsed_ms": round(elapsed*1000), "attempt": attempt})
|
||||||
|
if resp.status != 200:
|
||||||
|
logger.warning("A2A %s -> HTTP %d (attempt %d/%d)", method, resp.status, attempt, self.config.max_retries)
|
||||||
|
if attempt < self.config.max_retries:
|
||||||
|
await asyncio.sleep(self.config.retry_delay * attempt)
|
||||||
|
continue
|
||||||
|
return JSONRPCResponse(id=req.id, error=A2AError.internal_error(f"HTTP {resp.status}: {body[:200]}"))
|
||||||
|
return JSONRPCResponse.from_dict(json.loads(body))
|
||||||
|
except Exception as exc:
|
||||||
|
elapsed = time.monotonic() - t0
|
||||||
|
last_error = exc
|
||||||
|
self._audit_log.append({"method": method, "error": str(exc), "elapsed_ms": round(elapsed*1000), "attempt": attempt})
|
||||||
|
if attempt < self.config.max_retries:
|
||||||
|
await asyncio.sleep(self.config.retry_delay * attempt)
|
||||||
|
continue
|
||||||
|
return JSONRPCResponse(id=req.id, error=A2AError.internal_error(str(last_error)))
|
||||||
|
|
||||||
|
async def get_agent_card(self):
|
||||||
|
resp = await self._rpc_call("GetAgentCard")
|
||||||
|
if resp.error: raise RuntimeError(f"Failed to get agent card: {resp.error.message}")
|
||||||
|
return AgentCard.from_dict(resp.result)
|
||||||
|
|
||||||
|
async def send_message(self, text, context_id=None, skill_id=None, metadata=None):
|
||||||
|
msg = Message(role="user", parts=[TextPart(text=text)], context_id=context_id)
|
||||||
|
params = {"message": msg.to_dict()}
|
||||||
|
if skill_id: params["skillId"] = skill_id
|
||||||
|
if metadata: params["metadata"] = metadata
|
||||||
|
resp = await self._rpc_call("SendMessage", params)
|
||||||
|
if resp.error: raise RuntimeError(f"SendMessage failed: {resp.error.message}")
|
||||||
|
return Task.from_dict(resp.result)
|
||||||
|
|
||||||
|
async def get_task(self, task_id):
|
||||||
|
resp = await self._rpc_call("GetTask", {"taskId": task_id})
|
||||||
|
if resp.error: raise RuntimeError(f"GetTask failed: {resp.error.message}")
|
||||||
|
return Task.from_dict(resp.result)
|
||||||
|
|
||||||
|
async def cancel_task(self, task_id):
|
||||||
|
resp = await self._rpc_call("CancelTask", {"taskId": task_id})
|
||||||
|
if resp.error: raise RuntimeError(f"CancelTask failed: {resp.error.message}")
|
||||||
|
return Task.from_dict(resp.result)
|
||||||
|
|
||||||
|
async def wait_for_completion(self, task_id, poll_interval=2.0, timeout=300.0):
|
||||||
|
t0 = time.monotonic()
|
||||||
|
while True:
|
||||||
|
task = await self.get_task(task_id)
|
||||||
|
if task.status.state.terminal: return task
|
||||||
|
if time.monotonic() - t0 > timeout: raise TimeoutError(f"Task {task_id} did not complete within {timeout}s")
|
||||||
|
await asyncio.sleep(poll_interval)
|
||||||
|
|
||||||
|
async def delegate(self, text, skill_id=None, wait=True, timeout=300.0):
|
||||||
|
task = await self.send_message(text, skill_id=skill_id)
|
||||||
|
if wait: return await self.wait_for_completion(task.id, timeout=timeout)
|
||||||
|
return task
|
||||||
|
|
||||||
|
@property
|
||||||
|
def audit_log(self): return list(self._audit_log)
|
||||||
60
a2a/server.py
Normal file
60
a2a/server.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
"""A2A Server - receive and execute tasks via JSON-RPC 2.0."""
|
||||||
|
from __future__ import annotations
|
||||||
|
import json,logging,uuid
|
||||||
|
from datetime import datetime,timezone
|
||||||
|
from typing import Any,Callable,Dict,List,Optional,Awaitable
|
||||||
|
from a2a.types import AgentCard,Artifact,JSONRPCError,JSONRPCRequest,JSONRPCResponse,Message,Task,TaskState,TaskStatus,TextPart,A2AError,part_from_dict
|
||||||
|
logger=logging.getLogger(__name__)
|
||||||
|
TaskHandler=Callable[[Task,AgentCard],Awaitable[Task]]
|
||||||
|
class A2AServer:
|
||||||
|
def __init__(s,card):s.card=card;s._tasks={};s._handlers={};s._default_handler=None;s._audit_log=[]
|
||||||
|
def register_handler(s,skill_id,handler):s._handlers[skill_id]=handler
|
||||||
|
def set_default_handler(s,handler):s._default_handler=handler
|
||||||
|
async def handle_rpc(s,raw):
|
||||||
|
try:data=json.loads(raw)
|
||||||
|
except (json.JSONDecodeError,TypeError):return json.dumps(JSONRPCResponse(id="",error=A2AError.parse_error()).to_dict())
|
||||||
|
req_id=data.get("id","");method=data.get("method","");params=data.get("params")
|
||||||
|
s._audit_log.append({"method":method,"id":req_id,"ts":datetime.now(timezone.utc).isoformat()})
|
||||||
|
try:
|
||||||
|
if method=="SendMessage":result=await s._handle_send_message(params)
|
||||||
|
elif method=="GetTask":result=s._handle_get_task(params)
|
||||||
|
elif method=="CancelTask":result=s._handle_cancel_task(params)
|
||||||
|
elif method=="GetAgentCard":result=s.card.to_dict()
|
||||||
|
elif method=="ListTasks":result=s._handle_list_tasks(params)
|
||||||
|
else:return json.dumps(JSONRPCResponse(id=req_id,error=A2AError.method_not_found()).to_dict())
|
||||||
|
return json.dumps(JSONRPCResponse(id=req_id,result=result).to_dict())
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("A2A handler error for %s",method)
|
||||||
|
return json.dumps(JSONRPCResponse(id=req_id,error=A2AError.internal_error(str(exc))).to_dict())
|
||||||
|
async def _handle_send_message(s,params):
|
||||||
|
if not params or "message" not in params:raise ValueError("SendMessage requires message param")
|
||||||
|
msg=Message.from_dict(params["message"])
|
||||||
|
task=Task(id=str(uuid.uuid4()),context_id=msg.context_id,status=TaskStatus(state=TaskState.SUBMITTED),history=[msg])
|
||||||
|
s._tasks[task.id]=task
|
||||||
|
skill_id=params.get("skillId");handler=s._handlers.get(skill_id) if skill_id else None
|
||||||
|
if handler is None:handler=s._default_handler
|
||||||
|
if handler is None:
|
||||||
|
text=msg.parts[0].text if msg.parts and hasattr(msg.parts[0],"text") else ""
|
||||||
|
task.status=TaskStatus(state=TaskState.COMPLETED)
|
||||||
|
task.artifacts=[Artifact(parts=[TextPart(text=f"Received: {text}")])]
|
||||||
|
else:
|
||||||
|
task.status=TaskStatus(state=TaskState.WORKING);s._tasks[task.id]=task
|
||||||
|
task=await handler(task,s.card);s._tasks[task.id]=task
|
||||||
|
return task.to_dict()
|
||||||
|
def _handle_get_task(s,params):
|
||||||
|
task_id=(params or {}).get("taskId")
|
||||||
|
if not task_id or task_id not in s._tasks:raise KeyError(f"Task not found: {task_id}")
|
||||||
|
return s._tasks[task_id].to_dict()
|
||||||
|
def _handle_cancel_task(s,params):
|
||||||
|
task_id=(params or {}).get("taskId")
|
||||||
|
if not task_id or task_id not in s._tasks:raise KeyError(f"Task not found: {task_id}")
|
||||||
|
task=s._tasks[task_id]
|
||||||
|
if task.status.state.terminal:raise ValueError(f"Task already {task.status.state.value}")
|
||||||
|
task.status=TaskStatus(state=TaskState.CANCELED);s._tasks[task_id]=task
|
||||||
|
return task.to_dict()
|
||||||
|
def _handle_list_tasks(s,params):
|
||||||
|
context_id=(params or {}).get("contextId")
|
||||||
|
return {"tasks":[t.to_dict() for t in s._tasks.values() if not context_id or t.context_id==context_id]}
|
||||||
|
def add_task(s,task):s._tasks[task.id]=task
|
||||||
|
@property
|
||||||
|
def audit_log(s):return list(s._audit_log)
|
||||||
230
a2a/types.py
Normal file
230
a2a/types.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
"""A2A Protocol Types - Agent2Agent v1.0 data structures."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
class TaskState(str, Enum):
|
||||||
|
SUBMITTED = "TASK_STATE_SUBMITTED"
|
||||||
|
WORKING = "TASK_STATE_WORKING"
|
||||||
|
INPUT_REQUIRED = "TASK_STATE_INPUT_REQUIRED"
|
||||||
|
COMPLETED = "TASK_STATE_COMPLETED"
|
||||||
|
FAILED = "TASK_STATE_FAILED"
|
||||||
|
CANCELED = "TASK_STATE_CANCELED"
|
||||||
|
REJECTED = "TASK_STATE_REJECTED"
|
||||||
|
@property
|
||||||
|
def terminal(self) -> bool:
|
||||||
|
return self in {TaskState.COMPLETED, TaskState.FAILED, TaskState.CANCELED, TaskState.REJECTED}
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TextPart:
|
||||||
|
text: str
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
d = {"text": self.text}
|
||||||
|
if self.metadata: d["metadata"] = self.metadata
|
||||||
|
return d
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d): return cls(text=d["text"], metadata=d.get("metadata"))
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FilePart:
|
||||||
|
media_type: str = "application/octet-stream"
|
||||||
|
raw: Optional[str] = None
|
||||||
|
url: Optional[str] = None
|
||||||
|
filename: Optional[str] = None
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
d = {"mediaType": self.media_type}
|
||||||
|
if self.raw is not None: d["raw"] = self.raw
|
||||||
|
if self.url is not None: d["url"] = self.url
|
||||||
|
if self.filename: d["filename"] = self.filename
|
||||||
|
if self.metadata: d["metadata"] = self.metadata
|
||||||
|
return d
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d):
|
||||||
|
return cls(media_type=d.get("mediaType","application/octet-stream"), raw=d.get("raw"), url=d.get("url"), filename=d.get("filename"), metadata=d.get("metadata"))
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataPart:
|
||||||
|
data: Dict[str, Any]
|
||||||
|
media_type: str = "application/json"
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
d = {"data": self.data, "mediaType": self.media_type}
|
||||||
|
if self.metadata: d["metadata"] = self.metadata
|
||||||
|
return d
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d): return cls(data=d["data"], media_type=d.get("mediaType","application/json"), metadata=d.get("metadata"))
|
||||||
|
|
||||||
|
Part = Union[TextPart, FilePart, DataPart]
|
||||||
|
|
||||||
|
def part_from_dict(d):
|
||||||
|
if "text" in d: return TextPart.from_dict(d)
|
||||||
|
if "raw" in d or "url" in d: return FilePart.from_dict(d)
|
||||||
|
if "data" in d: return DataPart.from_dict(d)
|
||||||
|
raise ValueError(f"Cannot discriminate Part type from keys: {list(d.keys())}")
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Message:
|
||||||
|
role: str
|
||||||
|
parts: List[Part]
|
||||||
|
message_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
|
context_id: Optional[str] = None
|
||||||
|
task_id: Optional[str] = None
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
def to_dict(self):
|
||||||
|
d = {"role": self.role, "messageId": self.message_id, "parts": [p.to_dict() for p in self.parts]}
|
||||||
|
if self.context_id: d["contextId"] = self.context_id
|
||||||
|
if self.task_id: d["taskId"] = self.task_id
|
||||||
|
if self.metadata: d["metadata"] = self.metadata
|
||||||
|
return d
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d):
|
||||||
|
return cls(role=d["role"], parts=[part_from_dict(p) for p in d["parts"]], message_id=d.get("messageId",str(uuid.uuid4())), context_id=d.get("contextId"), task_id=d.get("taskId"), metadata=d.get("metadata"))
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Artifact:
|
||||||
|
artifact_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
|
parts: List[Part] = field(default_factory=list)
|
||||||
|
name: Optional[str] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
def to_dict(self):
|
||||||
|
d = {"artifactId": self.artifact_id, "parts": [p.to_dict() for p in self.parts]}
|
||||||
|
if self.name: d["name"] = self.name
|
||||||
|
if self.description: d["description"] = self.description
|
||||||
|
if self.metadata: d["metadata"] = self.metadata
|
||||||
|
return d
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d):
|
||||||
|
return cls(artifact_id=d.get("artifactId",str(uuid.uuid4())), parts=[part_from_dict(p) for p in d.get("parts",[])], name=d.get("name"), description=d.get("description"), metadata=d.get("metadata"))
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TaskStatus:
|
||||||
|
state: TaskState
|
||||||
|
message: Optional[Message] = None
|
||||||
|
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
|
def to_dict(self):
|
||||||
|
d = {"state": self.state.value, "timestamp": self.timestamp}
|
||||||
|
if self.message: d["message"] = self.message.to_dict()
|
||||||
|
return d
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d):
|
||||||
|
msg = d.get("message")
|
||||||
|
return cls(state=TaskState(d["state"]), message=Message.from_dict(msg) if msg else None, timestamp=d.get("timestamp",datetime.now(timezone.utc).isoformat()))
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Task:
|
||||||
|
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
|
context_id: Optional[str] = None
|
||||||
|
status: TaskStatus = field(default_factory=lambda: TaskStatus(state=TaskState.SUBMITTED))
|
||||||
|
artifacts: List[Artifact] = field(default_factory=list)
|
||||||
|
history: List[Message] = field(default_factory=list)
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
def to_dict(self):
|
||||||
|
d = {"id": self.id, "status": self.status.to_dict()}
|
||||||
|
if self.context_id: d["contextId"] = self.context_id
|
||||||
|
if self.artifacts: d["artifacts"] = [a.to_dict() for a in self.artifacts]
|
||||||
|
if self.history: d["history"] = [m.to_dict() for m in self.history]
|
||||||
|
if self.metadata: d["metadata"] = self.metadata
|
||||||
|
return d
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d):
|
||||||
|
return cls(id=d.get("id",str(uuid.uuid4())), context_id=d.get("contextId"), status=TaskStatus.from_dict(d["status"]) if "status" in d else TaskStatus(TaskState.SUBMITTED), artifacts=[Artifact.from_dict(a) for a in d.get("artifacts",[])], history=[Message.from_dict(m) for m in d.get("history",[])], metadata=d.get("metadata"))
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AgentSkill:
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
description: str = ""
|
||||||
|
tags: List[str] = field(default_factory=list)
|
||||||
|
examples: List[str] = field(default_factory=list)
|
||||||
|
input_modes: List[str] = field(default_factory=lambda: ["text"])
|
||||||
|
output_modes: List[str] = field(default_factory=lambda: ["text"])
|
||||||
|
def to_dict(self):
|
||||||
|
d = {"id": self.id, "name": self.name, "description": self.description}
|
||||||
|
if self.tags: d["tags"] = self.tags
|
||||||
|
if self.examples: d["examples"] = self.examples
|
||||||
|
if self.input_modes != ["text"]: d["inputModes"] = self.input_modes
|
||||||
|
if self.output_modes != ["text"]: d["outputModes"] = self.output_modes
|
||||||
|
return d
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d):
|
||||||
|
return cls(id=d["id"], name=d.get("name",d["id"]), description=d.get("description",""), tags=d.get("tags",[]), examples=d.get("examples",[]), input_modes=d.get("inputModes",["text"]), output_modes=d.get("outputModes",["text"]))
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AgentCard:
|
||||||
|
name: str
|
||||||
|
description: str = ""
|
||||||
|
version: str = "1.0.0"
|
||||||
|
url: str = ""
|
||||||
|
skills: List[AgentSkill] = field(default_factory=list)
|
||||||
|
capabilities: Dict[str, bool] = field(default_factory=dict)
|
||||||
|
provider: Optional[Dict[str, str]] = None
|
||||||
|
def to_dict(self):
|
||||||
|
d = {"name": self.name, "description": self.description, "version": self.version, "url": self.url, "skills": [s.to_dict() for s in self.skills]}
|
||||||
|
if self.capabilities: d["capabilities"] = self.capabilities
|
||||||
|
if self.provider: d["provider"] = self.provider
|
||||||
|
return d
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d):
|
||||||
|
return cls(name=d["name"], description=d.get("description",""), version=d.get("version","1.0.0"), url=d.get("url",""), skills=[AgentSkill.from_dict(s) for s in d.get("skills",[])], capabilities=d.get("capabilities",{}), provider=d.get("provider"))
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class JSONRPCError:
|
||||||
|
code: int
|
||||||
|
message: str
|
||||||
|
data: Optional[Any] = None
|
||||||
|
def to_dict(self):
|
||||||
|
d = {"code": self.code, "message": self.message}
|
||||||
|
if self.data is not None: d["data"] = self.data
|
||||||
|
return d
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class JSONRPCRequest:
|
||||||
|
method: str
|
||||||
|
params: Optional[Dict[str, Any]] = None
|
||||||
|
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
|
jsonrpc: str = "2.0"
|
||||||
|
def to_dict(self):
|
||||||
|
d = {"jsonrpc": self.jsonrpc, "method": self.method, "id": self.id}
|
||||||
|
if self.params is not None: d["params"] = self.params
|
||||||
|
return d
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class JSONRPCResponse:
|
||||||
|
id: str
|
||||||
|
result: Optional[Any] = None
|
||||||
|
error: Optional[JSONRPCError] = None
|
||||||
|
jsonrpc: str = "2.0"
|
||||||
|
def to_dict(self):
|
||||||
|
d = {"jsonrpc": self.jsonrpc, "id": self.id}
|
||||||
|
if self.error: d["error"] = self.error.to_dict()
|
||||||
|
else: d["result"] = self.result
|
||||||
|
return d
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d):
|
||||||
|
err = d.get("error")
|
||||||
|
return cls(id=d["id"], result=d.get("result"), error=JSONRPCError(err["code"],err["message"],err.get("data")) if err else None)
|
||||||
|
|
||||||
|
class A2AError:
|
||||||
|
@staticmethod
|
||||||
|
def parse_error(data=None): return JSONRPCError(-32700, "Parse error", data)
|
||||||
|
@staticmethod
|
||||||
|
def invalid_request(data=None): return JSONRPCError(-32600, "Invalid Request", data)
|
||||||
|
@staticmethod
|
||||||
|
def method_not_found(data=None): return JSONRPCError(-32601, "Method not found", data)
|
||||||
|
@staticmethod
|
||||||
|
def invalid_params(data=None): return JSONRPCError(-32602, "Invalid params", data)
|
||||||
|
@staticmethod
|
||||||
|
def internal_error(data=None): return JSONRPCError(-32603, "Internal error", data)
|
||||||
|
@staticmethod
|
||||||
|
def task_not_found(task_id): return JSONRPCError(-32001, f"Task not found: {task_id}")
|
||||||
|
@staticmethod
|
||||||
|
def task_not_cancelable(task_id): return JSONRPCError(-32002, f"Task not cancelable: {task_id}")
|
||||||
|
@staticmethod
|
||||||
|
def agent_not_found(name): return JSONRPCError(-32009, f"Agent not found: {name}")
|
||||||
116
tests/test_a2a.py
Normal file
116
tests/test_a2a.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
"""Tests for A2A protocol implementation."""
|
||||||
|
import asyncio,json,pytest
|
||||||
|
from a2a.types import AgentCard,AgentSkill,Artifact,DataPart,FilePart,JSONRPCError,JSONRPCRequest,JSONRPCResponse,Message,Task,TaskState,TaskStatus,TextPart,A2AError,part_from_dict
|
||||||
|
from a2a.client import A2AClient,A2AClientConfig
|
||||||
|
from a2a.server import A2AServer
|
||||||
|
|
||||||
|
class TestTextPart:
|
||||||
|
def test_roundtrip(self):
|
||||||
|
p=TextPart(text="hello",metadata={"k":"v"});d=p.to_dict();assert d=={"text":"hello","metadata":{"k":"v"}};p2=TextPart.from_dict(d);assert p2.text=="hello"
|
||||||
|
def test_no_metadata(self):
|
||||||
|
p=TextPart(text="hi");d=p.to_dict();assert "metadata" not in d
|
||||||
|
|
||||||
|
class TestFilePart:
|
||||||
|
def test_inline(self):
|
||||||
|
p=FilePart(media_type="text/plain",raw="SGVsbG8=",filename="hello.txt");d=p.to_dict();assert d["raw"]=="SGVsbG8=";p2=FilePart.from_dict(d);assert p2.filename=="hello.txt"
|
||||||
|
def test_url(self):
|
||||||
|
p=FilePart(url="https://x.com/f");d=p.to_dict();assert d["url"]=="https://x.com/f"
|
||||||
|
|
||||||
|
class TestDataPart:
|
||||||
|
def test_roundtrip(self):
|
||||||
|
p=DataPart(data={"key":42});d=p.to_dict();assert d["data"]=={"key":42}
|
||||||
|
|
||||||
|
class TestPartDiscrimination:
|
||||||
|
def test_text(self):assert isinstance(part_from_dict({"text":"hi"}),TextPart)
|
||||||
|
def test_file_raw(self):assert isinstance(part_from_dict({"raw":"d","mediaType":"t"}),FilePart)
|
||||||
|
def test_file_url(self):assert isinstance(part_from_dict({"url":"https://x.com"}),FilePart)
|
||||||
|
def test_data(self):assert isinstance(part_from_dict({"data":{"a":1}}),DataPart)
|
||||||
|
def test_unknown(self):
|
||||||
|
with pytest.raises(ValueError):part_from_dict({"unknown":True})
|
||||||
|
|
||||||
|
class TestMessage:
|
||||||
|
def test_roundtrip(self):
|
||||||
|
m=Message(role="user",parts=[TextPart(text="hi")],context_id="c1");d=m.to_dict();assert d["role"]=="user";m2=Message.from_dict(d);assert m2.parts[0].text=="hi"
|
||||||
|
|
||||||
|
class TestArtifact:
|
||||||
|
def test_roundtrip(self):
|
||||||
|
a=Artifact(name="r",parts=[TextPart(text="d")]);d=a.to_dict();assert d["name"]=="r"
|
||||||
|
|
||||||
|
class TestTaskStatus:
|
||||||
|
def test_roundtrip(self):
|
||||||
|
s=TaskStatus(state=TaskState.COMPLETED);d=s.to_dict();assert d["state"]=="TASK_STATE_COMPLETED";s2=TaskStatus.from_dict(d);assert s2.state==TaskState.COMPLETED
|
||||||
|
def test_terminal(self):
|
||||||
|
assert TaskState.COMPLETED.terminal;assert TaskState.FAILED.terminal;assert not TaskState.SUBMITTED.terminal
|
||||||
|
|
||||||
|
class TestTask:
|
||||||
|
def test_roundtrip(self):
|
||||||
|
t=Task(id="t1",status=TaskStatus(state=TaskState.WORKING));d=t.to_dict();assert d["id"]=="t1";t2=Task.from_dict(d);assert t2.status.state==TaskState.WORKING
|
||||||
|
|
||||||
|
class TestAgentCard:
|
||||||
|
def test_roundtrip(self):
|
||||||
|
c=AgentCard(name="a",skills=[AgentSkill(id="s1",name="S")]);d=c.to_dict();assert d["name"]=="a";c2=AgentCard.from_dict(d);assert c2.skills[0].id=="s1"
|
||||||
|
|
||||||
|
class TestJSONRPC:
|
||||||
|
def test_request(self):
|
||||||
|
r=JSONRPCRequest(method="GetTask",params={"taskId":"1"});d=r.to_dict();assert d["method"]=="GetTask"
|
||||||
|
def test_response_success(self):
|
||||||
|
r=JSONRPCResponse(id="1",result={"ok":True});d=r.to_dict();assert d["result"]["ok"]==True
|
||||||
|
def test_response_error(self):
|
||||||
|
r=JSONRPCResponse(id="1",error=A2AError.parse_error());d=r.to_dict();assert d["error"]["code"]==-32700
|
||||||
|
def test_response_from_dict(self):
|
||||||
|
r=JSONRPCResponse.from_dict({"jsonrpc":"2.0","id":"1","result":{"ok":True}});assert r.result["ok"]==True
|
||||||
|
|
||||||
|
class TestA2AError:
|
||||||
|
def test_codes(self):
|
||||||
|
assert A2AError.parse_error().code==-32700;assert A2AError.invalid_request().code==-32600
|
||||||
|
assert A2AError.method_not_found().code==-32601;assert A2AError.task_not_found("x").code==-32001
|
||||||
|
|
||||||
|
def _make_card():
|
||||||
|
return AgentCard(name="test",description="Test",url="http://localhost:9999/a2a/v1",skills=[AgentSkill(id="echo",name="Echo",tags=["test"])])
|
||||||
|
|
||||||
|
def _run(coro):return asyncio.get_event_loop().run_until_complete(coro)
|
||||||
|
|
||||||
|
class TestServer:
|
||||||
|
def test_echo(self):
|
||||||
|
s=A2AServer(_make_card());msg=Message(role="user",parts=[TextPart(text="hello")])
|
||||||
|
raw=json.dumps(JSONRPCRequest(method="SendMessage",params={"message":msg.to_dict()}).to_dict())
|
||||||
|
resp=json.loads(_run(s.handle_rpc(raw)));assert "error" not in resp;assert resp["result"]["status"]["state"]=="TASK_STATE_COMPLETED"
|
||||||
|
def test_get_task(self):
|
||||||
|
s=A2AServer(_make_card());msg=Message(role="user",parts=[TextPart(text="hi")])
|
||||||
|
raw=json.dumps(JSONRPCRequest(method="SendMessage",params={"message":msg.to_dict()}).to_dict())
|
||||||
|
task_id=json.loads(_run(s.handle_rpc(raw)))["result"]["id"]
|
||||||
|
raw2=json.dumps(JSONRPCRequest(method="GetTask",params={"taskId":task_id}).to_dict())
|
||||||
|
resp=json.loads(_run(s.handle_rpc(raw2)));assert resp["result"]["id"]==task_id
|
||||||
|
def test_cancel(self):
|
||||||
|
s=A2AServer(_make_card());task=Task(id="c1",status=TaskStatus(state=TaskState.WORKING));s.add_task(task)
|
||||||
|
raw=json.dumps(JSONRPCRequest(method="CancelTask",params={"taskId":"c1"}).to_dict())
|
||||||
|
resp=json.loads(_run(s.handle_rpc(raw)));assert resp["result"]["status"]["state"]=="TASK_STATE_CANCELED"
|
||||||
|
def test_cancel_terminal(self):
|
||||||
|
s=A2AServer(_make_card());task=Task(id="d1",status=TaskStatus(state=TaskState.COMPLETED));s.add_task(task)
|
||||||
|
raw=json.dumps(JSONRPCRequest(method="CancelTask",params={"taskId":"d1"}).to_dict())
|
||||||
|
resp=json.loads(_run(s.handle_rpc(raw)));assert "error" in resp
|
||||||
|
def test_card(self):
|
||||||
|
s=A2AServer(_make_card());raw=json.dumps(JSONRPCRequest(method="GetAgentCard").to_dict())
|
||||||
|
resp=json.loads(_run(s.handle_rpc(raw)));assert resp["result"]["name"]=="test"
|
||||||
|
def test_list(self):
|
||||||
|
s=A2AServer(_make_card());msg=Message(role="user",parts=[TextPart(text="hi")])
|
||||||
|
raw=json.dumps(JSONRPCRequest(method="SendMessage",params={"message":msg.to_dict()}).to_dict())
|
||||||
|
_run(s.handle_rpc(raw));raw2=json.dumps(JSONRPCRequest(method="ListTasks").to_dict())
|
||||||
|
resp=json.loads(_run(s.handle_rpc(raw2)));assert len(resp["result"]["tasks"])>=1
|
||||||
|
def test_unknown(self):
|
||||||
|
s=A2AServer(_make_card());raw=json.dumps(JSONRPCRequest(method="Nope").to_dict())
|
||||||
|
resp=json.loads(_run(s.handle_rpc(raw)));assert resp["error"]["code"]==-32601
|
||||||
|
def test_invalid_json(self):
|
||||||
|
s=A2AServer(_make_card());resp=json.loads(_run(s.handle_rpc("bad")));assert resp["error"]["code"]==-32700
|
||||||
|
def test_custom_handler(self):
|
||||||
|
s=A2AServer(_make_card())
|
||||||
|
async def h(task,card):task.status=TaskStatus(state=TaskState.COMPLETED);task.artifacts=[Artifact(parts=[TextPart(text="custom")])];return task
|
||||||
|
s.register_handler("echo",h)
|
||||||
|
msg=Message(role="user",parts=[TextPart(text="t")])
|
||||||
|
raw=json.dumps(JSONRPCRequest(method="SendMessage",params={"message":msg.to_dict(),"skillId":"echo"}).to_dict())
|
||||||
|
resp=json.loads(_run(s.handle_rpc(raw)));assert resp["result"]["artifacts"][0]["parts"][0]["text"]=="custom"
|
||||||
|
def test_audit(self):
|
||||||
|
s=A2AServer(_make_card());raw=json.dumps(JSONRPCRequest(method="GetAgentCard").to_dict())
|
||||||
|
_run(s.handle_rpc(raw));assert len(s.audit_log)==1;assert s.audit_log[0]["method"]=="GetAgentCard"
|
||||||
|
|
||||||
|
if __name__=="__main__":pytest.main([__file__,"-v"])
|
||||||
Reference in New Issue
Block a user