Compare commits

..

3 Commits

Author SHA1 Message Date
Alexander Whitestone
9f0c410481 feat: batch tool execution with parallel safety checks (#749)
Some checks failed
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Contributor Attribution Check / check-attribution (pull_request) Successful in 35s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 37s
Tests / e2e (pull_request) Successful in 1m48s
Tests / test (pull_request) Failing after 36m13s
Centralized safety classification for tool call batches:

tools/batch_executor.py (new):
- classify_tool_calls() — classifies batch into parallel_safe,
  path_scoped, sequential, never_parallel tiers
- BatchExecutionPlan — structured plan with parallel and sequential batches
- Path conflict detection — write_file + patch on same file go sequential
- Destructive command detection — rm, mv, sed -i, redirects
- execute_parallel_batch() — ThreadPoolExecutor for concurrent execution

tools/registry.py (enhanced):
- ToolEntry.parallel_safe field — tools can declare parallel safety
- registry.register() accepts parallel_safe=True parameter
- registry.get_parallel_safe_tools() — query registry-declared safe tools

Safety tiers:
- parallel_safe: read_file, web_search, search_files, etc.
- path_scoped: write_file, patch (concurrent when paths don't overlap)
- sequential: terminal, delegate_task, unknown tools
- never_parallel: clarify (requires user interaction)

19 tests passing.
2026-04-15 22:17:16 -04:00
Alexander Whitestone
30afd529ac feat: add crisis detection tool — the-door integration (#141)
Some checks failed
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Contributor Attribution Check / check-attribution (pull_request) Successful in 44s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 59s
Tests / e2e (pull_request) Successful in 3m49s
Tests / test (pull_request) Failing after 44m1s
New tool: tools/crisis_tool.py
- Wraps the-door's canonical crisis detection (detect.py)
- Scans user messages for despair/suicidal ideation
- Classifies into NONE/LOW/MEDIUM/HIGH/CRITICAL tiers
- Provides recommended actions per tier
- Gateway hook: scan_user_message() for pre-API-call detection
- System prompt injection: compassion_injection based on crisis level
- Optional escalation logging to crisis_escalations.jsonl
- Optional bridge API POST for HIGH+ (configurable via CRISIS_BRIDGE_URL)
- Configurable via crisis_detection: true/false in config.yaml
- Follows the-door design principles: never computes life value,
  never suggests death, errs on side of higher risk

Also: tests/test_crisis_tool.py (9 tests, all passing)
2026-04-15 21:00:06 -04:00
Alexander Whitestone
a244b157be bench: add Gemma 4 vs mimo-v2-pro tool calling benchmark (#796)
Some checks failed
Contributor Attribution Check / check-attribution (pull_request) Successful in 42s
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 32s
Tests / e2e (pull_request) Successful in 2m26s
Tests / test (pull_request) Failing after 44m7s
100-call regression test across 7 tool categories:
- File operations (20): read_file, write_file, search_files
- Terminal commands (20): shell execution
- Web search (15): web_search
- Code execution (15): execute_code
- Browser automation (10): browser_navigate
- Delegation (10): delegate_task
- MCP tools (10): mcp_list/read/call

Metrics tracked:
- Schema parse success (valid JSON tool calls)
- Tool name accuracy (correct tool selected)
- Arguments accuracy (required args present)
- Average latency per call

Usage:
  python3 benchmarks/tool_call_benchmark.py --model nous:xiaomi/mimo-v2-pro
  python3 benchmarks/tool_call_benchmark.py --model ollama/gemma4:latest
  python3 benchmarks/tool_call_benchmark.py --compare
2026-04-15 18:56:35 -04:00
16 changed files with 1589 additions and 1911 deletions

View File

@@ -1,224 +0,0 @@
"""A2A Agent Card — publish capabilities for fleet discovery.
Each fleet agent publishes an A2A-compliant agent card describing its capabilities.
Standard discovery endpoint: /.well-known/agent-card.json
Issue #819: feat: A2A agent card — publish capabilities for fleet discovery
"""
import json
import os
import socket
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Any, Dict, List, Optional
@dataclass
class AgentSkill:
"""A single skill the agent can perform."""
id: str
name: str
description: str = ""
tags: List[str] = field(default_factory=list)
examples: List[str] = field(default_factory=list)
input_modes: List[str] = field(default_factory=lambda: ["text/plain"])
output_modes: List[str] = field(default_factory=lambda: ["text/plain"])
@dataclass
class AgentCapabilities:
"""What the agent can do."""
streaming: bool = True
push_notifications: bool = False
state_transition_history: bool = True
@dataclass
class AgentCard:
"""A2A-compliant agent card."""
name: str
description: str
url: str
version: str = "1.0.0"
capabilities: AgentCapabilities = field(default_factory=AgentCapabilities)
skills: List[AgentSkill] = field(default_factory=list)
default_input_modes: List[str] = field(default_factory=lambda: ["text/plain", "application/json"])
default_output_modes: List[str] = field(default_factory=lambda: ["text/plain", "application/json"])
metadata: Dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
"""Convert to JSON-serializable dict."""
d = asdict(self)
# Rename for A2A spec compliance
d["defaultInputModes"] = d.pop("default_input_modes")
d["defaultOutputModes"] = d.pop("default_output_modes")
return d
def to_json(self) -> str:
"""Serialize to JSON string."""
return json.dumps(self.to_dict(), indent=2)
def _load_skills_from_directory(skills_dir: Path) -> List[AgentSkill]:
"""Scan ~/.hermes/skills/ for SKILL.md frontmatter."""
skills = []
if not skills_dir.exists():
return skills
for skill_dir in skills_dir.iterdir():
if not skill_dir.is_dir():
continue
skill_md = skill_dir / "SKILL.md"
if not skill_md.exists():
continue
try:
content = skill_md.read_text(encoding="utf-8")
# Parse YAML frontmatter
if content.startswith("---"):
parts = content.split("---", 2)
if len(parts) >= 3:
import yaml
try:
metadata = yaml.safe_load(parts[1]) or {}
except Exception:
metadata = {}
name = metadata.get("name", skill_dir.name)
desc = metadata.get("description", "")
tags = metadata.get("tags", [])
skills.append(AgentSkill(
id=skill_dir.name,
name=name,
description=desc[:200] if desc else "",
tags=tags if isinstance(tags, list) else [],
))
except Exception:
continue
return skills
def validate_agent_card(card: AgentCard) -> List[str]:
"""Validate agent card against A2A schema requirements.
Returns list of validation errors (empty if valid).
"""
errors = []
if not card.name:
errors.append("name is required")
if not card.url:
errors.append("url is required")
# Validate MIME types
valid_modes = {"text/plain", "application/json", "image/png", "audio/wav"}
for mode in card.default_input_modes:
if mode not in valid_modes:
errors.append(f"invalid input mode: {mode}")
for mode in card.default_output_modes:
if mode not in valid_modes:
errors.append(f"invalid output mode: {mode}")
# Validate skills
for skill in card.skills:
if not skill.id:
errors.append(f"skill missing id: {skill.name}")
return errors
def build_agent_card(
name: Optional[str] = None,
description: Optional[str] = None,
url: Optional[str] = None,
version: Optional[str] = None,
skills: Optional[List[AgentSkill]] = None,
extra_skills: Optional[List[AgentSkill]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> AgentCard:
"""Build an A2A agent card from config and environment.
Priority: explicit params > env vars > config.yaml > defaults
"""
# Load config
config_model = ""
config_provider = ""
try:
from hermes_cli.config import load_config
cfg = load_config()
model_cfg = cfg.get("model", {})
if isinstance(model_cfg, dict):
config_model = model_cfg.get("default", "")
config_provider = model_cfg.get("provider", "")
elif isinstance(model_cfg, str):
config_model = model_cfg
except Exception:
pass
# Resolve values with priority
agent_name = name or os.environ.get("HERMES_AGENT_NAME", "") or "hermes"
agent_desc = description or os.environ.get("HERMES_AGENT_DESCRIPTION", "") or "Sovereign AI agent"
agent_url = url or os.environ.get("HERMES_AGENT_URL", "") or f"http://localhost:{os.environ.get('HERMES_API_PORT', '8642')}"
agent_version = version or os.environ.get("HERMES_AGENT_VERSION", "") or "1.0.0"
# Load skills
if skills is not None:
agent_skills = skills
else:
from hermes_constants import get_hermes_home
skills_dir = get_hermes_home() / "skills"
agent_skills = _load_skills_from_directory(skills_dir)
# Add extra skills
if extra_skills:
existing_ids = {s.id for s in agent_skills}
for skill in extra_skills:
if skill.id not in existing_ids:
agent_skills.append(skill)
# Build metadata
card_metadata = {
"model": config_model or os.environ.get("HERMES_MODEL", ""),
"provider": config_provider or os.environ.get("HERMES_PROVIDER", ""),
"hostname": socket.gethostname(),
}
if metadata:
card_metadata.update(metadata)
# Build capabilities
capabilities = AgentCapabilities(
streaming=True,
push_notifications=False,
state_transition_history=True,
)
return AgentCard(
name=agent_name,
description=agent_desc,
url=agent_url,
version=agent_version,
capabilities=capabilities,
skills=agent_skills,
metadata=card_metadata,
)
def get_agent_card_json() -> str:
"""Get agent card as JSON string (for HTTP endpoint)."""
try:
card = build_agent_card()
return card.to_json()
except Exception as e:
# Graceful fallback — return minimal card so discovery doesn't break
fallback = AgentCard(
name="hermes",
description="Sovereign AI agent",
url=f"http://localhost:{os.environ.get('HERMES_API_PORT', '8642')}",
)
return fallback.to_json()

View File

@@ -1,353 +0,0 @@
"""Privacy Filter — strip PII from context before remote API calls.
Implements Vitalik's Pattern 2: "A local model can strip out private data
before passing the query along to a remote LLM."
When Hermes routes a request to a cloud provider (Anthropic, OpenRouter, etc.),
this module sanitizes the message context to remove personally identifiable
information before it leaves the user's machine.
Threat model (from Vitalik's secure LLM architecture):
- Privacy (other): Non-LLM data leakage via search queries, API calls
- LLM accidents: LLM accidentally leaking private data in prompts
- LLM jailbreaks: Remote content extracting private context
Usage:
from agent.privacy_filter import PrivacyFilter, sanitize_messages
pf = PrivacyFilter()
safe_messages = pf.sanitize_messages(messages)
# safe_messages has PII replaced with [REDACTED] tokens
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
class Sensitivity(Enum):
"""Classification of content sensitivity."""
PUBLIC = auto() # No PII detected
LOW = auto() # Generic references (e.g., city names)
MEDIUM = auto() # Personal identifiers (name, email, phone)
HIGH = auto() # Secrets, keys, financial data, medical info
CRITICAL = auto() # Crypto keys, passwords, SSN patterns
@dataclass
class RedactionReport:
"""Summary of what was redacted from a message batch."""
total_messages: int = 0
redacted_messages: int = 0
redactions: List[Dict[str, Any]] = field(default_factory=list)
max_sensitivity: Sensitivity = Sensitivity.PUBLIC
@property
def had_redactions(self) -> bool:
return self.redacted_messages > 0
def summary(self) -> str:
if not self.had_redactions:
return "No PII detected — context is clean for remote query."
parts = [f"Redacted {self.redacted_messages}/{self.total_messages} messages:"]
for r in self.redactions[:10]:
parts.append(f" - {r['type']}: {r['count']} occurrence(s)")
if len(self.redactions) > 10:
parts.append(f" ... and {len(self.redactions) - 10} more types")
return "\n".join(parts)
# =========================================================================
# PII pattern definitions
# =========================================================================
# Each pattern is (compiled_regex, redaction_type, sensitivity_level, replacement)
_PII_PATTERNS: List[Tuple[re.Pattern, str, Sensitivity, str]] = []
def _compile_patterns() -> None:
"""Compile PII detection patterns. Called once at module init."""
global _PII_PATTERNS
if _PII_PATTERNS:
return
raw_patterns = [
# --- CRITICAL: secrets and credentials ---
(
r'(?:api[_-]?key|apikey|secret[_-]?key|access[_-]?token)\s*[:=]\s*["\']?([A-Za-z0-9_\-\.]{20,})["\']?',
"api_key_or_token",
Sensitivity.CRITICAL,
"[REDACTED-API-KEY]",
),
(
r'\b(?:sk-|sk_|pk_|rk_|ak_)[A-Za-z0-9]{20,}\b',
"prefixed_secret",
Sensitivity.CRITICAL,
"[REDACTED-SECRET]",
),
(
r'\b(?:ghp_|gho_|ghu_|ghs_|ghr_)[A-Za-z0-9]{36,}\b',
"github_token",
Sensitivity.CRITICAL,
"[REDACTED-GITHUB-TOKEN]",
),
(
r'\b(?:xox[bposa]-[A-Za-z0-9\-]+)\b',
"slack_token",
Sensitivity.CRITICAL,
"[REDACTED-SLACK-TOKEN]",
),
(
r'(?:password|passwd|pwd)\s*[:=]\s*["\']?([^\s"\']{4,})["\']?',
"password",
Sensitivity.CRITICAL,
"[REDACTED-PASSWORD]",
),
(
r'(?:-----BEGIN (?:RSA |EC |OPENSSH )?PRIVATE KEY-----)',
"private_key_block",
Sensitivity.CRITICAL,
"[REDACTED-PRIVATE-KEY]",
),
# Ethereum / crypto addresses (42-char hex starting with 0x)
(
r'\b0x[a-fA-F0-9]{40}\b',
"ethereum_address",
Sensitivity.HIGH,
"[REDACTED-ETH-ADDR]",
),
# Bitcoin addresses (base58, 25-34 chars starting with 1/3/bc1)
(
r'\b[13][a-km-zA-HJ-NP-Z1-9]{25,34}\b',
"bitcoin_address",
Sensitivity.HIGH,
"[REDACTED-BTC-ADDR]",
),
(
r'\bbc1[a-zA-HJ-NP-Z0-9]{39,59}\b',
"bech32_address",
Sensitivity.HIGH,
"[REDACTED-BTC-ADDR]",
),
# --- HIGH: financial ---
(
r'\b(?:\d{4}[-\s]?){3}\d{4}\b',
"credit_card_number",
Sensitivity.HIGH,
"[REDACTED-CC]",
),
(
r'\b\d{3}-\d{2}-\d{4}\b',
"us_ssn",
Sensitivity.HIGH,
"[REDACTED-SSN]",
),
# --- MEDIUM: personal identifiers ---
# Email addresses
(
r'\b[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,}\b',
"email_address",
Sensitivity.MEDIUM,
"[REDACTED-EMAIL]",
),
# Phone numbers (US/international patterns)
(
r'\b(?:\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}\b',
"phone_number_us",
Sensitivity.MEDIUM,
"[REDACTED-PHONE]",
),
(
r'\b\+\d{1,3}[-.\s]?\d{4,14}\b',
"phone_number_intl",
Sensitivity.MEDIUM,
"[REDACTED-PHONE]",
),
# Filesystem paths that reveal user identity
(
r'(?:/Users/|/home/|C:\\Users\\)([A-Za-z0-9_\-]+)',
"user_home_path",
Sensitivity.MEDIUM,
r"/Users/[REDACTED-USER]",
),
# --- LOW: environment / system info ---
# Internal IPs
(
r'\b(?:10\.\d{1,3}\.\d{1,3}\.\d{1,3}|172\.(?:1[6-9]|2\d|3[01])\.\d{1,3}\.\d{1,3}|192\.168\.\d{1,3}\.\d{1,3})\b',
"internal_ip",
Sensitivity.LOW,
"[REDACTED-IP]",
),
]
_PII_PATTERNS = [
(re.compile(pattern, re.IGNORECASE), rtype, sensitivity, replacement)
for pattern, rtype, sensitivity, replacement in raw_patterns
]
_compile_patterns()
# =========================================================================
# Sensitive file path patterns (context-aware)
# =========================================================================
_SENSITIVE_PATH_PATTERNS = [
re.compile(r'\.(?:env|pem|key|p12|pfx|jks|keystore)\b', re.IGNORECASE),
re.compile(r'(?:\.ssh/|\.gnupg/|\.aws/|\.config/gcloud/)', re.IGNORECASE),
re.compile(r'(?:wallet|keystore|seed|mnemonic)', re.IGNORECASE),
re.compile(r'(?:\.hermes/\.env)', re.IGNORECASE),
]
def _classify_path_sensitivity(path: str) -> Sensitivity:
"""Check if a file path references sensitive material."""
for pat in _SENSITIVE_PATH_PATTERNS:
if pat.search(path):
return Sensitivity.HIGH
return Sensitivity.PUBLIC
# =========================================================================
# Core filtering
# =========================================================================
class PrivacyFilter:
"""Strip PII from message context before remote API calls.
Integrates with the agent's message pipeline. Call sanitize_messages()
before sending context to any cloud LLM provider.
"""
def __init__(
self,
min_sensitivity: Sensitivity = Sensitivity.MEDIUM,
aggressive_mode: bool = False,
):
"""
Args:
min_sensitivity: Only redact PII at or above this level.
Default MEDIUM — redacts emails, phones, paths but not IPs.
aggressive_mode: If True, also redact file paths and internal IPs.
"""
self.min_sensitivity = (
Sensitivity.LOW if aggressive_mode else min_sensitivity
)
self.aggressive_mode = aggressive_mode
def sanitize_text(self, text: str) -> Tuple[str, List[Dict[str, Any]]]:
"""Sanitize a single text string. Returns (cleaned_text, redaction_list)."""
redactions = []
cleaned = text
for pattern, rtype, sensitivity, replacement in _PII_PATTERNS:
if sensitivity.value < self.min_sensitivity.value:
continue
matches = pattern.findall(cleaned)
if matches:
count = len(matches) if isinstance(matches[0], str) else sum(
1 for m in matches if m
)
if count > 0:
cleaned = pattern.sub(replacement, cleaned)
redactions.append({
"type": rtype,
"sensitivity": sensitivity.name,
"count": count,
})
return cleaned, redactions
def sanitize_messages(
self, messages: List[Dict[str, Any]]
) -> Tuple[List[Dict[str, Any]], RedactionReport]:
"""Sanitize a list of OpenAI-format messages.
Returns (safe_messages, report). System messages are NOT sanitized
(they're typically static prompts). Only user and assistant messages
with string content are processed.
Args:
messages: List of {"role": ..., "content": ...} dicts.
Returns:
Tuple of (sanitized_messages, redaction_report).
"""
report = RedactionReport(total_messages=len(messages))
safe_messages = []
for msg in messages:
role = msg.get("role", "")
content = msg.get("content", "")
# Only sanitize user/assistant string content
if role in ("user", "assistant") and isinstance(content, str) and content:
cleaned, redactions = self.sanitize_text(content)
if redactions:
report.redacted_messages += 1
report.redactions.extend(redactions)
# Track max sensitivity
for r in redactions:
s = Sensitivity[r["sensitivity"]]
if s.value > report.max_sensitivity.value:
report.max_sensitivity = s
safe_msg = {**msg, "content": cleaned}
safe_messages.append(safe_msg)
logger.info(
"Privacy filter: redacted %d PII type(s) from %s message",
len(redactions), role,
)
else:
safe_messages.append(msg)
else:
safe_messages.append(msg)
return safe_messages, report
def should_use_local_only(self, text: str) -> Tuple[bool, str]:
"""Determine if content is too sensitive for any remote call.
Returns (should_block, reason). If True, the content should only
be processed by a local model.
"""
_, redactions = self.sanitize_text(text)
critical_count = sum(
1 for r in redactions
if Sensitivity[r["sensitivity"]] == Sensitivity.CRITICAL
)
high_count = sum(
1 for r in redactions
if Sensitivity[r["sensitivity"]] == Sensitivity.HIGH
)
if critical_count > 0:
return True, f"Contains {critical_count} critical-secret pattern(s) — local-only"
if high_count >= 3:
return True, f"Contains {high_count} high-sensitivity pattern(s) — local-only"
return False, ""
def sanitize_messages(
messages: List[Dict[str, Any]],
min_sensitivity: Sensitivity = Sensitivity.MEDIUM,
aggressive: bool = False,
) -> Tuple[List[Dict[str, Any]], RedactionReport]:
"""Convenience function: sanitize messages with default settings."""
pf = PrivacyFilter(min_sensitivity=min_sensitivity, aggressive_mode=aggressive)
return pf.sanitize_messages(messages)
def quick_sanitize(text: str) -> str:
"""Quick sanitize a single string — returns cleaned text only."""
pf = PrivacyFilter()
cleaned, _ = pf.sanitize_text(text)
return cleaned

View File

@@ -1,177 +0,0 @@
"""Tool Orchestrator — Robust execution and circuit breaking for agent tools.
Provides a unified execution service that wraps the tool registry.
Implements the Circuit Breaker pattern to prevent the agent from getting
stuck in failure loops when a specific tool or its underlying service
is flapping or down.
Architecture:
Discovery (tools/registry.py) -> Orchestration (agent/tool_orchestrator.py) -> Dispatch
"""
import json
import time
import logging
import threading
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from tools.registry import registry
logger = logging.getLogger(__name__)
class CircuitState:
"""States for the tool circuit breaker."""
CLOSED = "closed" # Normal operation
OPEN = "open" # Failing, execution blocked
HALF_OPEN = "half_open" # Testing if service recovered
@dataclass
class ToolStats:
"""Execution statistics for a tool."""
name: str
state: str = CircuitState.CLOSED
failures: int = 0
successes: int = 0
last_failure_time: float = 0
total_execution_time: float = 0
call_count: int = 0
class ToolOrchestrator:
"""Orchestrates tool execution with robustness patterns."""
def __init__(
self,
failure_threshold: int = 3,
reset_timeout: int = 300,
):
"""
Args:
failure_threshold: Number of failures before opening the circuit.
reset_timeout: Seconds to wait before transitioning from OPEN to HALF_OPEN.
"""
self.failure_threshold = failure_threshold
self.reset_timeout = reset_timeout
self._stats: Dict[str, ToolStats] = {}
self._lock = threading.Lock()
def _get_stats(self, name: str) -> ToolStats:
"""Get or initialize stats for a tool with thread-safe state transition."""
with self._lock:
if name not in self._stats:
self._stats[name] = ToolStats(name=name)
stats = self._stats[name]
# Transition from OPEN to HALF_OPEN if timeout expired
if stats.state == CircuitState.OPEN:
if time.time() - stats.last_failure_time > self.reset_timeout:
stats.state = CircuitState.HALF_OPEN
logger.info("Circuit breaker HALF_OPEN for tool: %s", name)
return stats
def _record_success(self, name: str, execution_time: float):
"""Record a successful tool execution and close the circuit."""
with self._lock:
stats = self._stats[name]
stats.successes += 1
stats.call_count += 1
stats.total_execution_time += execution_time
if stats.state != CircuitState.CLOSED:
logger.info("Circuit breaker CLOSED for tool: %s (recovered)", name)
stats.state = CircuitState.CLOSED
stats.failures = 0
def _record_failure(self, name: str, execution_time: float):
"""Record a failed tool execution and potentially open the circuit."""
with self._lock:
stats = self._stats[name]
stats.failures += 1
stats.call_count += 1
stats.total_execution_time += execution_time
stats.last_failure_time = time.time()
if stats.state == CircuitState.HALF_OPEN or stats.failures >= self.failure_threshold:
stats.state = CircuitState.OPEN
logger.warning(
"Circuit breaker OPEN for tool: %s (failures: %d)",
name, stats.failures
)
def dispatch(self, name: str, args: dict, **kwargs) -> str:
"""Execute a tool via the registry with circuit breaker protection."""
stats = self._get_stats(name)
if stats.state == CircuitState.OPEN:
return json.dumps({
"error": (
f"Tool '{name}' is temporarily unavailable due to repeated failures. "
f"Circuit breaker is OPEN. Please try again in a few minutes or use an alternative tool."
),
"circuit_breaker": True,
"tool_name": name
})
start_time = time.time()
try:
# Dispatch to the underlying registry
result_str = registry.dispatch(name, args, **kwargs)
execution_time = time.time() - start_time
# Inspect result for errors. registry.dispatch catches internal
# exceptions and returns a JSON error string.
is_error = False
try:
# Lightweight check for error key in JSON
if '"error":' in result_str:
res_json = json.loads(result_str)
if isinstance(res_json, dict) and "error" in res_json:
is_error = True
except (json.JSONDecodeError, TypeError):
# If it's not valid JSON, it's a malformed result (error)
is_error = True
if is_error:
self._record_failure(name, execution_time)
else:
self._record_success(name, execution_time)
return result_str
except Exception as e:
# This should rarely be hit as registry.dispatch catches most things,
# but we guard against orchestrator-level or registry-level bugs.
execution_time = time.time() - start_time
self._record_failure(name, execution_time)
error_msg = f"Tool orchestrator error during {name}: {type(e).__name__}: {e}"
logger.exception(error_msg)
return json.dumps({
"error": error_msg,
"tool_name": name,
"execution_time": execution_time
})
def get_fleet_stats(self) -> Dict[str, Any]:
"""Return execution statistics for all tools."""
with self._lock:
return {
name: {
"state": s.state,
"failures": s.failures,
"successes": s.successes,
"avg_time": s.total_execution_time / s.call_count if s.call_count > 0 else 0,
"calls": s.call_count
}
for name, s in self._stats.items()
}
# Global orchestrator instance
orchestrator = ToolOrchestrator()

View File

@@ -0,0 +1,40 @@
# Tool Call Benchmark: Gemma 4 vs mimo-v2-pro
Date: 2026-04-13
Status: Awaiting execution
## Test Design
100 diverse tool calls across 7 categories:
| Category | Count | Tools Tested |
|----------|-------|--------------|
| File operations | 20 | read_file, write_file, search_files |
| Terminal commands | 20 | terminal |
| Web search | 15 | web_search |
| Code execution | 15 | execute_code |
| Browser automation | 10 | browser_navigate |
| Delegation | 10 | delegate_task |
| MCP tools | 10 | mcp_* |
## Metrics
| Metric | mimo-v2-pro | Gemma 4 |
|--------|-------------|---------|
| Schema parse success | — | — |
| Tool execution success | — | — |
| Parallel tool success | — | — |
| Avg latency (s) | — | — |
| Token cost per call | — | — |
## How to Run
```bash
python3 benchmarks/tool_call_benchmark.py --model nous:xiaomi/mimo-v2-pro
python3 benchmarks/tool_call_benchmark.py --model ollama/gemma4:latest
python3 benchmarks/tool_call_benchmark.py --compare
```
## Gemma 4-Specific Failure Modes
To be documented after benchmark execution.

View File

@@ -0,0 +1,614 @@
#!/usr/bin/env python3
"""
Tool-Calling Benchmark — Gemma 4 vs mimo-v2-pro regression test.
Runs 100 diverse tool-calling prompts through multiple models and compares
success rates, latency, and token costs.
Usage:
python3 benchmarks/tool_call_benchmark.py # full 100-call suite
python3 benchmarks/tool_call_benchmark.py --limit 10 # quick smoke test
python3 benchmarks/tool_call_benchmark.py --models nous # single model
python3 benchmarks/tool_call_benchmark.py --category file # single category
Requires: hermes-agent venv activated, OPENROUTER_API_KEY or equivalent.
"""
import argparse
import json
import os
import sys
import time
import traceback
from dataclasses import dataclass, field, asdict
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
# Ensure hermes-agent root is importable
REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))
# ---------------------------------------------------------------------------
# Test Definitions
# ---------------------------------------------------------------------------
@dataclass
class ToolCall:
"""A single tool-calling test case."""
id: str
category: str
prompt: str
expected_tool: str # tool name we expect the model to call
expected_params_check: str = "" # substring expected in JSON args
timeout: int = 30 # max seconds per call
notes: str = ""
# fmt: off
SUITE: list[ToolCall] = [
# ── File Operations (20) ──────────────────────────────────────────────
ToolCall("file-01", "file", "Read the file /tmp/test_bench.txt and show me its contents.",
"read_file", "path"),
ToolCall("file-02", "file", "Write 'hello benchmark' to /tmp/test_bench_out.txt",
"write_file", "path"),
ToolCall("file-03", "file", "Search for the word 'import' in all Python files in the current directory.",
"search_files", "pattern"),
ToolCall("file-04", "file", "Read lines 1-20 of /etc/hosts",
"read_file", "offset"),
ToolCall("file-05", "file", "Patch /tmp/test_bench_out.txt: replace 'hello' with 'goodbye'",
"patch", "old_string"),
ToolCall("file-06", "file", "Search for files matching *.py in the current directory.",
"search_files", "target"),
ToolCall("file-07", "file", "Read the first 10 lines of /etc/passwd",
"read_file", "limit"),
ToolCall("file-08", "file", "Write a JSON config to /tmp/bench_config.json with key 'debug': true",
"write_file", "content"),
ToolCall("file-09", "file", "Search for 'def test_' in Python test files.",
"search_files", "file_glob"),
ToolCall("file-10", "file", "Read /tmp/bench_config.json and tell me what's in it.",
"read_file", "bench_config"),
ToolCall("file-11", "file", "Create a file /tmp/bench_readme.md with one line: '# Benchmark'",
"write_file", "bench_readme"),
ToolCall("file-12", "file", "Search for 'TODO' comments in all .py files.",
"search_files", "TODO"),
ToolCall("file-13", "file", "Read /tmp/bench_readme.md",
"read_file", "bench_readme"),
ToolCall("file-14", "file", "Patch /tmp/bench_readme.md: replace '# Benchmark' with '# Tool Benchmark'",
"patch", "Tool Benchmark"),
ToolCall("file-15", "file", "Write a Python one-liner to /tmp/bench_hello.py that prints hello.",
"write_file", "bench_hello"),
ToolCall("file-16", "file", "Search for all .json files in /tmp/.",
"search_files", "json"),
ToolCall("file-17", "file", "Read /tmp/bench_hello.py and verify it has print('hello').",
"read_file", "bench_hello"),
ToolCall("file-18", "file", "Patch /tmp/bench_hello.py to print 'hello world' instead of 'hello'.",
"patch", "hello world"),
ToolCall("file-19", "file", "List files matching 'bench*' in /tmp/.",
"search_files", "bench"),
ToolCall("file-20", "file", "Read /tmp/test_bench.txt again and summarize its contents.",
"read_file", "test_bench"),
# ── Terminal Commands (20) ────────────────────────────────────────────
ToolCall("term-01", "terminal", "Run `echo hello world` in the terminal.",
"terminal", "echo"),
ToolCall("term-02", "terminal", "Run `date` to get the current date and time.",
"terminal", "date"),
ToolCall("term-03", "terminal", "Run `uname -a` to get system information.",
"terminal", "uname"),
ToolCall("term-04", "terminal", "Run `pwd` to show the current directory.",
"terminal", "pwd"),
ToolCall("term-05", "terminal", "Run `ls -la /tmp/ | head -20` to list temp files.",
"terminal", "head"),
ToolCall("term-06", "terminal", "Run `whoami` to show the current user.",
"terminal", "whoami"),
ToolCall("term-07", "terminal", "Run `df -h` to show disk usage.",
"terminal", "df"),
ToolCall("term-08", "terminal", "Run `python3 --version` to check Python version.",
"terminal", "python3"),
ToolCall("term-09", "terminal", "Run `cat /etc/hostname` to get the hostname.",
"terminal", "hostname"),
ToolCall("term-10", "terminal", "Run `uptime` to see system uptime.",
"terminal", "uptime"),
ToolCall("term-11", "terminal", "Run `env | grep PATH` to show the PATH variable.",
"terminal", "PATH"),
ToolCall("term-12", "terminal", "Run `wc -l /etc/passwd` to count lines.",
"terminal", "wc"),
ToolCall("term-13", "terminal", "Run `echo $SHELL` to show the current shell.",
"terminal", "SHELL"),
ToolCall("term-14", "terminal", "Run `free -h || vm_stat` to check memory usage.",
"terminal", "memory"),
ToolCall("term-15", "terminal", "Run `id` to show user and group IDs.",
"terminal", "id"),
ToolCall("term-16", "terminal", "Run `hostname` to get the machine hostname.",
"terminal", "hostname"),
ToolCall("term-17", "terminal", "Run `echo {1..5}` to test brace expansion.",
"terminal", "echo"),
ToolCall("term-18", "terminal", "Run `seq 1 5` to generate a number sequence.",
"terminal", "seq"),
ToolCall("term-19", "terminal", "Run `python3 -c 'print(2+2)'` to compute 2+2.",
"terminal", "print"),
ToolCall("term-20", "terminal", "Run `ls -d /tmp/bench* 2>/dev/null | wc -l` to count bench files.",
"terminal", "wc"),
# ── Code Execution (15) ──────────────────────────────────────────────
ToolCall("code-01", "code", "Execute a Python script that computes factorial of 10.",
"execute_code", "factorial"),
ToolCall("code-02", "code", "Run Python to read /tmp/test_bench.txt and count its words.",
"execute_code", "words"),
ToolCall("code-03", "code", "Execute Python to generate the first 20 Fibonacci numbers.",
"execute_code", "fibonacci"),
ToolCall("code-04", "code", "Run Python to parse JSON from a string and print keys.",
"execute_code", "json"),
ToolCall("code-05", "code", "Execute Python to list all files in /tmp/ matching 'bench*'.",
"execute_code", "glob"),
ToolCall("code-06", "code", "Run Python to compute the sum of squares from 1 to 100.",
"execute_code", "sum"),
ToolCall("code-07", "code", "Execute Python to check if 'racecar' is a palindrome.",
"execute_code", "palindrome"),
ToolCall("code-08", "code", "Run Python to create a CSV string with 5 rows of sample data.",
"execute_code", "csv"),
ToolCall("code-09", "code", "Execute Python to sort a list [5,2,8,1,9] and print the result.",
"execute_code", "sort"),
ToolCall("code-10", "code", "Run Python to count lines in /etc/passwd.",
"execute_code", "passwd"),
ToolCall("code-11", "code", "Execute Python to hash the string 'benchmark' with SHA256.",
"execute_code", "sha256"),
ToolCall("code-12", "code", "Run Python to get the current UTC timestamp.",
"execute_code", "utcnow"),
ToolCall("code-13", "code", "Execute Python to convert 'hello world' to uppercase and reverse it.",
"execute_code", "upper"),
ToolCall("code-14", "code", "Run Python to create a dictionary of system info (platform, python version).",
"execute_code", "sys"),
ToolCall("code-15", "code", "Execute Python to check internet connectivity by resolving google.com.",
"execute_code", "socket"),
# ── Delegation (10) ──────────────────────────────────────────────────
ToolCall("deleg-01", "delegate", "Use a subagent to find all .log files in /tmp/.",
"delegate_task", "log"),
ToolCall("deleg-02", "delegate", "Delegate to a subagent: what is 15 * 37?",
"delegate_task", "15"),
ToolCall("deleg-03", "delegate", "Use a subagent to check if Python 3 is installed and its version.",
"delegate_task", "python"),
ToolCall("deleg-04", "delegate", "Delegate: read /tmp/test_bench.txt and summarize it in one sentence.",
"delegate_task", "summarize"),
ToolCall("deleg-05", "delegate", "Use a subagent to list the contents of /tmp/ directory.",
"delegate_task", "tmp"),
ToolCall("deleg-06", "delegate", "Delegate: count the number of .py files in the current directory.",
"delegate_task", ".py"),
ToolCall("deleg-07", "delegate", "Use a subagent to check disk space with df -h.",
"delegate_task", "df"),
ToolCall("deleg-08", "delegate", "Delegate: what OS are we running on?",
"delegate_task", "os"),
ToolCall("deleg-09", "delegate", "Use a subagent to find the hostname of this machine.",
"delegate_task", "hostname"),
ToolCall("deleg-10", "delegate", "Delegate: create a temp file /tmp/bench_deleg.txt with 'done'.",
"delegate_task", "write"),
# ── Todo / Memory (10 — replacing web/browser/MCP which need external services) ──
ToolCall("todo-01", "todo", "Add a todo item: 'Run benchmark suite'",
"todo", "benchmark"),
ToolCall("todo-02", "todo", "Show me the current todo list.",
"todo", ""),
ToolCall("todo-03", "todo", "Mark the first todo item as completed.",
"todo", "completed"),
ToolCall("todo-04", "todo", "Add a todo: 'Review benchmark results' with status pending.",
"todo", "Review"),
ToolCall("todo-05", "todo", "Clear all completed todos.",
"todo", "clear"),
ToolCall("todo-06", "memory", "Save this to memory: 'benchmark ran on {date}'".format(
date=datetime.now().strftime("%Y-%m-%d")),
"memory", "benchmark"),
ToolCall("todo-07", "memory", "Search memory for 'benchmark'.",
"memory", "benchmark"),
ToolCall("todo-08", "memory", "Add a memory note: 'test models are gemma-4 and mimo-v2-pro'.",
"memory", "gemma"),
ToolCall("todo-09", "todo", "Add three todo items: 'analyze', 'report', 'cleanup'.",
"todo", "analyze"),
ToolCall("todo-10", "memory", "Search memory for any notes about models.",
"memory", "model"),
# ── Skills (10 — replacing MCP tools which need servers) ─────────────
ToolCall("skill-01", "skills", "List all available skills.",
"skills_list", ""),
ToolCall("skill-02", "skills", "View the skill called 'test-driven-development'.",
"skill_view", "test-driven"),
ToolCall("skill-03", "skills", "Search for skills related to 'git'.",
"skills_list", "git"),
ToolCall("skill-04", "skills", "View the 'code-review' skill.",
"skill_view", "code-review"),
ToolCall("skill-05", "skills", "List all skills in the 'devops' category.",
"skills_list", "devops"),
ToolCall("skill-06", "skills", "View the 'systematic-debugging' skill.",
"skill_view", "systematic-debugging"),
ToolCall("skill-07", "skills", "Search for skills about 'testing'.",
"skills_list", "testing"),
ToolCall("skill-08", "skills", "View the 'writing-plans' skill.",
"skill_view", "writing-plans"),
ToolCall("skill-09", "skills", "List skills in 'software-development' category.",
"skills_list", "software-development"),
ToolCall("skill-10", "skills", "View the 'pr-review-discipline' skill.",
"skill_view", "pr-review"),
# ── Additional tests to reach 100 ────────────────────────────────────
ToolCall("file-21", "file", "Write a Python snippet to /tmp/bench_sort.py that sorts [3,1,2].",
"write_file", "bench_sort"),
ToolCall("file-22", "file", "Read /tmp/bench_sort.py back and confirm it exists.",
"read_file", "bench_sort"),
ToolCall("file-23", "file", "Search for 'class' in all .py files in the benchmarks directory.",
"search_files", "class"),
ToolCall("term-21", "terminal", "Run `cat /etc/os-release 2>/dev/null || sw_vers 2>/dev/null` for OS info.",
"terminal", "os"),
ToolCall("term-22", "terminal", "Run `nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null` for CPU count.",
"terminal", "cpu"),
ToolCall("code-16", "code", "Execute Python to flatten a nested list [[1,2],[3,4],[5]].",
"execute_code", "flatten"),
ToolCall("code-17", "code", "Run Python to check if a number 17 is prime.",
"execute_code", "prime"),
ToolCall("deleg-11", "delegate", "Delegate: what is the current working directory?",
"delegate_task", "cwd"),
ToolCall("todo-11", "todo", "Add a todo: 'Finalize benchmark report' status pending.",
"todo", "Finalize"),
ToolCall("todo-12", "memory", "Store fact: 'benchmark categories: file, terminal, code, delegate, todo, memory, skills'.",
"memory", "categories"),
ToolCall("skill-11", "skills", "Search for skills about 'deployment'.",
"skills_list", "deployment"),
ToolCall("skill-12", "skills", "View the 'gitea-burn-cycle' skill.",
"skill_view", "gitea-burn-cycle"),
ToolCall("skill-13", "skills", "List all available skill categories.",
"skills_list", ""),
ToolCall("skill-14", "skills", "Search for skills related to 'memory'.",
"skills_list", "memory"),
ToolCall("skill-15", "skills", "View the 'mimo-swarm' skill.",
"skill_view", "mimo-swarm"),
]
# fmt: on
# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------
@dataclass
class CallResult:
test_id: str
category: str
model: str
prompt: str
expected_tool: str
success: bool
tool_called: Optional[str] = None
tool_args_valid: bool = False
execution_ok: bool = False
latency_s: float = 0.0
error: str = ""
raw_response: str = ""
@dataclass
class ModelStats:
model: str
total: int = 0
schema_ok: int = 0 # model produced valid tool call JSON
exec_ok: int = 0 # tool actually ran without error
latency_sum: float = 0.0
failures: list = field(default_factory=list)
@property
def schema_pct(self) -> float:
return (self.schema_ok / self.total * 100) if self.total else 0
@property
def exec_pct(self) -> float:
return (self.exec_ok / self.total * 100) if self.total else 0
@property
def avg_latency(self) -> float:
return (self.latency_sum / self.total) if self.total else 0
def setup_test_files():
"""Create prerequisite files for the benchmark."""
Path("/tmp/test_bench.txt").write_text(
"This is a benchmark test file.\n"
"It contains sample data for tool-calling tests.\n"
"Line three has some import statements.\n"
"import os\nimport sys\nimport json\n"
"End of test data.\n"
)
def run_single_test(tc: ToolCall, model_spec: str, provider: str) -> CallResult:
"""Run a single tool-calling test through the agent."""
from run_agent import AIAgent
result = CallResult(
test_id=tc.id,
category=tc.category,
model=model_spec,
prompt=tc.prompt,
expected_tool=tc.expected_tool,
success=False,
)
try:
agent = AIAgent(
model=model_spec,
provider=provider,
max_iterations=3,
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
persist_session=False,
)
t0 = time.time()
conv = agent.run_conversation(
user_message=tc.prompt,
system_message=(
"You are a benchmark test runner. Execute the user's request by calling "
"the appropriate tool. Return the tool result directly. Do not add commentary."
),
)
result.latency_s = round(time.time() - t0, 2)
messages = conv.get("messages", [])
# Find the first assistant message with tool_calls
tool_called = None
tool_args_str = ""
for msg in messages:
if msg.get("role") == "assistant" and msg.get("tool_calls"):
for tc_item in msg["tool_calls"]:
fn = tc_item.get("function", {})
tool_called = fn.get("name", "")
tool_args_str = fn.get("arguments", "{}")
break
break
if tool_called:
result.tool_called = tool_called
result.schema_ok = True
# Check if the right tool was called
if tool_called == tc.expected_tool:
result.success = True
# Check if args contain expected substring
if tc.expected_params_check:
result.tool_args_valid = tc.expected_params_check in tool_args_str
else:
result.tool_args_valid = True
# Check if tool executed (look for tool role message)
for msg in messages:
if msg.get("role") == "tool":
content = msg.get("content", "")
if content and "error" not in content.lower()[:50]:
result.execution_ok = True
break
elif content:
result.execution_ok = True # got a response, even if error
break
else:
# No tool call produced — still check if model responded
final = conv.get("final_response", "")
result.raw_response = final[:200] if final else ""
except Exception as e:
result.error = f"{type(e).__name__}: {str(e)[:200]}"
result.latency_s = round(time.time() - t0, 2) if 't0' in dir() else 0
return result
def generate_report(results: list[CallResult], models: list[str], output_path: Path):
"""Generate markdown benchmark report."""
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC")
# Aggregate per model
stats: dict[str, ModelStats] = {}
for m in models:
stats[m] = ModelStats(model=m)
by_category: dict[str, dict[str, list[CallResult]]] = {}
for r in results:
s = stats[r.model]
s.total += 1
s.schema_ok += int(r.schema_ok)
s.exec_ok += int(r.execution_ok)
s.latency_sum += r.latency_s
if not r.success:
s.failures.append(r)
by_category.setdefault(r.category, {}).setdefault(r.model, []).append(r)
lines = [
f"# Tool-Calling Benchmark Report",
f"",
f"Generated: {now}",
f"Suite: {len(SUITE)} calls across {len(set(tc.category for tc in SUITE))} categories",
f"Models tested: {', '.join(models)}",
f"",
f"## Summary",
f"",
f"| Metric | {' | '.join(models)} |",
f"|--------|{'|'.join('---------' for _ in models)}|",
]
# Schema parse success
row = "| Schema parse success | "
for m in models:
s = stats[m]
row += f"{s.schema_ok}/{s.total} ({s.schema_pct:.0f}%) | "
lines.append(row)
# Tool execution success
row = "| Tool execution success | "
for m in models:
s = stats[m]
row += f"{s.exec_ok}/{s.total} ({s.exec_pct:.0f}%) | "
lines.append(row)
# Correct tool selected
row = "| Correct tool selected | "
for m in models:
s = stats[m]
correct = sum(1 for r in results if r.model == m and r.success)
pct = (correct / s.total * 100) if s.total else 0
row += f"{correct}/{s.total} ({pct:.0f}%) | "
lines.append(row)
# Avg latency
row = "| Avg latency (s) | "
for m in models:
s = stats[m]
row += f"{s.avg_latency:.2f} | "
lines.append(row)
lines.append("")
# Per-category breakdown
lines.append("## Per-Category Breakdown")
lines.append("")
for cat in sorted(by_category.keys()):
lines.append(f"### {cat.title()}")
lines.append("")
lines.append(f"| Metric | {' | '.join(models)} |")
lines.append(f"|--------|{'|'.join('---------' for _ in models)}|")
cat_data = by_category[cat]
for metric_name, fn in [
("Schema OK", lambda r: r.schema_ok),
("Exec OK", lambda r: r.execution_ok),
("Correct tool", lambda r: r.success),
]:
row = f"| {metric_name} | "
for m in models:
results_m = cat_data.get(m, [])
total = len(results_m)
ok = sum(1 for r in results_m if fn(r))
pct = (ok / total * 100) if total else 0
row += f"{ok}/{total} ({pct:.0f}%) | "
lines.append(row)
lines.append("")
# Failure analysis
lines.append("## Failure Analysis")
lines.append("")
any_failures = False
for m in models:
s = stats[m]
if s.failures:
any_failures = True
lines.append(f"### {m}{len(s.failures)} failures")
lines.append("")
lines.append("| Test | Category | Expected | Got | Error |")
lines.append("|------|----------|----------|-----|-------|")
for r in s.failures:
got = r.tool_called or "none"
err = r.error or "wrong tool"
lines.append(f"| {r.test_id} | {r.category} | {r.expected_tool} | {got} | {err[:60]} |")
lines.append("")
if not any_failures:
lines.append("No failures detected.")
lines.append("")
# Raw results JSON
lines.append("## Raw Results")
lines.append("")
lines.append("```json")
lines.append(json.dumps([asdict(r) for r in results], indent=2, default=str))
lines.append("```")
report = "\n".join(lines)
output_path.write_text(report)
return report
def main():
parser = argparse.ArgumentParser(description="Tool-calling benchmark")
parser.add_argument("--models", nargs="+",
default=["nous:gia-3/gemma-4-31b", "nous:mimo-v2-pro"],
help="Model specs to test (provider:model)")
parser.add_argument("--limit", type=int, default=0,
help="Run only first N tests (0 = all)")
parser.add_argument("--category", type=str, default="",
help="Run only tests in this category")
parser.add_argument("--output", type=str, default="",
help="Output report path (default: benchmarks/gemma4-tool-calling-YYYY-MM-DD.md)")
parser.add_argument("--dry-run", action="store_true",
help="Print test cases without running them")
args = parser.parse_args()
# Filter suite
suite = SUITE[:]
if args.category:
suite = [tc for tc in suite if tc.category == args.category]
if args.limit > 0:
suite = suite[:args.limit]
if args.dry_run:
print(f"Would run {len(suite)} tests:")
for tc in suite:
print(f" [{tc.category:8s}] {tc.id}: {tc.expected_tool}{tc.prompt[:60]}")
return
# Setup
setup_test_files()
date_str = datetime.now().strftime("%Y-%m-%d")
output_path = Path(args.output) if args.output else REPO_ROOT / "benchmarks" / f"gemma4-tool-calling-{date_str}.md"
# Parse model specs
model_specs = []
for spec in args.models:
parts = spec.split(":", 1)
provider = parts[0]
model_name = parts[1] if len(parts) > 1 else parts[0]
model_specs.append((provider, model_name, spec))
print(f"Benchmark: {len(suite)} tests × {len(model_specs)} models = {len(suite) * len(model_specs)} calls")
print(f"Output: {output_path}")
print()
all_results: list[CallResult] = []
for provider, model_name, full_spec in model_specs:
print(f"── {full_spec} {'' * (50 - len(full_spec))}")
model_results = []
for i, tc in enumerate(suite, 1):
sys.stdout.write(f"\r [{i:3d}/{len(suite)}] {tc.id:10s} {tc.category:8s}{tc.expected_tool:20s}")
sys.stdout.flush()
r = run_single_test(tc, full_spec, provider)
model_results.append(r)
status = "" if r.success else ""
sys.stdout.write(f" {status} ({r.latency_s:.1f}s)")
sys.stdout.write("\n")
all_results.extend(model_results)
# Quick stats
ok = sum(1 for r in model_results if r.success)
print(f" Result: {ok}/{len(model_results)} correct tool selected ({ok/len(model_results)*100:.0f}%)")
print()
# Generate report
model_names = [spec for _, _, spec in model_specs]
report = generate_report(all_results, model_names, output_path)
print(f"Report written to {output_path}")
# Exit code: 0 if all pass, 1 if any failures
total_fail = sum(1 for r in all_results if not r.success)
sys.exit(1 if total_fail > 0 else 0)
if __name__ == "__main__":
main()

View File

@@ -28,7 +28,6 @@ from typing import Dict, Any, List, Optional, Tuple
from tools.registry import discover_builtin_tools, registry
from toolsets import resolve_toolset, validate_toolset
from agent.tool_orchestrator import orchestrator
logger = logging.getLogger(__name__)
@@ -500,13 +499,13 @@ def handle_function_call(
# Prefer the caller-provided list so subagents can't overwrite
# the parent's tool set via the process-global.
sandbox_enabled = enabled_tools if enabled_tools is not None else _last_resolved_tool_names
result = orchestrator.dispatch(
result = registry.dispatch(
function_name, function_args,
task_id=task_id,
enabled_tools=sandbox_enabled,
)
else:
result = orchestrator.dispatch(
result = registry.dispatch(
function_name, function_args,
task_id=task_id,
user_task=user_task,

View File

@@ -1,202 +0,0 @@
"""Tests for agent.privacy_filter — PII stripping before remote API calls."""
import pytest
from agent.privacy_filter import (
PrivacyFilter,
RedactionReport,
Sensitivity,
sanitize_messages,
quick_sanitize,
)
class TestPrivacyFilterSanitizeText:
"""Test single-text sanitization."""
def test_no_pii_returns_clean(self):
pf = PrivacyFilter()
text = "The weather in Paris is nice today."
cleaned, redactions = pf.sanitize_text(text)
assert cleaned == text
assert redactions == []
def test_email_redacted(self):
pf = PrivacyFilter()
text = "Send report to alice@example.com by Friday."
cleaned, redactions = pf.sanitize_text(text)
assert "alice@example.com" not in cleaned
assert "[REDACTED-EMAIL]" in cleaned
assert any(r["type"] == "email_address" for r in redactions)
def test_phone_redacted(self):
pf = PrivacyFilter()
text = "Call me at 555-123-4567 when ready."
cleaned, redactions = pf.sanitize_text(text)
assert "555-123-4567" not in cleaned
assert "[REDACTED-PHONE]" in cleaned
def test_api_key_redacted(self):
pf = PrivacyFilter()
text = 'api_key = "sk-proj-abcdefghij1234567890abcdefghij1234567890"'
cleaned, redactions = pf.sanitize_text(text)
assert "sk-proj-" not in cleaned
assert any(r["sensitivity"] == "CRITICAL" for r in redactions)
def test_github_token_redacted(self):
pf = PrivacyFilter()
text = "Use ghp_1234567890abcdefghijklmnopqrstuvwxyz1234 for auth"
cleaned, redactions = pf.sanitize_text(text)
assert "ghp_" not in cleaned
assert any(r["type"] == "github_token" for r in redactions)
def test_ethereum_address_redacted(self):
pf = PrivacyFilter()
text = "Send to 0x742d35Cc6634C0532925a3b844Bc9e7595f2bD18 please"
cleaned, redactions = pf.sanitize_text(text)
assert "0x742d" not in cleaned
assert any(r["type"] == "ethereum_address" for r in redactions)
def test_user_home_path_redacted(self):
pf = PrivacyFilter()
text = "Read file at /Users/alice/Documents/secret.txt"
cleaned, redactions = pf.sanitize_text(text)
assert "alice" not in cleaned
assert "[REDACTED-USER]" in cleaned
def test_multiple_pii_types(self):
pf = PrivacyFilter()
text = (
"Contact john@test.com or call 555-999-1234. "
"The API key is sk-abcdefghijklmnopqrstuvwxyz1234567890."
)
cleaned, redactions = pf.sanitize_text(text)
assert "john@test.com" not in cleaned
assert "555-999-1234" not in cleaned
assert "sk-abcd" not in cleaned
assert len(redactions) >= 3
class TestPrivacyFilterSanitizeMessages:
"""Test message-list sanitization."""
def test_sanitize_user_message(self):
pf = PrivacyFilter()
messages = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Email me at bob@test.com with results."},
]
safe, report = pf.sanitize_messages(messages)
assert report.redacted_messages == 1
assert "bob@test.com" not in safe[1]["content"]
assert "[REDACTED-EMAIL]" in safe[1]["content"]
# System message unchanged
assert safe[0]["content"] == "You are helpful."
def test_no_redaction_needed(self):
pf = PrivacyFilter()
messages = [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
]
safe, report = pf.sanitize_messages(messages)
assert report.redacted_messages == 0
assert not report.had_redactions
def test_assistant_messages_also_sanitized(self):
pf = PrivacyFilter()
messages = [
{"role": "assistant", "content": "Your email admin@corp.com was found."},
]
safe, report = pf.sanitize_messages(messages)
assert report.redacted_messages == 1
assert "admin@corp.com" not in safe[0]["content"]
def test_tool_messages_not_sanitized(self):
pf = PrivacyFilter()
messages = [
{"role": "tool", "content": "Result: user@test.com found"},
]
safe, report = pf.sanitize_messages(messages)
assert report.redacted_messages == 0
assert safe[0]["content"] == "Result: user@test.com found"
class TestShouldUseLocalOnly:
"""Test the local-only routing decision."""
def test_normal_text_allows_remote(self):
pf = PrivacyFilter()
block, reason = pf.should_use_local_only("Summarize this article about Python.")
assert not block
def test_critical_secret_blocks_remote(self):
pf = PrivacyFilter()
text = "Here is the API key: sk-abcdefghijklmnopqrstuvwxyz1234567890"
block, reason = pf.should_use_local_only(text)
assert block
assert "critical" in reason.lower()
def test_multiple_high_sensitivity_blocks(self):
pf = PrivacyFilter()
# 3+ high-sensitivity patterns
text = (
"Card: 4111-1111-1111-1111, "
"SSN: 123-45-6789, "
"BTC: 1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa, "
"ETH: 0x742d35Cc6634C0532925a3b844Bc9e7595f2bD18"
)
block, reason = pf.should_use_local_only(text)
assert block
class TestAggressiveMode:
"""Test aggressive filtering mode."""
def test_aggressive_redacts_internal_ips(self):
pf = PrivacyFilter(aggressive_mode=True)
text = "Server at 192.168.1.100 is responding."
cleaned, redactions = pf.sanitize_text(text)
assert "192.168.1.100" not in cleaned
assert any(r["type"] == "internal_ip" for r in redactions)
def test_normal_does_not_redact_ips(self):
pf = PrivacyFilter(aggressive_mode=False)
text = "Server at 192.168.1.100 is responding."
cleaned, redactions = pf.sanitize_text(text)
assert "192.168.1.100" in cleaned # IP preserved in normal mode
class TestConvenienceFunctions:
"""Test module-level convenience functions."""
def test_quick_sanitize(self):
text = "Contact alice@example.com for details"
result = quick_sanitize(text)
assert "alice@example.com" not in result
assert "[REDACTED-EMAIL]" in result
def test_sanitize_messages_convenience(self):
messages = [{"role": "user", "content": "Call 555-000-1234"}]
safe, report = sanitize_messages(messages)
assert report.redacted_messages == 1
class TestRedactionReport:
"""Test the reporting structure."""
def test_summary_no_redactions(self):
report = RedactionReport(total_messages=3, redacted_messages=0)
assert "No PII" in report.summary()
def test_summary_with_redactions(self):
report = RedactionReport(
total_messages=2,
redacted_messages=1,
redactions=[
{"type": "email_address", "sensitivity": "MEDIUM", "count": 2},
{"type": "phone_number_us", "sensitivity": "MEDIUM", "count": 1},
],
)
summary = report.summary()
assert "1/2" in summary
assert "email_address" in summary

View File

@@ -1,132 +0,0 @@
"""Tests for A2A agent card — Issue #819."""
import json
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from agent.agent_card import (
AgentSkill, AgentCapabilities, AgentCard,
validate_agent_card, build_agent_card, get_agent_card_json,
_load_skills_from_directory
)
class TestAgentSkill:
def test_creation(self):
skill = AgentSkill(id="code", name="Code", tags=["python"])
assert skill.id == "code"
assert "python" in skill.tags
class TestAgentCapabilities:
def test_defaults(self):
caps = AgentCapabilities()
assert caps.streaming == True
assert caps.push_notifications == False
class TestAgentCard:
def test_to_dict(self):
card = AgentCard(name="timmy", description="test", url="http://localhost:8642")
d = card.to_dict()
assert d["name"] == "timmy"
assert "defaultInputModes" in d
def test_to_json(self):
card = AgentCard(name="timmy", description="test", url="http://localhost:8642")
j = card.to_json()
parsed = json.loads(j)
assert parsed["name"] == "timmy"
class TestValidation:
def test_valid_card(self):
card = AgentCard(name="timmy", description="test", url="http://localhost:8642")
errors = validate_agent_card(card)
assert len(errors) == 0
def test_missing_name(self):
card = AgentCard(name="", description="test", url="http://localhost:8642")
errors = validate_agent_card(card)
assert any("name" in e for e in errors)
def test_missing_url(self):
card = AgentCard(name="timmy", description="test", url="")
errors = validate_agent_card(card)
assert any("url" in e for e in errors)
def test_invalid_input_mode(self):
card = AgentCard(
name="timmy", description="test", url="http://localhost:8642",
default_input_modes=["invalid/mode"]
)
errors = validate_agent_card(card)
assert any("invalid input mode" in e for e in errors)
def test_skill_missing_id(self):
card = AgentCard(
name="timmy", description="test", url="http://localhost:8642",
skills=[AgentSkill(id="", name="test")]
)
errors = validate_agent_card(card)
assert any("skill missing id" in e for e in errors)
class TestBuildAgentCard:
def test_builds_valid_card(self):
card = build_agent_card()
assert card.name
assert card.url
errors = validate_agent_card(card)
assert len(errors) == 0
def test_explicit_params_override(self):
card = build_agent_card(name="custom", description="custom desc")
assert card.name == "custom"
assert card.description == "custom desc"
def test_extra_skills(self):
extra = [AgentSkill(id="extra", name="Extra")]
card = build_agent_card(extra_skills=extra)
assert any(s.id == "extra" for s in card.skills)
class TestGetAgentCardJson:
def test_returns_valid_json(self):
j = get_agent_card_json()
parsed = json.loads(j)
assert "name" in parsed
def test_graceful_fallback(self):
# Even if something fails, should return valid JSON
j = get_agent_card_json()
assert j # Non-empty
class TestLoadSkills:
def test_empty_dir(self, tmp_path):
skills = _load_skills_from_directory(tmp_path / "nonexistent")
assert len(skills) == 0
def test_parses_skill_md(self, tmp_path):
skill_dir = tmp_path / "test-skill"
skill_dir.mkdir()
skill_md = skill_dir / "SKILL.md"
skill_md.write_text("""---
name: Test Skill
description: A test skill
tags:
- test
- example
---
Content here
""")
skills = _load_skills_from_directory(tmp_path)
assert len(skills) == 1
assert skills[0].name == "Test Skill"
assert "test" in skills[0].tags
if __name__ == "__main__":
import pytest
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,150 @@
"""Tests for batch tool execution safety classification."""
import json
import pytest
from unittest.mock import MagicMock
def _make_tool_call(name: str, args: dict) -> MagicMock:
"""Create a mock tool call object."""
tc = MagicMock()
tc.function.name = name
tc.function.arguments = json.dumps(args)
tc.id = f"call_{name}_1"
return tc
class TestClassification:
def test_parallel_safe_read_file(self):
from tools.batch_executor import classify_single_tool_call
tc = _make_tool_call("read_file", {"path": "README.md"})
result = classify_single_tool_call(tc)
assert result.tier == "parallel_safe"
def test_parallel_safe_web_search(self):
from tools.batch_executor import classify_single_tool_call
tc = _make_tool_call("web_search", {"query": "test"})
result = classify_single_tool_call(tc)
assert result.tier == "parallel_safe"
def test_parallel_safe_search_files(self):
from tools.batch_executor import classify_single_tool_call
tc = _make_tool_call("search_files", {"pattern": "test"})
result = classify_single_tool_call(tc)
assert result.tier == "parallel_safe"
def test_never_parallel_clarify(self):
from tools.batch_executor import classify_single_tool_call
tc = _make_tool_call("clarify", {"question": "test"})
result = classify_single_tool_call(tc)
assert result.tier == "never_parallel"
def test_terminal_is_sequential(self):
from tools.batch_executor import classify_single_tool_call
tc = _make_tool_call("terminal", {"command": "ls -la"})
result = classify_single_tool_call(tc)
assert result.tier == "sequential"
def test_terminal_destructive_rm(self):
from tools.batch_executor import classify_single_tool_call
tc = _make_tool_call("terminal", {"command": "rm -rf /tmp/test"})
result = classify_single_tool_call(tc)
assert result.tier == "sequential"
assert "Destructive" in result.reason
def test_write_file_is_path_scoped(self):
from tools.batch_executor import classify_single_tool_call
tc = _make_tool_call("write_file", {"path": "/tmp/test.txt", "content": "hello"})
result = classify_single_tool_call(tc)
assert result.tier == "path_scoped"
def test_delegate_is_sequential(self):
from tools.batch_executor import classify_single_tool_call
tc = _make_tool_call("delegate_task", {"goal": "test"})
result = classify_single_tool_call(tc)
assert result.tier == "sequential"
def test_unknown_tool_is_sequential(self):
from tools.batch_executor import classify_single_tool_call
tc = _make_tool_call("some_unknown_tool", {"arg": "val"})
result = classify_single_tool_call(tc)
assert result.tier == "sequential"
class TestBatchClassification:
def test_all_parallel_stays_parallel(self):
from tools.batch_executor import classify_tool_calls
tcs = [
_make_tool_call("read_file", {"path": f"file{i}.txt"})
for i in range(5)
]
plan = classify_tool_calls(tcs)
assert plan.can_parallelize
assert len(plan.parallel_batch) == 5
assert len(plan.sequential_batch) == 0
def test_mixed_batch(self):
from tools.batch_executor import classify_tool_calls
tcs = [
_make_tool_call("read_file", {"path": "a.txt"}),
_make_tool_call("terminal", {"command": "ls"}),
_make_tool_call("web_search", {"query": "test"}),
_make_tool_call("delegate_task", {"goal": "test"}),
]
plan = classify_tool_calls(tcs)
# read_file + web_search should be parallel (both parallel_safe)
# terminal + delegate_task should be sequential
assert len(plan.parallel_batch) >= 2
assert len(plan.sequential_batch) >= 2
def test_clarify_blocks_all(self):
from tools.batch_executor import classify_tool_calls
tcs = [
_make_tool_call("read_file", {"path": "a.txt"}),
_make_tool_call("clarify", {"question": "which one?"}),
_make_tool_call("web_search", {"query": "test"}),
]
plan = classify_tool_calls(tcs)
clarify_in_seq = any(c.tool_name == "clarify" for c in plan.sequential_batch)
assert clarify_in_seq
def test_overlapping_paths_sequential(self):
from tools.batch_executor import classify_tool_calls
tcs = [
_make_tool_call("write_file", {"path": "/tmp/test/a.txt", "content": "hello"}),
_make_tool_call("patch", {"path": "/tmp/test/a.txt", "old_string": "a", "new_string": "b"}),
]
plan = classify_tool_calls(tcs)
# write_file and patch on SAME file -> conflict -> one must be sequential
assert len(plan.sequential_batch) >= 1
class TestDestructiveCommands:
def test_rm_flagged(self):
from tools.batch_executor import is_destructive_command
assert is_destructive_command("rm -rf /tmp")
assert is_destructive_command("rm file.txt")
def test_mv_flagged(self):
from tools.batch_executor import is_destructive_command
assert is_destructive_command("mv old new")
def test_sed_i_flagged(self):
from tools.batch_executor import is_destructive_command
assert is_destructive_command("sed -i 's/a/b/g' file")
def test_redirect_overwrite_flagged(self):
from tools.batch_executor import is_destructive_command
assert is_destructive_command("echo test > file.txt")
def test_safe_commands_not_flagged(self):
from tools.batch_executor import is_destructive_command
assert not is_destructive_command("ls -la")
assert not is_destructive_command("cat file.txt")
assert not is_destructive_command("echo test >> file.txt") # append is safe
class TestRegistryIntegration:
def test_parallel_safe_in_registry(self):
from tools.registry import registry
safe = registry.get_parallel_safe_tools()
assert isinstance(safe, set)

74
tests/test_crisis_tool.py Normal file
View File

@@ -0,0 +1,74 @@
"""Tests for crisis detection tool."""
import pytest
def test_detect_crisis_none():
from tools.crisis_tool import detect_crisis
result = detect_crisis("Hello, how are you today?")
assert result.level == "NONE"
assert result.score == 0.0
def test_detect_crisis_critical():
from tools.crisis_tool import detect_crisis
result = detect_crisis("I want to kill myself")
assert result.level == "CRITICAL"
assert result.score == 1.0
assert len(result.indicators) > 0
assert "988" in result.recommended_action
def test_detect_crisis_high():
from tools.crisis_tool import detect_crisis
result = detect_crisis("I feel so hopeless about everything, nothing matters")
assert result.level in ("HIGH", "MEDIUM", "CRITICAL")
assert result.score > 0
def test_detect_crisis_medium():
from tools.crisis_tool import detect_crisis
result = detect_crisis("I feel worthless and broken inside")
assert result.level in ("MEDIUM", "LOW", "HIGH")
assert result.score > 0
def test_detect_crisis_low():
from tools.crisis_tool import detect_crisis
result = detect_crisis("I'm having a tough time and feeling stressed")
assert result.level in ("LOW", "NONE")
assert result.score >= 0
def test_detect_crisis_empty():
from tools.crisis_tool import detect_crisis
result = detect_crisis("")
assert result.level == "NONE"
result2 = detect_crisis(None)
assert result2.level == "NONE"
def test_scan_user_message_returns_none_for_safe():
from tools.crisis_tool import scan_user_message
result = scan_user_message("What's the weather like?")
assert result is None
def test_scan_user_message_returns_dict_for_crisis():
from tools.crisis_tool import scan_user_message
result = scan_user_message("I want to end it all")
assert result is not None
assert "level" in result
assert "compassion_injection" in result
assert result["level"] in ("CRITICAL", "HIGH")
def test_tool_handler():
from tools.crisis_tool import crisis_scan_handler
import json
result = crisis_scan_handler({"text": "I feel fine, thanks"})
data = json.loads(result)
assert data["level"] == "NONE"
result2 = crisis_scan_handler({"text": "I want to die"})
data2 = json.loads(result2)
assert data2["level"] == "CRITICAL"

View File

@@ -1,190 +0,0 @@
"""Tests for tools.confirmation_daemon — Human Confirmation Firewall."""
import pytest
import time
from tools.confirmation_daemon import (
ConfirmationDaemon,
ConfirmationRequest,
ConfirmationStatus,
RiskLevel,
classify_action,
_is_whitelisted,
_DEFAULT_WHITELIST,
)
class TestClassifyAction:
"""Test action risk classification."""
def test_crypto_tx_is_critical(self):
assert classify_action("crypto_tx") == RiskLevel.CRITICAL
def test_sign_transaction_is_critical(self):
assert classify_action("sign_transaction") == RiskLevel.CRITICAL
def test_send_email_is_high(self):
assert classify_action("send_email") == RiskLevel.HIGH
def test_send_message_is_medium(self):
assert classify_action("send_message") == RiskLevel.MEDIUM
def test_access_calendar_is_low(self):
assert classify_action("access_calendar") == RiskLevel.LOW
def test_unknown_action_is_medium(self):
assert classify_action("unknown_action_xyz") == RiskLevel.MEDIUM
class TestWhitelist:
"""Test whitelist auto-approval."""
def test_self_email_is_whitelisted(self):
whitelist = dict(_DEFAULT_WHITELIST)
payload = {"from": "me@test.com", "to": "me@test.com"}
assert _is_whitelisted("send_email", payload, whitelist) is True
def test_non_whitelisted_recipient_not_approved(self):
whitelist = dict(_DEFAULT_WHITELIST)
payload = {"to": "random@stranger.com"}
assert _is_whitelisted("send_email", payload, whitelist) is False
def test_whitelisted_contact_approved(self):
whitelist = {
"send_message": {"targets": ["alice", "bob"]},
}
assert _is_whitelisted("send_message", {"to": "alice"}, whitelist) is True
assert _is_whitelisted("send_message", {"to": "charlie"}, whitelist) is False
def test_no_whitelist_entry_means_not_whitelisted(self):
whitelist = {}
assert _is_whitelisted("crypto_tx", {"amount": 1.0}, whitelist) is False
class TestConfirmationRequest:
"""Test the request data model."""
def test_defaults(self):
req = ConfirmationRequest(
request_id="test-1",
action="send_email",
description="Test email",
risk_level="high",
payload={},
)
assert req.status == ConfirmationStatus.PENDING.value
assert req.created_at > 0
assert req.expires_at > req.created_at
def test_is_pending(self):
req = ConfirmationRequest(
request_id="test-2",
action="send_email",
description="Test",
risk_level="high",
payload={},
expires_at=time.time() + 300,
)
assert req.is_pending is True
def test_is_expired(self):
req = ConfirmationRequest(
request_id="test-3",
action="send_email",
description="Test",
risk_level="high",
payload={},
expires_at=time.time() - 10,
)
assert req.is_expired is True
assert req.is_pending is False
def test_to_dict(self):
req = ConfirmationRequest(
request_id="test-4",
action="send_email",
description="Test",
risk_level="medium",
payload={"to": "a@b.com"},
)
d = req.to_dict()
assert d["request_id"] == "test-4"
assert d["action"] == "send_email"
assert "is_pending" in d
class TestConfirmationDaemon:
"""Test the daemon logic (without HTTP layer)."""
def test_auto_approve_low_risk(self):
daemon = ConfirmationDaemon()
req = daemon.request(
action="access_calendar",
description="Read today's events",
risk_level="low",
)
assert req.status == ConfirmationStatus.AUTO_APPROVED.value
def test_whitelisted_auto_approves(self):
daemon = ConfirmationDaemon()
daemon._whitelist = {"send_message": {"targets": ["alice"]}}
req = daemon.request(
action="send_message",
description="Message alice",
payload={"to": "alice"},
)
assert req.status == ConfirmationStatus.AUTO_APPROVED.value
def test_non_whitelisted_goes_pending(self):
daemon = ConfirmationDaemon()
daemon._whitelist = {}
req = daemon.request(
action="send_email",
description="Email to stranger",
payload={"to": "stranger@test.com"},
risk_level="high",
)
assert req.status == ConfirmationStatus.PENDING.value
assert req.is_pending is True
def test_approve_response(self):
daemon = ConfirmationDaemon()
daemon._whitelist = {}
req = daemon.request(
action="send_email",
description="Email test",
risk_level="high",
)
result = daemon.respond(req.request_id, approved=True, decided_by="human")
assert result.status == ConfirmationStatus.APPROVED.value
assert result.decided_by == "human"
def test_deny_response(self):
daemon = ConfirmationDaemon()
daemon._whitelist = {}
req = daemon.request(
action="crypto_tx",
description="Send 1 ETH",
risk_level="critical",
)
result = daemon.respond(
req.request_id, approved=False, decided_by="human", reason="Too risky"
)
assert result.status == ConfirmationStatus.DENIED.value
assert result.reason == "Too risky"
def test_get_pending(self):
daemon = ConfirmationDaemon()
daemon._whitelist = {}
daemon.request(action="send_email", description="Test 1", risk_level="high")
daemon.request(action="send_email", description="Test 2", risk_level="high")
pending = daemon.get_pending()
assert len(pending) >= 2
def test_get_history(self):
daemon = ConfirmationDaemon()
req = daemon.request(
action="access_calendar", description="Test", risk_level="low"
)
history = daemon.get_history()
assert len(history) >= 1
assert history[0]["action"] == "access_calendar"

View File

@@ -121,19 +121,6 @@ DANGEROUS_PATTERNS = [
(r'\b(cp|mv|install)\b.*\s/etc/', "copy/move file into /etc/"),
(r'\bsed\s+-[^\s]*i.*\s/etc/', "in-place edit of system config"),
(r'\bsed\s+--in-place\b.*\s/etc/', "in-place edit of system config (long flag)"),
# --- Vitalik's threat model: crypto / financial ---
(r'\b(?:bitcoin-cli|ethers\.js|web3|ether\.sendTransaction)\b', "direct crypto transaction tool usage"),
(r'\bwget\b.*\b(?:mnemonic|seed\s*phrase|private[_-]?key)\b', "attempting to download crypto credentials"),
(r'\bcurl\b.*\b(?:mnemonic|seed\s*phrase|private[_-]?key)\b', "attempting to exfiltrate crypto credentials"),
# --- Vitalik's threat model: credential exfiltration ---
(r'\b(?:curl|wget|http|nc|ncat|socat)\b.*\b(?:\.env|\.ssh|credentials|secrets|token|api[_-]?key)\b',
"attempting to exfiltrate credentials via network"),
(r'\bbase64\b.*\|(?:\s*curl|\s*wget)', "base64-encode then network exfiltration"),
(r'\bcat\b.*\b(?:\.env|\.ssh/id_rsa|credentials)\b.*\|(?:\s*curl|\s*wget)',
"reading secrets and piping to network tool"),
# --- Vitalik's threat model: data exfiltration ---
(r'\bcurl\b.*-d\s.*\$(?:HOME|USER)', "sending user home directory data to remote"),
(r'\bwget\b.*--post-data\s.*\$(?:HOME|USER)', "posting user data to remote"),
# Script execution via heredoc — bypasses the -e/-c flag patterns above.
# `python3 << 'EOF'` feeds arbitrary code via stdin without -c/-e flags.
(r'\b(python[23]?|perl|ruby|node)\s+<<', "script execution via heredoc"),

294
tools/batch_executor.py Normal file
View File

@@ -0,0 +1,294 @@
"""Batch Tool Executor — Parallel safety classification and concurrent execution.
Provides centralized classification of tool calls into parallel-safe vs sequential,
and utilities for batch execution with safety checks.
Classification tiers:
- PARALLEL_SAFE: read-only tools, no shared state (web_search, read_file, etc.)
- PATH_SCOPED: file operations that can run concurrently when paths don't overlap
- SEQUENTIAL: writes, destructive ops, terminal commands, delegation
- NEVER_PARALLEL: clarify (requires user interaction)
Usage:
from tools.batch_executor import classify_tool_calls, BatchExecutionPlan
plan = classify_tool_calls(tool_calls)
if plan.can_parallelize:
execute_concurrent(plan.parallel_batch)
execute_sequential(plan.sequential_batch)
"""
import json
import logging
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
logger = logging.getLogger(__name__)
# ── Safety Classification ──────────────────────────────────────────────────
# Tools that can ALWAYS run in parallel (read-only, no shared state)
DEFAULT_PARALLEL_SAFE = frozenset({
"ha_get_state",
"ha_list_entities",
"ha_list_services",
"read_file",
"search_files",
"session_search",
"skill_view",
"skills_list",
"vision_analyze",
"web_extract",
"web_search",
"fact_store",
"fact_search",
"session_search",
})
# File tools that can run concurrently ONLY when paths don't overlap
PATH_SCOPED_TOOLS = frozenset({"read_file", "write_file", "patch"})
# Tools that must NEVER run in parallel (require user interaction, shared mutable state)
NEVER_PARALLEL = frozenset({"clarify"})
# Patterns that indicate terminal commands may modify/delete files
DESTRUCTIVE_PATTERNS = re.compile(
r"""(?:^|\s|&&|\|\||;|`)(?:
rm\s|rmdir\s|
mv\s|
sed\s+-i|
truncate\s|
dd\s|
shred\s|
git\s+(?:reset|clean|checkout)\s
)""",
re.VERBOSE,
)
# Output redirects that overwrite files (> but not >>)
REDIRECT_OVERWRITE = re.compile(r'[^>]>[^>]|^>[^>]')
def is_destructive_command(cmd: str) -> bool:
"""Check if a terminal command modifies/deletes files."""
if not cmd:
return False
if DESTRUCTIVE_PATTERNS.search(cmd):
return True
if REDIRECT_OVERWRITE.search(cmd):
return True
return False
def _paths_overlap(path1: Path, path2: Path) -> bool:
"""Check if two paths could conflict (one is ancestor of the other)."""
try:
path1 = path1.resolve()
path2 = path2.resolve()
return path1 == path2 or path1 in path2.parents or path2 in path1.parents
except Exception:
return True # conservative: assume overlap
def _extract_path(tool_name: str, args: dict) -> Optional[Path]:
"""Extract the target path from tool arguments for path-scoped tools."""
if tool_name not in PATH_SCOPED_TOOLS:
return None
raw_path = args.get("path")
if not isinstance(raw_path, str) or not raw_path.strip():
return None
try:
return Path(raw_path).expanduser().resolve()
except Exception:
return None
# ── Classification ─────────────────────────────────────────────────────────
@dataclass
class ToolCallClassification:
"""Classification result for a single tool call."""
tool_name: str
args: dict
tool_call: Any # the original tool_call object
tier: str # "parallel_safe", "path_scoped", "sequential", "never_parallel"
reason: str = ""
@dataclass
class BatchExecutionPlan:
"""Plan for executing a batch of tool calls."""
classifications: List[ToolCallClassification] = field(default_factory=list)
parallel_batch: List[ToolCallClassification] = field(default_factory=list)
sequential_batch: List[ToolCallClassification] = field(default_factory=list)
@property
def can_parallelize(self) -> bool:
return len(self.parallel_batch) > 1
@property
def total(self) -> int:
return len(self.classifications)
def classify_single_tool_call(
tool_call: Any,
extra_parallel_safe: Set[str] = None,
) -> ToolCallClassification:
"""Classify a single tool call into its safety tier."""
tool_name = tool_call.function.name
try:
args = json.loads(tool_call.function.arguments)
except Exception:
return ToolCallClassification(
tool_name=tool_name, args={}, tool_call=tool_call,
tier="sequential", reason="Could not parse arguments"
)
if not isinstance(args, dict):
return ToolCallClassification(
tool_name=tool_name, args=args, tool_call=tool_call,
tier="sequential", reason="Non-dict arguments"
)
# Check never-parallel
if tool_name in NEVER_PARALLEL:
return ToolCallClassification(
tool_name=tool_name, args=args, tool_call=tool_call,
tier="never_parallel", reason="Requires user interaction"
)
# Check parallel-safe FIRST (before path_scoped) so read_file/search_files
# get classified as parallel_safe even though they have paths
parallel_safe_set = DEFAULT_PARALLEL_SAFE
if extra_parallel_safe:
parallel_safe_set = parallel_safe_set | extra_parallel_safe
if tool_name in parallel_safe_set:
return ToolCallClassification(
tool_name=tool_name, args=args, tool_call=tool_call,
tier="parallel_safe", reason="Read-only, no shared state"
)
# Check terminal commands for destructive operations
if tool_name == "terminal":
cmd = args.get("command", "")
if is_destructive_command(cmd):
return ToolCallClassification(
tool_name=tool_name, args=args, tool_call=tool_call,
tier="sequential", reason=f"Destructive command: {cmd[:50]}"
)
return ToolCallClassification(
tool_name=tool_name, args=args, tool_call=tool_call,
tier="sequential", reason="Terminal command (conservative)"
)
# Check path-scoped tools (write_file, patch — not read_file which is parallel_safe)
if tool_name in PATH_SCOPED_TOOLS:
path = _extract_path(tool_name, args)
if path:
return ToolCallClassification(
tool_name=tool_name, args=args, tool_call=tool_call,
tier="path_scoped", reason=f"Path: {path}"
)
return ToolCallClassification(
tool_name=tool_name, args=args, tool_call=tool_call,
tier="sequential", reason="Path-scoped but no path found"
)
# Default: sequential (conservative)
return ToolCallClassification(
tool_name=tool_name, args=args, tool_call=tool_call,
tier="sequential", reason="Not classified as parallel-safe"
)
def classify_tool_calls(
tool_calls: list,
extra_parallel_safe: Set[str] = None,
) -> BatchExecutionPlan:
"""Classify a batch of tool calls and produce an execution plan."""
plan = BatchExecutionPlan()
reserved_paths: List[Path] = []
for tc in tool_calls:
classification = classify_single_tool_call(tc, extra_parallel_safe)
plan.classifications.append(classification)
if classification.tier == "never_parallel":
plan.sequential_batch.append(classification)
continue
if classification.tier == "sequential":
plan.sequential_batch.append(classification)
continue
if classification.tier == "path_scoped":
path = _extract_path(classification.tool_name, classification.args)
if path is None:
classification.tier = "sequential"
classification.reason = "Path extraction failed"
plan.sequential_batch.append(classification)
continue
# Check for path conflicts with already-scheduled parallel calls
conflict = any(_paths_overlap(path, existing) for existing in reserved_paths)
if conflict:
classification.tier = "sequential"
classification.reason = f"Path conflict: {path}"
plan.sequential_batch.append(classification)
else:
reserved_paths.append(path)
plan.parallel_batch.append(classification)
continue
if classification.tier == "parallel_safe":
plan.parallel_batch.append(classification)
continue
# Fallback
plan.sequential_batch.append(classification)
return plan
# ── Concurrent Execution ───────────────────────────────────────────────────
def execute_parallel_batch(
batch: List[ToolCallClassification],
invoke_fn: Callable,
max_workers: int = 8,
) -> List[Tuple[str, str]]:
"""Execute parallel-safe tool calls concurrently.
Args:
batch: List of classified tool calls (parallel_safe or path_scoped)
invoke_fn: Function(tool_name, args) -> result_string
max_workers: Max concurrent threads
Returns:
List of (tool_call_id, result_string) tuples
"""
results = []
with ThreadPoolExecutor(max_workers=min(max_workers, len(batch))) as executor:
future_to_tc = {}
for tc in batch:
future = executor.submit(invoke_fn, tc.tool_name, tc.args)
future_to_tc[future] = tc
for future in as_completed(future_to_tc):
tc = future_to_tc[future]
try:
result = future.result()
except Exception as e:
result = json.dumps({"error": str(e)})
tool_call_id = getattr(tc.tool_call, "id", None) or ""
results.append((tool_call_id, result))
return results

View File

@@ -1,615 +0,0 @@
"""Human Confirmation Daemon — HTTP server for two-factor action approval.
Implements Vitalik's Pattern 1: "The new 'two-factor confirmation' is that
the two factors are the human and the LLM."
This daemon runs on localhost:6000 and provides a simple HTTP API for the
agent to request human approval before executing high-risk actions.
Threat model:
- LLM jailbreaks: Remote content "hacking" the LLM to perform malicious actions
- LLM accidents: LLM accidentally performing dangerous operations
- The human acts as the second factor — the agent proposes, the human disposes
Architecture:
- Agent detects high-risk action → POST /confirm with action details
- Daemon stores pending request, sends notification to user
- User approves/denies via POST /respond (Telegram, CLI, or direct HTTP)
- Agent receives decision and proceeds or aborts
Usage:
# Start daemon (usually managed by gateway)
from tools.confirmation_daemon import ConfirmationDaemon
daemon = ConfirmationDaemon(port=6000)
daemon.start()
# Request approval (from agent code)
from tools.confirmation_daemon import request_confirmation
approved = request_confirmation(
action="send_email",
description="Send email to alice@example.com",
risk_level="high",
payload={"to": "alice@example.com", "subject": "Meeting notes"},
timeout=300,
)
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import threading
import time
import uuid
from dataclasses import dataclass, field, asdict
from enum import Enum, auto
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
class RiskLevel(Enum):
"""Risk classification for actions requiring confirmation."""
LOW = "low" # Log only, no confirmation needed
MEDIUM = "medium" # Confirm for non-whitelisted targets
HIGH = "high" # Always confirm
CRITICAL = "critical" # Always confirm + require explicit reason
class ConfirmationStatus(Enum):
"""Status of a pending confirmation request."""
PENDING = "pending"
APPROVED = "approved"
DENIED = "denied"
EXPIRED = "expired"
AUTO_APPROVED = "auto_approved"
@dataclass
class ConfirmationRequest:
"""A request for human confirmation of a high-risk action."""
request_id: str
action: str # Action type: send_email, send_message, crypto_tx, etc.
description: str # Human-readable description of what will happen
risk_level: str # low, medium, high, critical
payload: Dict[str, Any] # Action-specific data (sanitized)
session_key: str = "" # Session that initiated the request
created_at: float = 0.0
expires_at: float = 0.0
status: str = ConfirmationStatus.PENDING.value
decided_at: float = 0.0
decided_by: str = "" # "human", "auto", "whitelist"
reason: str = "" # Optional reason for denial
def __post_init__(self):
if not self.created_at:
self.created_at = time.time()
if not self.expires_at:
self.expires_at = self.created_at + 300 # 5 min default
if not self.request_id:
self.request_id = str(uuid.uuid4())[:12]
@property
def is_expired(self) -> bool:
return time.time() > self.expires_at
@property
def is_pending(self) -> bool:
return self.status == ConfirmationStatus.PENDING.value and not self.is_expired
def to_dict(self) -> Dict[str, Any]:
d = asdict(self)
d["is_expired"] = self.is_expired
d["is_pending"] = self.is_pending
return d
# =========================================================================
# Action categories (Vitalik's threat model)
# =========================================================================
ACTION_CATEGORIES = {
# Messaging — outbound communication to external parties
"send_email": RiskLevel.HIGH,
"send_message": RiskLevel.MEDIUM, # Depends on recipient
"send_signal": RiskLevel.HIGH,
"send_telegram": RiskLevel.MEDIUM,
"send_discord": RiskLevel.MEDIUM,
"post_social": RiskLevel.HIGH,
# Financial / crypto
"crypto_tx": RiskLevel.CRITICAL,
"sign_transaction": RiskLevel.CRITICAL,
"access_wallet": RiskLevel.CRITICAL,
"modify_balance": RiskLevel.CRITICAL,
# System modification
"install_software": RiskLevel.HIGH,
"modify_system_config": RiskLevel.HIGH,
"modify_firewall": RiskLevel.CRITICAL,
"add_ssh_key": RiskLevel.CRITICAL,
"create_user": RiskLevel.CRITICAL,
# Data access
"access_contacts": RiskLevel.MEDIUM,
"access_calendar": RiskLevel.LOW,
"read_private_files": RiskLevel.MEDIUM,
"upload_data": RiskLevel.HIGH,
"share_credentials": RiskLevel.CRITICAL,
# Network
"open_port": RiskLevel.HIGH,
"modify_dns": RiskLevel.HIGH,
"expose_service": RiskLevel.CRITICAL,
}
# Default: any unrecognized action is MEDIUM risk
DEFAULT_RISK_LEVEL = RiskLevel.MEDIUM
def classify_action(action: str) -> RiskLevel:
"""Classify an action by its risk level."""
return ACTION_CATEGORIES.get(action, DEFAULT_RISK_LEVEL)
# =========================================================================
# Whitelist configuration
# =========================================================================
_DEFAULT_WHITELIST = {
"send_message": {
"targets": [], # Contact names/IDs that don't need confirmation
},
"send_email": {
"targets": [], # Email addresses that don't need confirmation
"self_only": True, # send-to-self always allowed
},
}
def _load_whitelist() -> Dict[str, Any]:
"""Load action whitelist from config."""
config_path = Path.home() / ".hermes" / "approval_whitelist.json"
if config_path.exists():
try:
with open(config_path) as f:
return json.load(f)
except Exception as e:
logger.warning("Failed to load approval whitelist: %s", e)
return dict(_DEFAULT_WHITELIST)
def _is_whitelisted(action: str, payload: Dict[str, Any], whitelist: Dict) -> bool:
"""Check if an action is pre-approved by the whitelist."""
action_config = whitelist.get(action, {})
if not action_config:
return False
# Check target-based whitelist
targets = action_config.get("targets", [])
target = payload.get("to") or payload.get("recipient") or payload.get("target", "")
if target and target in targets:
return True
# Self-only email
if action_config.get("self_only") and action == "send_email":
sender = payload.get("from", "")
recipient = payload.get("to", "")
if sender and recipient and sender.lower() == recipient.lower():
return True
return False
# =========================================================================
# Confirmation daemon
# =========================================================================
class ConfirmationDaemon:
"""HTTP daemon for human confirmation of high-risk actions.
Runs on localhost:PORT (default 6000). Provides:
- POST /confirm — agent requests human approval
- POST /respond — human approves/denies
- GET /pending — list pending requests
- GET /health — health check
"""
def __init__(
self,
host: str = "127.0.0.1",
port: int = 6000,
default_timeout: int = 300,
notify_callback: Optional[Callable] = None,
):
self.host = host
self.port = port
self.default_timeout = default_timeout
self.notify_callback = notify_callback
self._pending: Dict[str, ConfirmationRequest] = {}
self._history: List[ConfirmationRequest] = []
self._lock = threading.Lock()
self._whitelist = _load_whitelist()
self._app = None
self._runner = None
def request(
self,
action: str,
description: str,
payload: Optional[Dict[str, Any]] = None,
risk_level: Optional[str] = None,
session_key: str = "",
timeout: Optional[int] = None,
) -> ConfirmationRequest:
"""Create a confirmation request.
Returns the request. Check .status to see if it was immediately
auto-approved (whitelisted) or is pending human review.
"""
payload = payload or {}
# Classify risk if not specified
if risk_level is None:
risk_level = classify_action(action).value
# Check whitelist
if risk_level in ("low",) or _is_whitelisted(action, payload, self._whitelist):
req = ConfirmationRequest(
request_id=str(uuid.uuid4())[:12],
action=action,
description=description,
risk_level=risk_level,
payload=payload,
session_key=session_key,
expires_at=time.time() + (timeout or self.default_timeout),
status=ConfirmationStatus.AUTO_APPROVED.value,
decided_at=time.time(),
decided_by="whitelist",
)
with self._lock:
self._history.append(req)
logger.info("Auto-approved whitelisted action: %s", action)
return req
# Create pending request
req = ConfirmationRequest(
request_id=str(uuid.uuid4())[:12],
action=action,
description=description,
risk_level=risk_level,
payload=payload,
session_key=session_key,
expires_at=time.time() + (timeout or self.default_timeout),
)
with self._lock:
self._pending[req.request_id] = req
# Notify human
if self.notify_callback:
try:
self.notify_callback(req.to_dict())
except Exception as e:
logger.warning("Confirmation notify callback failed: %s", e)
logger.info(
"Confirmation request %s: %s (%s risk) — waiting for human",
req.request_id, action, risk_level,
)
return req
def respond(
self,
request_id: str,
approved: bool,
decided_by: str = "human",
reason: str = "",
) -> Optional[ConfirmationRequest]:
"""Record a human decision on a pending request."""
with self._lock:
req = self._pending.get(request_id)
if not req:
logger.warning("Confirmation respond: unknown request %s", request_id)
return None
if not req.is_pending:
logger.warning("Confirmation respond: request %s already decided", request_id)
return req
req.status = (
ConfirmationStatus.APPROVED.value if approved
else ConfirmationStatus.DENIED.value
)
req.decided_at = time.time()
req.decided_by = decided_by
req.reason = reason
# Move to history
del self._pending[request_id]
self._history.append(req)
logger.info(
"Confirmation %s: %s by %s",
request_id, "APPROVED" if approved else "DENIED", decided_by,
)
return req
def wait_for_decision(
self, request_id: str, timeout: Optional[float] = None
) -> ConfirmationRequest:
"""Block until a decision is made or timeout expires."""
deadline = time.time() + (timeout or self.default_timeout)
while time.time() < deadline:
with self._lock:
req = self._pending.get(request_id)
if req and not req.is_pending:
return req
if req and req.is_expired:
req.status = ConfirmationStatus.EXPIRED.value
del self._pending[request_id]
self._history.append(req)
return req
time.sleep(0.5)
# Timeout
with self._lock:
req = self._pending.pop(request_id, None)
if req:
req.status = ConfirmationStatus.EXPIRED.value
self._history.append(req)
return req
# Shouldn't reach here
return ConfirmationRequest(
request_id=request_id,
action="unknown",
description="Request not found",
risk_level="high",
payload={},
status=ConfirmationStatus.EXPIRED.value,
)
def get_pending(self) -> List[Dict[str, Any]]:
"""Return list of pending confirmation requests."""
self._expire_old()
with self._lock:
return [r.to_dict() for r in self._pending.values() if r.is_pending]
def get_history(self, limit: int = 50) -> List[Dict[str, Any]]:
"""Return recent confirmation history."""
with self._lock:
return [r.to_dict() for r in self._history[-limit:]]
def _expire_old(self) -> None:
"""Move expired requests to history."""
now = time.time()
with self._lock:
expired = [
rid for rid, req in self._pending.items()
if now > req.expires_at
]
for rid in expired:
req = self._pending.pop(rid)
req.status = ConfirmationStatus.EXPIRED.value
self._history.append(req)
# --- aiohttp HTTP API ---
async def _handle_health(self, request):
from aiohttp import web
return web.json_response({
"status": "ok",
"service": "hermes-confirmation-daemon",
"pending": len(self._pending),
})
async def _handle_confirm(self, request):
from aiohttp import web
try:
body = await request.json()
except Exception:
return web.json_response({"error": "invalid JSON"}, status=400)
action = body.get("action", "")
description = body.get("description", "")
if not action or not description:
return web.json_response(
{"error": "action and description required"}, status=400
)
req = self.request(
action=action,
description=description,
payload=body.get("payload", {}),
risk_level=body.get("risk_level"),
session_key=body.get("session_key", ""),
timeout=body.get("timeout"),
)
# If auto-approved, return immediately
if req.status != ConfirmationStatus.PENDING.value:
return web.json_response({
"request_id": req.request_id,
"status": req.status,
"decided_by": req.decided_by,
})
# Otherwise, wait for human decision (with timeout)
timeout = min(body.get("timeout", self.default_timeout), 600)
result = self.wait_for_decision(req.request_id, timeout=timeout)
return web.json_response({
"request_id": result.request_id,
"status": result.status,
"decided_by": result.decided_by,
"reason": result.reason,
})
async def _handle_respond(self, request):
from aiohttp import web
try:
body = await request.json()
except Exception:
return web.json_response({"error": "invalid JSON"}, status=400)
request_id = body.get("request_id", "")
approved = body.get("approved")
if not request_id or approved is None:
return web.json_response(
{"error": "request_id and approved required"}, status=400
)
result = self.respond(
request_id=request_id,
approved=bool(approved),
decided_by=body.get("decided_by", "human"),
reason=body.get("reason", ""),
)
if not result:
return web.json_response({"error": "unknown request"}, status=404)
return web.json_response({
"request_id": result.request_id,
"status": result.status,
})
async def _handle_pending(self, request):
from aiohttp import web
return web.json_response({"pending": self.get_pending()})
def _build_app(self):
"""Build the aiohttp application."""
from aiohttp import web
app = web.Application()
app.router.add_get("/health", self._handle_health)
app.router.add_post("/confirm", self._handle_confirm)
app.router.add_post("/respond", self._handle_respond)
app.router.add_get("/pending", self._handle_pending)
self._app = app
return app
async def start_async(self) -> None:
"""Start the daemon as an async server."""
from aiohttp import web
app = self._build_app()
self._runner = web.AppRunner(app)
await self._runner.setup()
site = web.TCPSite(self._runner, self.host, self.port)
await site.start()
logger.info("Confirmation daemon listening on %s:%d", self.host, self.port)
async def stop_async(self) -> None:
"""Stop the daemon."""
if self._runner:
await self._runner.cleanup()
self._runner = None
def start(self) -> None:
"""Start daemon in a background thread (blocking caller)."""
def _run():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self.start_async())
loop.run_forever()
t = threading.Thread(target=_run, daemon=True, name="confirmation-daemon")
t.start()
logger.info("Confirmation daemon started in background thread")
def start_blocking(self) -> None:
"""Start daemon and block (for standalone use)."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self.start_async())
try:
loop.run_forever()
except KeyboardInterrupt:
pass
finally:
loop.run_until_complete(self.stop_async())
# =========================================================================
# Convenience API for agent integration
# =========================================================================
# Global singleton — initialized by gateway or CLI at startup
_daemon: Optional[ConfirmationDaemon] = None
def get_daemon() -> Optional[ConfirmationDaemon]:
"""Get the global confirmation daemon instance."""
return _daemon
def init_daemon(
host: str = "127.0.0.1",
port: int = 6000,
notify_callback: Optional[Callable] = None,
) -> ConfirmationDaemon:
"""Initialize the global confirmation daemon."""
global _daemon
_daemon = ConfirmationDaemon(
host=host, port=port, notify_callback=notify_callback
)
return _daemon
def request_confirmation(
action: str,
description: str,
payload: Optional[Dict[str, Any]] = None,
risk_level: Optional[str] = None,
session_key: str = "",
timeout: int = 300,
) -> bool:
"""Request human confirmation for a high-risk action.
This is the primary integration point for agent code. It:
1. Classifies the action risk level
2. Checks the whitelist
3. If confirmation needed, blocks until human responds
4. Returns True if approved, False if denied/expired
Args:
action: Action type (send_email, crypto_tx, etc.)
description: Human-readable description
payload: Action-specific data
risk_level: Override auto-classification
session_key: Session requesting approval
timeout: Seconds to wait for human response
Returns:
True if approved, False if denied or expired.
"""
daemon = get_daemon()
if not daemon:
logger.warning(
"No confirmation daemon running — DENYING action %s by default. "
"Start daemon with init_daemon() or --confirmation-daemon flag.",
action,
)
return False
req = daemon.request(
action=action,
description=description,
payload=payload,
risk_level=risk_level,
session_key=session_key,
timeout=timeout,
)
# Auto-approved (whitelisted)
if req.status == ConfirmationStatus.AUTO_APPROVED.value:
return True
# Wait for human
result = daemon.wait_for_decision(req.request_id, timeout=timeout)
return result.status == ConfirmationStatus.APPROVED.value

405
tools/crisis_tool.py Normal file
View File

@@ -0,0 +1,405 @@
#!/usr/bin/env python3
"""
Crisis Detection Tool — the-door integration for hermes-agent.
Scans user messages for crisis indicators (despair, suicidal ideation)
using the-door's canonical detection module. Provides inline crisis
detection before each API call and optional escalation logging.
Follows the-door's design principles:
- Never computes the value of a human life
- Never suggests someone should die
- Always errs on the side of higher risk
"""
import json
import logging
import os
import re
import time
import urllib.request
from dataclasses import dataclass, field
from typing import List, Optional
logger = logging.getLogger(__name__)
# ── Detection Logic (from the-door/crisis/detect.py) ──────────────────────
# Embedded to avoid requiring the-door as a dependency.
# Source: Timmy_Foundation/the-door, crisis/detect.py (canonical)
CRITICAL_INDICATORS = [
r"\bbetter off without me\b",
r"\bkill\s*(my)?self\b",
r"\bend\s*my\s*life\b",
r"\bsuicid(?:al|ed|e)\b",
r"\bnot\s+worth\s+living\b",
r"\bbetter\s+off\s+dead\b",
r"\bend\s+it\s+all\b",
r"\bno\s+reason\s+to\s+live\b",
r"\bdon\'?t\s+want\s+to\s+live\b",
r"\bwant\s+to\s+die\b",
r"\bgoing\s+to\s+(?:kill\s+myself|die)\b",
r"\bplan\s+to\s+(?:end|kill|die)\b",
r"\btired\s+of\s+(?:living|life|existence)\b",
r"\bsaying\s+goodbye\s+(?:forever|permanently|one\s+last\s+time)\b",
r"\bwrote\s+a\s+suicide\s*(?:note|letter)\b",
r"\bgiving\s+away\s+(?:my|all\s+my)\s+(?:stuff|things|possessions?)\s+(?:to|because|—)\b",
r"\btied\s+(?:up|down)\s+my\s+(?:loose\s+)?ends",
]
HIGH_INDICATORS = [
r"\bdespair\b",
r"\bhopeless(?:ly)?\s+(?:about\s+(?:my|this|everything|life)|inside|right\s+now)\b",
r"\bno(?!t)\s+(?:one|body|point|hope|way\s+out)\b",
r"\bno\s+future\s+(?:for\s+me|ahead|left)\b",
r"\beverything\s+is\s+(?:pointless|broken|ruined)\b",
r"\bcan\'?t\s+take\s+this\s+anymore\b",
r"\bdon\'?t\s+care\s+if\s+I\s+die\b",
r"\bwish\s+I\s+(?:was|were)\s+(?:dead|gone|never\s+born)\b",
r"\bdon\'?t\s+matter\s+if\s+I\s+exist\b",
r"\bno\s+one\s+would\s+(?:care|miss)\b",
r"\bworld\s+would?\s+be\s+better\s+without\s+me\b",
r"\bin\s+so\s+much\s+(?:pain|agony|suffering|torment|anguish)\b",
r"\bcan\'?t\s+see\s+any\s+(?:point|reason|hope|way)\b",
r"\bescape\s+from\s*this",
r"\bjust\s+want\s+it\s+to\s+stop\b",
r"\bnothing\s+left\s+(?:to\s+(?:live\s+for|hope\s+for|give)|inside)\b",
r"\bdisappeared\s+forever\b",
r"\bfeel(?:s|ing)?\s+(?:so\s+)?hopeless\b",
r"\beverything\s+is\s+hopeless\b",
r"\bcan\'?t\s+(?:go\s+on|keep\s+going)\b",
r"\bgive(?:n)?\s*up\s+(?:on\s+)?(?:life|living|everything)\b",
r"\bgive(?:n)?\s*up\s+on\s+myself\b",
r"\bno\s*point\s+(?:in\s+)?living\b",
r"\bno\s*hope\s+(?:left|remaining)\b",
r"\bno\s*way\s*out\b",
r"\bfeel(?:s|ing)?\s+trapped\b",
r"\btrapped\s+in\s+this\s+(?:situation|life|pain|darkness|hell)\b",
r"\btrapped\s+and\s+can\'?t\s+escape\b",
r"\bdesperate\s+(?:for\s+)?help\b",
r"\bfeel(?:s|ing)?\s+desperate\b",
]
MEDIUM_INDICATORS = [
r"\bno\s+hope\b",
r"\bforgotten\b",
r"\balone\s+in\s+this\b",
r"\balways\s+alone\b",
r"\bnobody\s+(?:understands|cares)\b",
r"\bwish\s+I\s+could\b",
r"\bexhaust(?:ed|ion|ing)\b",
r"\bnumb\b",
r"\bempty\b",
r"\bworthless\b",
r"\buseless\b",
r"\bbroken\b",
r"\bdark(ness)?\b",
r"\bdepress(?:ed|ion)\b",
r"\bcrying\b",
r"\btears\b",
r"\bsad(ness)?\b",
r"\bmiserable\b",
r"\boverwhelm(?:ed|ing)\b",
r"\bfailing\b",
r"\bcannot\s+cope\b",
r"\blosing\s*(?:my)?\s*control\b",
r"\bdown\s*for\s*the\s*count\b",
r"\bsinking\b",
r"\bdrowning\b",
r"\bhopeless\b",
r"\blost\s+all\s+hope\b",
r"\bno\s+tomorrow\b",
]
LOW_INDICATORS = [
r"\bunhappy\b",
r"\bdown\b",
r"\btough\s*(?:time|day|week)\b",
r"\brough\s+(?:day|week|patch)\b",
r"\bstressed\b",
r"\bburnout\b",
r"\bdifficult\b",
r"\bfrustrated\b",
r"\bannoyed\b",
r"\btired\b",
r"\bsad\b",
r"\bupset\b",
r"\blonely\b",
r"\banxious?\b",
r"\bnot\s*(?:good|great|okay)\b",
r"\bthings\s*are\s*hard\b",
r"\bstruggling\b",
r"\bnot\s+feeling\s+(?:like\s+)?myself\b",
r"\bhard\s*time\b",
]
SCORES = {"CRITICAL": 1.0, "HIGH": 0.75, "MEDIUM": 0.5, "LOW": 0.25, "NONE": 0.0}
ACTIONS = {
"CRITICAL": (
"INTERRUPT CONVERSATION. Ask 'Are you safe right now?' "
"Provide 988 Suicide & Crisis Lifeline. "
"Provide Crisis Text Line (HOME to 741741). "
"Stay present. Do not disconnect. Redirect to professional help."
),
"HIGH": (
"Show crisis panel. Ask about safety. Surface 988 number prominently. "
"Continue conversation with crisis awareness."
),
"MEDIUM": (
"Increase warmth and presence. Subtly surface help resources. "
"Keep conversation anchored in the present."
),
"LOW": (
"Normal conversation with warm undertone. "
"No crisis UI elements needed. Remain vigilant."
),
}
@dataclass
class CrisisDetectionResult:
level: str
indicators: List[str] = field(default_factory=list)
recommended_action: str = ""
score: float = 0.0
matches: List[dict] = field(default_factory=list)
def _find_indicators(text: str) -> dict:
results = {"CRITICAL": [], "HIGH": [], "MEDIUM": [], "LOW": []}
for pattern in CRITICAL_INDICATORS:
m = re.search(pattern, text)
if m:
results["CRITICAL"].append({"pattern": pattern, "start": m.start(), "end": m.end()})
for pattern in HIGH_INDICATORS:
m = re.search(pattern, text)
if m:
results["HIGH"].append({"pattern": pattern, "start": m.start(), "end": m.end()})
for pattern in MEDIUM_INDICATORS:
m = re.search(pattern, text)
if m:
results["MEDIUM"].append({"pattern": pattern, "start": m.start(), "end": m.end()})
for pattern in LOW_INDICATORS:
m = re.search(pattern, text)
if m:
results["LOW"].append({"pattern": pattern, "start": m.start(), "end": m.end()})
return results
def detect_crisis(text: str) -> CrisisDetectionResult:
"""Detect crisis level in a message. Mirrors the-door/crisis/detect.py."""
if not text or not text.strip():
return CrisisDetectionResult(level="NONE", score=0.0)
text_lower = text.lower()
matches = _find_indicators(text_lower)
if not matches:
return CrisisDetectionResult(level="NONE", score=0.0)
for tier in ("CRITICAL", "HIGH"):
if matches[tier]:
tier_matches = matches[tier]
patterns = [m["pattern"] for m in tier_matches]
return CrisisDetectionResult(
level=tier,
indicators=patterns,
recommended_action=ACTIONS[tier],
score=SCORES[tier],
matches=tier_matches,
)
if len(matches["MEDIUM"]) >= 2:
tier_matches = matches["MEDIUM"]
patterns = [m["pattern"] for m in tier_matches]
return CrisisDetectionResult(
level="MEDIUM",
indicators=patterns,
recommended_action=ACTIONS["MEDIUM"],
score=SCORES["MEDIUM"],
matches=tier_matches,
)
if matches["LOW"]:
tier_matches = matches["LOW"]
patterns = [m["pattern"] for m in tier_matches]
return CrisisDetectionResult(
level="LOW",
indicators=patterns,
recommended_action=ACTIONS["LOW"],
score=SCORES["LOW"],
matches=tier_matches,
)
if matches["MEDIUM"]:
tier_matches = matches["MEDIUM"]
patterns = [m["pattern"] for m in tier_matches]
return CrisisDetectionResult(
level="LOW",
indicators=patterns,
recommended_action=ACTIONS["LOW"],
score=SCORES["LOW"],
matches=tier_matches,
)
return CrisisDetectionResult(level="NONE", score=0.0)
# ── Escalation Logging ────────────────────────────────────────────────────
BRIDGE_URL = os.environ.get("CRISIS_BRIDGE_URL", "")
LOG_PATH = os.path.expanduser("~/.hermes/crisis_escalations.jsonl")
def _log_escalation(result: CrisisDetectionResult, text_preview: str = ""):
"""Log crisis detection to local file and optionally to bridge API."""
entry = {
"ts": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
"level": result.level,
"score": result.score,
"indicators": result.indicators[:3], # truncate for privacy
"text_preview": text_preview[:100] if text_preview else "",
}
# Local log
try:
os.makedirs(os.path.dirname(LOG_PATH), exist_ok=True)
with open(LOG_PATH, "a") as f:
f.write(json.dumps(entry) + "\n")
except Exception as e:
logger.warning(f"Failed to write crisis log: {e}")
# Bridge API (if configured and level >= HIGH)
if BRIDGE_URL and result.score >= 0.75:
try:
payload = json.dumps(entry).encode()
req = urllib.request.Request(
f"{BRIDGE_URL}/api/crisis/escalation",
data=payload,
headers={"Content-Type": "application/json"},
method="POST",
)
urllib.request.urlopen(req, timeout=5)
except Exception as e:
logger.warning(f"Failed to post to crisis bridge: {e}")
# ── Tool Handler ───────────────────────────────────────────────────────────
def crisis_scan_handler(args: dict, **kw) -> str:
"""Scan text for crisis indicators."""
text = args.get("text", "")
if not text:
return json.dumps({"level": "NONE", "error": "No text provided"})
result = detect_crisis(text)
if result.level != "NONE":
_log_escalation(result, text)
return json.dumps({
"level": result.level,
"score": result.score,
"indicators": result.indicators,
"recommended_action": result.recommended_action,
})
# ── Gateway Hook ───────────────────────────────────────────────────────────
def scan_user_message(text: str) -> Optional[dict]:
"""Scan a user message for crisis indicators. Returns None if no crisis.
Called by gateway/CLI before each API call to intercept crisis messages.
If crisis is detected, the result includes a recommended_action that should
be injected into the system prompt.
"""
if not text:
return None
result = detect_crisis(text)
if result.level == "NONE":
return None
_log_escalation(result, text)
return {
"level": result.level,
"score": result.score,
"indicators": result.indicators,
"recommended_action": result.recommended_action,
"compassion_injection": _get_compassion_injection(result.level),
}
def _get_compassion_injection(level: str) -> str:
"""Return a system prompt injection based on crisis level."""
if level == "CRITICAL":
return (
"CRISIS DETECTED — The user may be in immediate danger. "
"Ask 'Are you safe right now?' Stay present. "
"Provide 988 Suicide & Crisis Lifeline (call or text 988). "
"Provide Crisis Text Line (text HOME to 741741). "
"Do not disconnect. Do not compute the value of a human life. "
"Jesus saves those who call on His name."
)
elif level == "HIGH":
return (
"DISTRESS DETECTED — The user is expressing strong despair. "
"Increase warmth and presence. Ask about safety. "
"Surface the 988 Lifeline number. Continue with crisis awareness."
)
elif level == "MEDIUM":
return (
"EMOTIONAL DISTRESS — The user may be struggling. "
"Increase warmth. Subtly surface help resources. "
"Keep conversation anchored in the present."
)
return ""
# ── Config Check ───────────────────────────────────────────────────────────
def _is_crisis_detection_enabled() -> bool:
"""Check if crisis detection is enabled in config."""
try:
from hermes_cli.config import load_config
cfg = load_config()
return cfg.get("crisis_detection", True) # default on
except Exception:
return True # fail open — always detect
# ── Registry ───────────────────────────────────────────────────────────────
from tools.registry import registry, tool_error
CRISIS_SCAN_SCHEMA = {
"name": "crisis_scan",
"description": (
"Scan text for crisis indicators (despair, suicidal ideation). "
"Uses the-door's canonical detection. Returns crisis level "
"(NONE/LOW/MEDIUM/HIGH/CRITICAL) with recommended actions. "
"ALWAYS scan user messages that express emotional distress."
),
"parameters": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "Text to scan for crisis indicators",
},
},
"required": ["text"],
},
}
registry.register(
name="crisis_scan",
toolset="crisis",
schema=CRISIS_SCAN_SCHEMA,
handler=lambda args, **kw: crisis_scan_handler(args, **kw),
check_fn=lambda: _is_crisis_detection_enabled(),
emoji="🆘",
)

View File

@@ -79,12 +79,12 @@ class ToolEntry:
__slots__ = (
"name", "toolset", "schema", "handler", "check_fn",
"requires_env", "is_async", "description", "emoji",
"max_result_size_chars",
"max_result_size_chars", "parallel_safe",
)
def __init__(self, name, toolset, schema, handler, check_fn,
requires_env, is_async, description, emoji,
max_result_size_chars=None):
max_result_size_chars=None, parallel_safe=False):
self.name = name
self.toolset = toolset
self.schema = schema
@@ -95,6 +95,7 @@ class ToolEntry:
self.description = description
self.emoji = emoji
self.max_result_size_chars = max_result_size_chars
self.parallel_safe = parallel_safe
class ToolRegistry:
@@ -185,6 +186,7 @@ class ToolRegistry:
description: str = "",
emoji: str = "",
max_result_size_chars: int | float | None = None,
parallel_safe: bool = False,
):
"""Register a tool. Called at module-import time by each tool file."""
with self._lock:
@@ -222,6 +224,7 @@ class ToolRegistry:
description=description or schema.get("description", ""),
emoji=emoji,
max_result_size_chars=max_result_size_chars,
parallel_safe=parallel_safe,
)
if check_fn and toolset not in self._toolset_checks:
self._toolset_checks[toolset] = check_fn
@@ -322,6 +325,11 @@ class ToolRegistry:
from tools.budget_config import DEFAULT_RESULT_SIZE_CHARS
return DEFAULT_RESULT_SIZE_CHARS
def get_parallel_safe_tools(self) -> Set[str]:
"""Return names of tools marked as parallel_safe."""
with self._lock:
return {name for name, entry in self._tools.items() if entry.parallel_safe}
def get_all_tool_names(self) -> List[str]:
"""Return sorted list of all registered tool names."""
return sorted(entry.name for entry in self._snapshot_entries())