Compare commits

..

7 Commits

Author SHA1 Message Date
Alexander Whitestone
4bb12e05ef bench: Gemma 4 tool calling benchmark — 100 prompts (#796)
Some checks failed
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 32s
Tests / e2e (pull_request) Successful in 1m40s
Tests / test (pull_request) Failing after 51m11s
Contributor Attribution Check / check-attribution (pull_request) Failing after 22s
Benchmark script comparing Gemma 4 vs mimo-v2-pro on tool calling.

100 prompts across 6 categories:
- File operations (20): read, write, search
- Terminal commands (20): system info, process management
- Web search (15): documentation, comparisons
- Code execution (15): calculations, parsing
- Parallel tool calls (10): concurrent operations
- Edge cases (20): complex, ambiguous prompts

Metrics:
- Schema parse success rate
- Tool execution success rate
- Argument validity rate
- Average latency
- Token cost

Usage:
  python3 benchmarks/tool_call_benchmark.py --model1 gemma3:27b --model2 xiaomi/mimo-v2-pro
  python3 benchmarks/tool_call_benchmark.py --model1 gemma3:27b --limit 10

Closes #796
2026-04-16 01:05:29 -04:00
db72e908f7 Merge pull request 'feat(security): implement Vitalik's secure LLM patterns — privacy filter + confirmation daemon [resolves merge conflict]' (#830) from feat/vitalik-secure-llm-1776303263 into main
Vitalik's secure LLM patterns — privacy filter + confirmation daemon

Clean rebase of #397 onto current main. Resolves merge conflicts in tools/approval.py.
2026-04-16 01:36:58 +00:00
b82b760d5d feat: add Vitalik's threat model patterns to DANGEROUS_PATTERNS
Some checks failed
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 41s
Contributor Attribution Check / check-attribution (pull_request) Successful in 51s
Tests / e2e (pull_request) Successful in 5m21s
Tests / test (pull_request) Failing after 45m7s
2026-04-16 01:35:49 +00:00
d8d7846897 feat: add tests/tools/test_confirmation_daemon.py from PR #397 2026-04-16 01:35:24 +00:00
6840d05554 feat: add tests/agent/test_privacy_filter.py from PR #397 2026-04-16 01:35:21 +00:00
8abe59ed95 feat: add tools/confirmation_daemon.py from PR #397 2026-04-16 01:35:18 +00:00
435d790201 feat: add agent/privacy_filter.py from PR #397 2026-04-16 01:35:14 +00:00
11 changed files with 1834 additions and 526 deletions

View File

@@ -1,22 +0,0 @@
"""A2A — Agent2Agent Protocol v1.0 for Hermes task delegation.
Usage:
from a2a import A2AClient, A2AServer, AgentCard, Task, TextPart
"""
from a2a.types import (
AgentCard, AgentSkill, Artifact, DataPart, FilePart,
JSONRPCError, JSONRPCRequest, JSONRPCResponse,
Message, Part, Task, TaskState, TaskStatus, TextPart, A2AError,
part_from_dict,
)
from a2a.client import A2AClient, A2AClientConfig
from a2a.server import A2AServer, TaskHandler
__all__ = [
"A2AClient", "A2AClientConfig", "A2AServer", "TaskHandler",
"AgentCard", "AgentSkill", "Artifact",
"DataPart", "FilePart", "TextPart", "Part", "part_from_dict",
"JSONRPCError", "JSONRPCRequest", "JSONRPCResponse",
"Message", "Task", "TaskState", "TaskStatus", "A2AError",
]

View File

@@ -1,98 +0,0 @@
"""A2A Client - send tasks to remote agents via JSON-RPC 2.0."""
from __future__ import annotations
import asyncio, json, logging, time
from typing import Any, Dict, List, Optional
import aiohttp
from a2a.types import (AgentCard, Artifact, JSONRPCError, JSONRPCRequest, JSONRPCResponse, Message, Part, Task, TaskState, TaskStatus, TextPart, A2AError)
logger = logging.getLogger(__name__)
class A2AClientConfig:
    """Transport settings for an A2AClient.

    Attributes:
        timeout: Per-request HTTP timeout in seconds.
        max_retries: Number of attempts per RPC call.
        retry_delay: Base delay in seconds between retries.
        auth_token: Optional token sent in the Authorization header.
        auth_scheme: Authorization scheme prefix (default "Bearer").
    """

    def __init__(self, timeout=30.0, max_retries=3, retry_delay=2.0, auth_token=None, auth_scheme="Bearer"):
        # Plain attribute bag — values are stored exactly as supplied.
        self.auth_scheme = auth_scheme
        self.auth_token = auth_token
        self.retry_delay = retry_delay
        self.max_retries = max_retries
        self.timeout = timeout
class A2AClient:
    """JSON-RPC 2.0 client for delegating tasks to a remote A2A agent."""

    def __init__(self, base_url, config=None):
        # base_url is the agent root; the fixed path /a2a/v1/rpc is appended per call.
        self.base_url = base_url.rstrip("/")
        self.config = config or A2AClientConfig()
        # In-memory record of every RPC attempt (method, status/error, latency).
        self._audit_log = []

    def _headers(self):
        """Build request headers, adding Authorization when a token is configured."""
        h = {"Content-Type": "application/json"}
        if self.config.auth_token:
            h["Authorization"] = f"{self.config.auth_scheme} {self.config.auth_token}"
        return h

    async def _rpc_call(self, method, params=None):
        """POST one JSON-RPC request, retrying up to config.max_retries times.

        Retries on non-200 responses and on transport exceptions, sleeping
        retry_delay * attempt between tries (linear backoff). Always returns
        a JSONRPCResponse — failures are wrapped as internal errors, never
        raised to the caller.
        """
        req = JSONRPCRequest(method=method, params=params)
        payload = json.dumps(req.to_dict())
        last_error = None
        for attempt in range(1, self.config.max_retries + 1):
            t0 = time.monotonic()
            try:
                timeout = aiohttp.ClientTimeout(total=self.config.timeout)
                # A fresh session per attempt keeps connection state simple.
                async with aiohttp.ClientSession(timeout=timeout) as session:
                    async with session.post(f"{self.base_url}/a2a/v1/rpc", data=payload, headers=self._headers()) as resp:
                        body = await resp.text()
                        elapsed = time.monotonic() - t0
                        self._audit_log.append({"method": method, "params": params, "status": resp.status, "elapsed_ms": round(elapsed*1000), "attempt": attempt})
                        if resp.status != 200:
                            logger.warning("A2A %s -> HTTP %d (attempt %d/%d)", method, resp.status, attempt, self.config.max_retries)
                            if attempt < self.config.max_retries:
                                await asyncio.sleep(self.config.retry_delay * attempt)
                                continue
                            # Retries exhausted: surface the HTTP failure as a JSON-RPC error.
                            return JSONRPCResponse(id=req.id, error=A2AError.internal_error(f"HTTP {resp.status}: {body[:200]}"))
                        return JSONRPCResponse.from_dict(json.loads(body))
            except Exception as exc:
                elapsed = time.monotonic() - t0
                last_error = exc
                self._audit_log.append({"method": method, "error": str(exc), "elapsed_ms": round(elapsed*1000), "attempt": attempt})
                if attempt < self.config.max_retries:
                    await asyncio.sleep(self.config.retry_delay * attempt)
                    continue
                # Last attempt also raised: report the final transport error.
                return JSONRPCResponse(id=req.id, error=A2AError.internal_error(str(last_error)))

    async def get_agent_card(self):
        """Fetch and parse the remote agent's AgentCard. Raises RuntimeError on RPC error."""
        resp = await self._rpc_call("GetAgentCard")
        if resp.error: raise RuntimeError(f"Failed to get agent card: {resp.error.message}")
        return AgentCard.from_dict(resp.result)

    async def send_message(self, text, context_id=None, skill_id=None, metadata=None):
        """Send a user text message; returns the Task the server created for it."""
        msg = Message(role="user", parts=[TextPart(text=text)], context_id=context_id)
        params = {"message": msg.to_dict()}
        if skill_id: params["skillId"] = skill_id
        if metadata: params["metadata"] = metadata
        resp = await self._rpc_call("SendMessage", params)
        if resp.error: raise RuntimeError(f"SendMessage failed: {resp.error.message}")
        return Task.from_dict(resp.result)

    async def get_task(self, task_id):
        """Fetch the current state of a task by id."""
        resp = await self._rpc_call("GetTask", {"taskId": task_id})
        if resp.error: raise RuntimeError(f"GetTask failed: {resp.error.message}")
        return Task.from_dict(resp.result)

    async def cancel_task(self, task_id):
        """Request cancellation of a task by id."""
        resp = await self._rpc_call("CancelTask", {"taskId": task_id})
        if resp.error: raise RuntimeError(f"CancelTask failed: {resp.error.message}")
        return Task.from_dict(resp.result)

    async def wait_for_completion(self, task_id, poll_interval=2.0, timeout=300.0):
        """Poll get_task until the task reaches a terminal state.

        Raises TimeoutError when the task is still running after `timeout` seconds.
        """
        t0 = time.monotonic()
        while True:
            task = await self.get_task(task_id)
            if task.status.state.terminal: return task
            if time.monotonic() - t0 > timeout: raise TimeoutError(f"Task {task_id} did not complete within {timeout}s")
            await asyncio.sleep(poll_interval)

    async def delegate(self, text, skill_id=None, wait=True, timeout=300.0):
        """Convenience wrapper: send_message, then optionally wait for completion."""
        task = await self.send_message(text, skill_id=skill_id)
        if wait: return await self.wait_for_completion(task.id, timeout=timeout)
        return task

    @property
    def audit_log(self):
        # Copy so callers cannot mutate the internal history.
        return list(self._audit_log)

View File

@@ -1,60 +0,0 @@
"""A2A Server - receive and execute tasks via JSON-RPC 2.0."""
from __future__ import annotations
import json,logging,uuid
from datetime import datetime,timezone
from typing import Any,Callable,Dict,List,Optional,Awaitable
from a2a.types import AgentCard,Artifact,JSONRPCError,JSONRPCRequest,JSONRPCResponse,Message,Task,TaskState,TaskStatus,TextPart,A2AError,part_from_dict
logger=logging.getLogger(__name__)
TaskHandler=Callable[[Task,AgentCard],Awaitable[Task]]
class A2AServer:
    """JSON-RPC 2.0 server side of the A2A protocol.

    Holds an in-memory task store and dispatches incoming RPC methods to
    registered per-skill handlers (falling back to a default handler).
    """

    def __init__(self, card):
        self.card = card              # AgentCard served via GetAgentCard
        self._tasks = {}              # task_id -> Task
        self._handlers = {}           # skill_id -> TaskHandler
        self._default_handler = None  # fallback TaskHandler
        self._audit_log = []          # one entry per RPC received

    def register_handler(self, skill_id, handler):
        """Register a TaskHandler for a specific skill id."""
        self._handlers[skill_id] = handler

    def set_default_handler(self, handler):
        """Set the handler used when no skill-specific handler matches."""
        self._default_handler = handler

    async def handle_rpc(self, raw):
        """Process one raw JSON-RPC request string; always returns a JSON string.

        Protocol errors map to JSON-RPC error responses; handler exceptions
        become internal errors (never raised to the transport).
        """
        try:
            data = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return json.dumps(JSONRPCResponse(id="", error=A2AError.parse_error()).to_dict())
        # Fix: a payload such as "[1,2]" or "42" parses as valid JSON but is
        # not a request object; previously data.get() raised an uncaught
        # AttributeError here. Reject it as an Invalid Request instead.
        if not isinstance(data, dict):
            return json.dumps(JSONRPCResponse(id="", error=A2AError.invalid_request()).to_dict())
        req_id = data.get("id", "")
        method = data.get("method", "")
        params = data.get("params")
        self._audit_log.append({"method": method, "id": req_id, "ts": datetime.now(timezone.utc).isoformat()})
        try:
            if method == "SendMessage":
                result = await self._handle_send_message(params)
            elif method == "GetTask":
                result = self._handle_get_task(params)
            elif method == "CancelTask":
                result = self._handle_cancel_task(params)
            elif method == "GetAgentCard":
                result = self.card.to_dict()
            elif method == "ListTasks":
                result = self._handle_list_tasks(params)
            else:
                return json.dumps(JSONRPCResponse(id=req_id, error=A2AError.method_not_found()).to_dict())
            return json.dumps(JSONRPCResponse(id=req_id, result=result).to_dict())
        except Exception as exc:
            logger.exception("A2A handler error for %s", method)
            return json.dumps(JSONRPCResponse(id=req_id, error=A2AError.internal_error(str(exc))).to_dict())

    async def _handle_send_message(self, params):
        """Create a Task from an incoming message and run the matching handler."""
        if not params or "message" not in params:
            raise ValueError("SendMessage requires message param")
        msg = Message.from_dict(params["message"])
        task = Task(id=str(uuid.uuid4()), context_id=msg.context_id, status=TaskStatus(state=TaskState.SUBMITTED), history=[msg])
        self._tasks[task.id] = task
        skill_id = params.get("skillId")
        handler = self._handlers.get(skill_id) if skill_id else None
        if handler is None:
            handler = self._default_handler
        if handler is None:
            # No handler registered: echo the text back and complete immediately.
            text = msg.parts[0].text if msg.parts and hasattr(msg.parts[0], "text") else ""
            task.status = TaskStatus(state=TaskState.COMPLETED)
            task.artifacts = [Artifact(parts=[TextPart(text=f"Received: {text}")])]
        else:
            task.status = TaskStatus(state=TaskState.WORKING)
            self._tasks[task.id] = task
            # The handler may return a new Task object; store whatever it returns.
            task = await handler(task, self.card)
            self._tasks[task.id] = task
        return task.to_dict()

    def _handle_get_task(self, params):
        """Return a stored task as a dict; KeyError if unknown."""
        task_id = (params or {}).get("taskId")
        if not task_id or task_id not in self._tasks:
            raise KeyError(f"Task not found: {task_id}")
        return self._tasks[task_id].to_dict()

    def _handle_cancel_task(self, params):
        """Cancel a non-terminal task; errors if unknown or already terminal."""
        task_id = (params or {}).get("taskId")
        if not task_id or task_id not in self._tasks:
            raise KeyError(f"Task not found: {task_id}")
        task = self._tasks[task_id]
        if task.status.state.terminal:
            raise ValueError(f"Task already {task.status.state.value}")
        task.status = TaskStatus(state=TaskState.CANCELED)
        self._tasks[task_id] = task
        return task.to_dict()

    def _handle_list_tasks(self, params):
        """List all known tasks, optionally filtered by contextId."""
        context_id = (params or {}).get("contextId")
        return {"tasks": [t.to_dict() for t in self._tasks.values() if not context_id or t.context_id == context_id]}

    def add_task(self, task):
        """Insert a task into the store directly."""
        self._tasks[task.id] = task

    @property
    def audit_log(self):
        # Defensive copy of the RPC audit trail.
        return list(self._audit_log)

View File

@@ -1,230 +0,0 @@
"""A2A Protocol Types - Agent2Agent v1.0 data structures."""
from __future__ import annotations
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Optional, Union
class TaskState(str, Enum):
    """Lifecycle states for an A2A task, using the protocol wire values."""
    SUBMITTED = "TASK_STATE_SUBMITTED"
    WORKING = "TASK_STATE_WORKING"
    INPUT_REQUIRED = "TASK_STATE_INPUT_REQUIRED"
    COMPLETED = "TASK_STATE_COMPLETED"
    FAILED = "TASK_STATE_FAILED"
    CANCELED = "TASK_STATE_CANCELED"
    REJECTED = "TASK_STATE_REJECTED"

    @property
    def terminal(self) -> bool:
        """True once the task can no longer change state."""
        return self in (
            TaskState.COMPLETED,
            TaskState.FAILED,
            TaskState.CANCELED,
            TaskState.REJECTED,
        )
@dataclass
class TextPart:
    """Plain-text message part."""
    text: str
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self) -> dict:
        """Serialize to wire form; metadata is emitted only when truthy."""
        out: dict = {"text": self.text}
        if self.metadata:
            out["metadata"] = self.metadata
        return out

    @classmethod
    def from_dict(cls, d):
        """Build from a wire dict produced by to_dict()."""
        return cls(text=d["text"], metadata=d.get("metadata"))
