61 lines
3.4 KiB
Python
61 lines
3.4 KiB
Python
"""A2A Server - receive and execute tasks via JSON-RPC 2.0."""
|
|
from __future__ import annotations
|
|
import json,logging,uuid
|
|
from datetime import datetime,timezone
|
|
from typing import Any,Callable,Dict,List,Optional,Awaitable
|
|
from a2a.types import AgentCard,Artifact,JSONRPCError,JSONRPCRequest,JSONRPCResponse,Message,Task,TaskState,TaskStatus,TextPart,A2AError,part_from_dict
|
|
logger: logging.Logger = logging.getLogger(__name__)

# Signature a skill handler must satisfy: it is awaited with the newly created
# task and the server's agent card, and resolves to the task the server stores
# and returns to the client.
TaskHandler = Callable[[Task, AgentCard], Awaitable[Task]]
|
|
class A2AServer:
    """In-memory A2A server answering JSON-RPC 2.0 requests.

    Supported methods: ``SendMessage``, ``GetTask``, ``CancelTask``,
    ``GetAgentCard`` and ``ListTasks``. Tasks are kept in a dict keyed by task
    id for the lifetime of the server; nothing is persisted or evicted.
    """

    def __init__(self, card: AgentCard) -> None:
        """Create a server that advertises *card*.

        Args:
            card: Agent card returned by ``GetAgentCard`` and passed to every
                task handler.
        """
        self.card = card
        self._tasks: Dict[str, Task] = {}
        # skillId -> handler; consulted first by SendMessage.
        self._handlers: Dict[str, TaskHandler] = {}
        # Fallback handler; when this is also None, SendMessage just echoes.
        self._default_handler: Optional[TaskHandler] = None
        # One record per request that decoded to a JSON object; unbounded.
        self._audit_log: List[Dict[str, Any]] = []

    def register_handler(self, skill_id: str, handler: TaskHandler) -> None:
        """Route ``SendMessage`` requests whose ``skillId`` is *skill_id* to *handler*."""
        self._handlers[skill_id] = handler

    def set_default_handler(self, handler: TaskHandler) -> None:
        """Use *handler* for ``SendMessage`` requests with no matching skill handler."""
        self._default_handler = handler

    @staticmethod
    def _error_response(req_id: Any, error: Any) -> str:
        """Serialize a JSON-RPC error response carrying *error* for *req_id*."""
        return json.dumps(JSONRPCResponse(id=req_id, error=error).to_dict())

    async def handle_rpc(self, raw: Any) -> str:
        """Decode one JSON-RPC 2.0 request, dispatch it, and encode the reply.

        Args:
            raw: Request body; anything ``json.loads`` accepts (``str``,
                ``bytes`` or ``bytearray``).

        Returns:
            A JSON-encoded ``JSONRPCResponse`` carrying either ``result`` or
            ``error``. Malformed input, unknown methods and exceptions raised
            while handling a method are reported as JSON-RPC errors rather
            than propagated to the caller.
        """
        try:
            data = json.loads(raw)
        except (ValueError, TypeError, RecursionError):
            # ValueError covers json.JSONDecodeError and the UnicodeDecodeError
            # raised for undecodable bytes; TypeError covers inputs that are
            # not str/bytes/bytearray; RecursionError covers absurdly deep
            # nesting.
            return self._error_response("", A2AError.parse_error())
        if not isinstance(data, dict):
            # Well-formed JSON that is not a request object (array, number,
            # string, null) has no .get(); answer with an error instead.
            # NOTE(review): JSON-RPC 2.0 reserves -32600 "Invalid Request" for
            # this case; switch to it if a2a.types exposes a factory for it.
            return self._error_response("", A2AError.parse_error())

        req_id = data.get("id", "")
        method = data.get("method", "")
        params = data.get("params")
        self._audit_log.append(
            {
                "method": method,
                "id": req_id,
                "ts": datetime.now(timezone.utc).isoformat(),
            }
        )

        try:
            if method == "SendMessage":
                result = await self._handle_send_message(params)
            elif method == "GetTask":
                result = self._handle_get_task(params)
            elif method == "CancelTask":
                result = self._handle_cancel_task(params)
            elif method == "GetAgentCard":
                result = self.card.to_dict()
            elif method == "ListTasks":
                result = self._handle_list_tasks(params)
            else:
                return self._error_response(req_id, A2AError.method_not_found())
            return json.dumps(JSONRPCResponse(id=req_id, result=result).to_dict())
        except Exception as exc:
            # Top-level RPC boundary: log with traceback and report the failure
            # to the client instead of letting it escape to the transport.
            logger.exception("A2A handler error for %s", method)
            return self._error_response(req_id, A2AError.internal_error(str(exc)))

    async def _handle_send_message(self, params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Create a task from ``params["message"]``, run it, and return it.

        The handler is chosen by ``params["skillId"]``, falling back to the
        default handler. With no handler at all, the task is completed
        immediately with an artifact echoing the first part's text.

        Raises:
            ValueError: if ``params`` has no ``message``.
            Exception: whatever the handler raises; the stored task is marked
                FAILED first so it does not stay WORKING forever.
        """
        if not params or "message" not in params:
            raise ValueError("SendMessage requires message param")

        msg = Message.from_dict(params["message"])
        task = Task(
            id=str(uuid.uuid4()),
            context_id=msg.context_id,
            status=TaskStatus(state=TaskState.SUBMITTED),
            history=[msg],
        )
        self._tasks[task.id] = task

        skill_id = params.get("skillId")
        handler = self._handlers.get(skill_id) if skill_id else None
        if handler is None:
            handler = self._default_handler

        if handler is None:
            text = msg.parts[0].text if msg.parts and hasattr(msg.parts[0], "text") else ""
            task.status = TaskStatus(state=TaskState.COMPLETED)
            task.artifacts = [Artifact(parts=[TextPart(text=f"Received: {text}")])]
            return task.to_dict()

        # The stored entry is this same object, so the state change is visible
        # to GetTask/ListTasks while the handler runs.
        task.status = TaskStatus(state=TaskState.WORKING)
        try:
            task = await handler(task, self.card)
        except Exception:
            # `task` still refers to the stored object here; don't leave it
            # WORKING. handle_rpc turns the re-raised error into a response.
            task.status = TaskStatus(state=TaskState.FAILED)
            raise
        self._tasks[task.id] = task
        return task.to_dict()

    def _require_task(self, params: Optional[Dict[str, Any]]) -> Task:
        """Return the stored task named by ``params["taskId"]``.

        Raises:
            KeyError: if ``taskId`` is missing, empty or unknown.
        """
        task_id = (params or {}).get("taskId")
        if not task_id or task_id not in self._tasks:
            raise KeyError(f"Task not found: {task_id}")
        return self._tasks[task_id]

    def _handle_get_task(self, params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Return the task named by ``params["taskId"]`` (KeyError if unknown)."""
        return self._require_task(params).to_dict()

    def _handle_cancel_task(self, params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Move a non-terminal task to CANCELED and return it.

        Raises:
            KeyError: if the task is unknown.
            ValueError: if the task is already in a terminal state.
        """
        task = self._require_task(params)
        state = task.status.state
        if state.terminal:
            raise ValueError(f"Task already {state.value}")
        # `task` is the stored object, so no re-insert into _tasks is needed.
        task.status = TaskStatus(state=TaskState.CANCELED)
        return task.to_dict()

    def _handle_list_tasks(self, params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Return ``{"tasks": [...]}``, filtered by ``params["contextId"]`` when given."""
        context_id = (params or {}).get("contextId")
        return {
            "tasks": [
                t.to_dict()
                for t in self._tasks.values()
                if not context_id or t.context_id == context_id
            ]
        }

    def add_task(self, task: Task) -> None:
        """Store *task* under its id, replacing any task with the same id."""
        self._tasks[task.id] = task

    @property
    def audit_log(self) -> List[Dict[str, Any]]:
        """Shallow copy of the audit records (method, id, UTC ISO timestamp)."""
        return list(self._audit_log)
|