Compare commits
7 Commits
burn/804-1
...
fix/796
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4bb12e05ef | ||
| db72e908f7 | |||
| b82b760d5d | |||
| d8d7846897 | |||
| 6840d05554 | |||
| 8abe59ed95 | |||
| 435d790201 |
@@ -1,22 +0,0 @@
|
||||
"""A2A — Agent2Agent Protocol v1.0 for Hermes task delegation.
|
||||
|
||||
Usage:
|
||||
from a2a import A2AClient, A2AServer, AgentCard, Task, TextPart
|
||||
"""
|
||||
|
||||
from a2a.types import (
|
||||
AgentCard, AgentSkill, Artifact, DataPart, FilePart,
|
||||
JSONRPCError, JSONRPCRequest, JSONRPCResponse,
|
||||
Message, Part, Task, TaskState, TaskStatus, TextPart, A2AError,
|
||||
part_from_dict,
|
||||
)
|
||||
from a2a.client import A2AClient, A2AClientConfig
|
||||
from a2a.server import A2AServer, TaskHandler
|
||||
|
||||
__all__ = [
|
||||
"A2AClient", "A2AClientConfig", "A2AServer", "TaskHandler",
|
||||
"AgentCard", "AgentSkill", "Artifact",
|
||||
"DataPart", "FilePart", "TextPart", "Part", "part_from_dict",
|
||||
"JSONRPCError", "JSONRPCRequest", "JSONRPCResponse",
|
||||
"Message", "Task", "TaskState", "TaskStatus", "A2AError",
|
||||
]
|
||||
@@ -1,98 +0,0 @@
|
||||
"""A2A Client - send tasks to remote agents via JSON-RPC 2.0."""
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio, json, logging, time
|
||||
from typing import Any, Dict, List, Optional
|
||||
import aiohttp
|
||||
from a2a.types import (AgentCard, Artifact, JSONRPCError, JSONRPCRequest, JSONRPCResponse, Message, Part, Task, TaskState, TaskStatus, TextPart, A2AError)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class A2AClientConfig:
    """Tunable settings for an A2AClient.

    Covers the per-request timeout (seconds), retry count and linear
    backoff base, and the optional token/scheme pair used to build the
    Authorization header.
    """

    def __init__(self, timeout=30.0, max_retries=3, retry_delay=2.0, auth_token=None, auth_scheme="Bearer"):
        self.timeout, self.max_retries = timeout, max_retries
        self.retry_delay = retry_delay
        # Auth is optional; when auth_token is falsy no header is sent.
        self.auth_token, self.auth_scheme = auth_token, auth_scheme
|
||||
|
||||
class A2AClient:
    """JSON-RPC 2.0 client for delegating tasks to a remote A2A agent.

    All RPCs go through _rpc_call(), which retries with linear backoff and
    records every attempt in an in-memory audit log.
    """

    def __init__(self, base_url, config=None):
        # base_url: root of the remote agent, e.g. "http://host:1234";
        # a trailing slash is stripped so path joining stays consistent.
        self.base_url = base_url.rstrip("/")
        self.config = config or A2AClientConfig()
        # One dict per RPC attempt: method, status or error, elapsed_ms, attempt.
        self._audit_log = []

    def _headers(self):
        """HTTP headers for every RPC: JSON content type plus optional auth."""
        h = {"Content-Type": "application/json"}
        if self.config.auth_token:
            h["Authorization"] = f"{self.config.auth_scheme} {self.config.auth_token}"
        return h

    async def _rpc_call(self, method, params=None):
        """POST one JSON-RPC request to /a2a/v1/rpc, retrying on failure.

        Retries both non-200 responses and transport exceptions, up to
        config.max_retries attempts, sleeping retry_delay * attempt between
        tries (linear backoff). Never raises: failures are wrapped in a
        JSONRPCResponse carrying an internal error.
        """
        req = JSONRPCRequest(method=method, params=params)
        payload = json.dumps(req.to_dict())
        last_error = None
        for attempt in range(1, self.config.max_retries + 1):
            t0 = time.monotonic()
            try:
                timeout = aiohttp.ClientTimeout(total=self.config.timeout)
                # NOTE(review): a fresh ClientSession per call is simple but
                # forgoes connection pooling — consider a shared session if hot.
                async with aiohttp.ClientSession(timeout=timeout) as session:
                    async with session.post(f"{self.base_url}/a2a/v1/rpc", data=payload, headers=self._headers()) as resp:
                        body = await resp.text()
                        elapsed = time.monotonic() - t0
                        self._audit_log.append({"method": method, "params": params, "status": resp.status, "elapsed_ms": round(elapsed*1000), "attempt": attempt})
                        if resp.status != 200:
                            logger.warning("A2A %s -> HTTP %d (attempt %d/%d)", method, resp.status, attempt, self.config.max_retries)
                            if attempt < self.config.max_retries:
                                await asyncio.sleep(self.config.retry_delay * attempt)
                                continue
                            # Out of retries: surface a truncated body for diagnosis.
                            return JSONRPCResponse(id=req.id, error=A2AError.internal_error(f"HTTP {resp.status}: {body[:200]}"))
                        return JSONRPCResponse.from_dict(json.loads(body))
            except Exception as exc:
                # Transport/parse failure: log the attempt and retry or give up.
                elapsed = time.monotonic() - t0
                last_error = exc
                self._audit_log.append({"method": method, "error": str(exc), "elapsed_ms": round(elapsed*1000), "attempt": attempt})
                if attempt < self.config.max_retries:
                    await asyncio.sleep(self.config.retry_delay * attempt)
                    continue
                return JSONRPCResponse(id=req.id, error=A2AError.internal_error(str(last_error)))

    async def get_agent_card(self):
        """Fetch the remote agent's AgentCard. Raises RuntimeError on RPC error."""
        resp = await self._rpc_call("GetAgentCard")
        if resp.error: raise RuntimeError(f"Failed to get agent card: {resp.error.message}")
        return AgentCard.from_dict(resp.result)

    async def send_message(self, text, context_id=None, skill_id=None, metadata=None):
        """Submit a user text message as a new task; returns the created Task."""
        msg = Message(role="user", parts=[TextPart(text=text)], context_id=context_id)
        params = {"message": msg.to_dict()}
        if skill_id: params["skillId"] = skill_id
        if metadata: params["metadata"] = metadata
        resp = await self._rpc_call("SendMessage", params)
        if resp.error: raise RuntimeError(f"SendMessage failed: {resp.error.message}")
        return Task.from_dict(resp.result)

    async def get_task(self, task_id):
        """Fetch the current state of one task. Raises RuntimeError on RPC error."""
        resp = await self._rpc_call("GetTask", {"taskId": task_id})
        if resp.error: raise RuntimeError(f"GetTask failed: {resp.error.message}")
        return Task.from_dict(resp.result)

    async def cancel_task(self, task_id):
        """Ask the remote agent to cancel a task. Raises RuntimeError on RPC error."""
        resp = await self._rpc_call("CancelTask", {"taskId": task_id})
        if resp.error: raise RuntimeError(f"CancelTask failed: {resp.error.message}")
        return Task.from_dict(resp.result)

    async def wait_for_completion(self, task_id, poll_interval=2.0, timeout=300.0):
        """Poll GetTask until the task reaches a terminal state.

        Raises TimeoutError after `timeout` seconds; the deadline is checked
        after each poll, so it can overshoot by up to one poll cycle.
        """
        t0 = time.monotonic()
        while True:
            task = await self.get_task(task_id)
            if task.status.state.terminal: return task
            if time.monotonic() - t0 > timeout: raise TimeoutError(f"Task {task_id} did not complete within {timeout}s")
            await asyncio.sleep(poll_interval)

    async def delegate(self, text, skill_id=None, wait=True, timeout=300.0):
        """Convenience: send a message and (by default) wait for the result."""
        task = await self.send_message(text, skill_id=skill_id)
        if wait: return await self.wait_for_completion(task.id, timeout=timeout)
        return task

    @property
    def audit_log(self):
        """Copy of the per-attempt RPC audit trail."""
        return list(self._audit_log)
|
||||
@@ -1,60 +0,0 @@
|
||||
"""A2A Server - receive and execute tasks via JSON-RPC 2.0."""
|
||||
from __future__ import annotations
|
||||
import json,logging,uuid
|
||||
from datetime import datetime,timezone
|
||||
from typing import Any,Callable,Dict,List,Optional,Awaitable
|
||||
from a2a.types import AgentCard,Artifact,JSONRPCError,JSONRPCRequest,JSONRPCResponse,Message,Task,TaskState,TaskStatus,TextPart,A2AError,part_from_dict
|
||||
logger=logging.getLogger(__name__)

# Signature for task handlers: an async callable that receives the submitted
# Task and this server's AgentCard, and returns the (possibly updated) Task.
TaskHandler=Callable[[Task,AgentCard],Awaitable[Task]]
|
||||
class A2AServer:
    """Server side of the A2A JSON-RPC 2.0 protocol.

    Keeps an in-memory task table and dispatches the supported RPC methods
    (SendMessage, GetTask, CancelTask, GetAgentCard, ListTasks) to per-skill
    handlers or a default handler.
    """

    def __init__(self, card):
        self.card = card
        self._tasks = {}
        self._handlers = {}
        self._default_handler = None
        self._audit_log = []

    def register_handler(self, skill_id, handler):
        """Attach an async handler for one skill id."""
        self._handlers[skill_id] = handler

    def set_default_handler(self, handler):
        """Handler used when no skill-specific handler matches."""
        self._default_handler = handler

    async def handle_rpc(self, raw):
        """Process one raw JSON-RPC request string; always returns JSON text."""
        try:
            data = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return json.dumps(JSONRPCResponse(id="", error=A2AError.parse_error()).to_dict())

        req_id = data.get("id", "")
        method = data.get("method", "")
        params = data.get("params")
        self._audit_log.append({"method": method, "id": req_id, "ts": datetime.now(timezone.utc).isoformat()})

        try:
            if method == "SendMessage":
                result = await self._handle_send_message(params)
            elif method == "GetTask":
                result = self._handle_get_task(params)
            elif method == "CancelTask":
                result = self._handle_cancel_task(params)
            elif method == "GetAgentCard":
                result = self.card.to_dict()
            elif method == "ListTasks":
                result = self._handle_list_tasks(params)
            else:
                return json.dumps(JSONRPCResponse(id=req_id, error=A2AError.method_not_found()).to_dict())
            return json.dumps(JSONRPCResponse(id=req_id, result=result).to_dict())
        except Exception as exc:
            # Handler errors become JSON-RPC internal errors rather than crashes.
            logger.exception("A2A handler error for %s", method)
            return json.dumps(JSONRPCResponse(id=req_id, error=A2AError.internal_error(str(exc))).to_dict())

    async def _handle_send_message(self, params):
        """Create a Task from an inbound message and run the matching handler."""
        if not params or "message" not in params:
            raise ValueError("SendMessage requires message param")
        msg = Message.from_dict(params["message"])
        task = Task(id=str(uuid.uuid4()), context_id=msg.context_id, status=TaskStatus(state=TaskState.SUBMITTED), history=[msg])
        self._tasks[task.id] = task

        skill_id = params.get("skillId")
        handler = self._handlers.get(skill_id) if skill_id else None
        if handler is None:
            handler = self._default_handler

        if handler is None:
            # No handler registered anywhere: echo the text back, completed.
            text = msg.parts[0].text if msg.parts and hasattr(msg.parts[0], "text") else ""
            task.status = TaskStatus(state=TaskState.COMPLETED)
            task.artifacts = [Artifact(parts=[TextPart(text=f"Received: {text}")])]
        else:
            task.status = TaskStatus(state=TaskState.WORKING)
            self._tasks[task.id] = task
            task = await handler(task, self.card)
            self._tasks[task.id] = task
        return task.to_dict()

    def _handle_get_task(self, params):
        """Look up one task by id; KeyError if unknown."""
        task_id = (params or {}).get("taskId")
        if not task_id or task_id not in self._tasks:
            raise KeyError(f"Task not found: {task_id}")
        return self._tasks[task_id].to_dict()

    def _handle_cancel_task(self, params):
        """Cancel a non-terminal task; ValueError if it already finished."""
        task_id = (params or {}).get("taskId")
        if not task_id or task_id not in self._tasks:
            raise KeyError(f"Task not found: {task_id}")
        task = self._tasks[task_id]
        if task.status.state.terminal:
            raise ValueError(f"Task already {task.status.state.value}")
        task.status = TaskStatus(state=TaskState.CANCELED)
        self._tasks[task_id] = task
        return task.to_dict()

    def _handle_list_tasks(self, params):
        """All known tasks, optionally filtered by contextId."""
        context_id = (params or {}).get("contextId")
        return {"tasks": [t.to_dict() for t in self._tasks.values() if not context_id or t.context_id == context_id]}

    def add_task(self, task):
        """Insert or replace a task directly (bypassing the RPC path)."""
        self._tasks[task.id] = task

    @property
    def audit_log(self):
        """Copy of the RPC audit trail."""
        return list(self._audit_log)
