diff --git a/tests/test_a2a.py b/tests/test_a2a.py new file mode 100644 index 000000000..3afd7858d --- /dev/null +++ b/tests/test_a2a.py @@ -0,0 +1,116 @@ +"""Tests for A2A protocol implementation.""" +import asyncio,json,pytest +from a2a.types import AgentCard,AgentSkill,Artifact,DataPart,FilePart,JSONRPCError,JSONRPCRequest,JSONRPCResponse,Message,Task,TaskState,TaskStatus,TextPart,A2AError,part_from_dict +from a2a.client import A2AClient,A2AClientConfig +from a2a.server import A2AServer + +class TestTextPart: + def test_roundtrip(self): + p=TextPart(text="hello",metadata={"k":"v"});d=p.to_dict();assert d=={"text":"hello","metadata":{"k":"v"}};p2=TextPart.from_dict(d);assert p2.text=="hello" + def test_no_metadata(self): + p=TextPart(text="hi");d=p.to_dict();assert "metadata" not in d + +class TestFilePart: + def test_inline(self): + p=FilePart(media_type="text/plain",raw="SGVsbG8=",filename="hello.txt");d=p.to_dict();assert d["raw"]=="SGVsbG8=";p2=FilePart.from_dict(d);assert p2.filename=="hello.txt" + def test_url(self): + p=FilePart(url="https://x.com/f");d=p.to_dict();assert d["url"]=="https://x.com/f" + +class TestDataPart: + def test_roundtrip(self): + p=DataPart(data={"key":42});d=p.to_dict();assert d["data"]=={"key":42} + +class TestPartDiscrimination: + def test_text(self):assert isinstance(part_from_dict({"text":"hi"}),TextPart) + def test_file_raw(self):assert isinstance(part_from_dict({"raw":"d","mediaType":"t"}),FilePart) + def test_file_url(self):assert isinstance(part_from_dict({"url":"https://x.com"}),FilePart) + def test_data(self):assert isinstance(part_from_dict({"data":{"a":1}}),DataPart) + def test_unknown(self): + with pytest.raises(ValueError):part_from_dict({"unknown":True}) + +class TestMessage: + def test_roundtrip(self): + m=Message(role="user",parts=[TextPart(text="hi")],context_id="c1");d=m.to_dict();assert d["role"]=="user";m2=Message.from_dict(d);assert m2.parts[0].text=="hi" + +class TestArtifact: + def test_roundtrip(self): + a=Artifact(name="r",parts=[TextPart(text="d")]);d=a.to_dict();assert d["name"]=="r" + +class TestTaskStatus: + def test_roundtrip(self): + s=TaskStatus(state=TaskState.COMPLETED);d=s.to_dict();assert d["state"]=="TASK_STATE_COMPLETED";s2=TaskStatus.from_dict(d);assert s2.state==TaskState.COMPLETED + def test_terminal(self): + assert TaskState.COMPLETED.terminal;assert TaskState.FAILED.terminal;assert not TaskState.SUBMITTED.terminal + +class TestTask: + def test_roundtrip(self): + t=Task(id="t1",status=TaskStatus(state=TaskState.WORKING));d=t.to_dict();assert d["id"]=="t1";t2=Task.from_dict(d);assert t2.status.state==TaskState.WORKING + +class TestAgentCard: + def test_roundtrip(self): + c=AgentCard(name="a",skills=[AgentSkill(id="s1",name="S")]);d=c.to_dict();assert d["name"]=="a";c2=AgentCard.from_dict(d);assert c2.skills[0].id=="s1" + +class TestJSONRPC: + def test_request(self): + r=JSONRPCRequest(method="GetTask",params={"taskId":"1"});d=r.to_dict();assert d["method"]=="GetTask" + def test_response_success(self): + r=JSONRPCResponse(id="1",result={"ok":True});d=r.to_dict();assert d["result"]["ok"]==True + def test_response_error(self): + r=JSONRPCResponse(id="1",error=A2AError.parse_error());d=r.to_dict();assert d["error"]["code"]==-32700 + def test_response_from_dict(self): + r=JSONRPCResponse.from_dict({"jsonrpc":"2.0","id":"1","result":{"ok":True}});assert r.result["ok"]==True + +class TestA2AError: + def test_codes(self): + assert A2AError.parse_error().code==-32700;assert A2AError.invalid_request().code==-32600 + assert A2AError.method_not_found().code==-32601;assert A2AError.task_not_found("x").code==-32001 + +def _make_card(): + return AgentCard(name="test",description="Test",url="http://localhost:9999/a2a/v1",skills=[AgentSkill(id="echo",name="Echo",tags=["test"])]) + +def _run(coro):return asyncio.get_event_loop().run_until_complete(coro) + +class TestServer: + def test_echo(self): + s=A2AServer(_make_card());msg=Message(role="user",parts=[TextPart(text="hello")]) + raw=json.dumps(JSONRPCRequest(method="SendMessage",params={"message":msg.to_dict()}).to_dict()) + resp=json.loads(_run(s.handle_rpc(raw)));assert "error" not in resp;assert resp["result"]["status"]["state"]=="TASK_STATE_COMPLETED" + def test_get_task(self): + s=A2AServer(_make_card());msg=Message(role="user",parts=[TextPart(text="hi")]) + raw=json.dumps(JSONRPCRequest(method="SendMessage",params={"message":msg.to_dict()}).to_dict()) + task_id=json.loads(_run(s.handle_rpc(raw)))["result"]["id"] + raw2=json.dumps(JSONRPCRequest(method="GetTask",params={"taskId":task_id}).to_dict()) + resp=json.loads(_run(s.handle_rpc(raw2)));assert resp["result"]["id"]==task_id + def test_cancel(self): + s=A2AServer(_make_card());task=Task(id="c1",status=TaskStatus(state=TaskState.WORKING));s.add_task(task) + raw=json.dumps(JSONRPCRequest(method="CancelTask",params={"taskId":"c1"}).to_dict()) + resp=json.loads(_run(s.handle_rpc(raw)));assert resp["result"]["status"]["state"]=="TASK_STATE_CANCELED" + def test_cancel_terminal(self): + s=A2AServer(_make_card());task=Task(id="d1",status=TaskStatus(state=TaskState.COMPLETED));s.add_task(task) + raw=json.dumps(JSONRPCRequest(method="CancelTask",params={"taskId":"d1"}).to_dict()) + resp=json.loads(_run(s.handle_rpc(raw)));assert "error" in resp + def test_card(self): + s=A2AServer(_make_card());raw=json.dumps(JSONRPCRequest(method="GetAgentCard").to_dict()) + resp=json.loads(_run(s.handle_rpc(raw)));assert resp["result"]["name"]=="test" + def test_list(self): + s=A2AServer(_make_card());msg=Message(role="user",parts=[TextPart(text="hi")]) + raw=json.dumps(JSONRPCRequest(method="SendMessage",params={"message":msg.to_dict()}).to_dict()) + _run(s.handle_rpc(raw));raw2=json.dumps(JSONRPCRequest(method="ListTasks").to_dict()) + resp=json.loads(_run(s.handle_rpc(raw2)));assert len(resp["result"]["tasks"])>=1 + def test_unknown(self): + s=A2AServer(_make_card());raw=json.dumps(JSONRPCRequest(method="Nope").to_dict()) + resp=json.loads(_run(s.handle_rpc(raw)));assert resp["error"]["code"]==-32601 + def test_invalid_json(self): + s=A2AServer(_make_card());resp=json.loads(_run(s.handle_rpc("bad")));assert resp["error"]["code"]==-32700 + def test_custom_handler(self): + s=A2AServer(_make_card()) + async def h(task,card):task.status=TaskStatus(state=TaskState.COMPLETED);task.artifacts=[Artifact(parts=[TextPart(text="custom")])];return task + s.register_handler("echo",h) + msg=Message(role="user",parts=[TextPart(text="t")]) + raw=json.dumps(JSONRPCRequest(method="SendMessage",params={"message":msg.to_dict(),"skillId":"echo"}).to_dict()) + resp=json.loads(_run(s.handle_rpc(raw)));assert resp["result"]["artifacts"][0]["parts"][0]["text"]=="custom" + def test_audit(self): + s=A2AServer(_make_card());raw=json.dumps(JSONRPCRequest(method="GetAgentCard").to_dict()) + _run(s.handle_rpc(raw));assert len(s.audit_log)==1;assert s.audit_log[0]["method"]=="GetAgentCard" + +if __name__=="__main__":pytest.main([__file__,"-v"])