Compare commits
5 Commits
fix/806
...
burn/804-1
| Author | SHA1 | Date | |
|---|---|---|---|
| 84d247d320 | |||
| 5a3f53c8b5 | |||
| fec32b5659 | |||
| c2821055d9 | |||
| 5a03eba791 |
22
a2a/__init__.py
Normal file
22
a2a/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""A2A — Agent2Agent Protocol v1.0 for Hermes task delegation.
|
||||
|
||||
Usage:
|
||||
from a2a import A2AClient, A2AServer, AgentCard, Task, TextPart
|
||||
"""
|
||||
|
||||
from a2a.types import (
|
||||
AgentCard, AgentSkill, Artifact, DataPart, FilePart,
|
||||
JSONRPCError, JSONRPCRequest, JSONRPCResponse,
|
||||
Message, Part, Task, TaskState, TaskStatus, TextPart, A2AError,
|
||||
part_from_dict,
|
||||
)
|
||||
from a2a.client import A2AClient, A2AClientConfig
|
||||
from a2a.server import A2AServer, TaskHandler
|
||||
|
||||
__all__ = [
|
||||
"A2AClient", "A2AClientConfig", "A2AServer", "TaskHandler",
|
||||
"AgentCard", "AgentSkill", "Artifact",
|
||||
"DataPart", "FilePart", "TextPart", "Part", "part_from_dict",
|
||||
"JSONRPCError", "JSONRPCRequest", "JSONRPCResponse",
|
||||
"Message", "Task", "TaskState", "TaskStatus", "A2AError",
|
||||
]
|
||||
98
a2a/client.py
Normal file
98
a2a/client.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""A2A Client - send tasks to remote agents via JSON-RPC 2.0."""
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio, json, logging, time
|
||||
from typing import Any, Dict, List, Optional
|
||||
import aiohttp
|
||||
from a2a.types import (AgentCard, Artifact, JSONRPCError, JSONRPCRequest, JSONRPCResponse, Message, Part, Task, TaskState, TaskStatus, TextPart, A2AError)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class A2AClientConfig:
|
||||
def __init__(self, timeout=30.0, max_retries=3, retry_delay=2.0, auth_token=None, auth_scheme="Bearer"):
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self.auth_token = auth_token
|
||||
self.auth_scheme = auth_scheme
|
||||
|
||||
class A2AClient:
|
||||
def __init__(self, base_url, config=None):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.config = config or A2AClientConfig()
|
||||
self._audit_log = []
|
||||
|
||||
def _headers(self):
|
||||
h = {"Content-Type": "application/json"}
|
||||
if self.config.auth_token:
|
||||
h["Authorization"] = f"{self.config.auth_scheme} {self.config.auth_token}"
|
||||
return h
|
||||
|
||||
async def _rpc_call(self, method, params=None):
|
||||
req = JSONRPCRequest(method=method, params=params)
|
||||
payload = json.dumps(req.to_dict())
|
||||
last_error = None
|
||||
for attempt in range(1, self.config.max_retries + 1):
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
timeout = aiohttp.ClientTimeout(total=self.config.timeout)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.post(f"{self.base_url}/a2a/v1/rpc", data=payload, headers=self._headers()) as resp:
|
||||
body = await resp.text()
|
||||
elapsed = time.monotonic() - t0
|
||||
self._audit_log.append({"method": method, "params": params, "status": resp.status, "elapsed_ms": round(elapsed*1000), "attempt": attempt})
|
||||
if resp.status != 200:
|
||||
logger.warning("A2A %s -> HTTP %d (attempt %d/%d)", method, resp.status, attempt, self.config.max_retries)
|
||||
if attempt < self.config.max_retries:
|
||||
await asyncio.sleep(self.config.retry_delay * attempt)
|
||||
continue
|
||||
return JSONRPCResponse(id=req.id, error=A2AError.internal_error(f"HTTP {resp.status}: {body[:200]}"))
|
||||
return JSONRPCResponse.from_dict(json.loads(body))
|
||||
except Exception as exc:
|
||||
elapsed = time.monotonic() - t0
|
||||
last_error = exc
|
||||
self._audit_log.append({"method": method, "error": str(exc), "elapsed_ms": round(elapsed*1000), "attempt": attempt})
|
||||
if attempt < self.config.max_retries:
|
||||
await asyncio.sleep(self.config.retry_delay * attempt)
|
||||
continue
|
||||
return JSONRPCResponse(id=req.id, error=A2AError.internal_error(str(last_error)))
|
||||
|
||||
async def get_agent_card(self):
|
||||
resp = await self._rpc_call("GetAgentCard")
|
||||
if resp.error: raise RuntimeError(f"Failed to get agent card: {resp.error.message}")
|
||||
return AgentCard.from_dict(resp.result)
|
||||
|
||||
async def send_message(self, text, context_id=None, skill_id=None, metadata=None):
|
||||
msg = Message(role="user", parts=[TextPart(text=text)], context_id=context_id)
|
||||
params = {"message": msg.to_dict()}
|
||||
if skill_id: params["skillId"] = skill_id
|
||||
if metadata: params["metadata"] = metadata
|
||||
resp = await self._rpc_call("SendMessage", params)
|
||||
if resp.error: raise RuntimeError(f"SendMessage failed: {resp.error.message}")
|
||||
return Task.from_dict(resp.result)
|
||||
|
||||
async def get_task(self, task_id):
|
||||
resp = await self._rpc_call("GetTask", {"taskId": task_id})
|
||||
if resp.error: raise RuntimeError(f"GetTask failed: {resp.error.message}")
|
||||
return Task.from_dict(resp.result)
|
||||
|
||||
async def cancel_task(self, task_id):
|
||||
resp = await self._rpc_call("CancelTask", {"taskId": task_id})
|
||||
if resp.error: raise RuntimeError(f"CancelTask failed: {resp.error.message}")
|
||||
return Task.from_dict(resp.result)
|
||||
|
||||
async def wait_for_completion(self, task_id, poll_interval=2.0, timeout=300.0):
|
||||
t0 = time.monotonic()
|
||||
while True:
|
||||
task = await self.get_task(task_id)
|
||||
if task.status.state.terminal: return task
|
||||
if time.monotonic() - t0 > timeout: raise TimeoutError(f"Task {task_id} did not complete within {timeout}s")
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
async def delegate(self, text, skill_id=None, wait=True, timeout=300.0):
|
||||
task = await self.send_message(text, skill_id=skill_id)
|
||||
if wait: return await self.wait_for_completion(task.id, timeout=timeout)
|
||||
return task
|
||||
|
||||
@property
|
||||
def audit_log(self): return list(self._audit_log)
|
||||
60
a2a/server.py
Normal file
60
a2a/server.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""A2A Server - receive and execute tasks via JSON-RPC 2.0."""
|
||||
from __future__ import annotations
|
||||
import json,logging,uuid
|
||||
from datetime import datetime,timezone
|
||||
from typing import Any,Callable,Dict,List,Optional,Awaitable
|
||||
from a2a.types import AgentCard,Artifact,JSONRPCError,JSONRPCRequest,JSONRPCResponse,Message,Task,TaskState,TaskStatus,TextPart,A2AError,part_from_dict
|
||||
logger=logging.getLogger(__name__)
|
||||
TaskHandler=Callable[[Task,AgentCard],Awaitable[Task]]
|
||||
class A2AServer:
|
||||
def __init__(s,card):s.card=card;s._tasks={};s._handlers={};s._default_handler=None;s._audit_log=[]
|
||||
def register_handler(s,skill_id,handler):s._handlers[skill_id]=handler
|
||||
def set_default_handler(s,handler):s._default_handler=handler
|
||||
async def handle_rpc(s,raw):
|
||||
try:data=json.loads(raw)
|
||||
except (json.JSONDecodeError,TypeError):return json.dumps(JSONRPCResponse(id="",error=A2AError.parse_error()).to_dict())
|
||||
req_id=data.get("id","");method=data.get("method","");params=data.get("params")
|
||||
s._audit_log.append({"method":method,"id":req_id,"ts":datetime.now(timezone.utc).isoformat()})
|
||||
try:
|
||||
if method=="SendMessage":result=await s._handle_send_message(params)
|
||||
elif method=="GetTask":result=s._handle_get_task(params)
|
||||
elif method=="CancelTask":result=s._handle_cancel_task(params)
|
||||
elif method=="GetAgentCard":result=s.card.to_dict()
|
||||
elif method=="ListTasks":result=s._handle_list_tasks(params)
|
||||
else:return json.dumps(JSONRPCResponse(id=req_id,error=A2AError.method_not_found()).to_dict())
|
||||
return json.dumps(JSONRPCResponse(id=req_id,result=result).to_dict())
|
||||
except Exception as exc:
|
||||
logger.exception("A2A handler error for %s",method)
|
||||
return json.dumps(JSONRPCResponse(id=req_id,error=A2AError.internal_error(str(exc))).to_dict())
|
||||
async def _handle_send_message(s,params):
|
||||
if not params or "message" not in params:raise ValueError("SendMessage requires message param")
|
||||
msg=Message.from_dict(params["message"])
|
||||
task=Task(id=str(uuid.uuid4()),context_id=msg.context_id,status=TaskStatus(state=TaskState.SUBMITTED),history=[msg])
|
||||
s._tasks[task.id]=task
|
||||
skill_id=params.get("skillId");handler=s._handlers.get(skill_id) if skill_id else None
|
||||
if handler is None:handler=s._default_handler
|
||||
if handler is None:
|
||||
text=msg.parts[0].text if msg.parts and hasattr(msg.parts[0],"text") else ""
|
||||
task.status=TaskStatus(state=TaskState.COMPLETED)
|
||||
task.artifacts=[Artifact(parts=[TextPart(text=f"Received: {text}")])]
|
||||
else:
|
||||
task.status=TaskStatus(state=TaskState.WORKING);s._tasks[task.id]=task
|
||||
task=await handler(task,s.card);s._tasks[task.id]=task
|
||||
return task.to_dict()
|
||||
def _handle_get_task(s,params):
|
||||
task_id=(params or {}).get("taskId")
|
||||
if not task_id or task_id not in s._tasks:raise KeyError(f"Task not found: {task_id}")
|
||||
return s._tasks[task_id].to_dict()
|
||||
def _handle_cancel_task(s,params):
|
||||
task_id=(params or {}).get("taskId")
|
||||
if not task_id or task_id not in s._tasks:raise KeyError(f"Task not found: {task_id}")
|
||||
task=s._tasks[task_id]
|
||||
if task.status.state.terminal:raise ValueError(f"Task already {task.status.state.value}")
|
||||
task.status=TaskStatus(state=TaskState.CANCELED);s._tasks[task_id]=task
|
||||
return task.to_dict()
|
||||
def _handle_list_tasks(s,params):
|
||||
context_id=(params or {}).get("contextId")
|
||||
return {"tasks":[t.to_dict() for t in s._tasks.values() if not context_id or t.context_id==context_id]}
|
||||
def add_task(s,task):s._tasks[task.id]=task
|
||||
@property
|
||||
def audit_log(s):return list(s._audit_log)
|
||||
230
a2a/types.py
Normal file
230
a2a/types.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""A2A Protocol Types - Agent2Agent v1.0 data structures."""
|
||||
|
||||
from __future__ import annotations
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
class TaskState(str, Enum):
|
||||
SUBMITTED = "TASK_STATE_SUBMITTED"
|
||||
WORKING = "TASK_STATE_WORKING"
|
||||
INPUT_REQUIRED = "TASK_STATE_INPUT_REQUIRED"
|
||||
COMPLETED = "TASK_STATE_COMPLETED"
|
||||
FAILED = "TASK_STATE_FAILED"
|
||||
CANCELED = "TASK_STATE_CANCELED"
|
||||
REJECTED = "TASK_STATE_REJECTED"
|
||||
@property
|
||||
def terminal(self) -> bool:
|
||||
return self in {TaskState.COMPLETED, TaskState.FAILED, TaskState.CANCELED, TaskState.REJECTED}
|
||||
|
||||
@dataclass
|
||||
class TextPart:
|
||||
text: str
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
def to_dict(self) -> dict:
|
||||
d = {"text": self.text}
|
||||
if self.metadata: d["metadata"] = self.metadata
|
||||
return d
|
||||
@classmethod
|
||||
def from_dict(cls, d): return cls(text=d["text"], metadata=d.get("metadata"))
|
||||
|
||||
@dataclass
|
||||
class FilePart:
|
||||
media_type: str = "application/octet-stream"
|
||||
raw: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
filename: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
def to_dict(self) -> dict:
|
||||
d = {"mediaType": self.media_type}
|
||||
if self.raw is not None: d["raw"] = self.raw
|
||||
if self.url is not None: d["url"] = self.url
|
||||
if self.filename: d["filename"] = self.filename
|
||||
if self.metadata: d["metadata"] = self.metadata
|
||||
return d
|
||||
@classmethod
|
||||
def from_dict(cls, d):
|
||||
return cls(media_type=d.get("mediaType","application/octet-stream"), raw=d.get("raw"), url=d.get("url"), filename=d.get("filename"), metadata=d.get("metadata"))
|
||||
|
||||
@dataclass
|
||||
class DataPart:
|
||||
data: Dict[str, Any]
|
||||
media_type: str = "application/json"
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
def to_dict(self) -> dict:
|
||||
d = {"data": self.data, "mediaType": self.media_type}
|
||||
if self.metadata: d["metadata"] = self.metadata
|
||||
return d
|
||||
@classmethod
|
||||
def from_dict(cls, d): return cls(data=d["data"], media_type=d.get("mediaType","application/json"), metadata=d.get("metadata"))
|
||||
|
||||
Part = Union[TextPart, FilePart, DataPart]
|
||||
|
||||
def part_from_dict(d):
|
||||
if "text" in d: return TextPart.from_dict(d)
|
||||
if "raw" in d or "url" in d: return FilePart.from_dict(d)
|
||||
if "data" in d: return DataPart.from_dict(d)
|
||||
raise ValueError(f"Cannot discriminate Part type from keys: {list(d.keys())}")
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
role: str
|
||||
parts: List[Part]
|
||||
message_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
context_id: Optional[str] = None
|
||||
task_id: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
def to_dict(self):
|
||||
d = {"role": self.role, "messageId": self.message_id, "parts": [p.to_dict() for p in self.parts]}
|
||||
if self.context_id: d["contextId"] = self.context_id
|
||||
if self.task_id: d["taskId"] = self.task_id
|
||||
if self.metadata: d["metadata"] = self.metadata
|
||||
return d
|
||||
@classmethod
|
||||
def from_dict(cls, d):
|
||||
return cls(role=d["role"], parts=[part_from_dict(p) for p in d["parts"]], message_id=d.get("messageId",str(uuid.uuid4())), context_id=d.get("contextId"), task_id=d.get("taskId"), metadata=d.get("metadata"))
|
||||
|
||||
@dataclass
|
||||
class Artifact:
|
||||
artifact_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
parts: List[Part] = field(default_factory=list)
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
def to_dict(self):
|
||||
d = {"artifactId": self.artifact_id, "parts": [p.to_dict() for p in self.parts]}
|
||||
if self.name: d["name"] = self.name
|
||||
if self.description: d["description"] = self.description
|
||||
if self.metadata: d["metadata"] = self.metadata
|
||||
return d
|
||||
@classmethod
|
||||
def from_dict(cls, d):
|
||||
return cls(artifact_id=d.get("artifactId",str(uuid.uuid4())), parts=[part_from_dict(p) for p in d.get("parts",[])], name=d.get("name"), description=d.get("description"), metadata=d.get("metadata"))
|
||||
|
||||
@dataclass
|
||||
class TaskStatus:
|
||||
state: TaskState
|
||||
message: Optional[Message] = None
|
||||
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
def to_dict(self):
|
||||
d = {"state": self.state.value, "timestamp": self.timestamp}
|
||||
if self.message: d["message"] = self.message.to_dict()
|
||||
return d
|
||||
@classmethod
|
||||
def from_dict(cls, d):
|
||||
msg = d.get("message")
|
||||
return cls(state=TaskState(d["state"]), message=Message.from_dict(msg) if msg else None, timestamp=d.get("timestamp",datetime.now(timezone.utc).isoformat()))
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
context_id: Optional[str] = None
|
||||
status: TaskStatus = field(default_factory=lambda: TaskStatus(state=TaskState.SUBMITTED))
|
||||
artifacts: List[Artifact] = field(default_factory=list)
|
||||
history: List[Message] = field(default_factory=list)
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
def to_dict(self):
|
||||
d = {"id": self.id, "status": self.status.to_dict()}
|
||||
if self.context_id: d["contextId"] = self.context_id
|
||||
if self.artifacts: d["artifacts"] = [a.to_dict() for a in self.artifacts]
|
||||
if self.history: d["history"] = [m.to_dict() for m in self.history]
|
||||
if self.metadata: d["metadata"] = self.metadata
|
||||
return d
|
||||
@classmethod
|
||||
def from_dict(cls, d):
|
||||
return cls(id=d.get("id",str(uuid.uuid4())), context_id=d.get("contextId"), status=TaskStatus.from_dict(d["status"]) if "status" in d else TaskStatus(TaskState.SUBMITTED), artifacts=[Artifact.from_dict(a) for a in d.get("artifacts",[])], history=[Message.from_dict(m) for m in d.get("history",[])], metadata=d.get("metadata"))
|
||||
|
||||
@dataclass
|
||||
class AgentSkill:
|
||||
id: str
|
||||
name: str
|
||||
description: str = ""
|
||||
tags: List[str] = field(default_factory=list)
|
||||
examples: List[str] = field(default_factory=list)
|
||||
input_modes: List[str] = field(default_factory=lambda: ["text"])
|
||||
output_modes: List[str] = field(default_factory=lambda: ["text"])
|
||||
def to_dict(self):
|
||||
d = {"id": self.id, "name": self.name, "description": self.description}
|
||||
if self.tags: d["tags"] = self.tags
|
||||
if self.examples: d["examples"] = self.examples
|
||||
if self.input_modes != ["text"]: d["inputModes"] = self.input_modes
|
||||
if self.output_modes != ["text"]: d["outputModes"] = self.output_modes
|
||||
return d
|
||||
@classmethod
|
||||
def from_dict(cls, d):
|
||||
return cls(id=d["id"], name=d.get("name",d["id"]), description=d.get("description",""), tags=d.get("tags",[]), examples=d.get("examples",[]), input_modes=d.get("inputModes",["text"]), output_modes=d.get("outputModes",["text"]))
|
||||
|
||||
@dataclass
|
||||
class AgentCard:
|
||||
name: str
|
||||
description: str = ""
|
||||
version: str = "1.0.0"
|
||||
url: str = ""
|
||||
skills: List[AgentSkill] = field(default_factory=list)
|
||||
capabilities: Dict[str, bool] = field(default_factory=dict)
|
||||
provider: Optional[Dict[str, str]] = None
|
||||
def to_dict(self):
|
||||
d = {"name": self.name, "description": self.description, "version": self.version, "url": self.url, "skills": [s.to_dict() for s in self.skills]}
|
||||
if self.capabilities: d["capabilities"] = self.capabilities
|
||||
if self.provider: d["provider"] = self.provider
|
||||
return d
|
||||
@classmethod
|
||||
def from_dict(cls, d):
|
||||
return cls(name=d["name"], description=d.get("description",""), version=d.get("version","1.0.0"), url=d.get("url",""), skills=[AgentSkill.from_dict(s) for s in d.get("skills",[])], capabilities=d.get("capabilities",{}), provider=d.get("provider"))
|
||||
|
||||
@dataclass
|
||||
class JSONRPCError:
|
||||
code: int
|
||||
message: str
|
||||
data: Optional[Any] = None
|
||||
def to_dict(self):
|
||||
d = {"code": self.code, "message": self.message}
|
||||
if self.data is not None: d["data"] = self.data
|
||||
return d
|
||||
|
||||
@dataclass
|
||||
class JSONRPCRequest:
|
||||
method: str
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
jsonrpc: str = "2.0"
|
||||
def to_dict(self):
|
||||
d = {"jsonrpc": self.jsonrpc, "method": self.method, "id": self.id}
|
||||
if self.params is not None: d["params"] = self.params
|
||||
return d
|
||||
|
||||
@dataclass
|
||||
class JSONRPCResponse:
|
||||
id: str
|
||||
result: Optional[Any] = None
|
||||
error: Optional[JSONRPCError] = None
|
||||
jsonrpc: str = "2.0"
|
||||
def to_dict(self):
|
||||
d = {"jsonrpc": self.jsonrpc, "id": self.id}
|
||||
if self.error: d["error"] = self.error.to_dict()
|
||||
else: d["result"] = self.result
|
||||
return d
|
||||
@classmethod
|
||||
def from_dict(cls, d):
|
||||
err = d.get("error")
|
||||
return cls(id=d["id"], result=d.get("result"), error=JSONRPCError(err["code"],err["message"],err.get("data")) if err else None)
|
||||
|
||||
class A2AError:
|
||||
@staticmethod
|
||||
def parse_error(data=None): return JSONRPCError(-32700, "Parse error", data)
|
||||
@staticmethod
|
||||
def invalid_request(data=None): return JSONRPCError(-32600, "Invalid Request", data)
|
||||
@staticmethod
|
||||
def method_not_found(data=None): return JSONRPCError(-32601, "Method not found", data)
|
||||
@staticmethod
|
||||
def invalid_params(data=None): return JSONRPCError(-32602, "Invalid params", data)
|
||||
@staticmethod
|
||||
def internal_error(data=None): return JSONRPCError(-32603, "Internal error", data)
|
||||
@staticmethod
|
||||
def task_not_found(task_id): return JSONRPCError(-32001, f"Task not found: {task_id}")
|
||||
@staticmethod
|
||||
def task_not_cancelable(task_id): return JSONRPCError(-32002, f"Task not cancelable: {task_id}")
|
||||
@staticmethod
|
||||
def agent_not_found(name): return JSONRPCError(-32009, f"Agent not found: {name}")
|
||||
116
tests/test_a2a.py
Normal file
116
tests/test_a2a.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""Tests for A2A protocol implementation."""
|
||||
import asyncio,json,pytest
|
||||
from a2a.types import AgentCard,AgentSkill,Artifact,DataPart,FilePart,JSONRPCError,JSONRPCRequest,JSONRPCResponse,Message,Task,TaskState,TaskStatus,TextPart,A2AError,part_from_dict
|
||||
from a2a.client import A2AClient,A2AClientConfig
|
||||
from a2a.server import A2AServer
|
||||
|
||||
class TestTextPart:
|
||||
def test_roundtrip(self):
|
||||
p=TextPart(text="hello",metadata={"k":"v"});d=p.to_dict();assert d=={"text":"hello","metadata":{"k":"v"}};p2=TextPart.from_dict(d);assert p2.text=="hello"
|
||||
def test_no_metadata(self):
|
||||
p=TextPart(text="hi");d=p.to_dict();assert "metadata" not in d
|
||||
|
||||
class TestFilePart:
|
||||
def test_inline(self):
|
||||
p=FilePart(media_type="text/plain",raw="SGVsbG8=",filename="hello.txt");d=p.to_dict();assert d["raw"]=="SGVsbG8=";p2=FilePart.from_dict(d);assert p2.filename=="hello.txt"
|
||||
def test_url(self):
|
||||
p=FilePart(url="https://x.com/f");d=p.to_dict();assert d["url"]=="https://x.com/f"
|
||||
|
||||
class TestDataPart:
|
||||
def test_roundtrip(self):
|
||||
p=DataPart(data={"key":42});d=p.to_dict();assert d["data"]=={"key":42}
|
||||
|
||||
class TestPartDiscrimination:
|
||||
def test_text(self):assert isinstance(part_from_dict({"text":"hi"}),TextPart)
|
||||
def test_file_raw(self):assert isinstance(part_from_dict({"raw":"d","mediaType":"t"}),FilePart)
|
||||
def test_file_url(self):assert isinstance(part_from_dict({"url":"https://x.com"}),FilePart)
|
||||
def test_data(self):assert isinstance(part_from_dict({"data":{"a":1}}),DataPart)
|
||||
def test_unknown(self):
|
||||
with pytest.raises(ValueError):part_from_dict({"unknown":True})
|
||||
|
||||
class TestMessage:
|
||||
def test_roundtrip(self):
|
||||
m=Message(role="user",parts=[TextPart(text="hi")],context_id="c1");d=m.to_dict();assert d["role"]=="user";m2=Message.from_dict(d);assert m2.parts[0].text=="hi"
|
||||
|
||||
class TestArtifact:
|
||||
def test_roundtrip(self):
|
||||
a=Artifact(name="r",parts=[TextPart(text="d")]);d=a.to_dict();assert d["name"]=="r"
|
||||
|
||||
class TestTaskStatus:
|
||||
def test_roundtrip(self):
|
||||
s=TaskStatus(state=TaskState.COMPLETED);d=s.to_dict();assert d["state"]=="TASK_STATE_COMPLETED";s2=TaskStatus.from_dict(d);assert s2.state==TaskState.COMPLETED
|
||||
def test_terminal(self):
|
||||
assert TaskState.COMPLETED.terminal;assert TaskState.FAILED.terminal;assert not TaskState.SUBMITTED.terminal
|
||||
|
||||
class TestTask:
|
||||
def test_roundtrip(self):
|
||||
t=Task(id="t1",status=TaskStatus(state=TaskState.WORKING));d=t.to_dict();assert d["id"]=="t1";t2=Task.from_dict(d);assert t2.status.state==TaskState.WORKING
|
||||
|
||||
class TestAgentCard:
|
||||
def test_roundtrip(self):
|
||||
c=AgentCard(name="a",skills=[AgentSkill(id="s1",name="S")]);d=c.to_dict();assert d["name"]=="a";c2=AgentCard.from_dict(d);assert c2.skills[0].id=="s1"
|
||||
|
||||
class TestJSONRPC:
|
||||
def test_request(self):
|
||||
r=JSONRPCRequest(method="GetTask",params={"taskId":"1"});d=r.to_dict();assert d["method"]=="GetTask"
|
||||
def test_response_success(self):
|
||||
r=JSONRPCResponse(id="1",result={"ok":True});d=r.to_dict();assert d["result"]["ok"]==True
|
||||
def test_response_error(self):
|
||||
r=JSONRPCResponse(id="1",error=A2AError.parse_error());d=r.to_dict();assert d["error"]["code"]==-32700
|
||||
def test_response_from_dict(self):
|
||||
r=JSONRPCResponse.from_dict({"jsonrpc":"2.0","id":"1","result":{"ok":True}});assert r.result["ok"]==True
|
||||
|
||||
class TestA2AError:
|
||||
def test_codes(self):
|
||||
assert A2AError.parse_error().code==-32700;assert A2AError.invalid_request().code==-32600
|
||||
assert A2AError.method_not_found().code==-32601;assert A2AError.task_not_found("x").code==-32001
|
||||
|
||||
def _make_card():
|
||||
return AgentCard(name="test",description="Test",url="http://localhost:9999/a2a/v1",skills=[AgentSkill(id="echo",name="Echo",tags=["test"])])
|
||||
|
||||
def _run(coro):return asyncio.get_event_loop().run_until_complete(coro)
|
||||
|
||||
class TestServer:
|
||||
def test_echo(self):
|
||||
s=A2AServer(_make_card());msg=Message(role="user",parts=[TextPart(text="hello")])
|
||||
raw=json.dumps(JSONRPCRequest(method="SendMessage",params={"message":msg.to_dict()}).to_dict())
|
||||
resp=json.loads(_run(s.handle_rpc(raw)));assert "error" not in resp;assert resp["result"]["status"]["state"]=="TASK_STATE_COMPLETED"
|
||||
def test_get_task(self):
|
||||
s=A2AServer(_make_card());msg=Message(role="user",parts=[TextPart(text="hi")])
|
||||
raw=json.dumps(JSONRPCRequest(method="SendMessage",params={"message":msg.to_dict()}).to_dict())
|
||||
task_id=json.loads(_run(s.handle_rpc(raw)))["result"]["id"]
|
||||
raw2=json.dumps(JSONRPCRequest(method="GetTask",params={"taskId":task_id}).to_dict())
|
||||
resp=json.loads(_run(s.handle_rpc(raw2)));assert resp["result"]["id"]==task_id
|
||||
def test_cancel(self):
|
||||
s=A2AServer(_make_card());task=Task(id="c1",status=TaskStatus(state=TaskState.WORKING));s.add_task(task)
|
||||
raw=json.dumps(JSONRPCRequest(method="CancelTask",params={"taskId":"c1"}).to_dict())
|
||||
resp=json.loads(_run(s.handle_rpc(raw)));assert resp["result"]["status"]["state"]=="TASK_STATE_CANCELED"
|
||||
def test_cancel_terminal(self):
|
||||
s=A2AServer(_make_card());task=Task(id="d1",status=TaskStatus(state=TaskState.COMPLETED));s.add_task(task)
|
||||
raw=json.dumps(JSONRPCRequest(method="CancelTask",params={"taskId":"d1"}).to_dict())
|
||||
resp=json.loads(_run(s.handle_rpc(raw)));assert "error" in resp
|
||||
def test_card(self):
|
||||
s=A2AServer(_make_card());raw=json.dumps(JSONRPCRequest(method="GetAgentCard").to_dict())
|
||||
resp=json.loads(_run(s.handle_rpc(raw)));assert resp["result"]["name"]=="test"
|
||||
def test_list(self):
|
||||
s=A2AServer(_make_card());msg=Message(role="user",parts=[TextPart(text="hi")])
|
||||
raw=json.dumps(JSONRPCRequest(method="SendMessage",params={"message":msg.to_dict()}).to_dict())
|
||||
_run(s.handle_rpc(raw));raw2=json.dumps(JSONRPCRequest(method="ListTasks").to_dict())
|
||||
resp=json.loads(_run(s.handle_rpc(raw2)));assert len(resp["result"]["tasks"])>=1
|
||||
def test_unknown(self):
|
||||
s=A2AServer(_make_card());raw=json.dumps(JSONRPCRequest(method="Nope").to_dict())
|
||||
resp=json.loads(_run(s.handle_rpc(raw)));assert resp["error"]["code"]==-32601
|
||||
def test_invalid_json(self):
|
||||
s=A2AServer(_make_card());resp=json.loads(_run(s.handle_rpc("bad")));assert resp["error"]["code"]==-32700
|
||||
def test_custom_handler(self):
|
||||
s=A2AServer(_make_card())
|
||||
async def h(task,card):task.status=TaskStatus(state=TaskState.COMPLETED);task.artifacts=[Artifact(parts=[TextPart(text="custom")])];return task
|
||||
s.register_handler("echo",h)
|
||||
msg=Message(role="user",parts=[TextPart(text="t")])
|
||||
raw=json.dumps(JSONRPCRequest(method="SendMessage",params={"message":msg.to_dict(),"skillId":"echo"}).to_dict())
|
||||
resp=json.loads(_run(s.handle_rpc(raw)));assert resp["result"]["artifacts"][0]["parts"][0]["text"]=="custom"
|
||||
def test_audit(self):
|
||||
s=A2AServer(_make_card());raw=json.dumps(JSONRPCRequest(method="GetAgentCard").to_dict())
|
||||
_run(s.handle_rpc(raw));assert len(s.audit_log)==1;assert s.audit_log[0]["method"]=="GetAgentCard"
|
||||
|
||||
if __name__=="__main__":pytest.main([__file__,"-v"])
|
||||
Reference in New Issue
Block a user