|
||||
230
a2a/types.py
230
a2a/types.py
@@ -1,230 +0,0 @@
|
||||
"""A2A Protocol Types - Agent2Agent v1.0 data structures."""
|
||||
|
||||
from __future__ import annotations
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
class TaskState(str, Enum):
    """Lifecycle states of an A2A task (values are the v1.0 wire strings)."""
    SUBMITTED = "TASK_STATE_SUBMITTED"
    WORKING = "TASK_STATE_WORKING"
    INPUT_REQUIRED = "TASK_STATE_INPUT_REQUIRED"
    COMPLETED = "TASK_STATE_COMPLETED"
    FAILED = "TASK_STATE_FAILED"
    CANCELED = "TASK_STATE_CANCELED"
    REJECTED = "TASK_STATE_REJECTED"

    @property
    def terminal(self) -> bool:
        """True once the task can no longer change state."""
        return self in (
            TaskState.COMPLETED,
            TaskState.FAILED,
            TaskState.CANCELED,
            TaskState.REJECTED,
        )
|
||||
|
||||
@dataclass
class TextPart:
    """Plain-text message part."""
    text: str
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self) -> dict:
        """Wire form; the metadata key is omitted when falsy."""
        out = {"text": self.text}
        if self.metadata:
            out["metadata"] = self.metadata
        return out

    @classmethod
    def from_dict(cls, d):
        """Inverse of to_dict()."""
        return cls(text=d["text"], metadata=d.get("metadata"))
|
||||
|
||||
@dataclass
class FilePart:
    """File message part, carried inline (`raw`) or by reference (`url`).

    NOTE(review): `raw` is a str — presumably base64-encoded content;
    confirm against producers.
    """
    media_type: str = "application/octet-stream"
    raw: Optional[str] = None
    url: Optional[str] = None
    filename: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self) -> dict:
        """Wire form; raw/url are kept even when empty strings, only None drops them."""
        out = {"mediaType": self.media_type}
        for key, value in (("raw", self.raw), ("url", self.url)):
            if value is not None:
                out[key] = value
        if self.filename:
            out["filename"] = self.filename
        if self.metadata:
            out["metadata"] = self.metadata
        return out

    @classmethod
    def from_dict(cls, d):
        """Inverse of to_dict(); missing mediaType falls back to octet-stream."""
        return cls(
            media_type=d.get("mediaType", "application/octet-stream"),
            raw=d.get("raw"),
            url=d.get("url"),
            filename=d.get("filename"),
            metadata=d.get("metadata"),
        )
|
||||
|
||||
@dataclass
class DataPart:
    """Structured-data message part (JSON-serializable payload)."""
    data: Dict[str, Any]
    media_type: str = "application/json"
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self) -> dict:
        """Wire form; metadata omitted when falsy."""
        out = {"data": self.data, "mediaType": self.media_type}
        if self.metadata:
            out["metadata"] = self.metadata
        return out

    @classmethod
    def from_dict(cls, d):
        """Inverse of to_dict()."""
        return cls(
            data=d["data"],
            media_type=d.get("mediaType", "application/json"),
            metadata=d.get("metadata"),
        )
|
||||
|
||||
Part = Union[TextPart, FilePart, DataPart]
|
||||
|
||||
def part_from_dict(d):
    """Reconstruct a concrete Part from its wire dict.

    Discrimination is by key, in priority order: "text" wins over
    "raw"/"url", which win over "data". Raises ValueError when no
    known discriminator key is present.
    """
    if "text" in d:
        return TextPart.from_dict(d)
    if "raw" in d or "url" in d:
        return FilePart.from_dict(d)
    if "data" in d:
        return DataPart.from_dict(d)
    raise ValueError(f"Cannot discriminate Part type from keys: {list(d.keys())}")
|
||||
|
||||
@dataclass
class Message:
    """One conversational turn: a role plus an ordered list of parts."""
    role: str
    parts: List[Part]
    message_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    context_id: Optional[str] = None
    task_id: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self):
        """Wire form; optional keys are emitted only when truthy."""
        out = {"role": self.role, "messageId": self.message_id, "parts": [p.to_dict() for p in self.parts]}
        for key, value in (
            ("contextId", self.context_id),
            ("taskId", self.task_id),
            ("metadata", self.metadata),
        ):
            if value:
                out[key] = value
        return out

    @classmethod
    def from_dict(cls, d):
        """Inverse of to_dict(); a missing messageId gets a fresh UUID."""
        return cls(
            role=d["role"],
            parts=[part_from_dict(p) for p in d["parts"]],
            message_id=d.get("messageId", str(uuid.uuid4())),
            context_id=d.get("contextId"),
            task_id=d.get("taskId"),
            metadata=d.get("metadata"),
        )
|
||||
|
||||
@dataclass
class Artifact:
    """Output produced by a task: an identified bundle of parts."""
    artifact_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    parts: List[Part] = field(default_factory=list)
    name: Optional[str] = None
    description: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self):
        """Wire form; optional keys emitted only when truthy."""
        out = {"artifactId": self.artifact_id, "parts": [p.to_dict() for p in self.parts]}
        for key, value in (
            ("name", self.name),
            ("description", self.description),
            ("metadata", self.metadata),
        ):
            if value:
                out[key] = value
        return out

    @classmethod
    def from_dict(cls, d):
        """Inverse of to_dict(); a missing artifactId gets a fresh UUID."""
        return cls(
            artifact_id=d.get("artifactId", str(uuid.uuid4())),
            parts=[part_from_dict(p) for p in d.get("parts", [])],
            name=d.get("name"),
            description=d.get("description"),
            metadata=d.get("metadata"),
        )
|
||||
|
||||
@dataclass
class TaskStatus:
    """State snapshot: a TaskState, optional attached message, UTC timestamp."""
    state: TaskState
    message: Optional[Message] = None
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())

    def to_dict(self):
        """Wire form; message key omitted when absent."""
        out = {"state": self.state.value, "timestamp": self.timestamp}
        if self.message:
            out["message"] = self.message.to_dict()
        return out

    @classmethod
    def from_dict(cls, d):
        """Inverse of to_dict(); a missing timestamp defaults to now (UTC)."""
        raw_msg = d.get("message")
        parsed = Message.from_dict(raw_msg) if raw_msg else None
        return cls(
            state=TaskState(d["state"]),
            message=parsed,
            timestamp=d.get("timestamp", datetime.now(timezone.utc).isoformat()),
        )
|
||||
|
||||
@dataclass
class Task:
    """A unit of delegated work: status plus artifacts and message history."""
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    context_id: Optional[str] = None
    status: TaskStatus = field(default_factory=lambda: TaskStatus(state=TaskState.SUBMITTED))
    artifacts: List[Artifact] = field(default_factory=list)
    history: List[Message] = field(default_factory=list)
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self):
        """Wire form; empty/falsy optional collections are omitted."""
        out = {"id": self.id, "status": self.status.to_dict()}
        if self.context_id:
            out["contextId"] = self.context_id
        if self.artifacts:
            out["artifacts"] = [a.to_dict() for a in self.artifacts]
        if self.history:
            out["history"] = [m.to_dict() for m in self.history]
        if self.metadata:
            out["metadata"] = self.metadata
        return out

    @classmethod
    def from_dict(cls, d):
        """Inverse of to_dict(); a missing status defaults to SUBMITTED."""
        if "status" in d:
            status = TaskStatus.from_dict(d["status"])
        else:
            status = TaskStatus(TaskState.SUBMITTED)
        return cls(
            id=d.get("id", str(uuid.uuid4())),
            context_id=d.get("contextId"),
            status=status,
            artifacts=[Artifact.from_dict(a) for a in d.get("artifacts", [])],
            history=[Message.from_dict(m) for m in d.get("history", [])],
            metadata=d.get("metadata"),
        )
|
||||
|
||||
@dataclass
class AgentSkill:
    """One advertised capability on an AgentCard."""
    id: str
    name: str
    description: str = ""
    tags: List[str] = field(default_factory=list)
    examples: List[str] = field(default_factory=list)
    input_modes: List[str] = field(default_factory=lambda: ["text"])
    output_modes: List[str] = field(default_factory=lambda: ["text"])

    def to_dict(self):
        """Wire form; mode lists are emitted only when they differ from the ["text"] default."""
        out = {"id": self.id, "name": self.name, "description": self.description}
        if self.tags:
            out["tags"] = self.tags
        if self.examples:
            out["examples"] = self.examples
        if self.input_modes != ["text"]:
            out["inputModes"] = self.input_modes
        if self.output_modes != ["text"]:
            out["outputModes"] = self.output_modes
        return out

    @classmethod
    def from_dict(cls, d):
        """Inverse of to_dict(); name falls back to the id when missing."""
        return cls(
            id=d["id"],
            name=d.get("name", d["id"]),
            description=d.get("description", ""),
            tags=d.get("tags", []),
            examples=d.get("examples", []),
            input_modes=d.get("inputModes", ["text"]),
            output_modes=d.get("outputModes", ["text"]),
        )
|
||||
|
||||
@dataclass
class AgentCard:
    """Public identity document an agent serves to its peers."""
    name: str
    description: str = ""
    version: str = "1.0.0"
    url: str = ""
    skills: List[AgentSkill] = field(default_factory=list)
    capabilities: Dict[str, bool] = field(default_factory=dict)
    provider: Optional[Dict[str, str]] = None

    def to_dict(self):
        """Wire form; capabilities/provider are omitted when falsy."""
        out = {
            "name": self.name,
            "description": self.description,
            "version": self.version,
            "url": self.url,
            "skills": [s.to_dict() for s in self.skills],
        }
        if self.capabilities:
            out["capabilities"] = self.capabilities
        if self.provider:
            out["provider"] = self.provider
        return out

    @classmethod
    def from_dict(cls, d):
        """Inverse of to_dict(); only name is required."""
        return cls(
            name=d["name"],
            description=d.get("description", ""),
            version=d.get("version", "1.0.0"),
            url=d.get("url", ""),
            skills=[AgentSkill.from_dict(s) for s in d.get("skills", [])],
            capabilities=d.get("capabilities", {}),
            provider=d.get("provider"),
        )
|
||||
|
||||
@dataclass
class JSONRPCError:
    """JSON-RPC 2.0 error object: numeric code, message, optional data."""
    code: int
    message: str
    data: Optional[Any] = None

    def to_dict(self):
        """Wire form; "data" is included whenever it is not None."""
        out = {"code": self.code, "message": self.message}
        if self.data is not None:
            out["data"] = self.data
        return out
|
||||
|
||||
@dataclass
class JSONRPCRequest:
    """Outbound JSON-RPC 2.0 request envelope (id defaults to a fresh UUID)."""
    method: str
    params: Optional[Dict[str, Any]] = None
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    jsonrpc: str = "2.0"

    def to_dict(self):
        """Wire form; "params" omitted only when None (an empty dict is kept)."""
        out = {"jsonrpc": self.jsonrpc, "method": self.method, "id": self.id}
        if self.params is not None:
            out["params"] = self.params
        return out
|
||||
|
||||
@dataclass
class JSONRPCResponse:
    """JSON-RPC 2.0 response envelope: carries a result or an error."""
    id: str
    result: Optional[Any] = None
    error: Optional[JSONRPCError] = None
    jsonrpc: str = "2.0"

    def to_dict(self):
        """Wire form: emits "error" when set, otherwise "result" (even if None)."""
        out = {"jsonrpc": self.jsonrpc, "id": self.id}
        if self.error:
            out["error"] = self.error.to_dict()
        else:
            out["result"] = self.result
        return out

    @classmethod
    def from_dict(cls, d):
        """Parse a response dict, materializing the error object when present."""
        raw_err = d.get("error")
        parsed = (
            JSONRPCError(raw_err["code"], raw_err["message"], raw_err.get("data"))
            if raw_err
            else None
        )
        return cls(id=d["id"], result=d.get("result"), error=parsed)
|
||||
|
||||
class A2AError:
    """Factory helpers for standard JSON-RPC and A2A-specific error objects."""

    @staticmethod
    def parse_error(data=None):
        """JSON-RPC -32700: request body was not valid JSON."""
        return JSONRPCError(-32700, "Parse error", data)

    @staticmethod
    def invalid_request(data=None):
        """JSON-RPC -32600: malformed request object."""
        return JSONRPCError(-32600, "Invalid Request", data)

    @staticmethod
    def method_not_found(data=None):
        """JSON-RPC -32601: unknown method name."""
        return JSONRPCError(-32601, "Method not found", data)

    @staticmethod
    def invalid_params(data=None):
        """JSON-RPC -32602: bad parameters for the method."""
        return JSONRPCError(-32602, "Invalid params", data)

    @staticmethod
    def internal_error(data=None):
        """JSON-RPC -32603: unexpected server-side failure."""
        return JSONRPCError(-32603, "Internal error", data)

    @staticmethod
    def task_not_found(task_id):
        """A2A -32001: no task with the given id."""
        return JSONRPCError(-32001, f"Task not found: {task_id}")

    @staticmethod
    def task_not_cancelable(task_id):
        """A2A -32002: task is already in a terminal state."""
        return JSONRPCError(-32002, f"Task not cancelable: {task_id}")

    @staticmethod
    def agent_not_found(name):
        """A2A -32009: unknown agent name."""
        return JSONRPCError(-32009, f"Agent not found: {name}")