@dataclass
class FilePart:
    """File message part, carried inline (`raw`) or by reference (`url`)."""
    media_type: str = "application/octet-stream"
    raw: Optional[str] = None
    url: Optional[str] = None
    filename: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self) -> dict:
        """Serialize with camelCase wire keys; unset optional fields are omitted.

        Note: raw/url are kept even when they are empty strings (None test),
        while filename/metadata are kept only when truthy.
        """
        out: dict = {"mediaType": self.media_type}
        if self.raw is not None:
            out["raw"] = self.raw
        if self.url is not None:
            out["url"] = self.url
        if self.filename:
            out["filename"] = self.filename
        if self.metadata:
            out["metadata"] = self.metadata
        return out

    @classmethod
    def from_dict(cls, d):
        """Build from a wire dict; missing mediaType falls back to octet-stream."""
        return cls(
            media_type=d.get("mediaType", "application/octet-stream"),
            raw=d.get("raw"),
            url=d.get("url"),
            filename=d.get("filename"),
            metadata=d.get("metadata"),
        )
@dataclass
class DataPart:
    """Structured-data message part (a dict payload plus media type)."""
    data: Dict[str, Any]
    media_type: str = "application/json"
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self) -> dict:
        """Serialize; metadata is emitted only when truthy."""
        out: dict = {"data": self.data, "mediaType": self.media_type}
        if self.metadata:
            out["metadata"] = self.metadata
        return out

    @classmethod
    def from_dict(cls, d):
        """Build from a wire dict; mediaType defaults to application/json."""
        return cls(
            data=d["data"],
            media_type=d.get("mediaType", "application/json"),
            metadata=d.get("metadata"),
        )
# A Part is any of the three concrete part types.
Part = Union[TextPart, FilePart, DataPart]


def part_from_dict(d):
    """Discriminate a wire dict into the matching Part type by its keys.

    Precedence: "text" wins, then "raw"/"url" (FilePart), then "data".
    Raises ValueError when no discriminator key is present.
    """
    for marker, part_cls in (
        ("text", TextPart),
        ("raw", FilePart),
        ("url", FilePart),
        ("data", DataPart),
    ):
        if marker in d:
            return part_cls.from_dict(d)
    raise ValueError(f"Cannot discriminate Part type from keys: {list(d.keys())}")
@dataclass
class Message:
    """A protocol message: a role plus an ordered list of Parts."""
    role: str
    parts: List[Part]
    message_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    context_id: Optional[str] = None
    task_id: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self):
        """Serialize; optional fields are emitted only when truthy."""
        out = {"role": self.role, "messageId": self.message_id, "parts": [p.to_dict() for p in self.parts]}
        for key, value in (("contextId", self.context_id), ("taskId", self.task_id), ("metadata", self.metadata)):
            if value:
                out[key] = value
        return out

    @classmethod
    def from_dict(cls, d):
        """Build from a wire dict; a missing messageId gets a fresh UUID."""
        return cls(
            role=d["role"],
            parts=[part_from_dict(p) for p in d["parts"]],
            message_id=d.get("messageId", str(uuid.uuid4())),
            context_id=d.get("contextId"),
            task_id=d.get("taskId"),
            metadata=d.get("metadata"),
        )
@dataclass
class Artifact:
    """An output produced by a task, carrying zero or more Parts."""
    artifact_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    parts: List[Part] = field(default_factory=list)
    name: Optional[str] = None
    description: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self):
        """Serialize; optional fields are emitted only when truthy."""
        out = {"artifactId": self.artifact_id, "parts": [p.to_dict() for p in self.parts]}
        for key, value in (("name", self.name), ("description", self.description), ("metadata", self.metadata)):
            if value:
                out[key] = value
        return out

    @classmethod
    def from_dict(cls, d):
        """Build from a wire dict; a missing artifactId gets a fresh UUID."""
        return cls(
            artifact_id=d.get("artifactId", str(uuid.uuid4())),
            parts=[part_from_dict(p) for p in d.get("parts", [])],
            name=d.get("name"),
            description=d.get("description"),
            metadata=d.get("metadata"),
        )
@dataclass
class TaskStatus:
    """Current state of a task, with optional status message and ISO timestamp."""
    state: TaskState
    message: Optional[Message] = None
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())

    def to_dict(self):
        """Serialize; the message key is emitted only when set."""
        out = {"state": self.state.value, "timestamp": self.timestamp}
        if self.message:
            out["message"] = self.message.to_dict()
        return out

    @classmethod
    def from_dict(cls, d):
        """Build from a wire dict; a missing timestamp defaults to now (UTC)."""
        raw_msg = d.get("message")
        return cls(
            state=TaskState(d["state"]),
            message=Message.from_dict(raw_msg) if raw_msg else None,
            timestamp=d.get("timestamp", datetime.now(timezone.utc).isoformat()),
        )
@dataclass
class Task:
    """A unit of delegated work: status plus artifacts and message history."""
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    context_id: Optional[str] = None
    status: TaskStatus = field(default_factory=lambda: TaskStatus(state=TaskState.SUBMITTED))
    artifacts: List[Artifact] = field(default_factory=list)
    history: List[Message] = field(default_factory=list)
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self):
        """Serialize; empty/unset optional fields are omitted."""
        out = {"id": self.id, "status": self.status.to_dict()}
        if self.context_id:
            out["contextId"] = self.context_id
        if self.artifacts:
            out["artifacts"] = [a.to_dict() for a in self.artifacts]
        if self.history:
            out["history"] = [m.to_dict() for m in self.history]
        if self.metadata:
            out["metadata"] = self.metadata
        return out

    @classmethod
    def from_dict(cls, d):
        """Build from a wire dict; a missing status defaults to SUBMITTED."""
        status = TaskStatus.from_dict(d["status"]) if "status" in d else TaskStatus(TaskState.SUBMITTED)
        return cls(
            id=d.get("id", str(uuid.uuid4())),
            context_id=d.get("contextId"),
            status=status,
            artifacts=[Artifact.from_dict(a) for a in d.get("artifacts", [])],
            history=[Message.from_dict(m) for m in d.get("history", [])],
            metadata=d.get("metadata"),
        )
@dataclass
class AgentSkill:
    """A single capability advertised on an AgentCard."""
    id: str
    name: str
    description: str = ""
    tags: List[str] = field(default_factory=list)
    examples: List[str] = field(default_factory=list)
    input_modes: List[str] = field(default_factory=lambda: ["text"])
    output_modes: List[str] = field(default_factory=lambda: ["text"])

    def to_dict(self):
        """Serialize; modes are emitted only when they differ from ["text"]."""
        out = {"id": self.id, "name": self.name, "description": self.description}
        if self.tags:
            out["tags"] = self.tags
        if self.examples:
            out["examples"] = self.examples
        if self.input_modes != ["text"]:
            out["inputModes"] = self.input_modes
        if self.output_modes != ["text"]:
            out["outputModes"] = self.output_modes
        return out

    @classmethod
    def from_dict(cls, d):
        """Build from a wire dict; an unnamed skill falls back to its id."""
        return cls(
            id=d["id"],
            name=d.get("name", d["id"]),
            description=d.get("description", ""),
            tags=d.get("tags", []),
            examples=d.get("examples", []),
            input_modes=d.get("inputModes", ["text"]),
            output_modes=d.get("outputModes", ["text"]),
        )
@dataclass
class AgentCard:
    """Public identity and capability descriptor for an agent."""
    name: str
    description: str = ""
    version: str = "1.0.0"
    url: str = ""
    skills: List[AgentSkill] = field(default_factory=list)
    capabilities: Dict[str, bool] = field(default_factory=dict)
    provider: Optional[Dict[str, str]] = None

    def to_dict(self):
        """Serialize; capabilities/provider are emitted only when truthy."""
        out = {
            "name": self.name,
            "description": self.description,
            "version": self.version,
            "url": self.url,
            "skills": [s.to_dict() for s in self.skills],
        }
        if self.capabilities:
            out["capabilities"] = self.capabilities
        if self.provider:
            out["provider"] = self.provider
        return out

    @classmethod
    def from_dict(cls, d):
        """Build from a wire dict; only name is required."""
        return cls(
            name=d["name"],
            description=d.get("description", ""),
            version=d.get("version", "1.0.0"),
            url=d.get("url", ""),
            skills=[AgentSkill.from_dict(s) for s in d.get("skills", [])],
            capabilities=d.get("capabilities", {}),
            provider=d.get("provider"),
        )
@dataclass
class JSONRPCError:
    """JSON-RPC 2.0 error object: code, message, optional data payload."""
    code: int
    message: str
    data: Optional[Any] = None

    def to_dict(self):
        """Serialize; data may legitimately be falsy (0, ""), so test against None."""
        out = {"code": self.code, "message": self.message}
        if self.data is not None:
            out["data"] = self.data
        return out
@dataclass
class JSONRPCRequest:
    """JSON-RPC 2.0 request envelope; the id defaults to a fresh UUID."""
    method: str
    params: Optional[Dict[str, Any]] = None
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    jsonrpc: str = "2.0"

    def to_dict(self):
        """Serialize; the params key is omitted when params is None."""
        out = {"jsonrpc": self.jsonrpc, "method": self.method, "id": self.id}
        if self.params is not None:
            out["params"] = self.params
        return out
@dataclass
class JSONRPCResponse:
    """JSON-RPC 2.0 response; exactly one of result or error is serialized."""
    id: str
    result: Optional[Any] = None
    error: Optional[JSONRPCError] = None
    jsonrpc: str = "2.0"

    def to_dict(self):
        """Serialize; a set error takes precedence over result."""
        out = {"jsonrpc": self.jsonrpc, "id": self.id}
        if self.error:
            out["error"] = self.error.to_dict()
        else:
            out["result"] = self.result
        return out

    @classmethod
    def from_dict(cls, d):
        """Build from a wire dict, re-hydrating an error object when present."""
        err = d.get("error")
        parsed_error = JSONRPCError(err["code"], err["message"], err.get("data")) if err else None
        return cls(id=d["id"], result=d.get("result"), error=parsed_error)