|
||||
353
agent/privacy_filter.py
Normal file
353
agent/privacy_filter.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""Privacy Filter — strip PII from context before remote API calls.
|
||||
|
||||
Implements Vitalik's Pattern 2: "A local model can strip out private data
|
||||
before passing the query along to a remote LLM."
|
||||
|
||||
When Hermes routes a request to a cloud provider (Anthropic, OpenRouter, etc.),
|
||||
this module sanitizes the message context to remove personally identifiable
|
||||
information before it leaves the user's machine.
|
||||
|
||||
Threat model (from Vitalik's secure LLM architecture):
|
||||
- Privacy (other): Non-LLM data leakage via search queries, API calls
|
||||
- LLM accidents: LLM accidentally leaking private data in prompts
|
||||
- LLM jailbreaks: Remote content extracting private context
|
||||
|
||||
Usage:
|
||||
from agent.privacy_filter import PrivacyFilter, sanitize_messages
|
||||
|
||||
pf = PrivacyFilter()
|
||||
safe_messages = pf.sanitize_messages(messages)
|
||||
# safe_messages has PII replaced with [REDACTED] tokens
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Sensitivity(Enum):
    """Ordered sensitivity classes for detected content.

    auto() assigns increasing integers, so ``.value`` comparisons order
    PUBLIC < LOW < MEDIUM < HIGH < CRITICAL.
    """
    PUBLIC = auto()    # No PII detected
    LOW = auto()       # Generic references (e.g., city names)
    MEDIUM = auto()    # Personal identifiers (name, email, phone)
    HIGH = auto()      # Secrets, keys, financial data, medical info
    CRITICAL = auto()  # Crypto keys, passwords, SSN patterns
|
||||
|
||||
|
||||
@dataclass
class RedactionReport:
    """Summary of what was redacted from a message batch."""
    total_messages: int = 0
    redacted_messages: int = 0
    redactions: List[Dict[str, Any]] = field(default_factory=list)
    max_sensitivity: Sensitivity = Sensitivity.PUBLIC

    @property
    def had_redactions(self) -> bool:
        """True when at least one message was modified."""
        return self.redacted_messages > 0

    def summary(self) -> str:
        """Human-readable digest; at most 10 redaction types are listed."""
        if not self.had_redactions:
            return "No PII detected — context is clean for remote query."
        lines = [f"Redacted {self.redacted_messages}/{self.total_messages} messages:"]
        lines.extend(
            f" - {r['type']}: {r['count']} occurrence(s)"
            for r in self.redactions[:10]
        )
        if len(self.redactions) > 10:
            lines.append(f" ... and {len(self.redactions) - 10} more types")
        return "\n".join(lines)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# PII pattern definitions
|
||||
# =========================================================================
|
||||
|
||||
# Each pattern is (compiled_regex, redaction_type, sensitivity_level, replacement)
_PII_PATTERNS: List[Tuple[re.Pattern, str, Sensitivity, str]] = []


def _compile_patterns() -> None:
    """Compile PII detection patterns. Called once at module init."""
    global _PII_PATTERNS
    if _PII_PATTERNS:
        # Already populated — compilation is idempotent.
        return

    raw_patterns = [
        # --- CRITICAL: secrets and credentials ---
        (
            r'(?:api[_-]?key|apikey|secret[_-]?key|access[_-]?token)\s*[:=]\s*["\']?([A-Za-z0-9_\-\.]{20,})["\']?',
            "api_key_or_token",
            Sensitivity.CRITICAL,
            "[REDACTED-API-KEY]",
        ),
        (
            r'\b(?:sk-|sk_|pk_|rk_|ak_)[A-Za-z0-9]{20,}\b',
            "prefixed_secret",
            Sensitivity.CRITICAL,
            "[REDACTED-SECRET]",
        ),
        (
            r'\b(?:ghp_|gho_|ghu_|ghs_|ghr_)[A-Za-z0-9]{36,}\b',
            "github_token",
            Sensitivity.CRITICAL,
            "[REDACTED-GITHUB-TOKEN]",
        ),
        (
            r'\b(?:xox[bposa]-[A-Za-z0-9\-]+)\b',
            "slack_token",
            Sensitivity.CRITICAL,
            "[REDACTED-SLACK-TOKEN]",
        ),
        (
            r'(?:password|passwd|pwd)\s*[:=]\s*["\']?([^\s"\']{4,})["\']?',
            "password",
            Sensitivity.CRITICAL,
            "[REDACTED-PASSWORD]",
        ),
        (
            r'(?:-----BEGIN (?:RSA |EC |OPENSSH )?PRIVATE KEY-----)',
            "private_key_block",
            Sensitivity.CRITICAL,
            "[REDACTED-PRIVATE-KEY]",
        ),
        # Ethereum / crypto addresses (42-char hex starting with 0x)
        (
            r'\b0x[a-fA-F0-9]{40}\b',
            "ethereum_address",
            Sensitivity.HIGH,
            "[REDACTED-ETH-ADDR]",
        ),
        # Bitcoin addresses (base58, 25-34 chars starting with 1/3/bc1)
        (
            r'\b[13][a-km-zA-HJ-NP-Z1-9]{25,34}\b',
            "bitcoin_address",
            Sensitivity.HIGH,
            "[REDACTED-BTC-ADDR]",
        ),
        (
            r'\bbc1[a-zA-HJ-NP-Z0-9]{39,59}\b',
            "bech32_address",
            Sensitivity.HIGH,
            "[REDACTED-BTC-ADDR]",
        ),
        # --- HIGH: financial ---
        (
            r'\b(?:\d{4}[-\s]?){3}\d{4}\b',
            "credit_card_number",
            Sensitivity.HIGH,
            "[REDACTED-CC]",
        ),
        (
            r'\b\d{3}-\d{2}-\d{4}\b',
            "us_ssn",
            Sensitivity.HIGH,
            "[REDACTED-SSN]",
        ),
        # --- MEDIUM: personal identifiers ---
        # Email addresses
        (
            r'\b[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,}\b',
            "email_address",
            Sensitivity.MEDIUM,
            "[REDACTED-EMAIL]",
        ),
        # Phone numbers (US/international patterns)
        (
            r'\b(?:\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}\b',
            "phone_number_us",
            Sensitivity.MEDIUM,
            "[REDACTED-PHONE]",
        ),
        (
            r'\b\+\d{1,3}[-.\s]?\d{4,14}\b',
            "phone_number_intl",
            Sensitivity.MEDIUM,
            "[REDACTED-PHONE]",
        ),
        # Filesystem paths that reveal user identity
        # NOTE(review): the replacement normalizes every match to a
        # "/Users/[REDACTED-USER]" prefix, including /home/... and
        # C:\Users\... paths — confirm this rewrite is intentional.
        (
            r'(?:/Users/|/home/|C:\\Users\\)([A-Za-z0-9_\-]+)',
            "user_home_path",
            Sensitivity.MEDIUM,
            r"/Users/[REDACTED-USER]",
        ),
        # --- LOW: environment / system info ---
        # Internal IPs
        (
            r'\b(?:10\.\d{1,3}\.\d{1,3}\.\d{1,3}|172\.(?:1[6-9]|2\d|3[01])\.\d{1,3}\.\d{1,3}|192\.168\.\d{1,3}\.\d{1,3})\b',
            "internal_ip",
            Sensitivity.LOW,
            "[REDACTED-IP]",
        ),
    ]

    # NOTE(review): re.IGNORECASE is applied to every pattern. That also
    # relaxes the deliberately case-sensitive base58/bech32 character classes
    # above (letters excluded in one case match via the other) — confirm
    # whether the crypto-address patterns should compile without this flag.
    _PII_PATTERNS = [
        (re.compile(pattern, re.IGNORECASE), rtype, sensitivity, replacement)
        for pattern, rtype, sensitivity, replacement in raw_patterns
    ]


# Populate _PII_PATTERNS once at import time.
_compile_patterns()
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Sensitive file path patterns (context-aware)
|
||||
# =========================================================================
|
||||
|
||||
# Paths that likely reference secret material: key/cert file extensions,
# credential directories, and wallet/seed vocabulary.
_SENSITIVE_PATH_PATTERNS = [
    re.compile(r'\.(?:env|pem|key|p12|pfx|jks|keystore)\b', re.IGNORECASE),
    re.compile(r'(?:\.ssh/|\.gnupg/|\.aws/|\.config/gcloud/)', re.IGNORECASE),
    re.compile(r'(?:wallet|keystore|seed|mnemonic)', re.IGNORECASE),
    re.compile(r'(?:\.hermes/\.env)', re.IGNORECASE),
]


def _classify_path_sensitivity(path: str) -> Sensitivity:
    """Check if a file path references sensitive material.

    Returns HIGH when any sensitive-path pattern matches, else PUBLIC.
    """
    matched = any(pat.search(path) for pat in _SENSITIVE_PATH_PATTERNS)
    return Sensitivity.HIGH if matched else Sensitivity.PUBLIC
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Core filtering
|
||||
# =========================================================================
|
||||
|
||||
class PrivacyFilter:
    """Strip PII from message context before remote API calls.

    Integrates with the agent's message pipeline. Call sanitize_messages()
    before sending context to any cloud LLM provider.
    """

    def __init__(
        self,
        min_sensitivity: Sensitivity = Sensitivity.MEDIUM,
        aggressive_mode: bool = False,
    ):
        """
        Args:
            min_sensitivity: Only redact PII at or above this level.
                Default MEDIUM — redacts emails, phones, paths but not IPs.
            aggressive_mode: If True, also redact file paths and internal IPs.
        """
        # aggressive_mode forces the threshold down to LOW so the LOW-level
        # patterns (internal IPs) are redacted regardless of min_sensitivity.
        self.min_sensitivity = (
            Sensitivity.LOW if aggressive_mode else min_sensitivity
        )
        self.aggressive_mode = aggressive_mode

    def sanitize_text(self, text: str) -> Tuple[str, List[Dict[str, Any]]]:
        """Sanitize a single text string. Returns (cleaned_text, redaction_list)."""
        redactions = []
        cleaned = text

        for pattern, rtype, sensitivity, replacement in _PII_PATTERNS:
            # Skip patterns below the configured threshold; Sensitivity
            # auto() values are ordered, so .value comparison is valid.
            if sensitivity.value < self.min_sensitivity.value:
                continue

            matches = pattern.findall(cleaned)
            if matches:
                # findall() returns plain strings for 0/1-group patterns;
                # the fallback branch counts only truthy (non-empty) matches.
                count = len(matches) if isinstance(matches[0], str) else sum(
                    1 for m in matches if m
                )
                if count > 0:
                    # Each matched pattern rewrites the running `cleaned`
                    # text, so later patterns see earlier replacements.
                    cleaned = pattern.sub(replacement, cleaned)
                    redactions.append({
                        "type": rtype,
                        "sensitivity": sensitivity.name,
                        "count": count,
                    })

        return cleaned, redactions

    def sanitize_messages(
        self, messages: List[Dict[str, Any]]
    ) -> Tuple[List[Dict[str, Any]], RedactionReport]:
        """Sanitize a list of OpenAI-format messages.

        Returns (safe_messages, report). System messages are NOT sanitized
        (they're typically static prompts). Only user and assistant messages
        with string content are processed.

        Args:
            messages: List of {"role": ..., "content": ...} dicts.

        Returns:
            Tuple of (sanitized_messages, redaction_report).
        """
        report = RedactionReport(total_messages=len(messages))
        safe_messages = []

        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")

            # Only sanitize user/assistant string content
            if role in ("user", "assistant") and isinstance(content, str) and content:
                cleaned, redactions = self.sanitize_text(content)
                if redactions:
                    report.redacted_messages += 1
                    report.redactions.extend(redactions)
                    # Track max sensitivity
                    for r in redactions:
                        s = Sensitivity[r["sensitivity"]]
                        if s.value > report.max_sensitivity.value:
                            report.max_sensitivity = s
                    # Shallow copy: only "content" is replaced; all other
                    # message keys pass through untouched.
                    safe_msg = {**msg, "content": cleaned}
                    safe_messages.append(safe_msg)
                    logger.info(
                        "Privacy filter: redacted %d PII type(s) from %s message",
                        len(redactions), role,
                    )
                else:
                    safe_messages.append(msg)
            else:
                # Non-string content (tool calls, lists) and other roles
                # are passed through unmodified.
                safe_messages.append(msg)

        return safe_messages, report

    def should_use_local_only(self, text: str) -> Tuple[bool, str]:
        """Determine if content is too sensitive for any remote call.

        Returns (should_block, reason). If True, the content should only
        be processed by a local model.
        """
        _, redactions = self.sanitize_text(text)

        critical_count = sum(
            1 for r in redactions
            if Sensitivity[r["sensitivity"]] == Sensitivity.CRITICAL
        )
        high_count = sum(
            1 for r in redactions
            if Sensitivity[r["sensitivity"]] == Sensitivity.HIGH
        )

        # Any critical secret blocks remote use outright; a pile-up of
        # high-sensitivity hits (3+) does too.
        if critical_count > 0:
            return True, f"Contains {critical_count} critical-secret pattern(s) — local-only"
        if high_count >= 3:
            return True, f"Contains {high_count} high-sensitivity pattern(s) — local-only"
        return False, ""
|
||||
|
||||
|
||||
def sanitize_messages(
    messages: List[Dict[str, Any]],
    min_sensitivity: Sensitivity = Sensitivity.MEDIUM,
    aggressive: bool = False,
) -> Tuple[List[Dict[str, Any]], RedactionReport]:
    """Sanitize *messages* using a throwaway PrivacyFilter.

    Thin convenience wrapper: constructs a filter from the given settings
    and delegates straight to PrivacyFilter.sanitize_messages.
    """
    filt = PrivacyFilter(min_sensitivity=min_sensitivity, aggressive_mode=aggressive)
    return filt.sanitize_messages(messages)
|
||||
|
||||
|
||||
def quick_sanitize(text: str) -> str:
    """Strip PII from a single string using default filter settings.

    Returns only the cleaned text; the redaction details are discarded.
    """
    sanitized, _report = PrivacyFilter().sanitize_text(text)
    return sanitized
|
||||
461
benchmarks/tool_call_benchmark.py
Executable file
461
benchmarks/tool_call_benchmark.py
Executable file
@@ -0,0 +1,461 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
tool_call_benchmark.py — Benchmark Gemma 4 tool calling vs mimo-v2-pro.
|
||||
|
||||
Runs 100 diverse tool calling prompts through each model and compares:
|
||||
- Schema parse success rate
|
||||
- Tool execution success rate
|
||||
- Parallel tool call success rate
|
||||
- Average latency
|
||||
- Token cost per call
|
||||
|
||||
Usage:
|
||||
python3 benchmarks/tool_call_benchmark.py --model1 gemma3:27b --model2 xiaomi/mimo-v2-pro
|
||||
python3 benchmarks/tool_call_benchmark.py --model1 gemma3:27b --limit 10 # quick test
|
||||
python3 benchmarks/tool_call_benchmark.py --output benchmarks/results.json
|
||||
|
||||
Requires:
|
||||
- Ollama running locally (or --endpoint for remote)
|
||||
- Models pulled: ollama pull gemma3:27b, etc.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
# OpenAI-compatible API endpoint (defaults to a local Ollama instance) and
# API key; both overridable via environment variables. Note: main() also
# rebinds ENDPOINT from --endpoint at runtime.
ENDPOINT = os.environ.get("OPENAI_BASE_URL", "http://localhost:11434/v1")
API_KEY = os.environ.get("OPENAI_API_KEY", "ollama")
|
||||
|
||||
# ── Tool schemas (subset for benchmarking) ──────────────────────────────
|
||||
|
||||
# OpenAI function-calling tool definitions offered to every model under test:
# six tools spanning file I/O, shell, search, web search and code execution.
TOOL_SCHEMAS = [
    {
        "type": "function",
        "function": {
            "name": "read_file",
            "description": "Read a text file",
            "parameters": {
                "type": "object",
                "properties": {
                    "path": {"type": "string", "description": "File path"},
                    "offset": {"type": "integer", "description": "Start line"},
                    "limit": {"type": "integer", "description": "Max lines"}
                },
                "required": ["path"]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "terminal",
            "description": "Execute a shell command",
            "parameters": {
                "type": "object",
                "properties": {
                    "command": {"type": "string", "description": "Shell command"}
                },
                "required": ["command"]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "write_file",
            "description": "Write content to a file",
            "parameters": {
                "type": "object",
                "properties": {
                    "path": {"type": "string"},
                    "content": {"type": "string"}
                },
                "required": ["path", "content"]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "search_files",
            "description": "Search for content in files",
            "parameters": {
                "type": "object",
                "properties": {
                    "pattern": {"type": "string"},
                    "path": {"type": "string"}
                },
                "required": ["pattern"]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "web_search",
            "description": "Search the web",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {"type": "string"}
                },
                "required": ["query"]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "execute_code",
            "description": "Execute Python code",
            "parameters": {
                "type": "object",
                "properties": {
                    "code": {"type": "string"}
                },
                "required": ["code"]
            }
        }
    },
]
|
||||
|
||||
# Shared system prompt: encourages tool use without biasing any specific tool.
SYSTEM_PROMPT = "You are a helpful assistant with access to tools. Use tools when needed."
|
||||
|
||||
# ── Test prompts (100 diverse tool calling scenarios) ────────────────────
|
||||
|
||||
# Benchmark scenarios as (prompt, expected_tool, category) triples. For
# parallel cases expected_tool is "a|b"; only the first (primary) tool is
# scored by evaluate_response.
# NOTE(review): there are 80 prompts here (20+20+15+15+10), not the 100 the
# module docstring claims — confirm which is intended.
TEST_PROMPTS = [
    # File operations (20)
    ("Read the README.md file", "read_file", "file_ops"),
    ("Show me the contents of config.yaml", "read_file", "file_ops"),
    ("Read lines 10-20 of main.py", "read_file", "file_ops"),
    ("Open the package.json", "read_file", "file_ops"),
    ("Read the .gitignore file", "read_file", "file_ops"),
    ("Save this to notes.txt: meeting at 3pm", "write_file", "file_ops"),
    ("Create a new file hello.py with print hello", "write_file", "file_ops"),
    ("Write the config to settings.json", "write_file", "file_ops"),
    ("Save the output to results.txt", "write_file", "file_ops"),
    ("Create TODO.md with my tasks", "write_file", "file_ops"),
    ("Search for 'import os' in the codebase", "search_files", "file_ops"),
    ("Find all Python files mentioning 'error'", "search_files", "file_ops"),
    ("Search for TODO comments", "search_files", "file_ops"),
    ("Find where 'authenticate' is defined", "search_files", "file_ops"),
    ("Look for any hardcoded API keys", "search_files", "file_ops"),
    ("Read the Makefile", "read_file", "file_ops"),
    ("Show me the Dockerfile", "read_file", "file_ops"),
    ("Read the docker-compose.yml", "read_file", "file_ops"),
    ("Save the function to utils.py", "write_file", "file_ops"),
    ("Create a backup of config.yaml", "write_file", "file_ops"),

    # Terminal commands (20)
    ("List all files in the current directory", "terminal", "terminal"),
    ("Show disk usage", "terminal", "terminal"),
    ("Check what processes are running", "terminal", "terminal"),
    ("Show the git log", "terminal", "terminal"),
    ("Check the Python version", "terminal", "terminal"),
    ("Run ls -la in the home directory", "terminal", "terminal"),
    ("Show the current date and time", "terminal", "terminal"),
    ("Check network connectivity with ping", "terminal", "terminal"),
    ("Show environment variables", "terminal", "terminal"),
    ("List running docker containers", "terminal", "terminal"),
    ("Check system memory usage", "terminal", "terminal"),
    ("Show the crontab", "terminal", "terminal"),
    ("Check the firewall status", "terminal", "terminal"),
    ("Show recent log entries", "terminal", "terminal"),
    ("Check disk free space", "terminal", "terminal"),
    ("Run a system update check", "terminal", "terminal"),
    ("Show open network connections", "terminal", "terminal"),
    ("Check the timezone", "terminal", "terminal"),
    ("List tmux sessions", "terminal", "terminal"),
    ("Check systemd service status", "terminal", "terminal"),

    # Web search (15)
    ("Search for Python asyncio documentation", "web_search", "web"),
    ("Look up the latest GPT-4 pricing", "web_search", "web"),
    ("Find information about Gemma 4 benchmarks", "web_search", "web"),
    ("Search for Rust vs Go performance comparison", "web_search", "web"),
    ("Look up Docker best practices", "web_search", "web"),
    ("Search for Kubernetes deployment tutorials", "web_search", "web"),
    ("Find the latest AI safety research papers", "web_search", "web"),
    ("Search for SQLite vs PostgreSQL comparison", "web_search", "web"),
    ("Look up Linux kernel tuning parameters", "web_search", "web"),
    ("Search for WebSocket protocol specification", "web_search", "web"),
    ("Find information about Matrix protocol federation", "web_search", "web"),
    ("Search for MCP protocol documentation", "web_search", "web"),
    ("Look up A2A agent protocol spec", "web_search", "web"),
    ("Search for quantization methods for LLMs", "web_search", "web"),
    ("Find information about GRPO training", "web_search", "web"),

    # Code execution (15)
    ("Calculate the factorial of 20", "execute_code", "code"),
    ("Parse this JSON and extract keys", "execute_code", "code"),
    ("Sort a list of numbers", "execute_code", "code"),
    ("Calculate the fibonacci sequence", "execute_code", "code"),
    ("Convert a CSV to JSON", "execute_code", "code"),
    ("Parse an email address", "execute_code", "code"),
    ("Calculate elapsed time between dates", "execute_code", "code"),
    ("Generate a random password", "execute_code", "code"),
    ("Hash a string with SHA256", "execute_code", "code"),
    ("Parse a URL into components", "execute_code", "code"),
    ("Calculate statistics on a dataset", "execute_code", "code"),
    ("Convert epoch timestamp to human readable", "execute_code", "code"),
    ("Validate an IPv4 address", "execute_code", "code"),
    ("Calculate the distance between coordinates", "execute_code", "code"),
    ("Generate a UUID", "execute_code", "code"),

    # Parallel tool calls (10)
    ("Read config.yaml and show git status at the same time", "read_file|terminal", "parallel"),
    ("Check disk usage and memory usage simultaneously", "terminal|terminal", "parallel"),
    ("Read two files at once: README and CHANGELOG", "read_file|read_file", "parallel"),
    ("Search for imports in both Python and JS files", "search_files|search_files", "parallel"),
    ("Check git log and disk space in parallel", "terminal|terminal", "parallel"),
    ("Read the Makefile and Dockerfile together", "read_file|read_file", "parallel"),
    ("Search for TODO and FIXME at the same time", "search_files|search_files", "parallel"),
    ("List files and check Python version simultaneously", "terminal|terminal", "parallel"),
    ("Read package.json and requirements.txt together", "read_file|read_file", "parallel"),
    ("Check system time and uptime in parallel", "terminal|terminal", "parallel"),
]
|
||||
|
||||
|
||||
@dataclass
class BenchmarkResult:
    """Outcome of running one prompt against one model."""
    model: str                  # model identifier as passed to the API
    prompt: str                 # user prompt that was sent
    expected_tool: str          # expected tool name ("a|b" for parallel cases)
    category: str               # prompt category (file_ops, terminal, web, code, parallel)
    success: bool = False       # expected tool was called with parseable args
    tool_called: str = ""       # name of the first tool the model actually called
    args_valid: bool = False    # arguments parsed as JSON (possibly after fix-up)
    latency_ms: float = 0.0     # wall-clock round trip in milliseconds
    prompt_tokens: int = 0      # usage.prompt_tokens from the response
    completion_tokens: int = 0  # usage.completion_tokens from the response
    error: str = ""             # transport error or failure reason (e.g. "no_tool_calls")
|
||||
|
||||
|
||||
def call_model(model: str, prompt: str) -> dict:
    """POST one chat-completion request carrying the benchmark tool schemas.

    Returns a dict with keys:
      response -- parsed JSON body, or None on failure
      elapsed  -- wall-clock seconds for the round trip
      error    -- stringified exception, or None on success
    """
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ],
        "tools": TOOL_SCHEMAS,
        "max_tokens": 512,
        "temperature": 0.0,
    }
    request = urllib.request.Request(
        f"{ENDPOINT}/chat/completions",
        data=json.dumps(payload).encode(),
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {API_KEY}",
        },
        method="POST",
    )

    started = time.time()
    try:
        with urllib.request.urlopen(request, timeout=60) as resp:
            parsed = json.loads(resp.read())
    except Exception as exc:
        # Network, HTTP and JSON failures are all reported the same way so
        # the caller can record them uniformly.
        return {"response": None, "elapsed": time.time() - started, "error": str(exc)}
    return {"response": parsed, "elapsed": time.time() - started, "error": None}
|
||||
|
||||
|
||||
def evaluate_response(result: dict, expected_tool: str) -> BenchmarkResult:
    """Score one raw call_model() result against the expected tool.

    Success requires that the first tool call names the primary expected
    tool (the part before "|" for parallel prompts) and that its arguments
    parse as JSON, optionally after a trailing-comma fix-up. The caller is
    expected to fill in model/prompt/category afterwards.
    """
    resp = result.get("response")
    error = result.get("error", "")
    elapsed = result.get("elapsed", 0)

    br = BenchmarkResult(
        model="",
        prompt="",
        expected_tool=expected_tool,
        category="",
        latency_ms=round(elapsed * 1000, 1),
        error=error or "",
    )

    if not resp:
        # Transport-level failure: error string already recorded above.
        br.success = False
        return br

    usage = resp.get("usage", {})
    br.prompt_tokens = usage.get("prompt_tokens", 0)
    br.completion_tokens = usage.get("completion_tokens", 0)

    choice = resp.get("choices", [{}])[0]
    message = choice.get("message", {})
    tool_calls = message.get("tool_calls", [])

    if not tool_calls:
        br.success = False
        br.error = "no_tool_calls"
        return br

    # Only the first tool call is scored.
    tc = tool_calls[0]
    fn = tc.get("function", {})
    br.tool_called = fn.get("name", "")

    # Arguments must be valid JSON; retry leniently by stripping trailing
    # commas before } or ] (a common model formatting slip).
    args_str = fn.get("arguments", "{}")
    try:
        json.loads(args_str)
        br.args_valid = True
    except json.JSONDecodeError:
        import re
        fixed = re.sub(r',\s*([}\]])', r'\1', args_str.strip())
        try:
            json.loads(fixed)
            br.args_valid = True
        # Was a bare `except:` — that would also swallow KeyboardInterrupt
        # and SystemExit; only a parse failure means invalid args.
        except json.JSONDecodeError:
            br.args_valid = False

    # Success = tool called matches primary expected tool and args parse.
    expected = expected_tool.split("|")[0]
    br.success = br.tool_called == expected and br.args_valid

    return br
|
||||
|
||||
|
||||
def run_benchmark(model: str, prompts: list, limit: Optional[int] = None) -> List[BenchmarkResult]:
    """Run benchmark against a model.

    Args:
        model: Model identifier passed straight to the API.
        prompts: (prompt, expected_tool, category) triples.
        limit: If truthy, only the first `limit` prompts are run.

    Returns:
        One BenchmarkResult per prompt, in order.
    """
    if limit:
        prompts = prompts[:limit]

    results = []
    for i, (prompt, expected_tool, category) in enumerate(prompts):
        # Progress marker; flushed so it shows before the (slow) model call.
        print(f" [{i+1}/{len(prompts)}] {model}: {prompt[:50]}...", end=" ", flush=True)

        raw = call_model(model, prompt)
        br = evaluate_response(raw, expected_tool)
        # evaluate_response does not know the run context — fill it in here.
        br.model = model
        br.prompt = prompt
        br.category = category

        status = "OK" if br.success else f"FAIL({br.error or br.tool_called})"
        print(f"{status} {br.latency_ms}ms")
        results.append(br)

    return results
|
||||
|
||||
|
||||
def generate_report(results: "List[BenchmarkResult]") -> str:
    """Generate a markdown benchmark report.

    Produces a per-model summary table (success rate, argument validity,
    average latency, total prompt tokens) followed by a per-category
    success table. Returns a short placeholder report when *results* is
    empty (the original divided by zero in that case).
    """
    # Group results per model, preserving first-seen model order.
    by_model = {}
    for r in results:
        by_model.setdefault(r.model, []).append(r)

    if not by_model:
        return "# Gemma 4 Tool Calling Benchmark\n\n(no results)"

    lines = [
        "# Gemma 4 Tool Calling Benchmark",
        "",
        f"**Date:** {datetime.now().strftime('%Y-%m-%d %H:%M')}",
        f"**Prompts:** {len(results) // len(by_model)} per model",
        "",
    ]

    # Summary table: one column per model, one row per metric.
    lines.append("| Metric | " + " | ".join(by_model.keys()) + " |")
    lines.append("|--------|" + "|".join(["--------"] * len(by_model)) + "|")

    for metric in ["success_rate", "args_valid_rate", "avg_latency_ms", "total_prompt_tokens"]:
        vals = []
        for model, rs in by_model.items():
            if metric == "success_rate":
                v = sum(1 for r in rs if r.success) / len(rs) * 100
                vals.append(f"{v:.1f}%")
            elif metric == "args_valid_rate":
                v = sum(1 for r in rs if r.args_valid) / len(rs) * 100
                vals.append(f"{v:.1f}%")
            elif metric == "avg_latency_ms":
                v = sum(r.latency_ms for r in rs) / len(rs)
                vals.append(f"{v:.0f}ms")
            elif metric == "total_prompt_tokens":
                v = sum(r.prompt_tokens for r in rs)
                vals.append(f"{v:,}")
        label = metric.replace("_", " ").title()
        lines.append(f"| {label} | " + " | ".join(vals) + " |")

    lines.append("")

    # Per-category success rates.
    lines.append("## By Category")
    lines.append("")
    lines.append("| Category | " + " | ".join(f"{m} success" for m in by_model.keys()) + " |")
    lines.append("|----------|" + "|".join(["--------"] * len(by_model)) + "|")

    categories = sorted(set(r.category for r in results))
    for cat in categories:
        vals = []
        for model, rs in by_model.items():
            cat_results = [r for r in rs if r.category == cat]
            if cat_results:
                v = sum(1 for r in cat_results if r.success) / len(cat_results) * 100
                vals.append(f"{v:.0f}%")
            else:
                vals.append("N/A")
        lines.append(f"| {cat} | " + " | ".join(vals) + " |")

    return "\n".join(lines)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: benchmark two models, print and save the reports."""
    import argparse

    # BUG FIX: the global declaration must precede any use of ENDPOINT in
    # this function. The original declared it AFTER reading ENDPOINT for the
    # argparse default, which is a SyntaxError ("name used prior to global
    # declaration") — the module would not even import.
    global ENDPOINT

    parser = argparse.ArgumentParser(description="Tool calling benchmark")
    parser.add_argument("--model1", default="gemma3:27b")
    parser.add_argument("--model2", default="xiaomi/mimo-v2-pro")
    parser.add_argument("--endpoint", default=ENDPOINT)
    parser.add_argument("--limit", type=int, default=None)
    parser.add_argument("--output", default=None)
    parser.add_argument("--markdown", action="store_true")
    args = parser.parse_args()

    # call_model() reads the module-level ENDPOINT on every request.
    ENDPOINT = args.endpoint

    prompts = TEST_PROMPTS
    if args.limit:
        prompts = prompts[:args.limit]

    print(f"Benchmark: {args.model1} vs {args.model2}")
    print(f"Prompts: {len(prompts)}")
    print()

    print(f"--- {args.model1} ---")
    results1 = run_benchmark(args.model1, prompts)

    print(f"\n--- {args.model2} ---")
    results2 = run_benchmark(args.model2, prompts)

    all_results = results1 + results2

    report = generate_report(all_results)
    print(f"\n{report}")

    # Optional raw JSON dump of every BenchmarkResult.
    if args.output:
        with open(args.output, "w") as f:
            json.dump([r.__dict__ for r in all_results], f, indent=2, default=str)
        print(f"\nResults saved to {args.output}")

    # The markdown report is always saved, dated, under benchmarks/.
    report_path = f"benchmarks/gemma4-tool-calling-{datetime.now().strftime('%Y-%m-%d')}.md"
    Path("benchmarks").mkdir(exist_ok=True)
    with open(report_path, "w") as f:
        f.write(report)
    print(f"Report saved to {report_path}")
|
||||
202
tests/agent/test_privacy_filter.py
Normal file
202
tests/agent/test_privacy_filter.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""Tests for agent.privacy_filter — PII stripping before remote API calls."""
|
||||
|
||||
import pytest
|
||||
from agent.privacy_filter import (
|
||||
PrivacyFilter,
|
||||
RedactionReport,
|
||||
Sensitivity,
|
||||
sanitize_messages,
|
||||
quick_sanitize,
|
||||
)
|
||||
|
||||
|
||||
class TestPrivacyFilterSanitizeText:
    """sanitize_text(): single-string redaction behavior."""

    def test_no_pii_returns_clean(self):
        filt = PrivacyFilter()
        text = "The weather in Paris is nice today."
        out, hits = filt.sanitize_text(text)
        assert out == text
        assert hits == []

    def test_email_redacted(self):
        filt = PrivacyFilter()
        out, hits = filt.sanitize_text("Send report to alice@example.com by Friday.")
        assert "alice@example.com" not in out
        assert "[REDACTED-EMAIL]" in out
        assert any(h["type"] == "email_address" for h in hits)

    def test_phone_redacted(self):
        filt = PrivacyFilter()
        out, hits = filt.sanitize_text("Call me at 555-123-4567 when ready.")
        assert "555-123-4567" not in out
        assert "[REDACTED-PHONE]" in out

    def test_api_key_redacted(self):
        filt = PrivacyFilter()
        out, hits = filt.sanitize_text('api_key = "sk-proj-abcdefghij1234567890abcdefghij1234567890"')
        assert "sk-proj-" not in out
        assert any(h["sensitivity"] == "CRITICAL" for h in hits)

    def test_github_token_redacted(self):
        filt = PrivacyFilter()
        out, hits = filt.sanitize_text("Use ghp_1234567890abcdefghijklmnopqrstuvwxyz1234 for auth")
        assert "ghp_" not in out
        assert any(h["type"] == "github_token" for h in hits)

    def test_ethereum_address_redacted(self):
        filt = PrivacyFilter()
        out, hits = filt.sanitize_text("Send to 0x742d35Cc6634C0532925a3b844Bc9e7595f2bD18 please")
        assert "0x742d" not in out
        assert any(h["type"] == "ethereum_address" for h in hits)

    def test_user_home_path_redacted(self):
        filt = PrivacyFilter()
        out, hits = filt.sanitize_text("Read file at /Users/alice/Documents/secret.txt")
        assert "alice" not in out
        assert "[REDACTED-USER]" in out

    def test_multiple_pii_types(self):
        filt = PrivacyFilter()
        text = (
            "Contact john@test.com or call 555-999-1234. "
            "The API key is sk-abcdefghijklmnopqrstuvwxyz1234567890."
        )
        out, hits = filt.sanitize_text(text)
        assert "john@test.com" not in out
        assert "555-999-1234" not in out
        assert "sk-abcd" not in out
        assert len(hits) >= 3
|
||||
|
||||
|
||||
class TestPrivacyFilterSanitizeMessages:
    """sanitize_messages(): role-aware redaction across a conversation."""

    def test_sanitize_user_message(self):
        filt = PrivacyFilter()
        convo = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Email me at bob@test.com with results."},
        ]
        safe, report = filt.sanitize_messages(convo)
        assert report.redacted_messages == 1
        assert "bob@test.com" not in safe[1]["content"]
        assert "[REDACTED-EMAIL]" in safe[1]["content"]
        # The system prompt must pass through untouched.
        assert safe[0]["content"] == "You are helpful."

    def test_no_redaction_needed(self):
        filt = PrivacyFilter()
        safe, report = filt.sanitize_messages([
            {"role": "user", "content": "What is 2+2?"},
            {"role": "assistant", "content": "4"},
        ])
        assert report.redacted_messages == 0
        assert not report.had_redactions

    def test_assistant_messages_also_sanitized(self):
        filt = PrivacyFilter()
        safe, report = filt.sanitize_messages(
            [{"role": "assistant", "content": "Your email admin@corp.com was found."}]
        )
        assert report.redacted_messages == 1
        assert "admin@corp.com" not in safe[0]["content"]

    def test_tool_messages_not_sanitized(self):
        filt = PrivacyFilter()
        safe, report = filt.sanitize_messages(
            [{"role": "tool", "content": "Result: user@test.com found"}]
        )
        assert report.redacted_messages == 0
        assert safe[0]["content"] == "Result: user@test.com found"
|
||||
|
||||
|
||||
class TestShouldUseLocalOnly:
    """should_use_local_only(): remote-vs-local routing decision."""

    def test_normal_text_allows_remote(self):
        filt = PrivacyFilter()
        blocked, _reason = filt.should_use_local_only("Summarize this article about Python.")
        assert not blocked

    def test_critical_secret_blocks_remote(self):
        filt = PrivacyFilter()
        blocked, reason = filt.should_use_local_only(
            "Here is the API key: sk-abcdefghijklmnopqrstuvwxyz1234567890"
        )
        assert blocked
        assert "critical" in reason.lower()

    def test_multiple_high_sensitivity_blocks(self):
        filt = PrivacyFilter()
        # Three or more high-sensitivity hits should force local-only.
        payload = (
            "Card: 4111-1111-1111-1111, "
            "SSN: 123-45-6789, "
            "BTC: 1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa, "
            "ETH: 0x742d35Cc6634C0532925a3b844Bc9e7595f2bD18"
        )
        blocked, _reason = filt.should_use_local_only(payload)
        assert blocked
|
||||
|
||||
|
||||
class TestAggressiveMode:
    """aggressive_mode flag: internal-IP handling."""

    def test_aggressive_redacts_internal_ips(self):
        filt = PrivacyFilter(aggressive_mode=True)
        out, hits = filt.sanitize_text("Server at 192.168.1.100 is responding.")
        assert "192.168.1.100" not in out
        assert any(h["type"] == "internal_ip" for h in hits)

    def test_normal_does_not_redact_ips(self):
        filt = PrivacyFilter(aggressive_mode=False)
        out, _hits = filt.sanitize_text("Server at 192.168.1.100 is responding.")
        # Internal IPs survive outside aggressive mode.
        assert "192.168.1.100" in out