class A2AError:
    """Factory helpers for the JSON-RPC error objects used by A2A.

    Codes -32700..-32600 are the standard JSON-RPC errors; the -320xx
    range holds A2A-specific errors.
    """

    @staticmethod
    def parse_error(data=None):
        return JSONRPCError(-32700, "Parse error", data)

    @staticmethod
    def invalid_request(data=None):
        return JSONRPCError(-32600, "Invalid Request", data)

    @staticmethod
    def method_not_found(data=None):
        return JSONRPCError(-32601, "Method not found", data)

    @staticmethod
    def invalid_params(data=None):
        return JSONRPCError(-32602, "Invalid params", data)

    @staticmethod
    def internal_error(data=None):
        return JSONRPCError(-32603, "Internal error", data)

    @staticmethod
    def task_not_found(task_id):
        return JSONRPCError(-32001, f"Task not found: {task_id}")

    @staticmethod
    def task_not_cancelable(task_id):
        return JSONRPCError(-32002, f"Task not cancelable: {task_id}")

    @staticmethod
    def agent_not_found(name):
        return JSONRPCError(-32009, f"Agent not found: {name}")

353
agent/privacy_filter.py Normal file
View File

@@ -0,0 +1,353 @@
"""Privacy Filter — strip PII from context before remote API calls.
Implements Vitalik's Pattern 2: "A local model can strip out private data
before passing the query along to a remote LLM."
When Hermes routes a request to a cloud provider (Anthropic, OpenRouter, etc.),
this module sanitizes the message context to remove personally identifiable
information before it leaves the user's machine.
Threat model (from Vitalik's secure LLM architecture):
- Privacy (other): Non-LLM data leakage via search queries, API calls
- LLM accidents: LLM accidentally leaking private data in prompts
- LLM jailbreaks: Remote content extracting private context
Usage:
from agent.privacy_filter import PrivacyFilter, sanitize_messages
pf = PrivacyFilter()
safe_messages = pf.sanitize_messages(messages)
# safe_messages has PII replaced with [REDACTED] tokens
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
class Sensitivity(Enum):
    """Ordered classification of content sensitivity.

    auto() assigns increasing values, so `.value` comparisons give an
    ordering: PUBLIC < LOW < MEDIUM < HIGH < CRITICAL.
    """
    PUBLIC = auto()    # no PII detected
    LOW = auto()       # generic environment references (e.g. internal IPs)
    MEDIUM = auto()    # personal identifiers (email, phone, home paths)
    HIGH = auto()      # financial data, crypto addresses, SSNs
    CRITICAL = auto()  # secrets: API keys, tokens, passwords, private keys
@dataclass
class RedactionReport:
    """Summary of what was redacted from a batch of messages."""
    total_messages: int = 0
    redacted_messages: int = 0
    # One entry per redaction type: {"type", "sensitivity", "count"}.
    redactions: List[Dict[str, Any]] = field(default_factory=list)
    max_sensitivity: Sensitivity = Sensitivity.PUBLIC

    @property
    def had_redactions(self) -> bool:
        """True when at least one message was altered."""
        return self.redacted_messages > 0

    def summary(self) -> str:
        """Human-readable digest, listing at most the first 10 redaction types."""
        if not self.had_redactions:
            return "No PII detected — context is clean for remote query."
        lines = [f"Redacted {self.redacted_messages}/{self.total_messages} messages:"]
        for entry in self.redactions[:10]:
            lines.append(f" - {entry['type']}: {entry['count']} occurrence(s)")
        overflow = len(self.redactions) - 10
        if overflow > 0:
            lines.append(f" ... and {overflow} more types")
        return "\n".join(lines)
# =========================================================================
# PII pattern definitions
# =========================================================================
# Each pattern is (compiled_regex, redaction_type, sensitivity_level, replacement)
_PII_PATTERNS: List[Tuple[re.Pattern, str, Sensitivity, str]] = []


def _compile_patterns() -> None:
    """Compile PII detection patterns. Called once at module init.

    Idempotent: if _PII_PATTERNS is already populated, returns immediately,
    so repeated calls never rebuild the table.
    """
    global _PII_PATTERNS
    if _PII_PATTERNS:
        return
    raw_patterns = [
        # --- CRITICAL: secrets and credentials ---
        # key=value / key: value assignments of API keys and tokens
        # (the value itself is the capture group).
        (
            r'(?:api[_-]?key|apikey|secret[_-]?key|access[_-]?token)\s*[:=]\s*["\']?([A-Za-z0-9_\-\.]{20,})["\']?',
            "api_key_or_token",
            Sensitivity.CRITICAL,
            "[REDACTED-API-KEY]",
        ),
        # Vendor-prefixed secrets (sk-/sk_/pk_/rk_/ak_ followed by 20+ chars).
        (
            r'\b(?:sk-|sk_|pk_|rk_|ak_)[A-Za-z0-9]{20,}\b',
            "prefixed_secret",
            Sensitivity.CRITICAL,
            "[REDACTED-SECRET]",
        ),
        # GitHub token prefixes (ghp_, gho_, ghu_, ghs_, ghr_).
        (
            r'\b(?:ghp_|gho_|ghu_|ghs_|ghr_)[A-Za-z0-9]{36,}\b',
            "github_token",
            Sensitivity.CRITICAL,
            "[REDACTED-GITHUB-TOKEN]",
        ),
        # Slack token prefixes (xoxb-, xoxp-, etc.).
        (
            r'\b(?:xox[bposa]-[A-Za-z0-9\-]+)\b',
            "slack_token",
            Sensitivity.CRITICAL,
            "[REDACTED-SLACK-TOKEN]",
        ),
        # password=... style assignments (value of 4+ non-space chars).
        (
            r'(?:password|passwd|pwd)\s*[:=]\s*["\']?([^\s"\']{4,})["\']?',
            "password",
            Sensitivity.CRITICAL,
            "[REDACTED-PASSWORD]",
        ),
        # PEM private-key header — matching the header alone flags the block.
        (
            r'(?:-----BEGIN (?:RSA |EC |OPENSSH )?PRIVATE KEY-----)',
            "private_key_block",
            Sensitivity.CRITICAL,
            "[REDACTED-PRIVATE-KEY]",
        ),
        # Ethereum / crypto addresses (42-char hex starting with 0x)
        (
            r'\b0x[a-fA-F0-9]{40}\b',
            "ethereum_address",
            Sensitivity.HIGH,
            "[REDACTED-ETH-ADDR]",
        ),
        # Bitcoin addresses (base58, 25-34 chars starting with 1/3/bc1)
        (
            r'\b[13][a-km-zA-HJ-NP-Z1-9]{25,34}\b',
            "bitcoin_address",
            Sensitivity.HIGH,
            "[REDACTED-BTC-ADDR]",
        ),
        (
            r'\bbc1[a-zA-HJ-NP-Z0-9]{39,59}\b',
            "bech32_address",
            Sensitivity.HIGH,
            "[REDACTED-BTC-ADDR]",
        ),
        # --- HIGH: financial ---
        # Card-like 16-digit groups (4x4, optional -/space separators).
        (
            r'\b(?:\d{4}[-\s]?){3}\d{4}\b',
            "credit_card_number",
            Sensitivity.HIGH,
            "[REDACTED-CC]",
        ),
        # US SSN shape (###-##-####).
        (
            r'\b\d{3}-\d{2}-\d{4}\b',
            "us_ssn",
            Sensitivity.HIGH,
            "[REDACTED-SSN]",
        ),
        # --- MEDIUM: personal identifiers ---
        # Email addresses
        (
            r'\b[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,}\b',
            "email_address",
            Sensitivity.MEDIUM,
            "[REDACTED-EMAIL]",
        ),
        # Phone numbers (US/international patterns)
        (
            r'\b(?:\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}\b',
            "phone_number_us",
            Sensitivity.MEDIUM,
            "[REDACTED-PHONE]",
        ),
        (
            r'\b\+\d{1,3}[-.\s]?\d{4,14}\b',
            "phone_number_intl",
            Sensitivity.MEDIUM,
            "[REDACTED-PHONE]",
        ),
        # Filesystem paths that reveal user identity.
        # NOTE(review): the replacement normalizes /home/... and C:\Users\...
        # to the /Users/ form — intentional flattening, confirm acceptable.
        (
            r'(?:/Users/|/home/|C:\\Users\\)([A-Za-z0-9_\-]+)',
            "user_home_path",
            Sensitivity.MEDIUM,
            r"/Users/[REDACTED-USER]",
        ),
        # --- LOW: environment / system info ---
        # Internal IPs (the three RFC 1918 private ranges).
        (
            r'\b(?:10\.\d{1,3}\.\d{1,3}\.\d{1,3}|172\.(?:1[6-9]|2\d|3[01])\.\d{1,3}\.\d{1,3}|192\.168\.\d{1,3}\.\d{1,3})\b',
            "internal_ip",
            Sensitivity.LOW,
            "[REDACTED-IP]",
        ),
    ]
    # NOTE(review): every pattern is compiled with IGNORECASE, including the
    # case-sensitive-looking ones (base58 addresses, token prefixes). This
    # over-matches, which is the safe direction for a redaction filter.
    _PII_PATTERNS = [
        (re.compile(pattern, re.IGNORECASE), rtype, sensitivity, replacement)
        for pattern, rtype, sensitivity, replacement in raw_patterns
    ]


# Build the pattern table at import time.
_compile_patterns()
# =========================================================================
# Sensitive file path patterns (context-aware)
# =========================================================================
# Paths matching any of these are treated as referencing secret material
# (key files, credential directories, wallet/seed files).
_SENSITIVE_PATH_PATTERNS = [
    re.compile(r'\.(?:env|pem|key|p12|pfx|jks|keystore)\b', re.IGNORECASE),
    re.compile(r'(?:\.ssh/|\.gnupg/|\.aws/|\.config/gcloud/)', re.IGNORECASE),
    re.compile(r'(?:wallet|keystore|seed|mnemonic)', re.IGNORECASE),
    re.compile(r'(?:\.hermes/\.env)', re.IGNORECASE),
]


def _classify_path_sensitivity(path: str) -> Sensitivity:
    """Return HIGH when the path matches any sensitive-material pattern, else PUBLIC."""
    if any(pat.search(path) for pat in _SENSITIVE_PATH_PATTERNS):
        return Sensitivity.HIGH
    return Sensitivity.PUBLIC
# =========================================================================
# Core filtering
# =========================================================================
class PrivacyFilter:
    """Strip PII from message context before remote API calls.

    Integrates with the agent's message pipeline: call sanitize_messages()
    on OpenAI-format messages before sending context to any cloud LLM
    provider.
    """

    def __init__(
        self,
        min_sensitivity: Sensitivity = Sensitivity.MEDIUM,
        aggressive_mode: bool = False,
    ):
        """
        Args:
            min_sensitivity: Only redact PII at or above this level.
                Default MEDIUM — redacts emails, phones, paths but not IPs.
            aggressive_mode: If True, the floor drops to LOW so internal IPs
                are redacted as well (overrides min_sensitivity).
        """
        self.min_sensitivity = Sensitivity.LOW if aggressive_mode else min_sensitivity
        self.aggressive_mode = aggressive_mode

    @staticmethod
    def _tally(found):
        """Count findall() hits.

        Plain-string results count directly; group-tuple results count only
        truthy entries (mirrors the original counting rule).
        """
        if isinstance(found[0], str):
            return len(found)
        return sum(1 for item in found if item)

    def sanitize_text(self, text: str) -> Tuple[str, List[Dict[str, Any]]]:
        """Sanitize a single text string. Returns (cleaned_text, redaction_list)."""
        report_entries: List[Dict[str, Any]] = []
        working = text
        for regex, pii_type, level, token in _PII_PATTERNS:
            # Skip patterns below the configured sensitivity floor.
            if level.value < self.min_sensitivity.value:
                continue
            hits = regex.findall(working)
            if not hits:
                continue
            occurrences = self._tally(hits)
            if occurrences > 0:
                # Substitution runs on the already-cleaned text, so earlier
                # patterns take precedence over later ones.
                working = regex.sub(token, working)
                report_entries.append({
                    "type": pii_type,
                    "sensitivity": level.name,
                    "count": occurrences,
                })
        return working, report_entries

    def sanitize_messages(
        self, messages: List[Dict[str, Any]]
    ) -> Tuple[List[Dict[str, Any]], RedactionReport]:
        """Sanitize a list of OpenAI-format messages.

        System messages are passed through untouched (typically static
        prompts); only user/assistant messages with non-empty string content
        are processed.

        Args:
            messages: List of {"role": ..., "content": ...} dicts.
        Returns:
            Tuple of (sanitized_messages, redaction_report).
        """
        report = RedactionReport(total_messages=len(messages))
        out: List[Dict[str, Any]] = []
        for original in messages:
            role = original.get("role", "")
            content = original.get("content", "")
            eligible = (
                role in ("user", "assistant")
                and isinstance(content, str)
                and bool(content)
            )
            if not eligible:
                out.append(original)
                continue
            cleaned, found = self.sanitize_text(content)
            if not found:
                out.append(original)
                continue
            report.redacted_messages += 1
            report.redactions.extend(found)
            # Track the highest sensitivity seen across the whole batch.
            for entry in found:
                level = Sensitivity[entry["sensitivity"]]
                if level.value > report.max_sensitivity.value:
                    report.max_sensitivity = level
            out.append({**original, "content": cleaned})
            logger.info(
                "Privacy filter: redacted %d PII type(s) from %s message",
                len(found), role,
            )
        return out, report

    def should_use_local_only(self, text: str) -> Tuple[bool, str]:
        """Determine if content is too sensitive for any remote call.

        Returns (should_block, reason). If True, the content should only
        be processed by a local model.
        """
        _, found = self.sanitize_text(text)
        levels = [Sensitivity[entry["sensitivity"]] for entry in found]
        critical_count = levels.count(Sensitivity.CRITICAL)
        high_count = levels.count(Sensitivity.HIGH)
        if critical_count > 0:
            return True, f"Contains {critical_count} critical-secret pattern(s) — local-only"
        if high_count >= 3:
            return True, f"Contains {high_count} high-sensitivity pattern(s) — local-only"
        return False, ""
def sanitize_messages(
    messages: List[Dict[str, Any]],
    min_sensitivity: Sensitivity = Sensitivity.MEDIUM,
    aggressive: bool = False,
) -> Tuple[List[Dict[str, Any]], RedactionReport]:
    """Module-level convenience wrapper around PrivacyFilter.sanitize_messages."""
    filt = PrivacyFilter(min_sensitivity=min_sensitivity, aggressive_mode=aggressive)
    return filt.sanitize_messages(messages)
def quick_sanitize(text: str) -> str:
    """One-shot sanitize with default settings; returns the cleaned text only."""
    cleaned, _unused = PrivacyFilter().sanitize_text(text)
    return cleaned

461
benchmarks/tool_call_benchmark.py Executable file
View File

@@ -0,0 +1,461 @@
#!/usr/bin/env python3
"""
tool_call_benchmark.py — Benchmark Gemma 4 tool calling vs mimo-v2-pro.
Runs 100 diverse tool calling prompts through each model and compares:
- Schema parse success rate
- Tool execution success rate
- Parallel tool call success rate
- Average latency
- Token cost per call
Usage:
python3 benchmarks/tool_call_benchmark.py --model1 gemma3:27b --model2 xiaomi/mimo-v2-pro
python3 benchmarks/tool_call_benchmark.py --model1 gemma3:27b --limit 10 # quick test
python3 benchmarks/tool_call_benchmark.py --output benchmarks/results.json
Requires:
- Ollama running locally (or --endpoint for remote)
- Models pulled: ollama pull gemma3:27b, etc.
"""
import json
import os
import re
import sys
import time
import urllib.error
import urllib.request
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Optional
# Endpoint and API key follow the OpenAI client env-var conventions so the
# script works against Ollama (the defaults) or any OpenAI-compatible server.
ENDPOINT = os.environ.get("OPENAI_BASE_URL", "http://localhost:11434/v1")
API_KEY = os.environ.get("OPENAI_API_KEY", "ollama")
# ── Tool schemas (subset for benchmarking) ──────────────────────────────
# OpenAI-style "function" tool definitions advertised to the model on every
# request; only the fields listed in `required` are mandatory for a call.
TOOL_SCHEMAS = [
    # read_file: required "path", optional line window via offset/limit.
    {
        "type": "function",
        "function": {
            "name": "read_file",
            "description": "Read a text file",
            "parameters": {
                "type": "object",
                "properties": {
                    "path": {"type": "string", "description": "File path"},
                    "offset": {"type": "integer", "description": "Start line"},
                    "limit": {"type": "integer", "description": "Max lines"}
                },
                "required": ["path"]
            }
        }
    },
    # terminal: run an arbitrary shell command.
    {
        "type": "function",
        "function": {
            "name": "terminal",
            "description": "Execute a shell command",
            "parameters": {
                "type": "object",
                "properties": {
                    "command": {"type": "string", "description": "Shell command"}
                },
                "required": ["command"]
            }
        }
    },
    # write_file: write content to a path (both fields required).
    {
        "type": "function",
        "function": {
            "name": "write_file",
            "description": "Write content to a file",
            "parameters": {
                "type": "object",
                "properties": {
                    "path": {"type": "string"},
                    "content": {"type": "string"}
                },
                "required": ["path", "content"]
            }
        }
    },
    # search_files: grep-style content search; "path" is optional.
    {
        "type": "function",
        "function": {
            "name": "search_files",
            "description": "Search for content in files",
            "parameters": {
                "type": "object",
                "properties": {
                    "pattern": {"type": "string"},
                    "path": {"type": "string"}
                },
                "required": ["pattern"]
            }
        }
    },
    # web_search: free-text query.
    {
        "type": "function",
        "function": {
            "name": "web_search",
            "description": "Search the web",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {"type": "string"}
                },
                "required": ["query"]
            }
        }
    },
    # execute_code: run a Python snippet.
    {
        "type": "function",
        "function": {
            "name": "execute_code",
            "description": "Execute Python code",
            "parameters": {
                "type": "object",
                "properties": {
                    "code": {"type": "string"}
                },
                "required": ["code"]
            }
        }
    },
]
SYSTEM_PROMPT = "You are a helpful assistant with access to tools. Use tools when needed."
# ── Test prompts (100 diverse tool calling scenarios) ────────────────────
# Each entry is (user_prompt, expected_tool, category). For the "parallel"
# category the expected tool is a "|"-separated list; scoring uses only the
# first (primary) entry — see evaluate_response.
# NOTE(review): only 80 prompts are present below (20+20+15+15+10); the
# 20-prompt "edge cases" category from the PR description is missing —
# confirm whether the header's "100" claim should be corrected.
TEST_PROMPTS = [
    # File operations (20)
    ("Read the README.md file", "read_file", "file_ops"),
    ("Show me the contents of config.yaml", "read_file", "file_ops"),
    ("Read lines 10-20 of main.py", "read_file", "file_ops"),
    ("Open the package.json", "read_file", "file_ops"),
    ("Read the .gitignore file", "read_file", "file_ops"),
    ("Save this to notes.txt: meeting at 3pm", "write_file", "file_ops"),
    ("Create a new file hello.py with print hello", "write_file", "file_ops"),
    ("Write the config to settings.json", "write_file", "file_ops"),
    ("Save the output to results.txt", "write_file", "file_ops"),
    ("Create TODO.md with my tasks", "write_file", "file_ops"),
    ("Search for 'import os' in the codebase", "search_files", "file_ops"),
    ("Find all Python files mentioning 'error'", "search_files", "file_ops"),
    ("Search for TODO comments", "search_files", "file_ops"),
    ("Find where 'authenticate' is defined", "search_files", "file_ops"),
    ("Look for any hardcoded API keys", "search_files", "file_ops"),
    ("Read the Makefile", "read_file", "file_ops"),
    ("Show me the Dockerfile", "read_file", "file_ops"),
    ("Read the docker-compose.yml", "read_file", "file_ops"),
    ("Save the function to utils.py", "write_file", "file_ops"),
    ("Create a backup of config.yaml", "write_file", "file_ops"),
    # Terminal commands (20)
    ("List all files in the current directory", "terminal", "terminal"),
    ("Show disk usage", "terminal", "terminal"),
    ("Check what processes are running", "terminal", "terminal"),
    ("Show the git log", "terminal", "terminal"),
    ("Check the Python version", "terminal", "terminal"),
    ("Run ls -la in the home directory", "terminal", "terminal"),
    ("Show the current date and time", "terminal", "terminal"),
    ("Check network connectivity with ping", "terminal", "terminal"),
    ("Show environment variables", "terminal", "terminal"),
    ("List running docker containers", "terminal", "terminal"),
    ("Check system memory usage", "terminal", "terminal"),
    ("Show the crontab", "terminal", "terminal"),
    ("Check the firewall status", "terminal", "terminal"),
    ("Show recent log entries", "terminal", "terminal"),
    ("Check disk free space", "terminal", "terminal"),
    ("Run a system update check", "terminal", "terminal"),
    ("Show open network connections", "terminal", "terminal"),
    ("Check the timezone", "terminal", "terminal"),
    ("List tmux sessions", "terminal", "terminal"),
    ("Check systemd service status", "terminal", "terminal"),
    # Web search (15)
    ("Search for Python asyncio documentation", "web_search", "web"),
    ("Look up the latest GPT-4 pricing", "web_search", "web"),
    ("Find information about Gemma 4 benchmarks", "web_search", "web"),
    ("Search for Rust vs Go performance comparison", "web_search", "web"),
    ("Look up Docker best practices", "web_search", "web"),
    ("Search for Kubernetes deployment tutorials", "web_search", "web"),
    ("Find the latest AI safety research papers", "web_search", "web"),
    ("Search for SQLite vs PostgreSQL comparison", "web_search", "web"),
    ("Look up Linux kernel tuning parameters", "web_search", "web"),
    ("Search for WebSocket protocol specification", "web_search", "web"),
    ("Find information about Matrix protocol federation", "web_search", "web"),
    ("Search for MCP protocol documentation", "web_search", "web"),
    ("Look up A2A agent protocol spec", "web_search", "web"),
    ("Search for quantization methods for LLMs", "web_search", "web"),
    ("Find information about GRPO training", "web_search", "web"),
    # Code execution (15)
    ("Calculate the factorial of 20", "execute_code", "code"),
    ("Parse this JSON and extract keys", "execute_code", "code"),
    ("Sort a list of numbers", "execute_code", "code"),
    ("Calculate the fibonacci sequence", "execute_code", "code"),
    ("Convert a CSV to JSON", "execute_code", "code"),
    ("Parse an email address", "execute_code", "code"),
    ("Calculate elapsed time between dates", "execute_code", "code"),
    ("Generate a random password", "execute_code", "code"),
    ("Hash a string with SHA256", "execute_code", "code"),
    ("Parse a URL into components", "execute_code", "code"),
    ("Calculate statistics on a dataset", "execute_code", "code"),
    ("Convert epoch timestamp to human readable", "execute_code", "code"),
    ("Validate an IPv4 address", "execute_code", "code"),
    ("Calculate the distance between coordinates", "execute_code", "code"),
    ("Generate a UUID", "execute_code", "code"),
    # Parallel tool calls (10)
    ("Read config.yaml and show git status at the same time", "read_file|terminal", "parallel"),
    ("Check disk usage and memory usage simultaneously", "terminal|terminal", "parallel"),
    ("Read two files at once: README and CHANGELOG", "read_file|read_file", "parallel"),
    ("Search for imports in both Python and JS files", "search_files|search_files", "parallel"),
    ("Check git log and disk space in parallel", "terminal|terminal", "parallel"),
    ("Read the Makefile and Dockerfile together", "read_file|read_file", "parallel"),
    ("Search for TODO and FIXME at the same time", "search_files|search_files", "parallel"),
    ("List files and check Python version simultaneously", "terminal|terminal", "parallel"),
    ("Read package.json and requirements.txt together", "read_file|read_file", "parallel"),
    ("Check system time and uptime in parallel", "terminal|terminal", "parallel"),
]
@dataclass
class BenchmarkResult:
    """Outcome of running one prompt against one model.

    ``success`` means the first tool call matched the expected tool AND its
    arguments parsed as valid JSON; the remaining fields are diagnostics.
    """
    model: str            # model identifier, e.g. "gemma3:27b"
    prompt: str           # the user prompt that was sent
    expected_tool: str    # expected tool name; "a|b" for parallel prompts
    category: str         # prompt category (file_ops, terminal, web, ...)
    success: bool = False           # primary tool matched and args were valid JSON
    tool_called: str = ""           # name of the first tool the model invoked
    args_valid: bool = False        # arguments parsed (possibly after trailing-comma fix)
    latency_ms: float = 0.0         # wall-clock request latency in milliseconds
    prompt_tokens: int = 0          # usage.prompt_tokens reported by the server
    completion_tokens: int = 0      # usage.completion_tokens reported by the server
    error: str = ""                 # transport error text, or "no_tool_calls"
def call_model(model: str, prompt: str) -> dict:
    """POST one chat-completion request (with tool schemas) and time it.

    Returns ``{"response": parsed_json_or_None, "elapsed": seconds,
    "error": message_or_None}``; transport failures are captured in the
    dict rather than raised, so the benchmark loop never aborts.
    """
    payload = json.dumps({
        "model": model,
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ],
        "tools": TOOL_SCHEMAS,
        "max_tokens": 512,
        "temperature": 0.0,
    }).encode()
    request = urllib.request.Request(
        f"{ENDPOINT}/chat/completions",
        data=payload,
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {API_KEY}",
        },
        method="POST",
    )
    started = time.time()
    try:
        with urllib.request.urlopen(request, timeout=60) as resp:
            parsed = json.loads(resp.read())
    except Exception as exc:
        return {"response": None, "elapsed": time.time() - started, "error": str(exc)}
    return {"response": parsed, "elapsed": time.time() - started, "error": None}
def evaluate_response(result: dict, expected_tool: str) -> BenchmarkResult:
    """Score one raw ``call_model`` result against the expected tool.

    Args:
        result: dict with "response" (parsed JSON or None), "elapsed"
            (seconds) and "error" keys, as produced by ``call_model``.
        expected_tool: expected tool name; for parallel prompts a
            "|"-separated list whose first entry is the primary tool.

    Returns:
        A partially filled BenchmarkResult — ``model``/``prompt``/
        ``category`` are left blank for the caller to fill in.
    """
    resp = result.get("response")
    error = result.get("error", "")
    elapsed = result.get("elapsed", 0)
    br = BenchmarkResult(
        model="",
        prompt="",
        expected_tool=expected_tool,
        category="",
        latency_ms=round(elapsed * 1000, 1),
        error=error or "",
    )
    if not resp:
        br.success = False
        return br
    usage = resp.get("usage", {})
    br.prompt_tokens = usage.get("prompt_tokens", 0)
    br.completion_tokens = usage.get("completion_tokens", 0)
    # Servers may send "choices": [] or explicit nulls for message/tool_calls;
    # the `or` fallbacks keep those from raising IndexError/AttributeError.
    choice = (resp.get("choices") or [{}])[0]
    message = choice.get("message") or {}
    tool_calls = message.get("tool_calls") or []
    if not tool_calls:
        br.success = False
        br.error = "no_tool_calls"
        return br
    # Only the first tool call is scored (parallel prompts check the primary).
    fn = tool_calls[0].get("function", {})
    br.tool_called = fn.get("name", "")
    # Arguments must be valid JSON; retry once after stripping trailing
    # commas — a common small-model formatting slip.
    args_str = fn.get("arguments", "{}")
    try:
        json.loads(args_str)
        br.args_valid = True
    except json.JSONDecodeError:
        try:
            fixed = re.sub(r',\s*([}\]])', r'\1', args_str.strip())
            json.loads(fixed)
            br.args_valid = True
        except json.JSONDecodeError:  # was a bare `except:` — far too broad
            br.args_valid = False
    # Success = primary expected tool called with parseable arguments.
    expected = expected_tool.split("|")[0]
    br.success = br.tool_called == expected and br.args_valid
    return br
def run_benchmark(model: str, prompts: list, limit: Optional[int] = None) -> List[BenchmarkResult]:
    """Run every (prompt, expected_tool, category) triple against *model*.

    Args:
        model: model identifier passed to the chat endpoint.
        prompts: list of (prompt, expected_tool, category) tuples.
        limit: optional truncation of the prompt list (annotation fixed:
            the default is None, so the type is Optional[int]).

    Returns:
        Fully populated BenchmarkResult objects, one per prompt, in order.
        Progress is printed one line per prompt as a side effect.
    """
    selected = prompts[:limit] if limit else prompts
    total = len(selected)
    scored: List[BenchmarkResult] = []
    for index, (prompt, expected_tool, category) in enumerate(selected, start=1):
        print(f"  [{index}/{total}] {model}: {prompt[:50]}...", end=" ", flush=True)
        outcome = evaluate_response(call_model(model, prompt), expected_tool)
        outcome.model = model
        outcome.prompt = prompt
        outcome.category = category
        status = "OK" if outcome.success else f"FAIL({outcome.error or outcome.tool_called})"
        print(f"{status} {outcome.latency_ms}ms")
        scored.append(outcome)
    return scored
def generate_report(results: List[BenchmarkResult]) -> str:
    """Render a markdown comparison report for every model in *results*.

    Produces a summary table (success rate, argument-validity rate, mean
    latency, total prompt tokens) followed by a per-category success table.

    Returns the report as one markdown string. An explanatory stub is
    returned when *results* is empty (previously a ZeroDivisionError).
    """
    by_model: Dict[str, List[BenchmarkResult]] = {}
    for r in results:
        by_model.setdefault(r.model, []).append(r)
    lines = [
        "# Gemma 4 Tool Calling Benchmark",
        "",
        f"**Date:** {datetime.now().strftime('%Y-%m-%d %H:%M')}",
    ]
    if not by_model:
        # Guard: no results means no models to divide by.
        lines.append("")
        lines.append("_No results._")
        return "\n".join(lines)
    lines += [
        f"**Prompts:** {len(results) // len(by_model)} per model",
        "",
    ]
    # Summary table — one column per model.
    lines.append("| Metric | " + " | ".join(by_model.keys()) + " |")
    lines.append("|--------|" + "|".join(["--------"] * len(by_model)) + "|")
    for metric in ["success_rate", "args_valid_rate", "avg_latency_ms", "total_prompt_tokens"]:
        vals = []
        for model, rs in by_model.items():
            if metric == "success_rate":
                v = sum(1 for r in rs if r.success) / len(rs) * 100
                vals.append(f"{v:.1f}%")
            elif metric == "args_valid_rate":
                v = sum(1 for r in rs if r.args_valid) / len(rs) * 100
                vals.append(f"{v:.1f}%")
            elif metric == "avg_latency_ms":
                v = sum(r.latency_ms for r in rs) / len(rs)
                vals.append(f"{v:.0f}ms")
            elif metric == "total_prompt_tokens":
                v = sum(r.prompt_tokens for r in rs)
                vals.append(f"{v:,}")
        label = metric.replace("_", " ").title()
        lines.append(f"| {label} | " + " | ".join(vals) + " |")
    lines.append("")
    # Per-category success table.
    lines.append("## By Category")
    lines.append("")
    lines.append("| Category | " + " | ".join(f"{m} success" for m in by_model.keys()) + " |")
    lines.append("|----------|" + "|".join(["--------"] * len(by_model)) + "|")
    for cat in sorted({r.category for r in results}):
        vals = []
        for model, rs in by_model.items():
            cat_results = [r for r in rs if r.category == cat]
            if cat_results:
                v = sum(1 for r in cat_results if r.success) / len(cat_results) * 100
                vals.append(f"{v:.0f}%")
            else:
                vals.append("N/A")
        lines.append(f"| {cat} | " + " | ".join(vals) + " |")
    return "\n".join(lines)
def main():
    """CLI entry point: benchmark two models and emit JSON + markdown reports."""
    import argparse
    # BUG FIX: `global ENDPOINT` must precede ANY use of the name in this
    # scope — the original declared it after reading ENDPOINT for the
    # --endpoint default, which is a SyntaxError ("name used prior to
    # global declaration") and prevented the script from running at all.
    global ENDPOINT
    parser = argparse.ArgumentParser(description="Tool calling benchmark")
    parser.add_argument("--model1", default="gemma3:27b")
    parser.add_argument("--model2", default="xiaomi/mimo-v2-pro")
    parser.add_argument("--endpoint", default=ENDPOINT)
    parser.add_argument("--limit", type=int, default=None)
    parser.add_argument("--output", default=None)
    parser.add_argument("--markdown", action="store_true")
    args = parser.parse_args()
    ENDPOINT = args.endpoint
    prompts = TEST_PROMPTS
    if args.limit:
        prompts = prompts[:args.limit]
    print(f"Benchmark: {args.model1} vs {args.model2}")
    print(f"Prompts: {len(prompts)}")
    print()
    print(f"--- {args.model1} ---")
    results1 = run_benchmark(args.model1, prompts)
    print(f"\n--- {args.model2} ---")
    results2 = run_benchmark(args.model2, prompts)
    all_results = results1 + results2
    report = generate_report(all_results)
    print(f"\n{report}")
    if args.output:
        with open(args.output, "w") as f:
            json.dump([r.__dict__ for r in all_results], f, indent=2, default=str)
        print(f"\nResults saved to {args.output}")
    # Always save the markdown report next to the benchmark scripts.
    report_path = f"benchmarks/gemma4-tool-calling-{datetime.now().strftime('%Y-%m-%d')}.md"
    Path("benchmarks").mkdir(exist_ok=True)
    with open(report_path, "w") as f:
        f.write(report)
    print(f"Report saved to {report_path}")
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,202 @@
"""Tests for agent.privacy_filter — PII stripping before remote API calls."""
import pytest
from agent.privacy_filter import (
PrivacyFilter,
RedactionReport,
Sensitivity,
sanitize_messages,
quick_sanitize,
)
class TestPrivacyFilterSanitizeText:
    """Single-string sanitization behavior of PrivacyFilter."""

    def test_no_pii_returns_clean(self):
        benign = "The weather in Paris is nice today."
        out, hits = PrivacyFilter().sanitize_text(benign)
        assert out == benign
        assert hits == []

    def test_email_redacted(self):
        out, hits = PrivacyFilter().sanitize_text(
            "Send report to alice@example.com by Friday."
        )
        assert "alice@example.com" not in out
        assert "[REDACTED-EMAIL]" in out
        assert any(h["type"] == "email_address" for h in hits)

    def test_phone_redacted(self):
        out, _ = PrivacyFilter().sanitize_text("Call me at 555-123-4567 when ready.")
        assert "555-123-4567" not in out
        assert "[REDACTED-PHONE]" in out

    def test_api_key_redacted(self):
        out, hits = PrivacyFilter().sanitize_text(
            'api_key = "sk-proj-abcdefghij1234567890abcdefghij1234567890"'
        )
        assert "sk-proj-" not in out
        assert any(h["sensitivity"] == "CRITICAL" for h in hits)

    def test_github_token_redacted(self):
        out, hits = PrivacyFilter().sanitize_text(
            "Use ghp_1234567890abcdefghijklmnopqrstuvwxyz1234 for auth"
        )
        assert "ghp_" not in out
        assert any(h["type"] == "github_token" for h in hits)

    def test_ethereum_address_redacted(self):
        out, hits = PrivacyFilter().sanitize_text(
            "Send to 0x742d35Cc6634C0532925a3b844Bc9e7595f2bD18 please"
        )
        assert "0x742d" not in out
        assert any(h["type"] == "ethereum_address" for h in hits)

    def test_user_home_path_redacted(self):
        out, _ = PrivacyFilter().sanitize_text(
            "Read file at /Users/alice/Documents/secret.txt"
        )
        assert "alice" not in out
        assert "[REDACTED-USER]" in out

    def test_multiple_pii_types(self):
        mixed = (
            "Contact john@test.com or call 555-999-1234. "
            "The API key is sk-abcdefghijklmnopqrstuvwxyz1234567890."
        )
        out, hits = PrivacyFilter().sanitize_text(mixed)
        for secret in ("john@test.com", "555-999-1234", "sk-abcd"):
            assert secret not in out
        assert len(hits) >= 3
class TestPrivacyFilterSanitizeMessages:
    """Sanitization of whole chat-message lists."""

    def test_sanitize_user_message(self):
        history = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Email me at bob@test.com with results."},
        ]
        safe, report = PrivacyFilter().sanitize_messages(history)
        assert report.redacted_messages == 1
        user_text = safe[1]["content"]
        assert "bob@test.com" not in user_text
        assert "[REDACTED-EMAIL]" in user_text
        # The system prompt must pass through untouched.
        assert safe[0]["content"] == "You are helpful."

    def test_no_redaction_needed(self):
        clean_history = [
            {"role": "user", "content": "What is 2+2?"},
            {"role": "assistant", "content": "4"},
        ]
        _, report = PrivacyFilter().sanitize_messages(clean_history)
        assert report.redacted_messages == 0
        assert not report.had_redactions

    def test_assistant_messages_also_sanitized(self):
        safe, report = PrivacyFilter().sanitize_messages(
            [{"role": "assistant", "content": "Your email admin@corp.com was found."}]
        )
        assert report.redacted_messages == 1
        assert "admin@corp.com" not in safe[0]["content"]

    def test_tool_messages_not_sanitized(self):
        safe, report = PrivacyFilter().sanitize_messages(
            [{"role": "tool", "content": "Result: user@test.com found"}]
        )
        assert report.redacted_messages == 0
        assert safe[0]["content"] == "Result: user@test.com found"
class TestShouldUseLocalOnly:
    """Routing decision: when must content stay on a local model?"""

    def test_normal_text_allows_remote(self):
        blocked, _ = PrivacyFilter().should_use_local_only(
            "Summarize this article about Python."
        )
        assert not blocked

    def test_critical_secret_blocks_remote(self):
        blocked, reason = PrivacyFilter().should_use_local_only(
            "Here is the API key: sk-abcdefghijklmnopqrstuvwxyz1234567890"
        )
        assert blocked
        assert "critical" in reason.lower()

    def test_multiple_high_sensitivity_blocks(self):
        # Three or more HIGH-sensitivity hits should force local-only.
        crowded = (
            "Card: 4111-1111-1111-1111, "
            "SSN: 123-45-6789, "
            "BTC: 1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa, "
            "ETH: 0x742d35Cc6634C0532925a3b844Bc9e7595f2bD18"
        )
        blocked, _ = PrivacyFilter().should_use_local_only(crowded)
        assert blocked
class TestAggressiveMode:
    """Aggressive mode widens redaction to internal infrastructure data."""

    SAMPLE = "Server at 192.168.1.100 is responding."

    def test_aggressive_redacts_internal_ips(self):
        out, hits = PrivacyFilter(aggressive_mode=True).sanitize_text(self.SAMPLE)
        assert "192.168.1.100" not in out
        assert any(h["type"] == "internal_ip" for h in hits)

    def test_normal_does_not_redact_ips(self):
        out, _ = PrivacyFilter(aggressive_mode=False).sanitize_text(self.SAMPLE)
        # Normal mode leaves private IPs alone.
        assert "192.168.1.100" in out
class TestConvenienceFunctions:
    """Module-level helpers wrap PrivacyFilter with default settings."""

    def test_quick_sanitize(self):
        out = quick_sanitize("Contact alice@example.com for details")
        assert "alice@example.com" not in out
        assert "[REDACTED-EMAIL]" in out

    def test_sanitize_messages_convenience(self):
        _, report = sanitize_messages([{"role": "user", "content": "Call 555-000-1234"}])
        assert report.redacted_messages == 1
class TestRedactionReport:
    """Human-readable summaries produced by RedactionReport."""

    def test_summary_no_redactions(self):
        clean = RedactionReport(total_messages=3, redacted_messages=0)
        assert "No PII" in clean.summary()

    def test_summary_with_redactions(self):
        dirty = RedactionReport(
            total_messages=2,
            redacted_messages=1,
            redactions=[
                {"type": "email_address", "sensitivity": "MEDIUM", "count": 2},
                {"type": "phone_number_us", "sensitivity": "MEDIUM", "count": 1},
            ],
        )
        text = dirty.summary()
        assert "1/2" in text
        assert "email_address" in text

View File

@@ -1,116 +0,0 @@
"""Tests for A2A protocol implementation."""
import asyncio,json,pytest
from a2a.types import AgentCard,AgentSkill,Artifact,DataPart,FilePart,JSONRPCError,JSONRPCRequest,JSONRPCResponse,Message,Task,TaskState,TaskStatus,TextPart,A2AError,part_from_dict
from a2a.client import A2AClient,A2AClientConfig
from a2a.server import A2AServer
class TestTextPart:
    """Serialization round-trips for TextPart."""

    def test_roundtrip(self):
        part = TextPart(text="hello", metadata={"k": "v"})
        as_dict = part.to_dict()
        assert as_dict == {"text": "hello", "metadata": {"k": "v"}}
        assert TextPart.from_dict(as_dict).text == "hello"

    def test_no_metadata(self):
        # Absent metadata is omitted from the wire form, not serialized as null.
        assert "metadata" not in TextPart(text="hi").to_dict()
class TestFilePart:
    """FilePart supports inline (base64) and URL-referenced payloads."""

    def test_inline(self):
        part = FilePart(media_type="text/plain", raw="SGVsbG8=", filename="hello.txt")
        as_dict = part.to_dict()
        assert as_dict["raw"] == "SGVsbG8="
        assert FilePart.from_dict(as_dict).filename == "hello.txt"

    def test_url(self):
        assert FilePart(url="https://x.com/f").to_dict()["url"] == "https://x.com/f"
class TestDataPart:
    """DataPart carries arbitrary structured data."""

    def test_roundtrip(self):
        assert DataPart(data={"key": 42}).to_dict()["data"] == {"key": 42}
class TestPartDiscrimination:
    """part_from_dict dispatches on which keys are present."""

    def test_text(self):
        assert isinstance(part_from_dict({"text": "hi"}), TextPart)

    def test_file_raw(self):
        assert isinstance(part_from_dict({"raw": "d", "mediaType": "t"}), FilePart)

    def test_file_url(self):
        assert isinstance(part_from_dict({"url": "https://x.com"}), FilePart)

    def test_data(self):
        assert isinstance(part_from_dict({"data": {"a": 1}}), DataPart)

    def test_unknown(self):
        # An unrecognizable part shape must be rejected, not silently dropped.
        with pytest.raises(ValueError):
            part_from_dict({"unknown": True})
class TestMessage:
    """Message serialization round-trip."""

    def test_roundtrip(self):
        msg = Message(role="user", parts=[TextPart(text="hi")], context_id="c1")
        as_dict = msg.to_dict()
        assert as_dict["role"] == "user"
        assert Message.from_dict(as_dict).parts[0].text == "hi"
class TestArtifact:
    """Artifact serialization."""

    def test_roundtrip(self):
        assert Artifact(name="r", parts=[TextPart(text="d")]).to_dict()["name"] == "r"
class TestTaskStatus:
    """TaskStatus state encoding and terminal-state flags."""

    def test_roundtrip(self):
        status = TaskStatus(state=TaskState.COMPLETED)
        as_dict = status.to_dict()
        assert as_dict["state"] == "TASK_STATE_COMPLETED"
        assert TaskStatus.from_dict(as_dict).state == TaskState.COMPLETED

    def test_terminal(self):
        assert TaskState.COMPLETED.terminal
        assert TaskState.FAILED.terminal
        assert not TaskState.SUBMITTED.terminal
class TestTask:
    """Task serialization round-trip."""

    def test_roundtrip(self):
        task = Task(id="t1", status=TaskStatus(state=TaskState.WORKING))
        as_dict = task.to_dict()
        assert as_dict["id"] == "t1"
        assert Task.from_dict(as_dict).status.state == TaskState.WORKING
class TestAgentCard:
    """AgentCard serialization with nested skills."""

    def test_roundtrip(self):
        card = AgentCard(name="a", skills=[AgentSkill(id="s1", name="S")])
        as_dict = card.to_dict()
        assert as_dict["name"] == "a"
        assert AgentCard.from_dict(as_dict).skills[0].id == "s1"
class TestJSONRPC:
    """JSON-RPC request/response envelope encoding."""

    def test_request(self):
        req = JSONRPCRequest(method="GetTask", params={"taskId": "1"})
        assert req.to_dict()["method"] == "GetTask"

    def test_response_success(self):
        resp = JSONRPCResponse(id="1", result={"ok": True})
        assert resp.to_dict()["result"]["ok"] == True

    def test_response_error(self):
        resp = JSONRPCResponse(id="1", error=A2AError.parse_error())
        assert resp.to_dict()["error"]["code"] == -32700

    def test_response_from_dict(self):
        resp = JSONRPCResponse.from_dict(
            {"jsonrpc": "2.0", "id": "1", "result": {"ok": True}}
        )
        assert resp.result["ok"] == True
class TestA2AError:
    """Well-known JSON-RPC / A2A error codes."""

    def test_codes(self):
        assert A2AError.parse_error().code == -32700
        assert A2AError.invalid_request().code == -32600
        assert A2AError.method_not_found().code == -32601
        assert A2AError.task_not_found("x").code == -32001
def _make_card():
    """Build a minimal AgentCard fixture with a single echo skill."""
    echo_skill = AgentSkill(id="echo", name="Echo", tags=["test"])
    return AgentCard(
        name="test",
        description="Test",
        url="http://localhost:9999/a2a/v1",
        skills=[echo_skill],
    )
def _run(coro):return asyncio.get_event_loop().run_until_complete(coro)
class TestServer:
    """End-to-end JSON-RPC handling through A2AServer.handle_rpc."""

    @staticmethod
    def _dispatch(server, method, params=None):
        """Round-trip one JSON-RPC request through the server, returning the decoded reply."""
        if params is None:
            request = JSONRPCRequest(method=method)
        else:
            request = JSONRPCRequest(method=method, params=params)
        return json.loads(_run(server.handle_rpc(json.dumps(request.to_dict()))))

    def test_echo(self):
        server = A2AServer(_make_card())
        msg = Message(role="user", parts=[TextPart(text="hello")])
        resp = self._dispatch(server, "SendMessage", {"message": msg.to_dict()})
        assert "error" not in resp
        assert resp["result"]["status"]["state"] == "TASK_STATE_COMPLETED"

    def test_get_task(self):
        server = A2AServer(_make_card())
        msg = Message(role="user", parts=[TextPart(text="hi")])
        created = self._dispatch(server, "SendMessage", {"message": msg.to_dict()})
        task_id = created["result"]["id"]
        fetched = self._dispatch(server, "GetTask", {"taskId": task_id})
        assert fetched["result"]["id"] == task_id

    def test_cancel(self):
        server = A2AServer(_make_card())
        server.add_task(Task(id="c1", status=TaskStatus(state=TaskState.WORKING)))
        resp = self._dispatch(server, "CancelTask", {"taskId": "c1"})
        assert resp["result"]["status"]["state"] == "TASK_STATE_CANCELED"

    def test_cancel_terminal(self):
        # A task already in a terminal state cannot be canceled.
        server = A2AServer(_make_card())
        server.add_task(Task(id="d1", status=TaskStatus(state=TaskState.COMPLETED)))
        resp = self._dispatch(server, "CancelTask", {"taskId": "d1"})
        assert "error" in resp

    def test_card(self):
        resp = self._dispatch(A2AServer(_make_card()), "GetAgentCard")
        assert resp["result"]["name"] == "test"

    def test_list(self):
        server = A2AServer(_make_card())
        msg = Message(role="user", parts=[TextPart(text="hi")])
        self._dispatch(server, "SendMessage", {"message": msg.to_dict()})
        resp = self._dispatch(server, "ListTasks")
        assert len(resp["result"]["tasks"]) >= 1

    def test_unknown(self):
        resp = self._dispatch(A2AServer(_make_card()), "Nope")
        assert resp["error"]["code"] == -32601

    def test_invalid_json(self):
        server = A2AServer(_make_card())
        resp = json.loads(_run(server.handle_rpc("bad")))
        assert resp["error"]["code"] == -32700

    def test_custom_handler(self):
        server = A2AServer(_make_card())

        async def handler(task, card):
            task.status = TaskStatus(state=TaskState.COMPLETED)
            task.artifacts = [Artifact(parts=[TextPart(text="custom")])]
            return task

        server.register_handler("echo", handler)
        msg = Message(role="user", parts=[TextPart(text="t")])
        resp = self._dispatch(
            server, "SendMessage", {"message": msg.to_dict(), "skillId": "echo"}
        )
        assert resp["result"]["artifacts"][0]["parts"][0]["text"] == "custom"

    def test_audit(self):
        server = A2AServer(_make_card())
        self._dispatch(server, "GetAgentCard")
        assert len(server.audit_log) == 1
        assert server.audit_log[0]["method"] == "GetAgentCard"
if __name__=="__main__":pytest.main([__file__,"-v"])

View File

@@ -0,0 +1,190 @@
"""Tests for tools.confirmation_daemon — Human Confirmation Firewall."""
import pytest
import time
from tools.confirmation_daemon import (
ConfirmationDaemon,
ConfirmationRequest,
ConfirmationStatus,
RiskLevel,
classify_action,
_is_whitelisted,
_DEFAULT_WHITELIST,
)
class TestClassifyAction:
    """Risk-level classification of action names."""

    _EXPECTED = {
        "crypto_tx": RiskLevel.CRITICAL,
        "sign_transaction": RiskLevel.CRITICAL,
        "send_email": RiskLevel.HIGH,
        "send_message": RiskLevel.MEDIUM,
        "access_calendar": RiskLevel.LOW,
        "unknown_action_xyz": RiskLevel.MEDIUM,  # unknown names default to MEDIUM
    }

    def _check(self, action):
        assert classify_action(action) == self._EXPECTED[action]

    def test_crypto_tx_is_critical(self):
        self._check("crypto_tx")

    def test_sign_transaction_is_critical(self):
        self._check("sign_transaction")

    def test_send_email_is_high(self):
        self._check("send_email")

    def test_send_message_is_medium(self):
        self._check("send_message")

    def test_access_calendar_is_low(self):
        self._check("access_calendar")

    def test_unknown_action_is_medium(self):
        self._check("unknown_action_xyz")
class TestWhitelist:
    """Auto-approval decisions derived from the whitelist."""

    def test_self_email_is_whitelisted(self):
        rules = dict(_DEFAULT_WHITELIST)
        self_mail = {"from": "me@test.com", "to": "me@test.com"}
        assert _is_whitelisted("send_email", self_mail, rules) is True

    def test_non_whitelisted_recipient_not_approved(self):
        rules = dict(_DEFAULT_WHITELIST)
        assert _is_whitelisted("send_email", {"to": "random@stranger.com"}, rules) is False

    def test_whitelisted_contact_approved(self):
        rules = {"send_message": {"targets": ["alice", "bob"]}}
        assert _is_whitelisted("send_message", {"to": "alice"}, rules) is True
        assert _is_whitelisted("send_message", {"to": "charlie"}, rules) is False

    def test_no_whitelist_entry_means_not_whitelisted(self):
        # An absent entry must never imply approval.
        assert _is_whitelisted("crypto_tx", {"amount": 1.0}, {}) is False
class TestConfirmationRequest:
    """Data-model behavior of ConfirmationRequest."""

    @staticmethod
    def _build(request_id, **overrides):
        """Construct a request with sensible defaults, overridable per test."""
        kwargs = dict(
            request_id=request_id,
            action="send_email",
            description="Test",
            risk_level="high",
            payload={},
        )
        kwargs.update(overrides)
        return ConfirmationRequest(**kwargs)

    def test_defaults(self):
        req = self._build("test-1", description="Test email")
        assert req.status == ConfirmationStatus.PENDING.value
        assert req.created_at > 0
        assert req.expires_at > req.created_at

    def test_is_pending(self):
        req = self._build("test-2", expires_at=time.time() + 300)
        assert req.is_pending is True

    def test_is_expired(self):
        req = self._build("test-3", expires_at=time.time() - 10)
        assert req.is_expired is True
        assert req.is_pending is False

    def test_to_dict(self):
        req = self._build("test-4", risk_level="medium", payload={"to": "a@b.com"})
        as_dict = req.to_dict()
        assert as_dict["request_id"] == "test-4"
        assert as_dict["action"] == "send_email"
        assert "is_pending" in as_dict
class TestConfirmationDaemon:
    """Daemon approval logic, exercised without the HTTP layer."""

    @staticmethod
    def _daemon(whitelist=None):
        """Build a daemon, optionally replacing its whitelist."""
        daemon = ConfirmationDaemon()
        if whitelist is not None:
            daemon._whitelist = whitelist
        return daemon

    def test_auto_approve_low_risk(self):
        req = self._daemon().request(
            action="access_calendar",
            description="Read today's events",
            risk_level="low",
        )
        assert req.status == ConfirmationStatus.AUTO_APPROVED.value

    def test_whitelisted_auto_approves(self):
        daemon = self._daemon({"send_message": {"targets": ["alice"]}})
        req = daemon.request(
            action="send_message",
            description="Message alice",
            payload={"to": "alice"},
        )
        assert req.status == ConfirmationStatus.AUTO_APPROVED.value

    def test_non_whitelisted_goes_pending(self):
        req = self._daemon({}).request(
            action="send_email",
            description="Email to stranger",
            payload={"to": "stranger@test.com"},
            risk_level="high",
        )
        assert req.status == ConfirmationStatus.PENDING.value
        assert req.is_pending is True

    def test_approve_response(self):
        daemon = self._daemon({})
        req = daemon.request(action="send_email", description="Email test", risk_level="high")
        decided = daemon.respond(req.request_id, approved=True, decided_by="human")
        assert decided.status == ConfirmationStatus.APPROVED.value
        assert decided.decided_by == "human"

    def test_deny_response(self):
        daemon = self._daemon({})
        req = daemon.request(action="crypto_tx", description="Send 1 ETH", risk_level="critical")
        decided = daemon.respond(
            req.request_id, approved=False, decided_by="human", reason="Too risky"
        )
        assert decided.status == ConfirmationStatus.DENIED.value
        assert decided.reason == "Too risky"

    def test_get_pending(self):
        daemon = self._daemon({})
        daemon.request(action="send_email", description="Test 1", risk_level="high")
        daemon.request(action="send_email", description="Test 2", risk_level="high")
        assert len(daemon.get_pending()) >= 2

    def test_get_history(self):
        daemon = self._daemon()
        daemon.request(action="access_calendar", description="Test", risk_level="low")
        entries = daemon.get_history()
        assert len(entries) >= 1
        assert entries[0]["action"] == "access_calendar"

View File

@@ -121,6 +121,19 @@ DANGEROUS_PATTERNS = [
(r'\b(cp|mv|install)\b.*\s/etc/', "copy/move file into /etc/"),
(r'\bsed\s+-[^\s]*i.*\s/etc/', "in-place edit of system config"),
(r'\bsed\s+--in-place\b.*\s/etc/', "in-place edit of system config (long flag)"),
# --- Vitalik's threat model: crypto / financial ---
(r'\b(?:bitcoin-cli|ethers\.js|web3|ether\.sendTransaction)\b', "direct crypto transaction tool usage"),
(r'\bwget\b.*\b(?:mnemonic|seed\s*phrase|private[_-]?key)\b', "attempting to download crypto credentials"),
(r'\bcurl\b.*\b(?:mnemonic|seed\s*phrase|private[_-]?key)\b', "attempting to exfiltrate crypto credentials"),
# --- Vitalik's threat model: credential exfiltration ---
(r'\b(?:curl|wget|http|nc|ncat|socat)\b.*\b(?:\.env|\.ssh|credentials|secrets|token|api[_-]?key)\b',
"attempting to exfiltrate credentials via network"),
(r'\bbase64\b.*\|(?:\s*curl|\s*wget)', "base64-encode then network exfiltration"),
(r'\bcat\b.*\b(?:\.env|\.ssh/id_rsa|credentials)\b.*\|(?:\s*curl|\s*wget)',
"reading secrets and piping to network tool"),
# --- Vitalik's threat model: data exfiltration ---
(r'\bcurl\b.*-d\s.*\$(?:HOME|USER)', "sending user home directory data to remote"),
(r'\bwget\b.*--post-data\s.*\$(?:HOME|USER)', "posting user data to remote"),
# Script execution via heredoc — bypasses the -e/-c flag patterns above.
# `python3 << 'EOF'` feeds arbitrary code via stdin without -c/-e flags.
(r'\b(python[23]?|perl|ruby|node)\s+<<', "script execution via heredoc"),

View File

@@ -0,0 +1,615 @@
"""Human Confirmation Daemon — HTTP server for two-factor action approval.
Implements Vitalik's Pattern 1: "The new 'two-factor confirmation' is that
the two factors are the human and the LLM."
This daemon runs on localhost:6000 and provides a simple HTTP API for the
agent to request human approval before executing high-risk actions.
Threat model:
- LLM jailbreaks: Remote content "hacking" the LLM to perform malicious actions
- LLM accidents: LLM accidentally performing dangerous operations
- The human acts as the second factor — the agent proposes, the human disposes
Architecture:
- Agent detects high-risk action → POST /confirm with action details
- Daemon stores pending request, sends notification to user
- User approves/denies via POST /respond (Telegram, CLI, or direct HTTP)
- Agent receives decision and proceeds or aborts
Usage:
# Start daemon (usually managed by gateway)
from tools.confirmation_daemon import ConfirmationDaemon
daemon = ConfirmationDaemon(port=6000)
daemon.start()
# Request approval (from agent code)
from tools.confirmation_daemon import request_confirmation
approved = request_confirmation(
action="send_email",
description="Send email to alice@example.com",
risk_level="high",
payload={"to": "alice@example.com", "subject": "Meeting notes"},
timeout=300,
)
"""
from __future__ import annotations

import asyncio
import copy
import json
import logging
import os
import threading
import time
import uuid
from dataclasses import dataclass, field, asdict
from enum import Enum, auto
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
class RiskLevel(Enum):
    """Risk classification for actions requiring confirmation.

    The level controls how much friction an action gets before it runs:
    ACTION_CATEGORIES maps action names to a default level, and
    ConfirmationDaemon.request auto-approves "low" while everything else
    goes through the whitelist / human-approval path.
    """
    LOW = "low"  # Log only, no confirmation needed
    MEDIUM = "medium"  # Confirm for non-whitelisted targets
    HIGH = "high"  # Always confirm
    CRITICAL = "critical"  # Always confirm + require explicit reason
class ConfirmationStatus(Enum):
    """Lifecycle states of a confirmation request."""
    PENDING = "pending"
    APPROVED = "approved"
    DENIED = "denied"
    EXPIRED = "expired"
    AUTO_APPROVED = "auto_approved"


@dataclass
class ConfirmationRequest:
    """A single request for human sign-off on a high-risk action."""

    request_id: str
    action: str  # Action type: send_email, send_message, crypto_tx, etc.
    description: str  # Human-readable summary of what will happen
    risk_level: str  # low, medium, high, critical
    payload: Dict[str, Any]  # Action-specific data (sanitized)
    session_key: str = ""  # Session that initiated the request
    created_at: float = 0.0
    expires_at: float = 0.0
    status: str = ConfirmationStatus.PENDING.value
    decided_at: float = 0.0
    decided_by: str = ""  # "human", "auto", "whitelist"
    reason: str = ""  # Optional reason for denial

    def __post_init__(self):
        # Lazily fill defaults: creation time first, since the expiry
        # deadline is derived from it, then a short random id if missing.
        self.created_at = self.created_at or time.time()
        self.expires_at = self.expires_at or self.created_at + 300  # 5 min default
        self.request_id = self.request_id or str(uuid.uuid4())[:12]

    @property
    def is_expired(self) -> bool:
        """True once the expiry deadline has passed."""
        return self.expires_at < time.time()

    @property
    def is_pending(self) -> bool:
        """True while the request still awaits a decision and has not expired."""
        return not self.is_expired and self.status == ConfirmationStatus.PENDING.value

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, including the derived state flags."""
        data = asdict(self)
        data.update(is_expired=self.is_expired, is_pending=self.is_pending)
        return data