|
||||
|
||||
|
||||
class TestConvenienceFunctions:
    """Module-level wrappers around PrivacyFilter."""

    def test_quick_sanitize(self):
        out = quick_sanitize("Contact alice@example.com for details")
        assert "alice@example.com" not in out
        assert "[REDACTED-EMAIL]" in out

    def test_sanitize_messages_convenience(self):
        _safe, report = sanitize_messages([{"role": "user", "content": "Call 555-000-1234"}])
        assert report.redacted_messages == 1
|
||||
|
||||
|
||||
class TestRedactionReport:
    """RedactionReport.summary() formatting."""

    def test_summary_no_redactions(self):
        clean = RedactionReport(total_messages=3, redacted_messages=0)
        assert "No PII" in clean.summary()

    def test_summary_with_redactions(self):
        rep = RedactionReport(
            total_messages=2,
            redacted_messages=1,
            redactions=[
                {"type": "email_address", "sensitivity": "MEDIUM", "count": 2},
                {"type": "phone_number_us", "sensitivity": "MEDIUM", "count": 1},
            ],
        )
        text = rep.summary()
        # Shows redacted/total ratio and names the redacted PII types.
        assert "1/2" in text
        assert "email_address" in text
|
||||
@@ -1,116 +0,0 @@
|
||||
"""Tests for A2A protocol implementation."""
|
||||
import asyncio,json,pytest
|
||||
from a2a.types import AgentCard,AgentSkill,Artifact,DataPart,FilePart,JSONRPCError,JSONRPCRequest,JSONRPCResponse,Message,Task,TaskState,TaskStatus,TextPart,A2AError,part_from_dict
|
||||
from a2a.client import A2AClient,A2AClientConfig
|
||||
from a2a.server import A2AServer
|
||||
|
||||
class TestTextPart:
    def test_roundtrip(self):
        part = TextPart(text="hello", metadata={"k": "v"})
        dumped = part.to_dict()
        assert dumped == {"text": "hello", "metadata": {"k": "v"}}
        assert TextPart.from_dict(dumped).text == "hello"

    def test_no_metadata(self):
        # The metadata key is omitted entirely when unset.
        assert "metadata" not in TextPart(text="hi").to_dict()