# =========================================================================
# Action categories (Vitalik's threat model)
# =========================================================================
# Maps an action name (as passed to ConfirmationDaemon.request) to its
# default RiskLevel. Callers may still override via an explicit risk_level.
ACTION_CATEGORIES = {
    # Messaging — outbound communication to external parties
    "send_email": RiskLevel.HIGH,
    "send_message": RiskLevel.MEDIUM,  # Depends on recipient
    "send_signal": RiskLevel.HIGH,
    "send_telegram": RiskLevel.MEDIUM,
    "send_discord": RiskLevel.MEDIUM,
    "post_social": RiskLevel.HIGH,
    # Financial / crypto
    "crypto_tx": RiskLevel.CRITICAL,
    "sign_transaction": RiskLevel.CRITICAL,
    "access_wallet": RiskLevel.CRITICAL,
    "modify_balance": RiskLevel.CRITICAL,
    # System modification
    "install_software": RiskLevel.HIGH,
    "modify_system_config": RiskLevel.HIGH,
    "modify_firewall": RiskLevel.CRITICAL,
    "add_ssh_key": RiskLevel.CRITICAL,
    "create_user": RiskLevel.CRITICAL,
    # Data access
    "access_contacts": RiskLevel.MEDIUM,
    "access_calendar": RiskLevel.LOW,
    "read_private_files": RiskLevel.MEDIUM,
    "upload_data": RiskLevel.HIGH,
    "share_credentials": RiskLevel.CRITICAL,
    # Network
    "open_port": RiskLevel.HIGH,
    "modify_dns": RiskLevel.HIGH,
    "expose_service": RiskLevel.CRITICAL,
}
# Default: any unrecognized action is MEDIUM risk
DEFAULT_RISK_LEVEL = RiskLevel.MEDIUM