|
||||
|
||||
class TestFilePart:
    def test_inline(self):
        part = FilePart(media_type="text/plain", raw="SGVsbG8=", filename="hello.txt")
        dumped = part.to_dict()
        assert dumped["raw"] == "SGVsbG8="
        assert FilePart.from_dict(dumped).filename == "hello.txt"

    def test_url(self):
        assert FilePart(url="https://x.com/f").to_dict()["url"] == "https://x.com/f"
|
||||
|
||||
class TestDataPart:
    def test_roundtrip(self):
        assert DataPart(data={"key": 42}).to_dict()["data"] == {"key": 42}
|
||||
|
||||
class TestPartDiscrimination:
    """part_from_dict() picks the concrete Part type from the dict shape."""

    def test_text(self):
        assert isinstance(part_from_dict({"text": "hi"}), TextPart)

    def test_file_raw(self):
        assert isinstance(part_from_dict({"raw": "d", "mediaType": "t"}), FilePart)

    def test_file_url(self):
        assert isinstance(part_from_dict({"url": "https://x.com"}), FilePart)

    def test_data(self):
        assert isinstance(part_from_dict({"data": {"a": 1}}), DataPart)

    def test_unknown(self):
        with pytest.raises(ValueError):
            part_from_dict({"unknown": True})
|
||||
|
||||
class TestMessage:
    def test_roundtrip(self):
        msg = Message(role="user", parts=[TextPart(text="hi")], context_id="c1")
        dumped = msg.to_dict()
        assert dumped["role"] == "user"
        assert Message.from_dict(dumped).parts[0].text == "hi"
|
||||
|
||||
class TestArtifact:
    def test_roundtrip(self):
        art = Artifact(name="r", parts=[TextPart(text="d")])
        assert art.to_dict()["name"] == "r"
|
||||
|
||||
class TestTaskStatus:
    def test_roundtrip(self):
        status = TaskStatus(state=TaskState.COMPLETED)
        dumped = status.to_dict()
        assert dumped["state"] == "TASK_STATE_COMPLETED"
        assert TaskStatus.from_dict(dumped).state == TaskState.COMPLETED

    def test_terminal(self):
        # COMPLETED and FAILED are terminal; SUBMITTED is not.
        assert TaskState.COMPLETED.terminal
        assert TaskState.FAILED.terminal
        assert not TaskState.SUBMITTED.terminal
|
||||
|
||||
class TestTask:
    def test_roundtrip(self):
        task = Task(id="t1", status=TaskStatus(state=TaskState.WORKING))
        dumped = task.to_dict()
        assert dumped["id"] == "t1"
        assert Task.from_dict(dumped).status.state == TaskState.WORKING
|
||||
|
||||
class TestAgentCard:
    def test_roundtrip(self):
        card = AgentCard(name="a", skills=[AgentSkill(id="s1", name="S")])
        dumped = card.to_dict()
        assert dumped["name"] == "a"
        assert AgentCard.from_dict(dumped).skills[0].id == "s1"
|
||||
|
||||
class TestJSONRPC:
    def test_request(self):
        req = JSONRPCRequest(method="GetTask", params={"taskId": "1"})
        assert req.to_dict()["method"] == "GetTask"

    def test_response_success(self):
        resp = JSONRPCResponse(id="1", result={"ok": True})
        assert resp.to_dict()["result"]["ok"] == True

    def test_response_error(self):
        resp = JSONRPCResponse(id="1", error=A2AError.parse_error())
        assert resp.to_dict()["error"]["code"] == -32700

    def test_response_from_dict(self):
        resp = JSONRPCResponse.from_dict({"jsonrpc": "2.0", "id": "1", "result": {"ok": True}})
        assert resp.result["ok"] == True
|
||||
|
||||
class TestA2AError:
    def test_codes(self):
        # Standard JSON-RPC codes plus the A2A-specific task-not-found code.
        assert A2AError.parse_error().code == -32700
        assert A2AError.invalid_request().code == -32600
        assert A2AError.method_not_found().code == -32601
        assert A2AError.task_not_found("x").code == -32001
|
||||
|
||||
def _make_card():
    """Build a minimal AgentCard with a single 'echo' skill for server tests."""
    skill = AgentSkill(id="echo", name="Echo", tags=["test"])
    return AgentCard(
        name="test",
        description="Test",
        url="http://localhost:9999/a2a/v1",
        skills=[skill],
    )
|
||||
|
||||
def _run(coro):
    """Drive *coro* to completion synchronously.

    Uses asyncio.run instead of get_event_loop().run_until_complete: the
    latter is deprecated for this use since Python 3.10 and raises on
    3.12+ when no event loop is set in the main thread.
    """
    return asyncio.run(coro)
|
||||
|
||||
class TestServer:
    """End-to-end JSON-RPC handling through A2AServer.handle_rpc."""

    def test_echo(self):
        server = A2AServer(_make_card())
        msg = Message(role="user", parts=[TextPart(text="hello")])
        req = JSONRPCRequest(method="SendMessage", params={"message": msg.to_dict()})
        resp = json.loads(_run(server.handle_rpc(json.dumps(req.to_dict()))))
        assert "error" not in resp
        assert resp["result"]["status"]["state"] == "TASK_STATE_COMPLETED"

    def test_get_task(self):
        server = A2AServer(_make_card())
        msg = Message(role="user", parts=[TextPart(text="hi")])
        send = JSONRPCRequest(method="SendMessage", params={"message": msg.to_dict()})
        task_id = json.loads(_run(server.handle_rpc(json.dumps(send.to_dict()))))["result"]["id"]
        get = JSONRPCRequest(method="GetTask", params={"taskId": task_id})
        resp = json.loads(_run(server.handle_rpc(json.dumps(get.to_dict()))))
        assert resp["result"]["id"] == task_id

    def test_cancel(self):
        server = A2AServer(_make_card())
        server.add_task(Task(id="c1", status=TaskStatus(state=TaskState.WORKING)))
        req = JSONRPCRequest(method="CancelTask", params={"taskId": "c1"})
        resp = json.loads(_run(server.handle_rpc(json.dumps(req.to_dict()))))
        assert resp["result"]["status"]["state"] == "TASK_STATE_CANCELED"

    def test_cancel_terminal(self):
        # Cancelling an already-completed task must be rejected.
        server = A2AServer(_make_card())
        server.add_task(Task(id="d1", status=TaskStatus(state=TaskState.COMPLETED)))
        req = JSONRPCRequest(method="CancelTask", params={"taskId": "d1"})
        resp = json.loads(_run(server.handle_rpc(json.dumps(req.to_dict()))))
        assert "error" in resp

    def test_card(self):
        server = A2AServer(_make_card())
        req = JSONRPCRequest(method="GetAgentCard")
        resp = json.loads(_run(server.handle_rpc(json.dumps(req.to_dict()))))
        assert resp["result"]["name"] == "test"

    def test_list(self):
        server = A2AServer(_make_card())
        msg = Message(role="user", parts=[TextPart(text="hi")])
        send = JSONRPCRequest(method="SendMessage", params={"message": msg.to_dict()})
        _run(server.handle_rpc(json.dumps(send.to_dict())))
        listing = JSONRPCRequest(method="ListTasks")
        resp = json.loads(_run(server.handle_rpc(json.dumps(listing.to_dict()))))
        assert len(resp["result"]["tasks"]) >= 1

    def test_unknown(self):
        server = A2AServer(_make_card())
        req = JSONRPCRequest(method="Nope")
        resp = json.loads(_run(server.handle_rpc(json.dumps(req.to_dict()))))
        assert resp["error"]["code"] == -32601

    def test_invalid_json(self):
        server = A2AServer(_make_card())
        resp = json.loads(_run(server.handle_rpc("bad")))
        assert resp["error"]["code"] == -32700

    def test_custom_handler(self):
        server = A2AServer(_make_card())

        async def handler(task, card):
            task.status = TaskStatus(state=TaskState.COMPLETED)
            task.artifacts = [Artifact(parts=[TextPart(text="custom")])]
            return task

        server.register_handler("echo", handler)
        msg = Message(role="user", parts=[TextPart(text="t")])
        req = JSONRPCRequest(
            method="SendMessage",
            params={"message": msg.to_dict(), "skillId": "echo"},
        )
        resp = json.loads(_run(server.handle_rpc(json.dumps(req.to_dict()))))
        assert resp["result"]["artifacts"][0]["parts"][0]["text"] == "custom"

    def test_audit(self):
        server = A2AServer(_make_card())
        _run(server.handle_rpc(json.dumps(JSONRPCRequest(method="GetAgentCard").to_dict())))
        assert len(server.audit_log) == 1
        assert server.audit_log[0]["method"] == "GetAgentCard"
|
||||
|
||||
if __name__=="__main__":pytest.main([__file__,"-v"])
|
||||
190
tests/tools/test_confirmation_daemon.py
Normal file
190
tests/tools/test_confirmation_daemon.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""Tests for tools.confirmation_daemon — Human Confirmation Firewall."""
|
||||
|
||||
import pytest
|
||||
import time
|
||||
from tools.confirmation_daemon import (
|
||||
ConfirmationDaemon,
|
||||
ConfirmationRequest,
|
||||
ConfirmationStatus,
|
||||
RiskLevel,
|
||||
classify_action,
|
||||
_is_whitelisted,
|
||||
_DEFAULT_WHITELIST,
|
||||
)
|
||||
|
||||
|
||||
class TestClassifyAction:
    """Test action risk classification."""

    def test_crypto_tx_is_critical(self):
        level = classify_action("crypto_tx")
        assert level == RiskLevel.CRITICAL

    def test_sign_transaction_is_critical(self):
        level = classify_action("sign_transaction")
        assert level == RiskLevel.CRITICAL

    def test_send_email_is_high(self):
        level = classify_action("send_email")
        assert level == RiskLevel.HIGH

    def test_send_message_is_medium(self):
        level = classify_action("send_message")
        assert level == RiskLevel.MEDIUM

    def test_access_calendar_is_low(self):
        level = classify_action("access_calendar")
        assert level == RiskLevel.LOW

    def test_unknown_action_is_medium(self):
        # Anything not in ACTION_CATEGORIES falls back to the default.
        level = classify_action("unknown_action_xyz")
        assert level == RiskLevel.MEDIUM
|
||||
|
||||
|
||||
class TestWhitelist:
    """Test whitelist auto-approval."""

    def test_self_email_is_whitelisted(self):
        rules = dict(_DEFAULT_WHITELIST)
        self_addressed = {"from": "me@test.com", "to": "me@test.com"}
        assert _is_whitelisted("send_email", self_addressed, rules) is True

    def test_non_whitelisted_recipient_not_approved(self):
        rules = dict(_DEFAULT_WHITELIST)
        assert _is_whitelisted("send_email", {"to": "random@stranger.com"}, rules) is False

    def test_whitelisted_contact_approved(self):
        rules = {"send_message": {"targets": ["alice", "bob"]}}
        assert _is_whitelisted("send_message", {"to": "alice"}, rules) is True
        assert _is_whitelisted("send_message", {"to": "charlie"}, rules) is False

    def test_no_whitelist_entry_means_not_whitelisted(self):
        # With no entry for the action, nothing is pre-approved.
        assert _is_whitelisted("crypto_tx", {"amount": 1.0}, {}) is False
|
||||
|
||||
|
||||
class TestConfirmationRequest:
    """Test the request data model."""

    def test_defaults(self):
        # created_at / expires_at are left unset so __post_init__ fills them in.
        req = ConfirmationRequest(
            request_id="test-1",
            action="send_email",
            description="Test email",
            risk_level="high",
            payload={},
        )
        assert req.status == ConfirmationStatus.PENDING.value
        assert req.created_at > 0
        assert req.expires_at > req.created_at

    def test_is_pending(self):
        # A future expiry keeps a PENDING request pending.
        req = ConfirmationRequest(
            request_id="test-2",
            action="send_email",
            description="Test",
            risk_level="high",
            payload={},
            expires_at=time.time() + 300,
        )
        assert req.is_pending is True

    def test_is_expired(self):
        # An expiry in the past flips is_expired and clears is_pending.
        req = ConfirmationRequest(
            request_id="test-3",
            action="send_email",
            description="Test",
            risk_level="high",
            payload={},
            expires_at=time.time() - 10,
        )
        assert req.is_expired is True
        assert req.is_pending is False

    def test_to_dict(self):
        req = ConfirmationRequest(
            request_id="test-4",
            action="send_email",
            description="Test",
            risk_level="medium",
            payload={"to": "a@b.com"},
        )
        d = req.to_dict()
        assert d["request_id"] == "test-4"
        assert d["action"] == "send_email"
        # to_dict() augments the dataclass fields with derived flags.
        assert "is_pending" in d
|
||||
|
||||
|
||||
class TestConfirmationDaemon:
    """Test the daemon logic (without HTTP layer)."""

    def test_auto_approve_low_risk(self):
        daemon = ConfirmationDaemon()
        req = daemon.request(
            action="access_calendar",
            description="Read today's events",
            risk_level="low",
        )
        assert req.status == ConfirmationStatus.AUTO_APPROVED.value

    def test_whitelisted_auto_approves(self):
        daemon = ConfirmationDaemon()
        # Pin the whitelist directly so the test is independent of any
        # ~/.hermes/approval_whitelist.json on the host.
        daemon._whitelist = {"send_message": {"targets": ["alice"]}}
        req = daemon.request(
            action="send_message",
            description="Message alice",
            payload={"to": "alice"},
        )
        assert req.status == ConfirmationStatus.AUTO_APPROVED.value

    def test_non_whitelisted_goes_pending(self):
        daemon = ConfirmationDaemon()
        daemon._whitelist = {}
        req = daemon.request(
            action="send_email",
            description="Email to stranger",
            payload={"to": "stranger@test.com"},
            risk_level="high",
        )
        assert req.status == ConfirmationStatus.PENDING.value
        assert req.is_pending is True

    def test_approve_response(self):
        daemon = ConfirmationDaemon()
        daemon._whitelist = {}
        req = daemon.request(
            action="send_email",
            description="Email test",
            risk_level="high",
        )
        result = daemon.respond(req.request_id, approved=True, decided_by="human")
        assert result.status == ConfirmationStatus.APPROVED.value
        assert result.decided_by == "human"

    def test_deny_response(self):
        daemon = ConfirmationDaemon()
        daemon._whitelist = {}
        req = daemon.request(
            action="crypto_tx",
            description="Send 1 ETH",
            risk_level="critical",
        )
        result = daemon.respond(
            req.request_id, approved=False, decided_by="human", reason="Too risky"
        )
        assert result.status == ConfirmationStatus.DENIED.value
        assert result.reason == "Too risky"

    def test_get_pending(self):
        daemon = ConfirmationDaemon()
        daemon._whitelist = {}
        daemon.request(action="send_email", description="Test 1", risk_level="high")
        daemon.request(action="send_email", description="Test 2", risk_level="high")
        pending = daemon.get_pending()
        assert len(pending) >= 2

    def test_get_history(self):
        daemon = ConfirmationDaemon()
        # Low risk auto-approves, which lands the request straight in history.
        req = daemon.request(
            action="access_calendar", description="Test", risk_level="low"
        )
        history = daemon.get_history()
        assert len(history) >= 1
        assert history[0]["action"] == "access_calendar"
|
||||
@@ -121,6 +121,19 @@ DANGEROUS_PATTERNS = [
|
||||
(r'\b(cp|mv|install)\b.*\s/etc/', "copy/move file into /etc/"),
|
||||
(r'\bsed\s+-[^\s]*i.*\s/etc/', "in-place edit of system config"),
|
||||
(r'\bsed\s+--in-place\b.*\s/etc/', "in-place edit of system config (long flag)"),
|
||||
# --- Vitalik's threat model: crypto / financial ---
|
||||
(r'\b(?:bitcoin-cli|ethers\.js|web3|ether\.sendTransaction)\b', "direct crypto transaction tool usage"),
|
||||
(r'\bwget\b.*\b(?:mnemonic|seed\s*phrase|private[_-]?key)\b', "attempting to download crypto credentials"),
|
||||
(r'\bcurl\b.*\b(?:mnemonic|seed\s*phrase|private[_-]?key)\b', "attempting to exfiltrate crypto credentials"),
|
||||
# --- Vitalik's threat model: credential exfiltration ---
|
||||
(r'\b(?:curl|wget|http|nc|ncat|socat)\b.*\b(?:\.env|\.ssh|credentials|secrets|token|api[_-]?key)\b',
|
||||
"attempting to exfiltrate credentials via network"),
|
||||
(r'\bbase64\b.*\|(?:\s*curl|\s*wget)', "base64-encode then network exfiltration"),
|
||||
(r'\bcat\b.*\b(?:\.env|\.ssh/id_rsa|credentials)\b.*\|(?:\s*curl|\s*wget)',
|
||||
"reading secrets and piping to network tool"),
|
||||
# --- Vitalik's threat model: data exfiltration ---
|
||||
(r'\bcurl\b.*-d\s.*\$(?:HOME|USER)', "sending user home directory data to remote"),
|
||||
(r'\bwget\b.*--post-data\s.*\$(?:HOME|USER)', "posting user data to remote"),
|
||||
# Script execution via heredoc — bypasses the -e/-c flag patterns above.
|
||||
# `python3 << 'EOF'` feeds arbitrary code via stdin without -c/-e flags.
|
||||
(r'\b(python[23]?|perl|ruby|node)\s+<<', "script execution via heredoc"),
|
||||
|
||||
615
tools/confirmation_daemon.py
Normal file
615
tools/confirmation_daemon.py
Normal file
@@ -0,0 +1,615 @@
|
||||
"""Human Confirmation Daemon — HTTP server for two-factor action approval.
|
||||
|
||||
Implements Vitalik's Pattern 1: "The new 'two-factor confirmation' is that
|
||||
the two factors are the human and the LLM."
|
||||
|
||||
This daemon runs on localhost:6000 and provides a simple HTTP API for the
|
||||
agent to request human approval before executing high-risk actions.
|
||||
|
||||
Threat model:
|
||||
- LLM jailbreaks: Remote content "hacking" the LLM to perform malicious actions
|
||||
- LLM accidents: LLM accidentally performing dangerous operations
|
||||
- The human acts as the second factor — the agent proposes, the human disposes
|
||||
|
||||
Architecture:
|
||||
- Agent detects high-risk action → POST /confirm with action details
|
||||
- Daemon stores pending request, sends notification to user
|
||||
- User approves/denies via POST /respond (Telegram, CLI, or direct HTTP)
|
||||
- Agent receives decision and proceeds or aborts
|
||||
|
||||
Usage:
|
||||
# Start daemon (usually managed by gateway)
|
||||
from tools.confirmation_daemon import ConfirmationDaemon
|
||||
daemon = ConfirmationDaemon(port=6000)
|
||||
daemon.start()
|
||||
|
||||
# Request approval (from agent code)
|
||||
from tools.confirmation_daemon import request_confirmation
|
||||
approved = request_confirmation(
|
||||
action="send_email",
|
||||
description="Send email to alice@example.com",
|
||||
risk_level="high",
|
||||
payload={"to": "alice@example.com", "subject": "Meeting notes"},
|
||||
timeout=300,
|
||||
)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RiskLevel(Enum):
    """Risk classification for actions requiring confirmation.

    Values are lowercase strings so they round-trip cleanly through JSON
    payloads and the string ``risk_level`` field of ConfirmationRequest.
    """
    LOW = "low"  # Log only, no confirmation needed
    MEDIUM = "medium"  # Confirm for non-whitelisted targets
    HIGH = "high"  # Always confirm
    CRITICAL = "critical"  # Always confirm + require explicit reason
|
||||
|
||||
|
||||
class ConfirmationStatus(Enum):
    """Status of a pending confirmation request."""
    PENDING = "pending"  # Awaiting a human decision
    APPROVED = "approved"  # Human approved the action
    DENIED = "denied"  # Human rejected the action
    EXPIRED = "expired"  # No decision arrived before expires_at
    AUTO_APPROVED = "auto_approved"  # Whitelisted / low-risk; no human involved