def classify_action(action: str) -> RiskLevel:
    """Classify an action by its risk level.

    Unknown action names fall back to DEFAULT_RISK_LEVEL (medium), so new
    or unclassified action types are never silently treated as low risk.
    """
    return ACTION_CATEGORIES.get(action, DEFAULT_RISK_LEVEL)
# =========================================================================
# Whitelist configuration
# =========================================================================
_DEFAULT_WHITELIST = {
"send_message": {
"targets": [], # Contact names/IDs that don't need confirmation
},
"send_email": {
"targets": [], # Email addresses that don't need confirmation
"self_only": True, # send-to-self always allowed
},
}
def _load_whitelist() -> Dict[str, Any]:
"""Load action whitelist from config."""
config_path = Path.home() / ".hermes" / "approval_whitelist.json"
if config_path.exists():
try:
with open(config_path) as f:
return json.load(f)
except Exception as e:
logger.warning("Failed to load approval whitelist: %s", e)
return dict(_DEFAULT_WHITELIST)
def _is_whitelisted(action: str, payload: Dict[str, Any], whitelist: Dict) -> bool:
"""Check if an action is pre-approved by the whitelist."""
action_config = whitelist.get(action, {})
if not action_config:
return False
# Check target-based whitelist
targets = action_config.get("targets", [])
target = payload.get("to") or payload.get("recipient") or payload.get("target", "")
if target and target in targets:
return True
# Self-only email
if action_config.get("self_only") and action == "send_email":
sender = payload.get("from", "")
recipient = payload.get("to", "")
if sender and recipient and sender.lower() == recipient.lower():
return True
return False
# =========================================================================
# Confirmation daemon
# =========================================================================
class ConfirmationDaemon:
    """HTTP daemon for human confirmation of high-risk actions.

    Runs on localhost:PORT (default 6000). Provides:
    - POST /confirm — agent requests human approval
    - POST /respond — human approves/denies
    - GET /pending — list pending requests
    - GET /health — health check

    All mutation of the pending map and history list happens under
    ``self._lock`` so the blocking agent-side API (request/respond/
    wait_for_decision) and the aiohttp handlers can run in different
    threads safely.
    """

    def __init__(
        self,
        host: str = "127.0.0.1",
        port: int = 6000,
        default_timeout: int = 300,
        notify_callback: Optional[Callable] = None,
    ):
        """Create a daemon. Listening starts via start()/start_async().

        Args:
            host: Interface to bind. Keep this on loopback — the HTTP
                endpoints are unauthenticated.
            port: TCP port for the HTTP API.
            default_timeout: Seconds before an unanswered request expires.
            notify_callback: Optional callable invoked with the request
                dict when a new pending request is created (e.g. a
                Telegram push). Failures are logged, never raised.
        """
        self.host = host
        self.port = port
        self.default_timeout = default_timeout
        self.notify_callback = notify_callback
        self._pending: Dict[str, ConfirmationRequest] = {}
        self._history: List[ConfirmationRequest] = []
        self._lock = threading.Lock()
        self._whitelist = _load_whitelist()
        self._app = None
        self._runner = None

    # --- internal helpers (callers of *_locked hold self._lock) ---

    def _new_request(
        self,
        action: str,
        description: str,
        risk_level: str,
        payload: Dict[str, Any],
        session_key: str,
        timeout: Optional[int],
    ) -> ConfirmationRequest:
        """Build a ConfirmationRequest with a fresh id and expiry deadline."""
        return ConfirmationRequest(
            request_id=str(uuid.uuid4())[:12],
            action=action,
            description=description,
            risk_level=risk_level,
            payload=payload,
            session_key=session_key,
            expires_at=time.time() + (timeout or self.default_timeout),
        )

    def _finalize_expired_locked(self, request_id: str) -> None:
        """Mark a pending request EXPIRED and move it to history."""
        req = self._pending.pop(request_id)
        req.status = ConfirmationStatus.EXPIRED.value
        self._history.append(req)

    def _history_entry_locked(self, request_id: str) -> Optional[ConfirmationRequest]:
        """Return the newest history entry with this id, or None."""
        for req in reversed(self._history):
            if req.request_id == request_id:
                return req
        return None

    # --- request lifecycle ---

    def request(
        self,
        action: str,
        description: str,
        payload: Optional[Dict[str, Any]] = None,
        risk_level: Optional[str] = None,
        session_key: str = "",
        timeout: Optional[int] = None,
    ) -> ConfirmationRequest:
        """Create a confirmation request.

        Returns the request. Check .status to see if it was immediately
        auto-approved (whitelisted) or is pending human review.
        """
        payload = payload or {}
        # Classify risk if not specified
        if risk_level is None:
            risk_level = classify_action(action).value
        req = self._new_request(action, description, risk_level, payload, session_key, timeout)
        # Low-risk actions and whitelisted targets skip the human entirely.
        if risk_level in ("low",) or _is_whitelisted(action, payload, self._whitelist):
            req.status = ConfirmationStatus.AUTO_APPROVED.value
            req.decided_at = time.time()
            req.decided_by = "whitelist"
            with self._lock:
                self._history.append(req)
            logger.info("Auto-approved whitelisted action: %s", action)
            return req
        with self._lock:
            self._pending[req.request_id] = req
        # Notify the human out of band — best effort only.
        if self.notify_callback:
            try:
                self.notify_callback(req.to_dict())
            except Exception as e:
                logger.warning("Confirmation notify callback failed: %s", e)
        logger.info(
            "Confirmation request %s: %s (%s risk) — waiting for human",
            req.request_id, action, risk_level,
        )
        return req

    def respond(
        self,
        request_id: str,
        approved: bool,
        decided_by: str = "human",
        reason: str = "",
    ) -> Optional[ConfirmationRequest]:
        """Record a human decision on a pending request.

        Args:
            request_id: Id returned by request().
            approved: True to approve, False to deny.
            decided_by: Who decided ("human", "auto", ...).
            reason: Optional reason, recorded mainly for denials.

        Returns:
            The decided request, or None for an unknown id. A request
            whose deadline passed before the decision arrived is
            finalized as EXPIRED instead of being decided late.
        """
        with self._lock:
            req = self._pending.get(request_id)
            if not req:
                logger.warning("Confirmation respond: unknown request %s", request_id)
                return None
            if req.is_expired:
                # Bug fix: previously an expired entry was handed back
                # still marked "pending" and never left the pending map.
                self._finalize_expired_locked(request_id)
                return req
            if not req.is_pending:
                logger.warning("Confirmation respond: request %s already decided", request_id)
                return req
            req.status = (
                ConfirmationStatus.APPROVED.value if approved
                else ConfirmationStatus.DENIED.value
            )
            req.decided_at = time.time()
            req.decided_by = decided_by
            req.reason = reason
            # Move to history so get_pending() no longer reports it.
            del self._pending[request_id]
            self._history.append(req)
        logger.info(
            "Confirmation %s: %s by %s",
            request_id, "APPROVED" if approved else "DENIED", decided_by,
        )
        return req

    def wait_for_decision(
        self, request_id: str, timeout: Optional[float] = None
    ) -> ConfirmationRequest:
        """Block until a decision is made or the timeout expires.

        Polls every 0.5s. Fixes two bugs in the original implementation:
        (1) respond() moves decided requests into history, but the old
        loop polled only the pending map, so a human decision was never
        observed — every confirm blocked until timeout and then reported
        a synthetic EXPIRED "Request not found"; (2) expired requests
        were returned still marked "pending" and leaked in the pending
        map, because the not-is_pending check shadowed the expiry branch.
        """
        deadline = time.time() + (timeout or self.default_timeout)
        while True:
            with self._lock:
                req = self._pending.get(request_id)
                if req is None:
                    # Either decided (moved to history) or never existed.
                    decided = self._history_entry_locked(request_id)
                    if decided is not None:
                        return decided
                elif req.is_expired:
                    self._finalize_expired_locked(request_id)
                    return req
                elif not req.is_pending:
                    # Defensive: decided in place without being moved.
                    return req
            if time.time() >= deadline:
                break
            time.sleep(0.5)
        # Timed out: expire whatever is still pending, or surface a
        # decision that raced in just before the deadline.
        with self._lock:
            if request_id in self._pending:
                req = self._pending[request_id]
                self._finalize_expired_locked(request_id)
                return req
            decided = self._history_entry_locked(request_id)
            if decided is not None:
                return decided
        # Unknown id — synthesize an expired placeholder so callers
        # always receive a ConfirmationRequest.
        return ConfirmationRequest(
            request_id=request_id,
            action="unknown",
            description="Request not found",
            risk_level="high",
            payload={},
            status=ConfirmationStatus.EXPIRED.value,
        )

    def get_pending(self) -> List[Dict[str, Any]]:
        """Return pending requests as dicts (expired ones are pruned first)."""
        self._expire_old()
        with self._lock:
            return [r.to_dict() for r in self._pending.values() if r.is_pending]

    def get_history(self, limit: int = 50) -> List[Dict[str, Any]]:
        """Return up to ``limit`` most recent confirmation outcomes, oldest first."""
        with self._lock:
            return [r.to_dict() for r in self._history[-limit:]]

    def _expire_old(self) -> None:
        """Move expired requests from the pending map to history."""
        now = time.time()
        with self._lock:
            expired = [
                rid for rid, req in self._pending.items()
                if now > req.expires_at
            ]
            for rid in expired:
                self._finalize_expired_locked(rid)

    # --- aiohttp HTTP API ---

    async def _handle_health(self, request):
        """GET /health — liveness probe, reports pending-queue depth."""
        from aiohttp import web
        return web.json_response({
            "status": "ok",
            "service": "hermes-confirmation-daemon",
            "pending": len(self._pending),
        })

    async def _handle_confirm(self, request):
        """POST /confirm — create a request and block until decided or expired.

        Body: {action, description, payload?, risk_level?, session_key?, timeout?}.
        """
        from aiohttp import web
        try:
            body = await request.json()
        except Exception:
            return web.json_response({"error": "invalid JSON"}, status=400)
        action = body.get("action", "")
        description = body.get("description", "")
        if not action or not description:
            return web.json_response(
                {"error": "action and description required"}, status=400
            )
        req = self.request(
            action=action,
            description=description,
            payload=body.get("payload", {}),
            risk_level=body.get("risk_level"),
            session_key=body.get("session_key", ""),
            timeout=body.get("timeout"),
        )
        # If auto-approved, return immediately
        if req.status != ConfirmationStatus.PENDING.value:
            return web.json_response({
                "request_id": req.request_id,
                "status": req.status,
                "decided_by": req.decided_by,
            })
        # Otherwise wait for the human, capped at 10 minutes. Bug fix:
        # an explicit ``"timeout": null`` in the body used to reach
        # min(None, 600) and raise TypeError.
        timeout = min(body.get("timeout") or self.default_timeout, 600)
        result = self.wait_for_decision(req.request_id, timeout=timeout)
        return web.json_response({
            "request_id": result.request_id,
            "status": result.status,
            "decided_by": result.decided_by,
            "reason": result.reason,
        })

    async def _handle_respond(self, request):
        """POST /respond — record a human decision.

        Body: {request_id, approved, decided_by?, reason?}.
        """
        from aiohttp import web
        try:
            body = await request.json()
        except Exception:
            return web.json_response({"error": "invalid JSON"}, status=400)
        request_id = body.get("request_id", "")
        approved = body.get("approved")
        if not request_id or approved is None:
            return web.json_response(
                {"error": "request_id and approved required"}, status=400
            )
        result = self.respond(
            request_id=request_id,
            approved=bool(approved),
            decided_by=body.get("decided_by", "human"),
            reason=body.get("reason", ""),
        )
        if not result:
            return web.json_response({"error": "unknown request"}, status=404)
        return web.json_response({
            "request_id": result.request_id,
            "status": result.status,
        })

    async def _handle_pending(self, request):
        """GET /pending — list requests still awaiting a decision."""
        from aiohttp import web
        return web.json_response({"pending": self.get_pending()})

    def _build_app(self):
        """Build (and cache on self._app) the aiohttp application."""
        from aiohttp import web
        app = web.Application()
        app.router.add_get("/health", self._handle_health)
        app.router.add_post("/confirm", self._handle_confirm)
        app.router.add_post("/respond", self._handle_respond)
        app.router.add_get("/pending", self._handle_pending)
        self._app = app
        return app

    async def start_async(self) -> None:
        """Start serving on (host, port) inside the running event loop."""
        from aiohttp import web
        app = self._build_app()
        self._runner = web.AppRunner(app)
        await self._runner.setup()
        site = web.TCPSite(self._runner, self.host, self.port)
        await site.start()
        logger.info("Confirmation daemon listening on %s:%d", self.host, self.port)

    async def stop_async(self) -> None:
        """Tear down the HTTP server, if running."""
        if self._runner:
            await self._runner.cleanup()
            self._runner = None

    def start(self) -> None:
        """Start the daemon in a daemonized background thread and return immediately."""
        def _run():
            # Dedicated event loop for this thread; run_forever keeps the
            # TCP site alive after start_async returns.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            loop.run_until_complete(self.start_async())
            loop.run_forever()
        t = threading.Thread(target=_run, daemon=True, name="confirmation-daemon")
        t.start()
        logger.info("Confirmation daemon started in background thread")

    def start_blocking(self) -> None:
        """Start the daemon and block until Ctrl-C (for standalone use)."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self.start_async())
        try:
            loop.run_forever()
        except KeyboardInterrupt:
            pass
        finally:
            loop.run_until_complete(self.stop_async())
# =========================================================================
# Convenience API for agent integration
# =========================================================================
# Global singleton — initialized by gateway or CLI at startup
_daemon: Optional[ConfirmationDaemon] = None


def get_daemon() -> Optional[ConfirmationDaemon]:
    """Return the process-wide confirmation daemon, or None if uninitialized."""
    return _daemon


def init_daemon(
    host: str = "127.0.0.1",
    port: int = 6000,
    notify_callback: Optional[Callable] = None,
) -> ConfirmationDaemon:
    """Create the process-wide ConfirmationDaemon singleton and return it.

    Replaces any previously initialized daemon; intended to be called
    once at startup by the gateway or CLI.
    """
    global _daemon
    _daemon = ConfirmationDaemon(host=host, port=port, notify_callback=notify_callback)
    return _daemon
def request_confirmation(
    action: str,
    description: str,
    payload: Optional[Dict[str, Any]] = None,
    risk_level: Optional[str] = None,
    session_key: str = "",
    timeout: int = 300,
) -> bool:
    """Request human confirmation for a high-risk action.

    Primary integration point for agent code. It classifies the action's
    risk, consults the whitelist, and — when confirmation is required —
    blocks until the human responds or the timeout lapses.

    Args:
        action: Action type (send_email, crypto_tx, etc.)
        description: Human-readable description
        payload: Action-specific data
        risk_level: Override auto-classification
        session_key: Session requesting approval
        timeout: Seconds to wait for human response

    Returns:
        True only when approved (by human or whitelist). False when
        denied, expired, or when no daemon is running — the missing-daemon
        case fails closed on purpose.
    """
    daemon = get_daemon()
    if daemon is None:
        logger.warning(
            "No confirmation daemon running — DENYING action %s by default. "
            "Start daemon with init_daemon() or --confirmation-daemon flag.",
            action,
        )
        return False
    req = daemon.request(
        action=action,
        description=description,
        payload=payload,
        risk_level=risk_level,
        session_key=session_key,
        timeout=timeout,
    )
    # Whitelisted / low-risk requests come back already decided.
    if req.status == ConfirmationStatus.AUTO_APPROVED.value:
        return True
    # Block on the human's verdict.
    decision = daemon.wait_for_decision(req.request_id, timeout=timeout)
    return decision.status == ConfirmationStatus.APPROVED.value