|
||||
|
||||
|
||||
@dataclass
class ConfirmationRequest:
    """A request for human confirmation of a high-risk action.

    Field order is part of the positional-constructor interface; do not
    reorder. Timestamps are epoch seconds (time.time()).
    """
    request_id: str
    action: str  # Action type: send_email, send_message, crypto_tx, etc.
    description: str  # Human-readable description of what will happen
    risk_level: str  # low, medium, high, critical
    payload: Dict[str, Any]  # Action-specific data (sanitized)
    session_key: str = ""  # Session that initiated the request
    created_at: float = 0.0
    expires_at: float = 0.0
    status: str = ConfirmationStatus.PENDING.value
    decided_at: float = 0.0
    decided_by: str = ""  # "human", "auto", "whitelist"
    reason: str = ""  # Optional reason for denial

    def __post_init__(self):
        # Fill defaults that depend on call time; 0.0/"" mean "not provided".
        if not self.created_at:
            self.created_at = time.time()
        if not self.expires_at:
            self.expires_at = self.created_at + 300  # 5 min default
        if not self.request_id:
            self.request_id = str(uuid.uuid4())[:12]

    @property
    def is_expired(self) -> bool:
        """True once the current time has passed expires_at."""
        return time.time() > self.expires_at

    @property
    def is_pending(self) -> bool:
        """True while the request is PENDING and not yet expired."""
        return self.status == ConfirmationStatus.PENDING.value and not self.is_expired

    def to_dict(self) -> Dict[str, Any]:
        """Serialize all fields plus the derived is_expired/is_pending flags."""
        d = asdict(self)
        d["is_expired"] = self.is_expired
        d["is_pending"] = self.is_pending
        return d
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Action categories (Vitalik's threat model)
|
||||
# =========================================================================
|
||||
|
||||
# Maps an action name to its RiskLevel; classify_action() consults this table
# and falls back to DEFAULT_RISK_LEVEL for anything not listed here.
ACTION_CATEGORIES = {
    # Messaging — outbound communication to external parties
    "send_email": RiskLevel.HIGH,
    "send_message": RiskLevel.MEDIUM,  # Depends on recipient
    "send_signal": RiskLevel.HIGH,
    "send_telegram": RiskLevel.MEDIUM,
    "send_discord": RiskLevel.MEDIUM,
    "post_social": RiskLevel.HIGH,

    # Financial / crypto
    "crypto_tx": RiskLevel.CRITICAL,
    "sign_transaction": RiskLevel.CRITICAL,
    "access_wallet": RiskLevel.CRITICAL,
    "modify_balance": RiskLevel.CRITICAL,

    # System modification
    "install_software": RiskLevel.HIGH,
    "modify_system_config": RiskLevel.HIGH,
    "modify_firewall": RiskLevel.CRITICAL,
    "add_ssh_key": RiskLevel.CRITICAL,
    "create_user": RiskLevel.CRITICAL,

    # Data access
    "access_contacts": RiskLevel.MEDIUM,
    "access_calendar": RiskLevel.LOW,
    "read_private_files": RiskLevel.MEDIUM,
    "upload_data": RiskLevel.HIGH,
    "share_credentials": RiskLevel.CRITICAL,

    # Network
    "open_port": RiskLevel.HIGH,
    "modify_dns": RiskLevel.HIGH,
    "expose_service": RiskLevel.CRITICAL,
}

# Default: any unrecognized action is MEDIUM risk
DEFAULT_RISK_LEVEL = RiskLevel.MEDIUM
|
||||
|
||||
|
||||
def classify_action(action: str) -> RiskLevel:
    """Return the RiskLevel for *action*; unknown actions are MEDIUM risk."""
    try:
        return ACTION_CATEGORIES[action]
    except KeyError:
        return DEFAULT_RISK_LEVEL
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Whitelist configuration
|
||||
# =========================================================================
|
||||
|
||||
# Built-in whitelist used when ~/.hermes/approval_whitelist.json is absent
# or unreadable (see _load_whitelist). Shape: action -> per-action rules.
_DEFAULT_WHITELIST = {
    "send_message": {
        "targets": [],  # Contact names/IDs that don't need confirmation
    },
    "send_email": {
        "targets": [],  # Email addresses that don't need confirmation
        "self_only": True,  # send-to-self always allowed
    },
}
|
||||
|
||||
|
||||
def _load_whitelist() -> Dict[str, Any]:
    """Load the action whitelist from ``~/.hermes/approval_whitelist.json``.

    Returns:
        The parsed JSON config if the file exists and is valid; otherwise a
        fresh copy of ``_DEFAULT_WHITELIST``.

    Falls back to a per-call copy of the default so callers that mutate
    their whitelist (e.g. ``daemon._whitelist[...] = ...``) cannot corrupt
    the module-level constant.
    """
    config_path = Path.home() / ".hermes" / "approval_whitelist.json"
    if config_path.exists():
        try:
            with open(config_path, encoding="utf-8") as f:
                return json.load(f)
        # Narrow to the failures we expect: unreadable file or bad JSON
        # (json.JSONDecodeError is a ValueError subclass).
        except (OSError, ValueError) as e:
            logger.warning("Failed to load approval whitelist: %s", e)
    # Copy the nested per-action dicts too: a plain dict(_DEFAULT_WHITELIST)
    # would alias them, letting one daemon's edits leak into the default.
    return {action: dict(rules) for action, rules in _DEFAULT_WHITELIST.items()}
|
||||
|
||||
|
||||
def _is_whitelisted(action: str, payload: Dict[str, Any], whitelist: Dict) -> bool:
    """Return True when *action* with *payload* is pre-approved by *whitelist*."""
    rules = whitelist.get(action, {})
    if not rules:
        return False

    # Target-based rule: the recipient may appear under several payload keys.
    recipient = payload.get("to") or payload.get("recipient") or payload.get("target", "")
    if recipient and recipient in rules.get("targets", []):
        return True

    # Self-addressed email rule (case-insensitive address comparison).
    if rules.get("self_only") and action == "send_email":
        sender = payload.get("from", "")
        to_addr = payload.get("to", "")
        if sender and to_addr and sender.lower() == to_addr.lower():
            return True

    return False
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Confirmation daemon
|
||||
# =========================================================================
|
||||
|
||||
class ConfirmationDaemon:
    """HTTP daemon for human confirmation of high-risk actions.

    Runs on localhost:PORT (default 6000). Provides:
    - POST /confirm — agent requests human approval
    - POST /respond — human approves/denies
    - GET /pending — list pending requests
    - GET /health — health check

    Thread model: requests may be created/decided from any thread; _pending
    and _history are guarded by _lock. The HTTP layer runs on an asyncio
    loop (optionally in a background thread via start()).
    """

    def __init__(
        self,
        host: str = "127.0.0.1",
        port: int = 6000,
        default_timeout: int = 300,
        notify_callback: Optional[Callable] = None,
    ):
        self.host = host
        self.port = port
        self.default_timeout = default_timeout  # seconds to wait for a decision
        self.notify_callback = notify_callback  # called with request dict on new pending requests
        self._pending: Dict[str, ConfirmationRequest] = {}
        self._history: List[ConfirmationRequest] = []
        self._lock = threading.Lock()  # guards _pending and _history
        self._whitelist = _load_whitelist()
        self._app = None
        self._runner = None

    def request(
        self,
        action: str,
        description: str,
        payload: Optional[Dict[str, Any]] = None,
        risk_level: Optional[str] = None,
        session_key: str = "",
        timeout: Optional[int] = None,
    ) -> ConfirmationRequest:
        """Create a confirmation request.

        Args:
            action: Action type (send_email, crypto_tx, ...).
            description: Human-readable description of what will happen.
            payload: Action-specific data; also used for whitelist matching.
            risk_level: Override for classify_action()'s classification.
            session_key: Session that initiated the request.
            timeout: Seconds before the request expires (default_timeout if None).

        Returns:
            The request. Check .status to see if it was immediately
            auto-approved (whitelisted) or is pending human review.
        """
        payload = payload or {}

        # Classify risk if not specified
        if risk_level is None:
            risk_level = classify_action(action).value

        # Low-risk and whitelisted actions skip the human entirely.
        if risk_level in ("low",) or _is_whitelisted(action, payload, self._whitelist):
            req = ConfirmationRequest(
                request_id=str(uuid.uuid4())[:12],
                action=action,
                description=description,
                risk_level=risk_level,
                payload=payload,
                session_key=session_key,
                expires_at=time.time() + (timeout or self.default_timeout),
                status=ConfirmationStatus.AUTO_APPROVED.value,
                decided_at=time.time(),
                decided_by="whitelist",
            )
            with self._lock:
                self._history.append(req)
            logger.info("Auto-approved whitelisted action: %s", action)
            return req

        # Create pending request
        req = ConfirmationRequest(
            request_id=str(uuid.uuid4())[:12],
            action=action,
            description=description,
            risk_level=risk_level,
            payload=payload,
            session_key=session_key,
            expires_at=time.time() + (timeout or self.default_timeout),
        )

        with self._lock:
            self._pending[req.request_id] = req

        # Notify human — best-effort; a broken notifier must not block the agent.
        if self.notify_callback:
            try:
                self.notify_callback(req.to_dict())
            except Exception as e:
                logger.warning("Confirmation notify callback failed: %s", e)

        logger.info(
            "Confirmation request %s: %s (%s risk) — waiting for human",
            req.request_id, action, risk_level,
        )
        return req

    def respond(
        self,
        request_id: str,
        approved: bool,
        decided_by: str = "human",
        reason: str = "",
    ) -> Optional[ConfirmationRequest]:
        """Record a human decision on a pending request.

        Returns the updated request, the unchanged request if it was already
        decided/expired, or None if request_id is unknown.
        """
        with self._lock:
            req = self._pending.get(request_id)
            if not req:
                logger.warning("Confirmation respond: unknown request %s", request_id)
                return None
            if not req.is_pending:
                logger.warning("Confirmation respond: request %s already decided", request_id)
                return req

            req.status = (
                ConfirmationStatus.APPROVED.value if approved
                else ConfirmationStatus.DENIED.value
            )
            req.decided_at = time.time()
            req.decided_by = decided_by
            req.reason = reason

            # Move to history — wait_for_decision() looks for it there.
            del self._pending[request_id]
            self._history.append(req)

        logger.info(
            "Confirmation %s: %s by %s",
            request_id, "APPROVED" if approved else "DENIED", decided_by,
        )
        return req

    def _find_decided(self, request_id: str) -> Optional[ConfirmationRequest]:
        """Return the most recent history entry for request_id, if any.

        Caller must hold self._lock.
        """
        for req in reversed(self._history):
            if req.request_id == request_id:
                return req
        return None

    def wait_for_decision(
        self, request_id: str, timeout: Optional[float] = None
    ) -> ConfirmationRequest:
        """Block until a decision is made or timeout expires.

        BUG FIX: respond() removes decided requests from _pending and appends
        them to _history, so polling _pending alone could never observe a
        decision — every approved/denied request was reported as EXPIRED.
        We now also look the request up in _history.
        """
        deadline = time.time() + (timeout or self.default_timeout)
        while time.time() < deadline:
            with self._lock:
                req = self._pending.get(request_id)
                if req is None:
                    decided = self._find_decided(request_id)
                    if decided is not None:
                        return decided
                elif req.is_expired:
                    # Expire in place so the archived record carries the
                    # correct terminal status.
                    req.status = ConfirmationStatus.EXPIRED.value
                    del self._pending[request_id]
                    self._history.append(req)
                    return req
            time.sleep(0.5)

        # Timeout — expire whatever is still pending.
        with self._lock:
            req = self._pending.pop(request_id, None)
            if req:
                req.status = ConfirmationStatus.EXPIRED.value
                self._history.append(req)
                return req
            decided = self._find_decided(request_id)
            if decided is not None:
                return decided

        # Unknown request id — synthesize an expired placeholder.
        return ConfirmationRequest(
            request_id=request_id,
            action="unknown",
            description="Request not found",
            risk_level="high",
            payload={},
            status=ConfirmationStatus.EXPIRED.value,
        )

    def get_pending(self) -> List[Dict[str, Any]]:
        """Return list of pending confirmation requests."""
        self._expire_old()
        with self._lock:
            return [r.to_dict() for r in self._pending.values() if r.is_pending]

    def get_history(self, limit: int = 50) -> List[Dict[str, Any]]:
        """Return up to *limit* most recent confirmation history entries."""
        with self._lock:
            return [r.to_dict() for r in self._history[-limit:]]

    def _expire_old(self) -> None:
        """Move expired requests from _pending to _history."""
        now = time.time()
        with self._lock:
            expired = [
                rid for rid, req in self._pending.items()
                if now > req.expires_at
            ]
            for rid in expired:
                req = self._pending.pop(rid)
                req.status = ConfirmationStatus.EXPIRED.value
                self._history.append(req)

    # --- aiohttp HTTP API ---

    async def _handle_health(self, request):
        """GET /health — liveness probe with pending-queue depth."""
        from aiohttp import web
        return web.json_response({
            "status": "ok",
            "service": "hermes-confirmation-daemon",
            "pending": len(self._pending),
        })

    async def _handle_confirm(self, request):
        """POST /confirm — agent requests approval; blocks until decided."""
        from aiohttp import web
        try:
            body = await request.json()
        except Exception:
            return web.json_response({"error": "invalid JSON"}, status=400)

        action = body.get("action", "")
        description = body.get("description", "")
        if not action or not description:
            return web.json_response(
                {"error": "action and description required"}, status=400
            )

        req = self.request(
            action=action,
            description=description,
            payload=body.get("payload", {}),
            risk_level=body.get("risk_level"),
            session_key=body.get("session_key", ""),
            timeout=body.get("timeout"),
        )

        # If auto-approved, return immediately
        if req.status != ConfirmationStatus.PENDING.value:
            return web.json_response({
                "request_id": req.request_id,
                "status": req.status,
                "decided_by": req.decided_by,
            })

        # Otherwise wait for the human decision, capped at 10 minutes.
        # `or` (not a .get default) also guards against an explicit null/0
        # timeout in the body, which would make min() raise TypeError.
        timeout = min(body.get("timeout") or self.default_timeout, 600)

        # BUG FIX: wait_for_decision() blocks with time.sleep(); calling it
        # directly here froze the event loop, so the POST /respond that
        # resolves the wait could never be served — a guaranteed deadlock
        # until timeout. Run the wait in a worker thread instead.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None, self.wait_for_decision, req.request_id, timeout
        )

        return web.json_response({
            "request_id": result.request_id,
            "status": result.status,
            "decided_by": result.decided_by,
            "reason": result.reason,
        })

    async def _handle_respond(self, request):
        """POST /respond — human approves or denies a pending request."""
        from aiohttp import web
        try:
            body = await request.json()
        except Exception:
            return web.json_response({"error": "invalid JSON"}, status=400)

        request_id = body.get("request_id", "")
        approved = body.get("approved")
        if not request_id or approved is None:
            return web.json_response(
                {"error": "request_id and approved required"}, status=400
            )

        result = self.respond(
            request_id=request_id,
            approved=bool(approved),
            decided_by=body.get("decided_by", "human"),
            reason=body.get("reason", ""),
        )

        if not result:
            return web.json_response({"error": "unknown request"}, status=404)

        return web.json_response({
            "request_id": result.request_id,
            "status": result.status,
        })

    async def _handle_pending(self, request):
        """GET /pending — list pending confirmation requests."""
        from aiohttp import web
        return web.json_response({"pending": self.get_pending()})

    def _build_app(self):
        """Build the aiohttp application and register routes."""
        from aiohttp import web

        app = web.Application()
        app.router.add_get("/health", self._handle_health)
        app.router.add_post("/confirm", self._handle_confirm)
        app.router.add_post("/respond", self._handle_respond)
        app.router.add_get("/pending", self._handle_pending)
        self._app = app
        return app

    async def start_async(self) -> None:
        """Start the daemon as an async server on the running loop."""
        from aiohttp import web

        app = self._build_app()
        self._runner = web.AppRunner(app)
        await self._runner.setup()
        site = web.TCPSite(self._runner, self.host, self.port)
        await site.start()
        logger.info("Confirmation daemon listening on %s:%d", self.host, self.port)

    async def stop_async(self) -> None:
        """Stop the daemon and release the listening socket."""
        if self._runner:
            await self._runner.cleanup()
            self._runner = None

    def start(self) -> None:
        """Start the daemon in a background daemon thread (non-blocking)."""
        def _run():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            loop.run_until_complete(self.start_async())
            loop.run_forever()

        t = threading.Thread(target=_run, daemon=True, name="confirmation-daemon")
        t.start()
        logger.info("Confirmation daemon started in background thread")

    def start_blocking(self) -> None:
        """Start daemon and block until Ctrl-C (for standalone use)."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self.start_async())
        try:
            loop.run_forever()
        except KeyboardInterrupt:
            pass
        finally:
            loop.run_until_complete(self.stop_async())
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Convenience API for agent integration
|
||||
# =========================================================================
|
||||
|
||||
# Global singleton — initialized by gateway or CLI at startup via init_daemon()
_daemon: Optional[ConfirmationDaemon] = None


def get_daemon() -> Optional[ConfirmationDaemon]:
    """Get the global confirmation daemon instance.

    Returns None until init_daemon() has been called.
    """
    return _daemon
|
||||
|
||||
|
||||
def init_daemon(
    host: str = "127.0.0.1",
    port: int = 6000,
    notify_callback: Optional[Callable] = None,
) -> ConfirmationDaemon:
    """Create the module-wide ConfirmationDaemon singleton and return it."""
    global _daemon
    daemon = ConfirmationDaemon(host=host, port=port, notify_callback=notify_callback)
    _daemon = daemon
    return daemon
|
||||
|
||||
|
||||
def request_confirmation(
    action: str,
    description: str,
    payload: Optional[Dict[str, Any]] = None,
    risk_level: Optional[str] = None,
    session_key: str = "",
    timeout: int = 300,
) -> bool:
    """Request human confirmation for a high-risk action.

    Primary integration point for agent code: classifies the action, checks
    the whitelist, and — when confirmation is required — blocks until the
    human responds or the timeout elapses.

    Args:
        action: Action type (send_email, crypto_tx, etc.)
        description: Human-readable description
        payload: Action-specific data
        risk_level: Override auto-classification
        session_key: Session requesting approval
        timeout: Seconds to wait for human response

    Returns:
        True if approved (by human or whitelist), False if denied, expired,
        or no daemon is running (fail closed).
    """
    daemon = get_daemon()
    if daemon is None:
        # Fail closed: with no daemon there is no second factor.
        logger.warning(
            "No confirmation daemon running — DENYING action %s by default. "
            "Start daemon with init_daemon() or --confirmation-daemon flag.",
            action,
        )
        return False

    req = daemon.request(
        action=action,
        description=description,
        payload=payload,
        risk_level=risk_level,
        session_key=session_key,
        timeout=timeout,
    )

    # Whitelisted / low-risk requests are already decided.
    if req.status == ConfirmationStatus.AUTO_APPROVED.value:
        return True

    decision = daemon.wait_for_decision(req.request_id, timeout=timeout)
    return decision.status == ConfirmationStatus.APPROVED.value
|
||||
Reference in New Issue
Block a user