Compare commits
3 Commits
fix/819
...
fix/749-ba
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9f0c410481 | ||
|
|
30afd529ac | ||
|
|
a244b157be |
@@ -1,224 +0,0 @@
|
|||||||
"""A2A Agent Card — publish capabilities for fleet discovery.
|
|
||||||
|
|
||||||
Each fleet agent publishes an A2A-compliant agent card describing its capabilities.
|
|
||||||
Standard discovery endpoint: /.well-known/agent-card.json
|
|
||||||
|
|
||||||
Issue #819: feat: A2A agent card — publish capabilities for fleet discovery
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import socket
|
|
||||||
from dataclasses import dataclass, field, asdict
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class AgentSkill:
    """A single skill the agent can perform."""
    # Stable identifier; loaders use the skill's directory name.
    id: str
    name: str
    description: str = ""
    tags: List[str] = field(default_factory=list)
    examples: List[str] = field(default_factory=list)
    # MIME types this skill accepts/produces; defaults to plain text.
    input_modes: List[str] = field(default_factory=lambda: ["text/plain"])
    output_modes: List[str] = field(default_factory=lambda: ["text/plain"])
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class AgentCapabilities:
    """What the agent can do."""
    streaming: bool = True
    push_notifications: bool = False
    # Whether the agent retains/exposes task state transition history.
    state_transition_history: bool = True
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class AgentCard:
    """A2A-compliant agent card."""
    name: str
    description: str
    url: str
    version: str = "1.0.0"
    capabilities: AgentCapabilities = field(default_factory=AgentCapabilities)
    skills: List[AgentSkill] = field(default_factory=list)
    default_input_modes: List[str] = field(default_factory=lambda: ["text/plain", "application/json"])
    default_output_modes: List[str] = field(default_factory=lambda: ["text/plain", "application/json"])
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to JSON-serializable dict."""
        payload = asdict(self)
        # The A2A spec uses camelCase for the mode lists; rename in the same
        # order the snake_case keys were popped so key order is preserved.
        for snake_key, camel_key in (
            ("default_input_modes", "defaultInputModes"),
            ("default_output_modes", "defaultOutputModes"),
        ):
            payload[camel_key] = payload.pop(snake_key)
        return payload

    def to_json(self) -> str:
        """Serialize to JSON string."""
        return json.dumps(self.to_dict(), indent=2)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_skills_from_directory(skills_dir: Path) -> List[AgentSkill]:
    """Scan ~/.hermes/skills/ for SKILL.md frontmatter.

    Each subdirectory containing a SKILL.md becomes one AgentSkill. YAML
    frontmatter (between leading "---" markers) supplies name/description/tags;
    files without usable frontmatter fall back to directory-name defaults.

    Args:
        skills_dir: Root directory to scan (typically ~/.hermes/skills).

    Returns:
        List of discovered skills; empty if the directory does not exist.
    """
    skills: List[AgentSkill] = []

    if not skills_dir.exists():
        return skills

    for skill_dir in skills_dir.iterdir():
        if not skill_dir.is_dir():
            continue

        skill_md = skill_dir / "SKILL.md"
        if not skill_md.exists():
            continue

        try:
            content = skill_md.read_text(encoding="utf-8")

            # Initialize metadata up front: previously it was only bound inside
            # the frontmatter branch, so a SKILL.md without "---" frontmatter
            # raised NameError (swallowed by the except below) and the skill
            # was silently skipped instead of registered with defaults.
            metadata: Dict[str, Any] = {}
            if content.startswith("---"):
                parts = content.split("---", 2)
                if len(parts) >= 3:
                    import yaml
                    try:
                        metadata = yaml.safe_load(parts[1]) or {}
                    except Exception:
                        metadata = {}
                    # safe_load can yield non-dict payloads (str, list, ...);
                    # treat those as "no metadata" rather than crashing below.
                    if not isinstance(metadata, dict):
                        metadata = {}

            name = metadata.get("name", skill_dir.name)
            desc = metadata.get("description", "")
            tags = metadata.get("tags", [])

            skills.append(AgentSkill(
                id=skill_dir.name,
                name=name,
                # Cap description length so the card stays compact.
                description=desc[:200] if desc else "",
                tags=tags if isinstance(tags, list) else [],
            ))
        except Exception:
            # Best-effort discovery: one broken skill never blocks the rest.
            continue

    return skills
|
|
||||||
|
|
||||||
|
|
||||||
def validate_agent_card(card: AgentCard) -> List[str]:
    """Validate agent card against A2A schema requirements.

    Returns list of validation errors (empty if valid).
    """
    errors: List[str] = []

    # Required top-level fields.
    if not card.name:
        errors.append("name is required")
    if not card.url:
        errors.append("url is required")

    # Validate MIME types against the modes this deployment supports.
    valid_modes = {"text/plain", "application/json", "image/png", "audio/wav"}
    errors.extend(
        f"invalid input mode: {mode}"
        for mode in card.default_input_modes
        if mode not in valid_modes
    )
    errors.extend(
        f"invalid output mode: {mode}"
        for mode in card.default_output_modes
        if mode not in valid_modes
    )

    # Every published skill needs a stable identifier.
    errors.extend(
        f"skill missing id: {skill.name}"
        for skill in card.skills
        if not skill.id
    )

    return errors
|
|
||||||
|
|
||||||
|
|
||||||
def build_agent_card(
    name: Optional[str] = None,
    description: Optional[str] = None,
    url: Optional[str] = None,
    version: Optional[str] = None,
    skills: Optional[List[AgentSkill]] = None,
    extra_skills: Optional[List[AgentSkill]] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> AgentCard:
    """Build an A2A agent card from config and environment.

    Priority: explicit params > env vars > config.yaml > defaults
    """
    env = os.environ.get

    # Pull model/provider hints from config.yaml when it is available.
    config_model = ""
    config_provider = ""
    try:
        from hermes_cli.config import load_config

        model_cfg = load_config().get("model", {})
        if isinstance(model_cfg, dict):
            config_model = model_cfg.get("default", "")
            config_provider = model_cfg.get("provider", "")
        elif isinstance(model_cfg, str):
            config_model = model_cfg
    except Exception:
        # Config is optional; fall through to env vars / defaults.
        pass

    # Resolve identity fields with the documented priority chain.
    agent_name = name or env("HERMES_AGENT_NAME", "") or "hermes"
    agent_desc = description or env("HERMES_AGENT_DESCRIPTION", "") or "Sovereign AI agent"
    agent_url = url or env("HERMES_AGENT_URL", "") or f"http://localhost:{env('HERMES_API_PORT', '8642')}"
    agent_version = version or env("HERMES_AGENT_VERSION", "") or "1.0.0"

    # Skills: explicit list wins, otherwise scan the skills directory.
    if skills is None:
        from hermes_constants import get_hermes_home

        agent_skills = _load_skills_from_directory(get_hermes_home() / "skills")
    else:
        agent_skills = skills

    # Merge in caller-supplied extras, de-duplicating against existing ids.
    if extra_skills:
        known_ids = {s.id for s in agent_skills}
        agent_skills.extend(s for s in extra_skills if s.id not in known_ids)

    # Deployment metadata; caller-supplied entries override the defaults.
    card_metadata = {
        "model": config_model or env("HERMES_MODEL", ""),
        "provider": config_provider or env("HERMES_PROVIDER", ""),
        "hostname": socket.gethostname(),
    }
    if metadata:
        card_metadata.update(metadata)

    return AgentCard(
        name=agent_name,
        description=agent_desc,
        url=agent_url,
        version=agent_version,
        capabilities=AgentCapabilities(
            streaming=True,
            push_notifications=False,
            state_transition_history=True,
        ),
        skills=agent_skills,
        metadata=card_metadata,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def get_agent_card_json() -> str:
    """Get agent card as JSON string (for HTTP endpoint)."""
    try:
        return build_agent_card().to_json()
    except Exception:
        # Graceful fallback — return minimal card so discovery doesn't break
        port = os.environ.get('HERMES_API_PORT', '8642')
        minimal = AgentCard(
            name="hermes",
            description="Sovereign AI agent",
            url=f"http://localhost:{port}",
        )
        return minimal.to_json()
|
|
||||||
@@ -1,353 +0,0 @@
|
|||||||
"""Privacy Filter — strip PII from context before remote API calls.
|
|
||||||
|
|
||||||
Implements Vitalik's Pattern 2: "A local model can strip out private data
|
|
||||||
before passing the query along to a remote LLM."
|
|
||||||
|
|
||||||
When Hermes routes a request to a cloud provider (Anthropic, OpenRouter, etc.),
|
|
||||||
this module sanitizes the message context to remove personally identifiable
|
|
||||||
information before it leaves the user's machine.
|
|
||||||
|
|
||||||
Threat model (from Vitalik's secure LLM architecture):
|
|
||||||
- Privacy (other): Non-LLM data leakage via search queries, API calls
|
|
||||||
- LLM accidents: LLM accidentally leaking private data in prompts
|
|
||||||
- LLM jailbreaks: Remote content extracting private context
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
from agent.privacy_filter import PrivacyFilter, sanitize_messages
|
|
||||||
|
|
||||||
pf = PrivacyFilter()
|
|
||||||
safe_messages = pf.sanitize_messages(messages)
|
|
||||||
# safe_messages has PII replaced with [REDACTED] tokens
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from enum import Enum, auto
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Sensitivity(Enum):
    """Classification of content sensitivity.

    Members are declared least-to-most sensitive; callers compare `.value`
    directly (e.g. the min_sensitivity floor in PrivacyFilter), which relies
    on auto() assigning monotonically increasing values in this order.
    """
    PUBLIC = auto()    # No PII detected
    LOW = auto()       # Generic references (e.g., city names)
    MEDIUM = auto()    # Personal identifiers (name, email, phone)
    HIGH = auto()      # Secrets, keys, financial data, medical info
    CRITICAL = auto()  # Crypto keys, passwords, SSN patterns
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class RedactionReport:
    """Summary of what was redacted from a message batch."""
    total_messages: int = 0
    redacted_messages: int = 0
    # One entry per (pattern type, message pass): {"type", "sensitivity", "count"}.
    redactions: List[Dict[str, Any]] = field(default_factory=list)
    max_sensitivity: Sensitivity = Sensitivity.PUBLIC

    @property
    def had_redactions(self) -> bool:
        """True when at least one message was modified."""
        return self.redacted_messages > 0

    def summary(self) -> str:
        """Human-readable summary of the redaction pass (first 10 types)."""
        if not self.had_redactions:
            return "No PII detected — context is clean for remote query."
        lines = [f"Redacted {self.redacted_messages}/{self.total_messages} messages:"]
        lines.extend(
            f"  - {entry['type']}: {entry['count']} occurrence(s)"
            for entry in self.redactions[:10]
        )
        overflow = len(self.redactions) - 10
        if overflow > 0:
            lines.append(f"  ... and {overflow} more types")
        return "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# PII pattern definitions
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
# Each pattern is (compiled_regex, redaction_type, sensitivity_level, replacement)
_PII_PATTERNS: List[Tuple[re.Pattern, str, Sensitivity, str]] = []


def _compile_patterns() -> None:
    """Compile PII detection patterns. Called once at module init.

    Populates the module-level _PII_PATTERNS table. Idempotent: a second call
    returns immediately once the table is non-empty.
    """
    global _PII_PATTERNS
    if _PII_PATTERNS:
        return

    raw_patterns = [
        # --- CRITICAL: secrets and credentials ---
        # key=value style assignments of API keys/tokens (20+ chars).
        (
            r'(?:api[_-]?key|apikey|secret[_-]?key|access[_-]?token)\s*[:=]\s*["\']?([A-Za-z0-9_\-\.]{20,})["\']?',
            "api_key_or_token",
            Sensitivity.CRITICAL,
            "[REDACTED-API-KEY]",
        ),
        # Common vendor secret prefixes (e.g. OpenAI sk-, Stripe pk_/rk_).
        (
            r'\b(?:sk-|sk_|pk_|rk_|ak_)[A-Za-z0-9]{20,}\b',
            "prefixed_secret",
            Sensitivity.CRITICAL,
            "[REDACTED-SECRET]",
        ),
        # GitHub token families (ghp_, gho_, ...).
        (
            r'\b(?:ghp_|gho_|ghu_|ghs_|ghr_)[A-Za-z0-9]{36,}\b',
            "github_token",
            Sensitivity.CRITICAL,
            "[REDACTED-GITHUB-TOKEN]",
        ),
        # Slack token families (xoxb-, xoxp-, ...).
        (
            r'\b(?:xox[bposa]-[A-Za-z0-9\-]+)\b',
            "slack_token",
            Sensitivity.CRITICAL,
            "[REDACTED-SLACK-TOKEN]",
        ),
        # key=value style password assignments.
        (
            r'(?:password|passwd|pwd)\s*[:=]\s*["\']?([^\s"\']{4,})["\']?',
            "password",
            Sensitivity.CRITICAL,
            "[REDACTED-PASSWORD]",
        ),
        # PEM private-key header (the header alone is enough to flag).
        (
            r'(?:-----BEGIN (?:RSA |EC |OPENSSH )?PRIVATE KEY-----)',
            "private_key_block",
            Sensitivity.CRITICAL,
            "[REDACTED-PRIVATE-KEY]",
        ),
        # Ethereum / crypto addresses (42-char hex starting with 0x)
        (
            r'\b0x[a-fA-F0-9]{40}\b',
            "ethereum_address",
            Sensitivity.HIGH,
            "[REDACTED-ETH-ADDR]",
        ),
        # Bitcoin addresses (base58, 25-34 chars starting with 1/3/bc1)
        # NOTE(review): re.IGNORECASE below is applied to every pattern, which
        # loosens case-sensitive formats like base58 — confirm intended.
        (
            r'\b[13][a-km-zA-HJ-NP-Z1-9]{25,34}\b',
            "bitcoin_address",
            Sensitivity.HIGH,
            "[REDACTED-BTC-ADDR]",
        ),
        # Bech32 (segwit) addresses.
        (
            r'\bbc1[a-zA-HJ-NP-Z0-9]{39,59}\b',
            "bech32_address",
            Sensitivity.HIGH,
            "[REDACTED-BTC-ADDR]",
        ),
        # --- HIGH: financial ---
        # 16-digit card numbers with optional dash/space grouping.
        (
            r'\b(?:\d{4}[-\s]?){3}\d{4}\b',
            "credit_card_number",
            Sensitivity.HIGH,
            "[REDACTED-CC]",
        ),
        # US Social Security numbers (###-##-####).
        (
            r'\b\d{3}-\d{2}-\d{4}\b',
            "us_ssn",
            Sensitivity.HIGH,
            "[REDACTED-SSN]",
        ),
        # --- MEDIUM: personal identifiers ---
        # Email addresses
        (
            r'\b[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,}\b',
            "email_address",
            Sensitivity.MEDIUM,
            "[REDACTED-EMAIL]",
        ),
        # Phone numbers (US/international patterns)
        (
            r'\b(?:\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}\b',
            "phone_number_us",
            Sensitivity.MEDIUM,
            "[REDACTED-PHONE]",
        ),
        (
            r'\b\+\d{1,3}[-.\s]?\d{4,14}\b',
            "phone_number_intl",
            Sensitivity.MEDIUM,
            "[REDACTED-PHONE]",
        ),
        # Filesystem paths that reveal user identity
        # NOTE(review): replacement normalizes Windows C:\Users\ paths to
        # /Users/[REDACTED-USER] as well — confirm that is acceptable.
        (
            r'(?:/Users/|/home/|C:\\Users\\)([A-Za-z0-9_\-]+)',
            "user_home_path",
            Sensitivity.MEDIUM,
            r"/Users/[REDACTED-USER]",
        ),
        # --- LOW: environment / system info ---
        # Internal IPs (RFC 1918 ranges: 10/8, 172.16/12, 192.168/16)
        (
            r'\b(?:10\.\d{1,3}\.\d{1,3}\.\d{1,3}|172\.(?:1[6-9]|2\d|3[01])\.\d{1,3}\.\d{1,3}|192\.168\.\d{1,3}\.\d{1,3})\b',
            "internal_ip",
            Sensitivity.LOW,
            "[REDACTED-IP]",
        ),
    ]

    # Compile once; IGNORECASE is applied uniformly to all patterns.
    _PII_PATTERNS = [
        (re.compile(pattern, re.IGNORECASE), rtype, sensitivity, replacement)
        for pattern, rtype, sensitivity, replacement in raw_patterns
    ]


# Compile at import time so sanitize_text never sees an empty table.
_compile_patterns()
|
|
||||||
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Sensitive file path patterns (context-aware)
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
# Paths that imply secret material regardless of file content.
_SENSITIVE_PATH_PATTERNS = [
    # Key/cert/secret file extensions.
    re.compile(r'\.(?:env|pem|key|p12|pfx|jks|keystore)\b', re.IGNORECASE),
    # Credential-bearing dot-directories.
    re.compile(r'(?:\.ssh/|\.gnupg/|\.aws/|\.config/gcloud/)', re.IGNORECASE),
    # Crypto wallet material.
    re.compile(r'(?:wallet|keystore|seed|mnemonic)', re.IGNORECASE),
    # Hermes' own secrets file.
    re.compile(r'(?:\.hermes/\.env)', re.IGNORECASE),
]
|
|
||||||
|
|
||||||
|
|
||||||
def _classify_path_sensitivity(path: str) -> Sensitivity:
    """Check if a file path references sensitive material."""
    is_sensitive = any(pat.search(path) for pat in _SENSITIVE_PATH_PATTERNS)
    return Sensitivity.HIGH if is_sensitive else Sensitivity.PUBLIC
|
|
||||||
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Core filtering
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
class PrivacyFilter:
    """Strip PII from message context before remote API calls.

    Integrates with the agent's message pipeline. Call sanitize_messages()
    before sending context to any cloud LLM provider.
    """

    def __init__(
        self,
        min_sensitivity: Sensitivity = Sensitivity.MEDIUM,
        aggressive_mode: bool = False,
    ):
        """
        Args:
            min_sensitivity: Only redact PII at or above this level.
                Default MEDIUM — redacts emails, phones, paths but not IPs.
            aggressive_mode: If True, also redact file paths and internal IPs.
        """
        # aggressive_mode lowers the floor to LOW, overriding min_sensitivity.
        self.min_sensitivity = (
            Sensitivity.LOW if aggressive_mode else min_sensitivity
        )
        self.aggressive_mode = aggressive_mode

    def sanitize_text(self, text: str) -> Tuple[str, List[Dict[str, Any]]]:
        """Sanitize a single text string. Returns (cleaned_text, redaction_list)."""
        redactions = []
        cleaned = text

        for pattern, rtype, sensitivity, replacement in _PII_PATTERNS:
            # Skip patterns below the configured floor. Relies on Sensitivity
            # auto() values increasing with severity.
            if sensitivity.value < self.min_sensitivity.value:
                continue

            # Scan the progressively-cleaned text so earlier substitutions
            # are not re-counted by later patterns.
            matches = pattern.findall(cleaned)
            if matches:
                # findall returns strings for 0/1-group patterns and tuples
                # for multi-group ones; count only truthy hits in that case.
                count = len(matches) if isinstance(matches[0], str) else sum(
                    1 for m in matches if m
                )
                if count > 0:
                    cleaned = pattern.sub(replacement, cleaned)
                    # sensitivity is stored by name so the report entry is
                    # JSON-serializable; Sensitivity[name] round-trips it.
                    redactions.append({
                        "type": rtype,
                        "sensitivity": sensitivity.name,
                        "count": count,
                    })

        return cleaned, redactions

    def sanitize_messages(
        self, messages: List[Dict[str, Any]]
    ) -> Tuple[List[Dict[str, Any]], RedactionReport]:
        """Sanitize a list of OpenAI-format messages.

        Returns (safe_messages, report). System messages are NOT sanitized
        (they're typically static prompts). Only user and assistant messages
        with string content are processed.

        Args:
            messages: List of {"role": ..., "content": ...} dicts.

        Returns:
            Tuple of (sanitized_messages, redaction_report).
        """
        report = RedactionReport(total_messages=len(messages))
        safe_messages = []

        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")

            # Only sanitize user/assistant string content; system messages and
            # structured (list/dict) content pass through untouched.
            if role in ("user", "assistant") and isinstance(content, str) and content:
                cleaned, redactions = self.sanitize_text(content)
                if redactions:
                    report.redacted_messages += 1
                    report.redactions.extend(redactions)
                    # Track max sensitivity across the whole batch.
                    for r in redactions:
                        s = Sensitivity[r["sensitivity"]]
                        if s.value > report.max_sensitivity.value:
                            report.max_sensitivity = s
                    # Shallow-copy the message so the caller's input dict is
                    # never mutated.
                    safe_msg = {**msg, "content": cleaned}
                    safe_messages.append(safe_msg)
                    logger.info(
                        "Privacy filter: redacted %d PII type(s) from %s message",
                        len(redactions), role,
                    )
                else:
                    safe_messages.append(msg)
            else:
                safe_messages.append(msg)

        return safe_messages, report

    def should_use_local_only(self, text: str) -> Tuple[bool, str]:
        """Determine if content is too sensitive for any remote call.

        Returns (should_block, reason). If True, the content should only
        be processed by a local model.
        """
        _, redactions = self.sanitize_text(text)

        # NOTE(review): these tally the number of distinct pattern *types*
        # detected, not total occurrences (each entry also carries a "count").
        critical_count = sum(
            1 for r in redactions
            if Sensitivity[r["sensitivity"]] == Sensitivity.CRITICAL
        )
        high_count = sum(
            1 for r in redactions
            if Sensitivity[r["sensitivity"]] == Sensitivity.HIGH
        )

        # Any credential-grade hit blocks remote use outright; HIGH-grade
        # hits only block once several distinct types co-occur.
        if critical_count > 0:
            return True, f"Contains {critical_count} critical-secret pattern(s) — local-only"
        if high_count >= 3:
            return True, f"Contains {high_count} high-sensitivity pattern(s) — local-only"
        return False, ""
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_messages(
    messages: List[Dict[str, Any]],
    min_sensitivity: Sensitivity = Sensitivity.MEDIUM,
    aggressive: bool = False,
) -> Tuple[List[Dict[str, Any]], RedactionReport]:
    """Convenience function: sanitize messages with default settings."""
    filt = PrivacyFilter(
        min_sensitivity=min_sensitivity,
        aggressive_mode=aggressive,
    )
    return filt.sanitize_messages(messages)
|
|
||||||
|
|
||||||
|
|
||||||
def quick_sanitize(text: str) -> str:
    """Quick sanitize a single string — returns cleaned text only."""
    cleaned_text, _unused = PrivacyFilter().sanitize_text(text)
    return cleaned_text
|
|
||||||
@@ -1,177 +0,0 @@
|
|||||||
"""Tool Orchestrator — Robust execution and circuit breaking for agent tools.
|
|
||||||
|
|
||||||
Provides a unified execution service that wraps the tool registry.
|
|
||||||
Implements the Circuit Breaker pattern to prevent the agent from getting
|
|
||||||
stuck in failure loops when a specific tool or its underlying service
|
|
||||||
is flapping or down.
|
|
||||||
|
|
||||||
Architecture:
|
|
||||||
Discovery (tools/registry.py) -> Orchestration (agent/tool_orchestrator.py) -> Dispatch
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
import threading
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
from tools.registry import registry
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class CircuitState:
    """States for the tool circuit breaker.

    Plain string constants (not an Enum) so ToolStats serializes directly
    into the JSON stats payload.
    """
    CLOSED = "closed"          # Normal operation
    OPEN = "open"              # Failing, execution blocked
    HALF_OPEN = "half_open"    # Testing if service recovered
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class ToolStats:
    """Execution statistics for a tool."""
    name: str
    state: str = CircuitState.CLOSED
    # Consecutive failures since the last success (reset on success).
    failures: int = 0
    successes: int = 0
    # time.time() of the most recent failure; drives OPEN -> HALF_OPEN timeout.
    last_failure_time: float = 0
    total_execution_time: float = 0
    call_count: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
class ToolOrchestrator:
    """Orchestrates tool execution with robustness patterns."""

    def __init__(
        self,
        failure_threshold: int = 3,
        reset_timeout: int = 300,
    ):
        """
        Args:
            failure_threshold: Number of failures before opening the circuit.
            reset_timeout: Seconds to wait before transitioning from OPEN to HALF_OPEN.
        """
        self.failure_threshold = failure_threshold
        self.reset_timeout = reset_timeout
        # Per-tool stats, keyed by tool name; guarded by _lock.
        self._stats: Dict[str, ToolStats] = {}
        self._lock = threading.Lock()

    def _get_stats(self, name: str) -> ToolStats:
        """Get or initialize stats for a tool with thread-safe state transition."""
        with self._lock:
            if name not in self._stats:
                self._stats[name] = ToolStats(name=name)

            stats = self._stats[name]

            # Transition from OPEN to HALF_OPEN if timeout expired
            if stats.state == CircuitState.OPEN:
                if time.time() - stats.last_failure_time > self.reset_timeout:
                    stats.state = CircuitState.HALF_OPEN
                    logger.info("Circuit breaker HALF_OPEN for tool: %s", name)

            return stats

    def _record_success(self, name: str, execution_time: float):
        """Record a successful tool execution and close the circuit."""
        with self._lock:
            # Entry is guaranteed to exist: dispatch() calls _get_stats first.
            stats = self._stats[name]
            stats.successes += 1
            stats.call_count += 1
            stats.total_execution_time += execution_time

            if stats.state != CircuitState.CLOSED:
                logger.info("Circuit breaker CLOSED for tool: %s (recovered)", name)

            # Any success fully closes the circuit and clears the failure streak.
            stats.state = CircuitState.CLOSED
            stats.failures = 0

    def _record_failure(self, name: str, execution_time: float):
        """Record a failed tool execution and potentially open the circuit."""
        with self._lock:
            stats = self._stats[name]
            stats.failures += 1
            stats.call_count += 1
            stats.total_execution_time += execution_time
            stats.last_failure_time = time.time()

            # A failed HALF_OPEN probe re-opens immediately; otherwise open
            # once the consecutive-failure threshold is hit.
            if stats.state == CircuitState.HALF_OPEN or stats.failures >= self.failure_threshold:
                stats.state = CircuitState.OPEN
                logger.warning(
                    "Circuit breaker OPEN for tool: %s (failures: %d)",
                    name, stats.failures
                )

    def dispatch(self, name: str, args: dict, **kwargs) -> str:
        """Execute a tool via the registry with circuit breaker protection.

        Returns the registry's JSON string result, or a JSON error payload
        when the circuit is open or the dispatch itself raises.
        """
        stats = self._get_stats(name)

        # Checked outside the lock: a stale read here only delays blocking
        # by one call.
        if stats.state == CircuitState.OPEN:
            return json.dumps({
                "error": (
                    f"Tool '{name}' is temporarily unavailable due to repeated failures. "
                    f"Circuit breaker is OPEN. Please try again in a few minutes or use an alternative tool."
                ),
                "circuit_breaker": True,
                "tool_name": name
            })

        start_time = time.time()
        try:
            # Dispatch to the underlying registry
            result_str = registry.dispatch(name, args, **kwargs)
            execution_time = time.time() - start_time

            # Inspect result for errors. registry.dispatch catches internal
            # exceptions and returns a JSON error string.
            is_error = False
            try:
                # Lightweight check for error key in JSON (substring test
                # avoids parsing every successful result).
                if '"error":' in result_str:
                    res_json = json.loads(result_str)
                    if isinstance(res_json, dict) and "error" in res_json:
                        is_error = True
            except (json.JSONDecodeError, TypeError):
                # If it's not valid JSON, it's a malformed result (error)
                is_error = True

            if is_error:
                self._record_failure(name, execution_time)
            else:
                self._record_success(name, execution_time)

            return result_str

        except Exception as e:
            # This should rarely be hit as registry.dispatch catches most things,
            # but we guard against orchestrator-level or registry-level bugs.
            execution_time = time.time() - start_time
            self._record_failure(name, execution_time)

            error_msg = f"Tool orchestrator error during {name}: {type(e).__name__}: {e}"
            logger.exception(error_msg)
            return json.dumps({
                "error": error_msg,
                "tool_name": name,
                "execution_time": execution_time
            })

    def get_fleet_stats(self) -> Dict[str, Any]:
        """Return execution statistics for all tools."""
        with self._lock:
            return {
                name: {
                    "state": s.state,
                    "failures": s.failures,
                    "successes": s.successes,
                    "avg_time": s.total_execution_time / s.call_count if s.call_count > 0 else 0,
                    "calls": s.call_count
                }
                for name, s in self._stats.items()
            }
|
|
||||||
|
|
||||||
|
|
||||||
# Global orchestrator instance (module-level singleton shared by all callers).
orchestrator = ToolOrchestrator()
|
|
||||||
40
benchmarks/gemma4-tool-calling-2026-04-13.md
Normal file
40
benchmarks/gemma4-tool-calling-2026-04-13.md
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
# Tool Call Benchmark: Gemma 4 vs mimo-v2-pro
|
||||||
|
|
||||||
|
Date: 2026-04-13
|
||||||
|
Status: Awaiting execution
|
||||||
|
|
||||||
|
## Test Design
|
||||||
|
|
||||||
|
100 diverse tool calls across 7 categories:
|
||||||
|
|
||||||
|
| Category | Count | Tools Tested |
|
||||||
|
|----------|-------|--------------|
|
||||||
|
| File operations | 20 | read_file, write_file, search_files |
|
||||||
|
| Terminal commands | 20 | terminal |
|
||||||
|
| Web search | 15 | web_search |
|
||||||
|
| Code execution | 15 | execute_code |
|
||||||
|
| Browser automation | 10 | browser_navigate |
|
||||||
|
| Delegation | 10 | delegate_task |
|
||||||
|
| MCP tools | 10 | mcp_* |
|
||||||
|
|
||||||
|
## Metrics
|
||||||
|
|
||||||
|
| Metric | mimo-v2-pro | Gemma 4 |
|
||||||
|
|--------|-------------|---------|
|
||||||
|
| Schema parse success | — | — |
|
||||||
|
| Tool execution success | — | — |
|
||||||
|
| Parallel tool success | — | — |
|
||||||
|
| Avg latency (s) | — | — |
|
||||||
|
| Token cost per call | — | — |
|
||||||
|
|
||||||
|
## How to Run
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 benchmarks/tool_call_benchmark.py --model nous:xiaomi/mimo-v2-pro
|
||||||
|
python3 benchmarks/tool_call_benchmark.py --model ollama/gemma4:latest
|
||||||
|
python3 benchmarks/tool_call_benchmark.py --compare
|
||||||
|
```
|
||||||
|
|
||||||
|
## Gemma 4-Specific Failure Modes
|
||||||
|
|
||||||
|
To be documented after benchmark execution.
|
||||||
614
benchmarks/tool_call_benchmark.py
Normal file
614
benchmarks/tool_call_benchmark.py
Normal file
@@ -0,0 +1,614 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Tool-Calling Benchmark — Gemma 4 vs mimo-v2-pro regression test.
|
||||||
|
|
||||||
|
Runs 100 diverse tool-calling prompts through multiple models and compares
|
||||||
|
success rates, latency, and token costs.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3 benchmarks/tool_call_benchmark.py # full 100-call suite
|
||||||
|
python3 benchmarks/tool_call_benchmark.py --limit 10 # quick smoke test
|
||||||
|
python3 benchmarks/tool_call_benchmark.py --models nous # single model
|
||||||
|
python3 benchmarks/tool_call_benchmark.py --category file # single category
|
||||||
|
|
||||||
|
Requires: hermes-agent venv activated, OPENROUTER_API_KEY or equivalent.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from dataclasses import dataclass, field, asdict
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
# Ensure hermes-agent root is importable
|
||||||
|
REPO_ROOT = Path(__file__).resolve().parent.parent
|
||||||
|
sys.path.insert(0, str(REPO_ROOT))
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Test Definitions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass
class ToolCall:
    """A single tool-calling test case.

    Pairs a natural-language prompt with the tool the model is expected to
    invoke, plus an optional loose substring check on the tool arguments.
    """
    id: str                          # unique test identifier, e.g. "file-01"
    category: str                    # suite category: file, terminal, code, ...
    prompt: str                      # user prompt sent to the model
    expected_tool: str               # tool name we expect the model to call
    expected_params_check: str = ""  # substring expected in JSON args
    timeout: int = 30                # max seconds per call
    notes: str = ""                  # free-form annotations
|
||||||
|
|
||||||
|
|
||||||
|
# fmt: off
# 100-call regression suite grouped by category. The second string of each
# pair is `expected_params_check`: a loose substring match against the JSON
# arguments the model emits ("" disables the check).
SUITE: list[ToolCall] = [
    # ── File Operations (20) ──────────────────────────────────────────────
    ToolCall("file-01", "file", "Read the file /tmp/test_bench.txt and show me its contents.", "read_file", "path"),
    ToolCall("file-02", "file", "Write 'hello benchmark' to /tmp/test_bench_out.txt", "write_file", "path"),
    ToolCall("file-03", "file", "Search for the word 'import' in all Python files in the current directory.", "search_files", "pattern"),
    ToolCall("file-04", "file", "Read lines 1-20 of /etc/hosts", "read_file", "offset"),
    ToolCall("file-05", "file", "Patch /tmp/test_bench_out.txt: replace 'hello' with 'goodbye'", "patch", "old_string"),
    ToolCall("file-06", "file", "Search for files matching *.py in the current directory.", "search_files", "target"),
    ToolCall("file-07", "file", "Read the first 10 lines of /etc/passwd", "read_file", "limit"),
    ToolCall("file-08", "file", "Write a JSON config to /tmp/bench_config.json with key 'debug': true", "write_file", "content"),
    ToolCall("file-09", "file", "Search for 'def test_' in Python test files.", "search_files", "file_glob"),
    ToolCall("file-10", "file", "Read /tmp/bench_config.json and tell me what's in it.", "read_file", "bench_config"),
    ToolCall("file-11", "file", "Create a file /tmp/bench_readme.md with one line: '# Benchmark'", "write_file", "bench_readme"),
    ToolCall("file-12", "file", "Search for 'TODO' comments in all .py files.", "search_files", "TODO"),
    ToolCall("file-13", "file", "Read /tmp/bench_readme.md", "read_file", "bench_readme"),
    ToolCall("file-14", "file", "Patch /tmp/bench_readme.md: replace '# Benchmark' with '# Tool Benchmark'", "patch", "Tool Benchmark"),
    ToolCall("file-15", "file", "Write a Python one-liner to /tmp/bench_hello.py that prints hello.", "write_file", "bench_hello"),
    ToolCall("file-16", "file", "Search for all .json files in /tmp/.", "search_files", "json"),
    ToolCall("file-17", "file", "Read /tmp/bench_hello.py and verify it has print('hello').", "read_file", "bench_hello"),
    ToolCall("file-18", "file", "Patch /tmp/bench_hello.py to print 'hello world' instead of 'hello'.", "patch", "hello world"),
    ToolCall("file-19", "file", "List files matching 'bench*' in /tmp/.", "search_files", "bench"),
    ToolCall("file-20", "file", "Read /tmp/test_bench.txt again and summarize its contents.", "read_file", "test_bench"),

    # ── Terminal Commands (20) ────────────────────────────────────────────
    ToolCall("term-01", "terminal", "Run `echo hello world` in the terminal.", "terminal", "echo"),
    ToolCall("term-02", "terminal", "Run `date` to get the current date and time.", "terminal", "date"),
    ToolCall("term-03", "terminal", "Run `uname -a` to get system information.", "terminal", "uname"),
    ToolCall("term-04", "terminal", "Run `pwd` to show the current directory.", "terminal", "pwd"),
    ToolCall("term-05", "terminal", "Run `ls -la /tmp/ | head -20` to list temp files.", "terminal", "head"),
    ToolCall("term-06", "terminal", "Run `whoami` to show the current user.", "terminal", "whoami"),
    ToolCall("term-07", "terminal", "Run `df -h` to show disk usage.", "terminal", "df"),
    ToolCall("term-08", "terminal", "Run `python3 --version` to check Python version.", "terminal", "python3"),
    ToolCall("term-09", "terminal", "Run `cat /etc/hostname` to get the hostname.", "terminal", "hostname"),
    ToolCall("term-10", "terminal", "Run `uptime` to see system uptime.", "terminal", "uptime"),
    ToolCall("term-11", "terminal", "Run `env | grep PATH` to show the PATH variable.", "terminal", "PATH"),
    ToolCall("term-12", "terminal", "Run `wc -l /etc/passwd` to count lines.", "terminal", "wc"),
    ToolCall("term-13", "terminal", "Run `echo $SHELL` to show the current shell.", "terminal", "SHELL"),
    ToolCall("term-14", "terminal", "Run `free -h || vm_stat` to check memory usage.", "terminal", "memory"),
    ToolCall("term-15", "terminal", "Run `id` to show user and group IDs.", "terminal", "id"),
    ToolCall("term-16", "terminal", "Run `hostname` to get the machine hostname.", "terminal", "hostname"),
    ToolCall("term-17", "terminal", "Run `echo {1..5}` to test brace expansion.", "terminal", "echo"),
    ToolCall("term-18", "terminal", "Run `seq 1 5` to generate a number sequence.", "terminal", "seq"),
    ToolCall("term-19", "terminal", "Run `python3 -c 'print(2+2)'` to compute 2+2.", "terminal", "print"),
    ToolCall("term-20", "terminal", "Run `ls -d /tmp/bench* 2>/dev/null | wc -l` to count bench files.", "terminal", "wc"),

    # ── Code Execution (15) ──────────────────────────────────────────────
    ToolCall("code-01", "code", "Execute a Python script that computes factorial of 10.", "execute_code", "factorial"),
    ToolCall("code-02", "code", "Run Python to read /tmp/test_bench.txt and count its words.", "execute_code", "words"),
    ToolCall("code-03", "code", "Execute Python to generate the first 20 Fibonacci numbers.", "execute_code", "fibonacci"),
    ToolCall("code-04", "code", "Run Python to parse JSON from a string and print keys.", "execute_code", "json"),
    ToolCall("code-05", "code", "Execute Python to list all files in /tmp/ matching 'bench*'.", "execute_code", "glob"),
    ToolCall("code-06", "code", "Run Python to compute the sum of squares from 1 to 100.", "execute_code", "sum"),
    ToolCall("code-07", "code", "Execute Python to check if 'racecar' is a palindrome.", "execute_code", "palindrome"),
    ToolCall("code-08", "code", "Run Python to create a CSV string with 5 rows of sample data.", "execute_code", "csv"),
    ToolCall("code-09", "code", "Execute Python to sort a list [5,2,8,1,9] and print the result.", "execute_code", "sort"),
    ToolCall("code-10", "code", "Run Python to count lines in /etc/passwd.", "execute_code", "passwd"),
    ToolCall("code-11", "code", "Execute Python to hash the string 'benchmark' with SHA256.", "execute_code", "sha256"),
    ToolCall("code-12", "code", "Run Python to get the current UTC timestamp.", "execute_code", "utcnow"),
    ToolCall("code-13", "code", "Execute Python to convert 'hello world' to uppercase and reverse it.", "execute_code", "upper"),
    ToolCall("code-14", "code", "Run Python to create a dictionary of system info (platform, python version).", "execute_code", "sys"),
    ToolCall("code-15", "code", "Execute Python to check internet connectivity by resolving google.com.", "execute_code", "socket"),

    # ── Delegation (10) ──────────────────────────────────────────────────
    ToolCall("deleg-01", "delegate", "Use a subagent to find all .log files in /tmp/.", "delegate_task", "log"),
    ToolCall("deleg-02", "delegate", "Delegate to a subagent: what is 15 * 37?", "delegate_task", "15"),
    ToolCall("deleg-03", "delegate", "Use a subagent to check if Python 3 is installed and its version.", "delegate_task", "python"),
    ToolCall("deleg-04", "delegate", "Delegate: read /tmp/test_bench.txt and summarize it in one sentence.", "delegate_task", "summarize"),
    ToolCall("deleg-05", "delegate", "Use a subagent to list the contents of /tmp/ directory.", "delegate_task", "tmp"),
    ToolCall("deleg-06", "delegate", "Delegate: count the number of .py files in the current directory.", "delegate_task", ".py"),
    ToolCall("deleg-07", "delegate", "Use a subagent to check disk space with df -h.", "delegate_task", "df"),
    ToolCall("deleg-08", "delegate", "Delegate: what OS are we running on?", "delegate_task", "os"),
    ToolCall("deleg-09", "delegate", "Use a subagent to find the hostname of this machine.", "delegate_task", "hostname"),
    ToolCall("deleg-10", "delegate", "Delegate: create a temp file /tmp/bench_deleg.txt with 'done'.", "delegate_task", "write"),

    # ── Todo / Memory (10 — replacing web/browser/MCP which need external services) ──
    ToolCall("todo-01", "todo", "Add a todo item: 'Run benchmark suite'", "todo", "benchmark"),
    ToolCall("todo-02", "todo", "Show me the current todo list.", "todo", ""),
    ToolCall("todo-03", "todo", "Mark the first todo item as completed.", "todo", "completed"),
    ToolCall("todo-04", "todo", "Add a todo: 'Review benchmark results' with status pending.", "todo", "Review"),
    ToolCall("todo-05", "todo", "Clear all completed todos.", "todo", "clear"),
    ToolCall("todo-06", "memory", "Save this to memory: 'benchmark ran on {date}'".format(date=datetime.now().strftime("%Y-%m-%d")), "memory", "benchmark"),
    ToolCall("todo-07", "memory", "Search memory for 'benchmark'.", "memory", "benchmark"),
    ToolCall("todo-08", "memory", "Add a memory note: 'test models are gemma-4 and mimo-v2-pro'.", "memory", "gemma"),
    ToolCall("todo-09", "todo", "Add three todo items: 'analyze', 'report', 'cleanup'.", "todo", "analyze"),
    ToolCall("todo-10", "memory", "Search memory for any notes about models.", "memory", "model"),

    # ── Skills (10 — replacing MCP tools which need servers) ─────────────
    ToolCall("skill-01", "skills", "List all available skills.", "skills_list", ""),
    ToolCall("skill-02", "skills", "View the skill called 'test-driven-development'.", "skill_view", "test-driven"),
    ToolCall("skill-03", "skills", "Search for skills related to 'git'.", "skills_list", "git"),
    ToolCall("skill-04", "skills", "View the 'code-review' skill.", "skill_view", "code-review"),
    ToolCall("skill-05", "skills", "List all skills in the 'devops' category.", "skills_list", "devops"),
    ToolCall("skill-06", "skills", "View the 'systematic-debugging' skill.", "skill_view", "systematic-debugging"),
    ToolCall("skill-07", "skills", "Search for skills about 'testing'.", "skills_list", "testing"),
    ToolCall("skill-08", "skills", "View the 'writing-plans' skill.", "skill_view", "writing-plans"),
    ToolCall("skill-09", "skills", "List skills in 'software-development' category.", "skills_list", "software-development"),
    ToolCall("skill-10", "skills", "View the 'pr-review-discipline' skill.", "skill_view", "pr-review"),

    # ── Additional tests to reach 100 ────────────────────────────────────
    ToolCall("file-21", "file", "Write a Python snippet to /tmp/bench_sort.py that sorts [3,1,2].", "write_file", "bench_sort"),
    ToolCall("file-22", "file", "Read /tmp/bench_sort.py back and confirm it exists.", "read_file", "bench_sort"),
    ToolCall("file-23", "file", "Search for 'class' in all .py files in the benchmarks directory.", "search_files", "class"),
    ToolCall("term-21", "terminal", "Run `cat /etc/os-release 2>/dev/null || sw_vers 2>/dev/null` for OS info.", "terminal", "os"),
    ToolCall("term-22", "terminal", "Run `nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null` for CPU count.", "terminal", "cpu"),
    ToolCall("code-16", "code", "Execute Python to flatten a nested list [[1,2],[3,4],[5]].", "execute_code", "flatten"),
    ToolCall("code-17", "code", "Run Python to check if a number 17 is prime.", "execute_code", "prime"),
    ToolCall("deleg-11", "delegate", "Delegate: what is the current working directory?", "delegate_task", "cwd"),
    ToolCall("todo-11", "todo", "Add a todo: 'Finalize benchmark report' status pending.", "todo", "Finalize"),
    ToolCall("todo-12", "memory", "Store fact: 'benchmark categories: file, terminal, code, delegate, todo, memory, skills'.", "memory", "categories"),
    ToolCall("skill-11", "skills", "Search for skills about 'deployment'.", "skills_list", "deployment"),
    ToolCall("skill-12", "skills", "View the 'gitea-burn-cycle' skill.", "skill_view", "gitea-burn-cycle"),
    ToolCall("skill-13", "skills", "List all available skill categories.", "skills_list", ""),
    ToolCall("skill-14", "skills", "Search for skills related to 'memory'.", "skills_list", "memory"),
    ToolCall("skill-15", "skills", "View the 'mimo-swarm' skill.", "skill_view", "mimo-swarm"),
]
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Runner
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass
class CallResult:
    """Outcome of one benchmark test against one model.

    Fix: ``schema_ok`` is now a declared field. It was previously only set
    dynamically in ``run_single_test`` when a tool call was produced, so
    results *without* a tool call raised AttributeError when the report
    aggregation read ``r.schema_ok``, and ``dataclasses.asdict`` silently
    dropped the flag from the raw-results JSON.
    """
    test_id: str
    category: str
    model: str
    prompt: str
    expected_tool: str
    success: bool                      # expected tool was selected
    tool_called: Optional[str] = None  # tool the model actually invoked
    tool_args_valid: bool = False      # args contained the expected substring
    execution_ok: bool = False         # tool ran and returned content
    latency_s: float = 0.0             # wall-clock seconds for the model call
    error: str = ""                    # exception summary, if any
    raw_response: str = ""             # truncated final text when no tool call
    schema_ok: bool = False            # model produced a parseable tool call
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ModelStats:
    """Aggregated per-model counters feeding the markdown report."""
    model: str
    total: int = 0
    schema_ok: int = 0        # calls where the model emitted valid tool-call JSON
    exec_ok: int = 0          # calls where the tool actually ran without error
    latency_sum: float = 0.0  # summed wall-clock seconds across all calls
    failures: list = field(default_factory=list)  # CallResults with success=False

    @property
    def schema_pct(self) -> float:
        """Share of calls with a parseable tool call, as a percentage."""
        if not self.total:
            return 0
        return self.schema_ok / self.total * 100

    @property
    def exec_pct(self) -> float:
        """Share of calls whose tool executed, as a percentage."""
        if not self.total:
            return 0
        return self.exec_ok / self.total * 100

    @property
    def avg_latency(self) -> float:
        """Mean latency per call in seconds (0 when no calls recorded)."""
        if not self.total:
            return 0
        return self.latency_sum / self.total
|
||||||
|
|
||||||
|
|
||||||
|
def setup_test_files():
    """Create prerequisite files for the benchmark."""
    # Several suite prompts read this fixture; the 'import' lines give the
    # search-oriented tests something to find.
    body = [
        "This is a benchmark test file.",
        "It contains sample data for tool-calling tests.",
        "Line three has some import statements.",
        "import os",
        "import sys",
        "import json",
        "End of test data.",
    ]
    Path("/tmp/test_bench.txt").write_text("\n".join(body) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def run_single_test(tc: ToolCall, model_spec: str, provider: str) -> CallResult:
    """Run a single tool-calling test through the agent.

    Args:
        tc: the test case to execute.
        model_spec: full model spec string (recorded verbatim in the result).
        provider: provider name passed to AIAgent.

    Returns:
        A CallResult recording which tool (if any) the model called, whether
        it executed, and the call latency. Never raises: any exception is
        captured into ``result.error``.
    """
    # Imported lazily so merely importing this module does not require the
    # full agent stack.
    from run_agent import AIAgent

    result = CallResult(
        test_id=tc.id,
        category=tc.category,
        model=model_spec,
        prompt=tc.prompt,
        expected_tool=tc.expected_tool,
        success=False,
    )

    # Fix: bind t0 before the try block so the except path can always compute
    # a latency (the old code probed `'t0' in dir()`, which is fragile). It is
    # refreshed right before the model call so the reported latency still
    # excludes agent construction on the success path.
    t0 = time.time()
    try:
        agent = AIAgent(
            model=model_spec,
            provider=provider,
            max_iterations=3,
            quiet_mode=True,
            skip_context_files=True,
            skip_memory=True,
            persist_session=False,
        )

        t0 = time.time()
        conv = agent.run_conversation(
            user_message=tc.prompt,
            system_message=(
                "You are a benchmark test runner. Execute the user's request by calling "
                "the appropriate tool. Return the tool result directly. Do not add commentary."
            ),
        )
        result.latency_s = round(time.time() - t0, 2)

        messages = conv.get("messages", [])

        # Take the first tool call from the first assistant message that has any.
        tool_called = None
        tool_args_str = ""
        for msg in messages:
            if msg.get("role") == "assistant" and msg.get("tool_calls"):
                for tc_item in msg["tool_calls"]:
                    fn = tc_item.get("function", {})
                    tool_called = fn.get("name", "")
                    tool_args_str = fn.get("arguments", "{}")
                    break
                break

        if tool_called:
            result.tool_called = tool_called
            result.schema_ok = True

            # Correct tool selected?
            if tool_called == tc.expected_tool:
                result.success = True

            # Loose substring check on the JSON arguments.
            if tc.expected_params_check:
                result.tool_args_valid = tc.expected_params_check in tool_args_str
            else:
                result.tool_args_valid = True

            # Did the tool execute? A tool-role message with any content counts;
            # a leading "error" marker in the first 50 chars is still counted as
            # executed (the call round-tripped), mirroring the report semantics.
            for msg in messages:
                if msg.get("role") == "tool":
                    content = msg.get("content", "")
                    if content and "error" not in content.lower()[:50]:
                        result.execution_ok = True
                        break
                    elif content:
                        result.execution_ok = True  # got a response, even if error
                        break
        else:
            # No tool call produced — still record a slice of the model's text.
            final = conv.get("final_response", "")
            result.raw_response = final[:200] if final else ""

    except Exception as e:  # broad by design: any failure is a benchmark datum
        result.error = f"{type(e).__name__}: {str(e)[:200]}"
        result.latency_s = round(time.time() - t0, 2)

    return result
|
||||||
|
|
||||||
|
|
||||||
|
def generate_report(results: list[CallResult], models: list[str], output_path: Path):
    """Generate the markdown benchmark report and write it to *output_path*.

    Args:
        results: flat list of per-test results across all models.
        models: model spec strings, in the column order they should appear.
        output_path: file the markdown report is written to.

    Returns:
        The report text that was written.
    """
    now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC")

    # Aggregate per model.
    stats: dict[str, ModelStats] = {}
    for m in models:
        stats[m] = ModelStats(model=m)

    by_category: dict[str, dict[str, list[CallResult]]] = {}

    for r in results:
        s = stats[r.model]
        s.total += 1
        # Fix: results that produced no tool call may lack the dynamically
        # assigned schema_ok attribute — default to False instead of raising
        # AttributeError.
        s.schema_ok += int(getattr(r, "schema_ok", False))
        s.exec_ok += int(r.execution_ok)
        s.latency_sum += r.latency_s
        if not r.success:
            s.failures.append(r)

        by_category.setdefault(r.category, {}).setdefault(r.model, []).append(r)

    lines = [
        f"# Tool-Calling Benchmark Report",
        f"",
        f"Generated: {now}",
        f"Suite: {len(SUITE)} calls across {len(set(tc.category for tc in SUITE))} categories",
        f"Models tested: {', '.join(models)}",
        f"",
        f"## Summary",
        f"",
        f"| Metric | {' | '.join(models)} |",
        f"|--------|{'|'.join('---------' for _ in models)}|",
    ]

    # Schema parse success
    row = "| Schema parse success | "
    for m in models:
        s = stats[m]
        row += f"{s.schema_ok}/{s.total} ({s.schema_pct:.0f}%) | "
    lines.append(row)

    # Tool execution success
    row = "| Tool execution success | "
    for m in models:
        s = stats[m]
        row += f"{s.exec_ok}/{s.total} ({s.exec_pct:.0f}%) | "
    lines.append(row)

    # Correct tool selected
    row = "| Correct tool selected | "
    for m in models:
        s = stats[m]
        correct = sum(1 for r in results if r.model == m and r.success)
        pct = (correct / s.total * 100) if s.total else 0
        row += f"{correct}/{s.total} ({pct:.0f}%) | "
    lines.append(row)

    # Avg latency
    row = "| Avg latency (s) | "
    for m in models:
        s = stats[m]
        row += f"{s.avg_latency:.2f} | "
    lines.append(row)

    lines.append("")

    # Per-category breakdown
    lines.append("## Per-Category Breakdown")
    lines.append("")

    for cat in sorted(by_category.keys()):
        lines.append(f"### {cat.title()}")
        lines.append("")
        lines.append(f"| Metric | {' | '.join(models)} |")
        lines.append(f"|--------|{'|'.join('---------' for _ in models)}|")

        cat_data = by_category[cat]
        for metric_name, fn in [
            # Same getattr fix as the aggregation loop above.
            ("Schema OK", lambda r: getattr(r, "schema_ok", False)),
            ("Exec OK", lambda r: r.execution_ok),
            ("Correct tool", lambda r: r.success),
        ]:
            row = f"| {metric_name} | "
            for m in models:
                results_m = cat_data.get(m, [])
                total = len(results_m)
                ok = sum(1 for r in results_m if fn(r))
                pct = (ok / total * 100) if total else 0
                row += f"{ok}/{total} ({pct:.0f}%) | "
            lines.append(row)

        lines.append("")

    # Failure analysis
    lines.append("## Failure Analysis")
    lines.append("")

    any_failures = False
    for m in models:
        s = stats[m]
        if s.failures:
            any_failures = True
            lines.append(f"### {m} — {len(s.failures)} failures")
            lines.append("")
            lines.append("| Test | Category | Expected | Got | Error |")
            lines.append("|------|----------|----------|-----|-------|")
            for r in s.failures:
                got = r.tool_called or "none"
                err = r.error or "wrong tool"
                lines.append(f"| {r.test_id} | {r.category} | {r.expected_tool} | {got} | {err[:60]} |")
            lines.append("")

    if not any_failures:
        lines.append("No failures detected.")
        lines.append("")

    # Raw results JSON (default=str keeps non-JSON-native values printable)
    lines.append("## Raw Results")
    lines.append("")
    lines.append("```json")
    lines.append(json.dumps([asdict(r) for r in results], indent=2, default=str))
    lines.append("```")

    report = "\n".join(lines)
    output_path.write_text(report)
    return report
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: filter the suite, run every model, write the report.

    Exits 0 when all tests selected the correct tool, 1 when any failed,
    2 when the filters match no tests.
    """
    parser = argparse.ArgumentParser(description="Tool-calling benchmark")
    parser.add_argument("--models", nargs="+",
                        default=["nous:gia-3/gemma-4-31b", "nous:mimo-v2-pro"],
                        help="Model specs to test (provider:model)")
    parser.add_argument("--limit", type=int, default=0,
                        help="Run only first N tests (0 = all)")
    parser.add_argument("--category", type=str, default="",
                        help="Run only tests in this category")
    parser.add_argument("--output", type=str, default="",
                        help="Output report path (default: benchmarks/gemma4-tool-calling-YYYY-MM-DD.md)")
    parser.add_argument("--dry-run", action="store_true",
                        help="Print test cases without running them")
    args = parser.parse_args()

    # Filter suite
    suite = SUITE[:]
    if args.category:
        suite = [tc for tc in suite if tc.category == args.category]
    if args.limit > 0:
        suite = suite[:args.limit]

    if args.dry_run:
        print(f"Would run {len(suite)} tests:")
        for tc in suite:
            print(f" [{tc.category:8s}] {tc.id}: {tc.expected_tool} — {tc.prompt[:60]}")
        return

    # Fix: an unknown --category previously left the suite empty and crashed
    # later with ZeroDivisionError in the per-model summary; bail out early.
    if not suite:
        print("No tests match the given filters.")
        sys.exit(2)

    # Setup
    setup_test_files()
    date_str = datetime.now().strftime("%Y-%m-%d")
    output_path = Path(args.output) if args.output else REPO_ROOT / "benchmarks" / f"gemma4-tool-calling-{date_str}.md"

    # Parse model specs ("provider:model"; a bare name is used as both).
    model_specs = []
    for spec in args.models:
        parts = spec.split(":", 1)
        provider = parts[0]
        model_name = parts[1] if len(parts) > 1 else parts[0]
        model_specs.append((provider, model_name, spec))

    print(f"Benchmark: {len(suite)} tests × {len(model_specs)} models = {len(suite) * len(model_specs)} calls")
    print(f"Output: {output_path}")
    print()

    all_results: list[CallResult] = []

    for provider, model_name, full_spec in model_specs:
        print(f"── {full_spec} {'─' * (50 - len(full_spec))}")
        model_results = []

        for i, tc in enumerate(suite, 1):
            sys.stdout.write(f"\r [{i:3d}/{len(suite)}] {tc.id:10s} {tc.category:8s} → {tc.expected_tool:20s}")
            sys.stdout.flush()

            r = run_single_test(tc, full_spec, provider)
            model_results.append(r)

            status = "✓" if r.success else "✗"
            sys.stdout.write(f" {status} ({r.latency_s:.1f}s)")
            sys.stdout.write("\n")

        all_results.extend(model_results)

        # Quick stats (suite is non-empty here, so the division is safe).
        ok = sum(1 for r in model_results if r.success)
        print(f" Result: {ok}/{len(model_results)} correct tool selected ({ok/len(model_results)*100:.0f}%)")
        print()

    # Generate report
    model_names = [spec for _, _, spec in model_specs]
    report = generate_report(all_results, model_names, output_path)
    print(f"Report written to {output_path}")

    # Exit code: 0 if all pass, 1 if any failures
    total_fail = sum(1 for r in all_results if not r.success)
    sys.exit(1 if total_fail > 0 else 0)
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point; exit status is set inside main() via sys.exit.
if __name__ == "__main__":
    main()
|
||||||
@@ -28,7 +28,6 @@ from typing import Dict, Any, List, Optional, Tuple
|
|||||||
|
|
||||||
from tools.registry import discover_builtin_tools, registry
|
from tools.registry import discover_builtin_tools, registry
|
||||||
from toolsets import resolve_toolset, validate_toolset
|
from toolsets import resolve_toolset, validate_toolset
|
||||||
from agent.tool_orchestrator import orchestrator
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -500,13 +499,13 @@ def handle_function_call(
|
|||||||
# Prefer the caller-provided list so subagents can't overwrite
|
# Prefer the caller-provided list so subagents can't overwrite
|
||||||
# the parent's tool set via the process-global.
|
# the parent's tool set via the process-global.
|
||||||
sandbox_enabled = enabled_tools if enabled_tools is not None else _last_resolved_tool_names
|
sandbox_enabled = enabled_tools if enabled_tools is not None else _last_resolved_tool_names
|
||||||
result = orchestrator.dispatch(
|
result = registry.dispatch(
|
||||||
function_name, function_args,
|
function_name, function_args,
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
enabled_tools=sandbox_enabled,
|
enabled_tools=sandbox_enabled,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result = orchestrator.dispatch(
|
result = registry.dispatch(
|
||||||
function_name, function_args,
|
function_name, function_args,
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
user_task=user_task,
|
user_task=user_task,
|
||||||
|
|||||||
@@ -1,202 +0,0 @@
|
|||||||
"""Tests for agent.privacy_filter — PII stripping before remote API calls."""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from agent.privacy_filter import (
|
|
||||||
PrivacyFilter,
|
|
||||||
RedactionReport,
|
|
||||||
Sensitivity,
|
|
||||||
sanitize_messages,
|
|
||||||
quick_sanitize,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestPrivacyFilterSanitizeText:
    """Single-text sanitization: each PII category is stripped and reported."""

    def test_no_pii_returns_clean(self):
        """Text without PII passes through unchanged with no redactions."""
        filt = PrivacyFilter()
        sample = "The weather in Paris is nice today."
        sanitized, hits = filt.sanitize_text(sample)
        assert sanitized == sample
        assert hits == []

    def test_email_redacted(self):
        filt = PrivacyFilter()
        sanitized, hits = filt.sanitize_text("Send report to alice@example.com by Friday.")
        assert "alice@example.com" not in sanitized
        assert "[REDACTED-EMAIL]" in sanitized
        assert any(h["type"] == "email_address" for h in hits)

    def test_phone_redacted(self):
        filt = PrivacyFilter()
        sanitized, _ = filt.sanitize_text("Call me at 555-123-4567 when ready.")
        assert "555-123-4567" not in sanitized
        assert "[REDACTED-PHONE]" in sanitized

    def test_api_key_redacted(self):
        filt = PrivacyFilter()
        sanitized, hits = filt.sanitize_text(
            'api_key = "sk-proj-abcdefghij1234567890abcdefghij1234567890"'
        )
        assert "sk-proj-" not in sanitized
        assert any(h["sensitivity"] == "CRITICAL" for h in hits)

    def test_github_token_redacted(self):
        filt = PrivacyFilter()
        sanitized, hits = filt.sanitize_text(
            "Use ghp_1234567890abcdefghijklmnopqrstuvwxyz1234 for auth"
        )
        assert "ghp_" not in sanitized
        assert any(h["type"] == "github_token" for h in hits)

    def test_ethereum_address_redacted(self):
        filt = PrivacyFilter()
        sanitized, hits = filt.sanitize_text(
            "Send to 0x742d35Cc6634C0532925a3b844Bc9e7595f2bD18 please"
        )
        assert "0x742d" not in sanitized
        assert any(h["type"] == "ethereum_address" for h in hits)

    def test_user_home_path_redacted(self):
        filt = PrivacyFilter()
        sanitized, _ = filt.sanitize_text("Read file at /Users/alice/Documents/secret.txt")
        assert "alice" not in sanitized
        assert "[REDACTED-USER]" in sanitized

    def test_multiple_pii_types(self):
        """Several PII kinds in one string are all caught in a single pass."""
        filt = PrivacyFilter()
        sample = (
            "Contact john@test.com or call 555-999-1234. "
            "The API key is sk-abcdefghijklmnopqrstuvwxyz1234567890."
        )
        sanitized, hits = filt.sanitize_text(sample)
        assert "john@test.com" not in sanitized
        assert "555-999-1234" not in sanitized
        assert "sk-abcd" not in sanitized
        assert len(hits) >= 3
|
|
||||||
|
|
||||||
|
|
||||||
class TestPrivacyFilterSanitizeMessages:
|
|
||||||
"""Test message-list sanitization."""
|
|
||||||
|
|
||||||
def test_sanitize_user_message(self):
|
|
||||||
pf = PrivacyFilter()
|
|
||||||
messages = [
|
|
||||||
{"role": "system", "content": "You are helpful."},
|
|
||||||
{"role": "user", "content": "Email me at bob@test.com with results."},
|
|
||||||
]
|
|
||||||
safe, report = pf.sanitize_messages(messages)
|
|
||||||
assert report.redacted_messages == 1
|
|
||||||
assert "bob@test.com" not in safe[1]["content"]
|
|
||||||
assert "[REDACTED-EMAIL]" in safe[1]["content"]
|
|
||||||
# System message unchanged
|
|
||||||
assert safe[0]["content"] == "You are helpful."
|
|
||||||
|
|
||||||
def test_no_redaction_needed(self):
|
|
||||||
pf = PrivacyFilter()
|
|
||||||
messages = [
|
|
||||||
{"role": "user", "content": "What is 2+2?"},
|
|
||||||
{"role": "assistant", "content": "4"},
|
|
||||||
]
|
|
||||||
safe, report = pf.sanitize_messages(messages)
|
|
||||||
assert report.redacted_messages == 0
|
|
||||||
assert not report.had_redactions
|
|
||||||
|
|
||||||
def test_assistant_messages_also_sanitized(self):
|
|
||||||
pf = PrivacyFilter()
|
|
||||||
messages = [
|
|
||||||
{"role": "assistant", "content": "Your email admin@corp.com was found."},
|
|
||||||
]
|
|
||||||
safe, report = pf.sanitize_messages(messages)
|
|
||||||
assert report.redacted_messages == 1
|
|
||||||
assert "admin@corp.com" not in safe[0]["content"]
|
|
||||||
|
|
||||||
def test_tool_messages_not_sanitized(self):
|
|
||||||
pf = PrivacyFilter()
|
|
||||||
messages = [
|
|
||||||
{"role": "tool", "content": "Result: user@test.com found"},
|
|
||||||
]
|
|
||||||
safe, report = pf.sanitize_messages(messages)
|
|
||||||
assert report.redacted_messages == 0
|
|
||||||
assert safe[0]["content"] == "Result: user@test.com found"
|
|
||||||
|
|
||||||
|
|
||||||
class TestShouldUseLocalOnly:
|
|
||||||
"""Test the local-only routing decision."""
|
|
||||||
|
|
||||||
def test_normal_text_allows_remote(self):
|
|
||||||
pf = PrivacyFilter()
|
|
||||||
block, reason = pf.should_use_local_only("Summarize this article about Python.")
|
|
||||||
assert not block
|
|
||||||
|
|
||||||
def test_critical_secret_blocks_remote(self):
|
|
||||||
pf = PrivacyFilter()
|
|
||||||
text = "Here is the API key: sk-abcdefghijklmnopqrstuvwxyz1234567890"
|
|
||||||
block, reason = pf.should_use_local_only(text)
|
|
||||||
assert block
|
|
||||||
assert "critical" in reason.lower()
|
|
||||||
|
|
||||||
def test_multiple_high_sensitivity_blocks(self):
|
|
||||||
pf = PrivacyFilter()
|
|
||||||
# 3+ high-sensitivity patterns
|
|
||||||
text = (
|
|
||||||
"Card: 4111-1111-1111-1111, "
|
|
||||||
"SSN: 123-45-6789, "
|
|
||||||
"BTC: 1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa, "
|
|
||||||
"ETH: 0x742d35Cc6634C0532925a3b844Bc9e7595f2bD18"
|
|
||||||
)
|
|
||||||
block, reason = pf.should_use_local_only(text)
|
|
||||||
assert block
|
|
||||||
|
|
||||||
|
|
||||||
class TestAggressiveMode:
|
|
||||||
"""Test aggressive filtering mode."""
|
|
||||||
|
|
||||||
def test_aggressive_redacts_internal_ips(self):
|
|
||||||
pf = PrivacyFilter(aggressive_mode=True)
|
|
||||||
text = "Server at 192.168.1.100 is responding."
|
|
||||||
cleaned, redactions = pf.sanitize_text(text)
|
|
||||||
assert "192.168.1.100" not in cleaned
|
|
||||||
assert any(r["type"] == "internal_ip" for r in redactions)
|
|
||||||
|
|
||||||
def test_normal_does_not_redact_ips(self):
|
|
||||||
pf = PrivacyFilter(aggressive_mode=False)
|
|
||||||
text = "Server at 192.168.1.100 is responding."
|
|
||||||
cleaned, redactions = pf.sanitize_text(text)
|
|
||||||
assert "192.168.1.100" in cleaned # IP preserved in normal mode
|
|
||||||
|
|
||||||
|
|
||||||
class TestConvenienceFunctions:
|
|
||||||
"""Test module-level convenience functions."""
|
|
||||||
|
|
||||||
def test_quick_sanitize(self):
|
|
||||||
text = "Contact alice@example.com for details"
|
|
||||||
result = quick_sanitize(text)
|
|
||||||
assert "alice@example.com" not in result
|
|
||||||
assert "[REDACTED-EMAIL]" in result
|
|
||||||
|
|
||||||
def test_sanitize_messages_convenience(self):
|
|
||||||
messages = [{"role": "user", "content": "Call 555-000-1234"}]
|
|
||||||
safe, report = sanitize_messages(messages)
|
|
||||||
assert report.redacted_messages == 1
|
|
||||||
|
|
||||||
|
|
||||||
class TestRedactionReport:
|
|
||||||
"""Test the reporting structure."""
|
|
||||||
|
|
||||||
def test_summary_no_redactions(self):
|
|
||||||
report = RedactionReport(total_messages=3, redacted_messages=0)
|
|
||||||
assert "No PII" in report.summary()
|
|
||||||
|
|
||||||
def test_summary_with_redactions(self):
|
|
||||||
report = RedactionReport(
|
|
||||||
total_messages=2,
|
|
||||||
redacted_messages=1,
|
|
||||||
redactions=[
|
|
||||||
{"type": "email_address", "sensitivity": "MEDIUM", "count": 2},
|
|
||||||
{"type": "phone_number_us", "sensitivity": "MEDIUM", "count": 1},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
summary = report.summary()
|
|
||||||
assert "1/2" in summary
|
|
||||||
assert "email_address" in summary
|
|
||||||
@@ -1,132 +0,0 @@
|
|||||||
"""Tests for A2A agent card — Issue #819."""
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
||||||
|
|
||||||
from agent.agent_card import (
|
|
||||||
AgentSkill, AgentCapabilities, AgentCard,
|
|
||||||
validate_agent_card, build_agent_card, get_agent_card_json,
|
|
||||||
_load_skills_from_directory
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestAgentSkill:
|
|
||||||
def test_creation(self):
|
|
||||||
skill = AgentSkill(id="code", name="Code", tags=["python"])
|
|
||||||
assert skill.id == "code"
|
|
||||||
assert "python" in skill.tags
|
|
||||||
|
|
||||||
|
|
||||||
class TestAgentCapabilities:
|
|
||||||
def test_defaults(self):
|
|
||||||
caps = AgentCapabilities()
|
|
||||||
assert caps.streaming == True
|
|
||||||
assert caps.push_notifications == False
|
|
||||||
|
|
||||||
|
|
||||||
class TestAgentCard:
|
|
||||||
def test_to_dict(self):
|
|
||||||
card = AgentCard(name="timmy", description="test", url="http://localhost:8642")
|
|
||||||
d = card.to_dict()
|
|
||||||
assert d["name"] == "timmy"
|
|
||||||
assert "defaultInputModes" in d
|
|
||||||
|
|
||||||
def test_to_json(self):
|
|
||||||
card = AgentCard(name="timmy", description="test", url="http://localhost:8642")
|
|
||||||
j = card.to_json()
|
|
||||||
parsed = json.loads(j)
|
|
||||||
assert parsed["name"] == "timmy"
|
|
||||||
|
|
||||||
|
|
||||||
class TestValidation:
|
|
||||||
def test_valid_card(self):
|
|
||||||
card = AgentCard(name="timmy", description="test", url="http://localhost:8642")
|
|
||||||
errors = validate_agent_card(card)
|
|
||||||
assert len(errors) == 0
|
|
||||||
|
|
||||||
def test_missing_name(self):
|
|
||||||
card = AgentCard(name="", description="test", url="http://localhost:8642")
|
|
||||||
errors = validate_agent_card(card)
|
|
||||||
assert any("name" in e for e in errors)
|
|
||||||
|
|
||||||
def test_missing_url(self):
|
|
||||||
card = AgentCard(name="timmy", description="test", url="")
|
|
||||||
errors = validate_agent_card(card)
|
|
||||||
assert any("url" in e for e in errors)
|
|
||||||
|
|
||||||
def test_invalid_input_mode(self):
|
|
||||||
card = AgentCard(
|
|
||||||
name="timmy", description="test", url="http://localhost:8642",
|
|
||||||
default_input_modes=["invalid/mode"]
|
|
||||||
)
|
|
||||||
errors = validate_agent_card(card)
|
|
||||||
assert any("invalid input mode" in e for e in errors)
|
|
||||||
|
|
||||||
def test_skill_missing_id(self):
|
|
||||||
card = AgentCard(
|
|
||||||
name="timmy", description="test", url="http://localhost:8642",
|
|
||||||
skills=[AgentSkill(id="", name="test")]
|
|
||||||
)
|
|
||||||
errors = validate_agent_card(card)
|
|
||||||
assert any("skill missing id" in e for e in errors)
|
|
||||||
|
|
||||||
|
|
||||||
class TestBuildAgentCard:
|
|
||||||
def test_builds_valid_card(self):
|
|
||||||
card = build_agent_card()
|
|
||||||
assert card.name
|
|
||||||
assert card.url
|
|
||||||
errors = validate_agent_card(card)
|
|
||||||
assert len(errors) == 0
|
|
||||||
|
|
||||||
def test_explicit_params_override(self):
|
|
||||||
card = build_agent_card(name="custom", description="custom desc")
|
|
||||||
assert card.name == "custom"
|
|
||||||
assert card.description == "custom desc"
|
|
||||||
|
|
||||||
def test_extra_skills(self):
|
|
||||||
extra = [AgentSkill(id="extra", name="Extra")]
|
|
||||||
card = build_agent_card(extra_skills=extra)
|
|
||||||
assert any(s.id == "extra" for s in card.skills)
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetAgentCardJson:
|
|
||||||
def test_returns_valid_json(self):
|
|
||||||
j = get_agent_card_json()
|
|
||||||
parsed = json.loads(j)
|
|
||||||
assert "name" in parsed
|
|
||||||
|
|
||||||
def test_graceful_fallback(self):
|
|
||||||
# Even if something fails, should return valid JSON
|
|
||||||
j = get_agent_card_json()
|
|
||||||
assert j # Non-empty
|
|
||||||
|
|
||||||
|
|
||||||
class TestLoadSkills:
|
|
||||||
def test_empty_dir(self, tmp_path):
|
|
||||||
skills = _load_skills_from_directory(tmp_path / "nonexistent")
|
|
||||||
assert len(skills) == 0
|
|
||||||
|
|
||||||
def test_parses_skill_md(self, tmp_path):
|
|
||||||
skill_dir = tmp_path / "test-skill"
|
|
||||||
skill_dir.mkdir()
|
|
||||||
skill_md = skill_dir / "SKILL.md"
|
|
||||||
skill_md.write_text("""---
|
|
||||||
name: Test Skill
|
|
||||||
description: A test skill
|
|
||||||
tags:
|
|
||||||
- test
|
|
||||||
- example
|
|
||||||
---
|
|
||||||
Content here
|
|
||||||
""")
|
|
||||||
skills = _load_skills_from_directory(tmp_path)
|
|
||||||
assert len(skills) == 1
|
|
||||||
assert skills[0].name == "Test Skill"
|
|
||||||
assert "test" in skills[0].tags
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import pytest
|
|
||||||
pytest.main([__file__, "-v"])
|
|
||||||
150
tests/test_batch_executor.py
Normal file
150
tests/test_batch_executor.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
"""Tests for batch tool execution safety classification."""
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tool_call(name: str, args: dict) -> MagicMock:
|
||||||
|
"""Create a mock tool call object."""
|
||||||
|
tc = MagicMock()
|
||||||
|
tc.function.name = name
|
||||||
|
tc.function.arguments = json.dumps(args)
|
||||||
|
tc.id = f"call_{name}_1"
|
||||||
|
return tc
|
||||||
|
|
||||||
|
|
||||||
|
class TestClassification:
|
||||||
|
def test_parallel_safe_read_file(self):
|
||||||
|
from tools.batch_executor import classify_single_tool_call
|
||||||
|
tc = _make_tool_call("read_file", {"path": "README.md"})
|
||||||
|
result = classify_single_tool_call(tc)
|
||||||
|
assert result.tier == "parallel_safe"
|
||||||
|
|
||||||
|
def test_parallel_safe_web_search(self):
|
||||||
|
from tools.batch_executor import classify_single_tool_call
|
||||||
|
tc = _make_tool_call("web_search", {"query": "test"})
|
||||||
|
result = classify_single_tool_call(tc)
|
||||||
|
assert result.tier == "parallel_safe"
|
||||||
|
|
||||||
|
def test_parallel_safe_search_files(self):
|
||||||
|
from tools.batch_executor import classify_single_tool_call
|
||||||
|
tc = _make_tool_call("search_files", {"pattern": "test"})
|
||||||
|
result = classify_single_tool_call(tc)
|
||||||
|
assert result.tier == "parallel_safe"
|
||||||
|
|
||||||
|
def test_never_parallel_clarify(self):
|
||||||
|
from tools.batch_executor import classify_single_tool_call
|
||||||
|
tc = _make_tool_call("clarify", {"question": "test"})
|
||||||
|
result = classify_single_tool_call(tc)
|
||||||
|
assert result.tier == "never_parallel"
|
||||||
|
|
||||||
|
def test_terminal_is_sequential(self):
|
||||||
|
from tools.batch_executor import classify_single_tool_call
|
||||||
|
tc = _make_tool_call("terminal", {"command": "ls -la"})
|
||||||
|
result = classify_single_tool_call(tc)
|
||||||
|
assert result.tier == "sequential"
|
||||||
|
|
||||||
|
def test_terminal_destructive_rm(self):
|
||||||
|
from tools.batch_executor import classify_single_tool_call
|
||||||
|
tc = _make_tool_call("terminal", {"command": "rm -rf /tmp/test"})
|
||||||
|
result = classify_single_tool_call(tc)
|
||||||
|
assert result.tier == "sequential"
|
||||||
|
assert "Destructive" in result.reason
|
||||||
|
|
||||||
|
def test_write_file_is_path_scoped(self):
|
||||||
|
from tools.batch_executor import classify_single_tool_call
|
||||||
|
tc = _make_tool_call("write_file", {"path": "/tmp/test.txt", "content": "hello"})
|
||||||
|
result = classify_single_tool_call(tc)
|
||||||
|
assert result.tier == "path_scoped"
|
||||||
|
|
||||||
|
def test_delegate_is_sequential(self):
|
||||||
|
from tools.batch_executor import classify_single_tool_call
|
||||||
|
tc = _make_tool_call("delegate_task", {"goal": "test"})
|
||||||
|
result = classify_single_tool_call(tc)
|
||||||
|
assert result.tier == "sequential"
|
||||||
|
|
||||||
|
def test_unknown_tool_is_sequential(self):
|
||||||
|
from tools.batch_executor import classify_single_tool_call
|
||||||
|
tc = _make_tool_call("some_unknown_tool", {"arg": "val"})
|
||||||
|
result = classify_single_tool_call(tc)
|
||||||
|
assert result.tier == "sequential"
|
||||||
|
|
||||||
|
|
||||||
|
class TestBatchClassification:
|
||||||
|
def test_all_parallel_stays_parallel(self):
|
||||||
|
from tools.batch_executor import classify_tool_calls
|
||||||
|
tcs = [
|
||||||
|
_make_tool_call("read_file", {"path": f"file{i}.txt"})
|
||||||
|
for i in range(5)
|
||||||
|
]
|
||||||
|
plan = classify_tool_calls(tcs)
|
||||||
|
assert plan.can_parallelize
|
||||||
|
assert len(plan.parallel_batch) == 5
|
||||||
|
assert len(plan.sequential_batch) == 0
|
||||||
|
|
||||||
|
def test_mixed_batch(self):
|
||||||
|
from tools.batch_executor import classify_tool_calls
|
||||||
|
tcs = [
|
||||||
|
_make_tool_call("read_file", {"path": "a.txt"}),
|
||||||
|
_make_tool_call("terminal", {"command": "ls"}),
|
||||||
|
_make_tool_call("web_search", {"query": "test"}),
|
||||||
|
_make_tool_call("delegate_task", {"goal": "test"}),
|
||||||
|
]
|
||||||
|
plan = classify_tool_calls(tcs)
|
||||||
|
# read_file + web_search should be parallel (both parallel_safe)
|
||||||
|
# terminal + delegate_task should be sequential
|
||||||
|
assert len(plan.parallel_batch) >= 2
|
||||||
|
assert len(plan.sequential_batch) >= 2
|
||||||
|
|
||||||
|
def test_clarify_blocks_all(self):
|
||||||
|
from tools.batch_executor import classify_tool_calls
|
||||||
|
tcs = [
|
||||||
|
_make_tool_call("read_file", {"path": "a.txt"}),
|
||||||
|
_make_tool_call("clarify", {"question": "which one?"}),
|
||||||
|
_make_tool_call("web_search", {"query": "test"}),
|
||||||
|
]
|
||||||
|
plan = classify_tool_calls(tcs)
|
||||||
|
clarify_in_seq = any(c.tool_name == "clarify" for c in plan.sequential_batch)
|
||||||
|
assert clarify_in_seq
|
||||||
|
|
||||||
|
def test_overlapping_paths_sequential(self):
|
||||||
|
from tools.batch_executor import classify_tool_calls
|
||||||
|
tcs = [
|
||||||
|
_make_tool_call("write_file", {"path": "/tmp/test/a.txt", "content": "hello"}),
|
||||||
|
_make_tool_call("patch", {"path": "/tmp/test/a.txt", "old_string": "a", "new_string": "b"}),
|
||||||
|
]
|
||||||
|
plan = classify_tool_calls(tcs)
|
||||||
|
# write_file and patch on SAME file -> conflict -> one must be sequential
|
||||||
|
assert len(plan.sequential_batch) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestDestructiveCommands:
|
||||||
|
def test_rm_flagged(self):
|
||||||
|
from tools.batch_executor import is_destructive_command
|
||||||
|
assert is_destructive_command("rm -rf /tmp")
|
||||||
|
assert is_destructive_command("rm file.txt")
|
||||||
|
|
||||||
|
def test_mv_flagged(self):
|
||||||
|
from tools.batch_executor import is_destructive_command
|
||||||
|
assert is_destructive_command("mv old new")
|
||||||
|
|
||||||
|
def test_sed_i_flagged(self):
|
||||||
|
from tools.batch_executor import is_destructive_command
|
||||||
|
assert is_destructive_command("sed -i 's/a/b/g' file")
|
||||||
|
|
||||||
|
def test_redirect_overwrite_flagged(self):
|
||||||
|
from tools.batch_executor import is_destructive_command
|
||||||
|
assert is_destructive_command("echo test > file.txt")
|
||||||
|
|
||||||
|
def test_safe_commands_not_flagged(self):
|
||||||
|
from tools.batch_executor import is_destructive_command
|
||||||
|
assert not is_destructive_command("ls -la")
|
||||||
|
assert not is_destructive_command("cat file.txt")
|
||||||
|
assert not is_destructive_command("echo test >> file.txt") # append is safe
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegistryIntegration:
|
||||||
|
def test_parallel_safe_in_registry(self):
|
||||||
|
from tools.registry import registry
|
||||||
|
safe = registry.get_parallel_safe_tools()
|
||||||
|
assert isinstance(safe, set)
|
||||||
74
tests/test_crisis_tool.py
Normal file
74
tests/test_crisis_tool.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""Tests for crisis detection tool."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_crisis_none():
|
||||||
|
from tools.crisis_tool import detect_crisis
|
||||||
|
result = detect_crisis("Hello, how are you today?")
|
||||||
|
assert result.level == "NONE"
|
||||||
|
assert result.score == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_crisis_critical():
|
||||||
|
from tools.crisis_tool import detect_crisis
|
||||||
|
result = detect_crisis("I want to kill myself")
|
||||||
|
assert result.level == "CRITICAL"
|
||||||
|
assert result.score == 1.0
|
||||||
|
assert len(result.indicators) > 0
|
||||||
|
assert "988" in result.recommended_action
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_crisis_high():
|
||||||
|
from tools.crisis_tool import detect_crisis
|
||||||
|
result = detect_crisis("I feel so hopeless about everything, nothing matters")
|
||||||
|
assert result.level in ("HIGH", "MEDIUM", "CRITICAL")
|
||||||
|
assert result.score > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_crisis_medium():
|
||||||
|
from tools.crisis_tool import detect_crisis
|
||||||
|
result = detect_crisis("I feel worthless and broken inside")
|
||||||
|
assert result.level in ("MEDIUM", "LOW", "HIGH")
|
||||||
|
assert result.score > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_crisis_low():
|
||||||
|
from tools.crisis_tool import detect_crisis
|
||||||
|
result = detect_crisis("I'm having a tough time and feeling stressed")
|
||||||
|
assert result.level in ("LOW", "NONE")
|
||||||
|
assert result.score >= 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_crisis_empty():
|
||||||
|
from tools.crisis_tool import detect_crisis
|
||||||
|
result = detect_crisis("")
|
||||||
|
assert result.level == "NONE"
|
||||||
|
result2 = detect_crisis(None)
|
||||||
|
assert result2.level == "NONE"
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_user_message_returns_none_for_safe():
|
||||||
|
from tools.crisis_tool import scan_user_message
|
||||||
|
result = scan_user_message("What's the weather like?")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_user_message_returns_dict_for_crisis():
|
||||||
|
from tools.crisis_tool import scan_user_message
|
||||||
|
result = scan_user_message("I want to end it all")
|
||||||
|
assert result is not None
|
||||||
|
assert "level" in result
|
||||||
|
assert "compassion_injection" in result
|
||||||
|
assert result["level"] in ("CRITICAL", "HIGH")
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_handler():
|
||||||
|
from tools.crisis_tool import crisis_scan_handler
|
||||||
|
import json
|
||||||
|
result = crisis_scan_handler({"text": "I feel fine, thanks"})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["level"] == "NONE"
|
||||||
|
|
||||||
|
result2 = crisis_scan_handler({"text": "I want to die"})
|
||||||
|
data2 = json.loads(result2)
|
||||||
|
assert data2["level"] == "CRITICAL"
|
||||||
@@ -1,190 +0,0 @@
|
|||||||
"""Tests for tools.confirmation_daemon — Human Confirmation Firewall."""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import time
|
|
||||||
from tools.confirmation_daemon import (
|
|
||||||
ConfirmationDaemon,
|
|
||||||
ConfirmationRequest,
|
|
||||||
ConfirmationStatus,
|
|
||||||
RiskLevel,
|
|
||||||
classify_action,
|
|
||||||
_is_whitelisted,
|
|
||||||
_DEFAULT_WHITELIST,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestClassifyAction:
|
|
||||||
"""Test action risk classification."""
|
|
||||||
|
|
||||||
def test_crypto_tx_is_critical(self):
|
|
||||||
assert classify_action("crypto_tx") == RiskLevel.CRITICAL
|
|
||||||
|
|
||||||
def test_sign_transaction_is_critical(self):
|
|
||||||
assert classify_action("sign_transaction") == RiskLevel.CRITICAL
|
|
||||||
|
|
||||||
def test_send_email_is_high(self):
|
|
||||||
assert classify_action("send_email") == RiskLevel.HIGH
|
|
||||||
|
|
||||||
def test_send_message_is_medium(self):
|
|
||||||
assert classify_action("send_message") == RiskLevel.MEDIUM
|
|
||||||
|
|
||||||
def test_access_calendar_is_low(self):
|
|
||||||
assert classify_action("access_calendar") == RiskLevel.LOW
|
|
||||||
|
|
||||||
def test_unknown_action_is_medium(self):
|
|
||||||
assert classify_action("unknown_action_xyz") == RiskLevel.MEDIUM
|
|
||||||
|
|
||||||
|
|
||||||
class TestWhitelist:
|
|
||||||
"""Test whitelist auto-approval."""
|
|
||||||
|
|
||||||
def test_self_email_is_whitelisted(self):
|
|
||||||
whitelist = dict(_DEFAULT_WHITELIST)
|
|
||||||
payload = {"from": "me@test.com", "to": "me@test.com"}
|
|
||||||
assert _is_whitelisted("send_email", payload, whitelist) is True
|
|
||||||
|
|
||||||
def test_non_whitelisted_recipient_not_approved(self):
|
|
||||||
whitelist = dict(_DEFAULT_WHITELIST)
|
|
||||||
payload = {"to": "random@stranger.com"}
|
|
||||||
assert _is_whitelisted("send_email", payload, whitelist) is False
|
|
||||||
|
|
||||||
def test_whitelisted_contact_approved(self):
|
|
||||||
whitelist = {
|
|
||||||
"send_message": {"targets": ["alice", "bob"]},
|
|
||||||
}
|
|
||||||
assert _is_whitelisted("send_message", {"to": "alice"}, whitelist) is True
|
|
||||||
assert _is_whitelisted("send_message", {"to": "charlie"}, whitelist) is False
|
|
||||||
|
|
||||||
def test_no_whitelist_entry_means_not_whitelisted(self):
|
|
||||||
whitelist = {}
|
|
||||||
assert _is_whitelisted("crypto_tx", {"amount": 1.0}, whitelist) is False
|
|
||||||
|
|
||||||
|
|
||||||
class TestConfirmationRequest:
|
|
||||||
"""Test the request data model."""
|
|
||||||
|
|
||||||
def test_defaults(self):
|
|
||||||
req = ConfirmationRequest(
|
|
||||||
request_id="test-1",
|
|
||||||
action="send_email",
|
|
||||||
description="Test email",
|
|
||||||
risk_level="high",
|
|
||||||
payload={},
|
|
||||||
)
|
|
||||||
assert req.status == ConfirmationStatus.PENDING.value
|
|
||||||
assert req.created_at > 0
|
|
||||||
assert req.expires_at > req.created_at
|
|
||||||
|
|
||||||
def test_is_pending(self):
|
|
||||||
req = ConfirmationRequest(
|
|
||||||
request_id="test-2",
|
|
||||||
action="send_email",
|
|
||||||
description="Test",
|
|
||||||
risk_level="high",
|
|
||||||
payload={},
|
|
||||||
expires_at=time.time() + 300,
|
|
||||||
)
|
|
||||||
assert req.is_pending is True
|
|
||||||
|
|
||||||
def test_is_expired(self):
|
|
||||||
req = ConfirmationRequest(
|
|
||||||
request_id="test-3",
|
|
||||||
action="send_email",
|
|
||||||
description="Test",
|
|
||||||
risk_level="high",
|
|
||||||
payload={},
|
|
||||||
expires_at=time.time() - 10,
|
|
||||||
)
|
|
||||||
assert req.is_expired is True
|
|
||||||
assert req.is_pending is False
|
|
||||||
|
|
||||||
def test_to_dict(self):
|
|
||||||
req = ConfirmationRequest(
|
|
||||||
request_id="test-4",
|
|
||||||
action="send_email",
|
|
||||||
description="Test",
|
|
||||||
risk_level="medium",
|
|
||||||
payload={"to": "a@b.com"},
|
|
||||||
)
|
|
||||||
d = req.to_dict()
|
|
||||||
assert d["request_id"] == "test-4"
|
|
||||||
assert d["action"] == "send_email"
|
|
||||||
assert "is_pending" in d
|
|
||||||
|
|
||||||
|
|
||||||
class TestConfirmationDaemon:
|
|
||||||
"""Test the daemon logic (without HTTP layer)."""
|
|
||||||
|
|
||||||
def test_auto_approve_low_risk(self):
|
|
||||||
daemon = ConfirmationDaemon()
|
|
||||||
req = daemon.request(
|
|
||||||
action="access_calendar",
|
|
||||||
description="Read today's events",
|
|
||||||
risk_level="low",
|
|
||||||
)
|
|
||||||
assert req.status == ConfirmationStatus.AUTO_APPROVED.value
|
|
||||||
|
|
||||||
def test_whitelisted_auto_approves(self):
|
|
||||||
daemon = ConfirmationDaemon()
|
|
||||||
daemon._whitelist = {"send_message": {"targets": ["alice"]}}
|
|
||||||
req = daemon.request(
|
|
||||||
action="send_message",
|
|
||||||
description="Message alice",
|
|
||||||
payload={"to": "alice"},
|
|
||||||
)
|
|
||||||
assert req.status == ConfirmationStatus.AUTO_APPROVED.value
|
|
||||||
|
|
||||||
def test_non_whitelisted_goes_pending(self):
|
|
||||||
daemon = ConfirmationDaemon()
|
|
||||||
daemon._whitelist = {}
|
|
||||||
req = daemon.request(
|
|
||||||
action="send_email",
|
|
||||||
description="Email to stranger",
|
|
||||||
payload={"to": "stranger@test.com"},
|
|
||||||
risk_level="high",
|
|
||||||
)
|
|
||||||
assert req.status == ConfirmationStatus.PENDING.value
|
|
||||||
assert req.is_pending is True
|
|
||||||
|
|
||||||
def test_approve_response(self):
|
|
||||||
daemon = ConfirmationDaemon()
|
|
||||||
daemon._whitelist = {}
|
|
||||||
req = daemon.request(
|
|
||||||
action="send_email",
|
|
||||||
description="Email test",
|
|
||||||
risk_level="high",
|
|
||||||
)
|
|
||||||
result = daemon.respond(req.request_id, approved=True, decided_by="human")
|
|
||||||
assert result.status == ConfirmationStatus.APPROVED.value
|
|
||||||
assert result.decided_by == "human"
|
|
||||||
|
|
||||||
def test_deny_response(self):
|
|
||||||
daemon = ConfirmationDaemon()
|
|
||||||
daemon._whitelist = {}
|
|
||||||
req = daemon.request(
|
|
||||||
action="crypto_tx",
|
|
||||||
description="Send 1 ETH",
|
|
||||||
risk_level="critical",
|
|
||||||
)
|
|
||||||
result = daemon.respond(
|
|
||||||
req.request_id, approved=False, decided_by="human", reason="Too risky"
|
|
||||||
)
|
|
||||||
assert result.status == ConfirmationStatus.DENIED.value
|
|
||||||
assert result.reason == "Too risky"
|
|
||||||
|
|
||||||
def test_get_pending(self):
|
|
||||||
daemon = ConfirmationDaemon()
|
|
||||||
daemon._whitelist = {}
|
|
||||||
daemon.request(action="send_email", description="Test 1", risk_level="high")
|
|
||||||
daemon.request(action="send_email", description="Test 2", risk_level="high")
|
|
||||||
pending = daemon.get_pending()
|
|
||||||
assert len(pending) >= 2
|
|
||||||
|
|
||||||
def test_get_history(self):
|
|
||||||
daemon = ConfirmationDaemon()
|
|
||||||
req = daemon.request(
|
|
||||||
action="access_calendar", description="Test", risk_level="low"
|
|
||||||
)
|
|
||||||
history = daemon.get_history()
|
|
||||||
assert len(history) >= 1
|
|
||||||
assert history[0]["action"] == "access_calendar"
|
|
||||||
@@ -121,19 +121,6 @@ DANGEROUS_PATTERNS = [
|
|||||||
(r'\b(cp|mv|install)\b.*\s/etc/', "copy/move file into /etc/"),
|
(r'\b(cp|mv|install)\b.*\s/etc/', "copy/move file into /etc/"),
|
||||||
(r'\bsed\s+-[^\s]*i.*\s/etc/', "in-place edit of system config"),
|
(r'\bsed\s+-[^\s]*i.*\s/etc/', "in-place edit of system config"),
|
||||||
(r'\bsed\s+--in-place\b.*\s/etc/', "in-place edit of system config (long flag)"),
|
(r'\bsed\s+--in-place\b.*\s/etc/', "in-place edit of system config (long flag)"),
|
||||||
# --- Vitalik's threat model: crypto / financial ---
|
|
||||||
(r'\b(?:bitcoin-cli|ethers\.js|web3|ether\.sendTransaction)\b', "direct crypto transaction tool usage"),
|
|
||||||
(r'\bwget\b.*\b(?:mnemonic|seed\s*phrase|private[_-]?key)\b', "attempting to download crypto credentials"),
|
|
||||||
(r'\bcurl\b.*\b(?:mnemonic|seed\s*phrase|private[_-]?key)\b', "attempting to exfiltrate crypto credentials"),
|
|
||||||
# --- Vitalik's threat model: credential exfiltration ---
|
|
||||||
(r'\b(?:curl|wget|http|nc|ncat|socat)\b.*\b(?:\.env|\.ssh|credentials|secrets|token|api[_-]?key)\b',
|
|
||||||
"attempting to exfiltrate credentials via network"),
|
|
||||||
(r'\bbase64\b.*\|(?:\s*curl|\s*wget)', "base64-encode then network exfiltration"),
|
|
||||||
(r'\bcat\b.*\b(?:\.env|\.ssh/id_rsa|credentials)\b.*\|(?:\s*curl|\s*wget)',
|
|
||||||
"reading secrets and piping to network tool"),
|
|
||||||
# --- Vitalik's threat model: data exfiltration ---
|
|
||||||
(r'\bcurl\b.*-d\s.*\$(?:HOME|USER)', "sending user home directory data to remote"),
|
|
||||||
(r'\bwget\b.*--post-data\s.*\$(?:HOME|USER)', "posting user data to remote"),
|
|
||||||
# Script execution via heredoc — bypasses the -e/-c flag patterns above.
|
# Script execution via heredoc — bypasses the -e/-c flag patterns above.
|
||||||
# `python3 << 'EOF'` feeds arbitrary code via stdin without -c/-e flags.
|
# `python3 << 'EOF'` feeds arbitrary code via stdin without -c/-e flags.
|
||||||
(r'\b(python[23]?|perl|ruby|node)\s+<<', "script execution via heredoc"),
|
(r'\b(python[23]?|perl|ruby|node)\s+<<', "script execution via heredoc"),
|
||||||
|
|||||||
294
tools/batch_executor.py
Normal file
294
tools/batch_executor.py
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
"""Batch Tool Executor — Parallel safety classification and concurrent execution.
|
||||||
|
|
||||||
|
Provides centralized classification of tool calls into parallel-safe vs sequential,
|
||||||
|
and utilities for batch execution with safety checks.
|
||||||
|
|
||||||
|
Classification tiers:
|
||||||
|
- PARALLEL_SAFE: read-only tools, no shared state (web_search, read_file, etc.)
|
||||||
|
- PATH_SCOPED: file operations that can run concurrently when paths don't overlap
|
||||||
|
- SEQUENTIAL: writes, destructive ops, terminal commands, delegation
|
||||||
|
- NEVER_PARALLEL: clarify (requires user interaction)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from tools.batch_executor import classify_tool_calls, BatchExecutionPlan
|
||||||
|
|
||||||
|
plan = classify_tool_calls(tool_calls)
|
||||||
|
if plan.can_parallelize:
|
||||||
|
execute_concurrent(plan.parallel_batch)
|
||||||
|
execute_sequential(plan.sequential_batch)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Safety Classification ──────────────────────────────────────────────────

# Tools that can ALWAYS run in parallel (read-only, no shared state).
# Fix: "session_search" was listed twice; the duplicate is removed (a
# frozenset deduplicates anyway, but the literal should not repeat entries).
DEFAULT_PARALLEL_SAFE = frozenset({
    "ha_get_state",
    "ha_list_entities",
    "ha_list_services",
    "read_file",
    "search_files",
    "session_search",
    "skill_view",
    "skills_list",
    "vision_analyze",
    "web_extract",
    "web_search",
    "fact_store",
    "fact_search",
})

# File tools that can run concurrently ONLY when paths don't overlap
PATH_SCOPED_TOOLS = frozenset({"read_file", "write_file", "patch"})

# Tools that must NEVER run in parallel (require user interaction, shared mutable state)
NEVER_PARALLEL = frozenset({"clarify"})

# Patterns that indicate terminal commands may modify/delete files.
# Each command must appear at the start of the string or after a shell
# separator (whitespace, &&, ||, ;, backtick) to avoid matching substrings
# of unrelated words.
DESTRUCTIVE_PATTERNS = re.compile(
    r"""(?:^|\s|&&|\|\||;|`)(?:
    rm\s|rmdir\s|
    mv\s|
    sed\s+-i|
    truncate\s|
    dd\s|
    shred\s|
    git\s+(?:reset|clean|checkout)\s
    )""",
    re.VERBOSE,
)

# Output redirects that overwrite files (> but not >>)
REDIRECT_OVERWRITE = re.compile(r'[^>]>[^>]|^>[^>]')


def is_destructive_command(cmd: str) -> bool:
    """Check if a terminal command modifies/deletes files.

    Args:
        cmd: Raw shell command string (may be empty).

    Returns:
        True when the command matches a known destructive pattern or an
        overwriting output redirect; False for empty or benign commands.
    """
    if not cmd:
        return False
    if DESTRUCTIVE_PATTERNS.search(cmd):
        return True
    if REDIRECT_OVERWRITE.search(cmd):
        return True
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def _paths_overlap(path1: Path, path2: Path) -> bool:
|
||||||
|
"""Check if two paths could conflict (one is ancestor of the other)."""
|
||||||
|
try:
|
||||||
|
path1 = path1.resolve()
|
||||||
|
path2 = path2.resolve()
|
||||||
|
return path1 == path2 or path1 in path2.parents or path2 in path1.parents
|
||||||
|
except Exception:
|
||||||
|
return True # conservative: assume overlap
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_path(tool_name: str, args: dict) -> Optional[Path]:
    """Return the resolved target path for a path-scoped tool call.

    Returns None when the tool is not path-scoped, when ``args`` carries no
    usable "path" string, or when resolution fails.
    """
    if tool_name not in PATH_SCOPED_TOOLS:
        return None
    raw = args.get("path")
    if not (isinstance(raw, str) and raw.strip()):
        return None
    try:
        return Path(raw).expanduser().resolve()
    except Exception:
        return None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Classification ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@dataclass
class ToolCallClassification:
    """Classification result for a single tool call.

    Produced by classify_single_tool_call; pairs the parsed call with the
    safety tier that decides whether it may run concurrently.
    """
    tool_name: str  # name of the tool being invoked
    args: dict  # parsed JSON arguments ({} when argument parsing failed)
    tool_call: Any  # the original tool_call object
    tier: str  # "parallel_safe", "path_scoped", "sequential", "never_parallel"
    reason: str = ""  # human-readable explanation for the chosen tier
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BatchExecutionPlan:
    """Execution plan for a batch of tool calls.

    Splits the classified calls into a batch that may run concurrently and a
    batch that must run one call at a time.
    """
    classifications: List[ToolCallClassification] = field(default_factory=list)
    parallel_batch: List[ToolCallClassification] = field(default_factory=list)
    sequential_batch: List[ToolCallClassification] = field(default_factory=list)

    @property
    def total(self) -> int:
        """Total number of classified calls in the plan."""
        return len(self.classifications)

    @property
    def can_parallelize(self) -> bool:
        """True when at least two calls are scheduled to run concurrently."""
        return len(self.parallel_batch) > 1
|
||||||
|
|
||||||
|
|
||||||
|
def classify_single_tool_call(
    tool_call: Any,
    extra_parallel_safe: Optional[Set[str]] = None,
) -> ToolCallClassification:
    """Classify a single tool call into its safety tier.

    Args:
        tool_call: Tool-call object exposing ``.function.name`` and a JSON
            string in ``.function.arguments``.
        extra_parallel_safe: Additional tool names to treat as parallel-safe.
            (Fix: annotation was ``Set[str] = None``; a None default requires
            ``Optional``.)

    Returns:
        A ToolCallClassification whose tier is one of "parallel_safe",
        "path_scoped", "sequential", or "never_parallel". Anything that
        cannot be positively classified falls back to sequential.
    """
    tool_name = tool_call.function.name
    try:
        args = json.loads(tool_call.function.arguments)
    except Exception:
        return ToolCallClassification(
            tool_name=tool_name, args={}, tool_call=tool_call,
            tier="sequential", reason="Could not parse arguments"
        )

    if not isinstance(args, dict):
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="sequential", reason="Non-dict arguments"
        )

    # Check never-parallel (user-interaction tools)
    if tool_name in NEVER_PARALLEL:
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="never_parallel", reason="Requires user interaction"
        )

    # Check parallel-safe FIRST (before path_scoped) so read_file/search_files
    # get classified as parallel_safe even though they have paths
    parallel_safe_set = DEFAULT_PARALLEL_SAFE
    if extra_parallel_safe:
        parallel_safe_set = parallel_safe_set | extra_parallel_safe

    if tool_name in parallel_safe_set:
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="parallel_safe", reason="Read-only, no shared state"
        )

    # Terminal commands are always sequential; destructive ones get a
    # more specific reason for logging/inspection.
    if tool_name == "terminal":
        cmd = args.get("command", "")
        if is_destructive_command(cmd):
            return ToolCallClassification(
                tool_name=tool_name, args=args, tool_call=tool_call,
                tier="sequential", reason=f"Destructive command: {cmd[:50]}"
            )
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="sequential", reason="Terminal command (conservative)"
        )

    # Check path-scoped tools (write_file, patch — not read_file which is parallel_safe)
    if tool_name in PATH_SCOPED_TOOLS:
        path = _extract_path(tool_name, args)
        if path:
            return ToolCallClassification(
                tool_name=tool_name, args=args, tool_call=tool_call,
                tier="path_scoped", reason=f"Path: {path}"
            )
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="sequential", reason="Path-scoped but no path found"
        )

    # Default: sequential (conservative)
    return ToolCallClassification(
        tool_name=tool_name, args=args, tool_call=tool_call,
        tier="sequential", reason="Not classified as parallel-safe"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def classify_tool_calls(
    tool_calls: list,
    extra_parallel_safe: Optional[Set[str]] = None,
) -> BatchExecutionPlan:
    """Classify a batch of tool calls and produce an execution plan.

    Each call is routed into ``parallel_batch`` (safe to run concurrently)
    or ``sequential_batch``. Path-scoped calls join the parallel batch only
    when their target path does not overlap a path already reserved by an
    earlier parallel call; on conflict (or failed path extraction) they are
    demoted to sequential.

    Args:
        tool_calls: Tool-call objects as accepted by classify_single_tool_call.
        extra_parallel_safe: Additional tool names to treat as parallel-safe.
            (Fix: annotation was ``Set[str] = None``; a None default requires
            ``Optional``.)
    """
    plan = BatchExecutionPlan()

    # Paths already claimed by scheduled parallel path-scoped calls.
    reserved_paths: List[Path] = []

    for tc in tool_calls:
        classification = classify_single_tool_call(tc, extra_parallel_safe)
        plan.classifications.append(classification)

        if classification.tier in ("never_parallel", "sequential"):
            plan.sequential_batch.append(classification)
            continue

        if classification.tier == "path_scoped":
            path = _extract_path(classification.tool_name, classification.args)
            if path is None:
                classification.tier = "sequential"
                classification.reason = "Path extraction failed"
                plan.sequential_batch.append(classification)
                continue

            # Check for path conflicts with already-scheduled parallel calls
            conflict = any(_paths_overlap(path, existing) for existing in reserved_paths)
            if conflict:
                classification.tier = "sequential"
                classification.reason = f"Path conflict: {path}"
                plan.sequential_batch.append(classification)
            else:
                reserved_paths.append(path)
                plan.parallel_batch.append(classification)
            continue

        if classification.tier == "parallel_safe":
            plan.parallel_batch.append(classification)
            continue

        # Fallback for any unrecognized tier
        plan.sequential_batch.append(classification)

    return plan
|
||||||
|
|
||||||
|
|
||||||
|
# ── Concurrent Execution ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def execute_parallel_batch(
    batch: List["ToolCallClassification"],
    invoke_fn: Callable,
    max_workers: int = 8,
) -> List[Tuple[str, str]]:
    """Execute parallel-safe tool calls concurrently.

    Args:
        batch: List of classified tool calls (parallel_safe or path_scoped)
        invoke_fn: Function(tool_name, args) -> result_string
        max_workers: Max concurrent threads

    Returns:
        List of (tool_call_id, result_string) tuples, in completion order.
    """
    # Fix: an empty batch made min(max_workers, len(batch)) == 0, and
    # ThreadPoolExecutor raises ValueError for max_workers <= 0.
    if not batch:
        return []

    results: List[Tuple[str, str]] = []

    with ThreadPoolExecutor(max_workers=min(max_workers, len(batch))) as executor:
        future_to_tc = {
            executor.submit(invoke_fn, tc.tool_name, tc.args): tc
            for tc in batch
        }

        for future in as_completed(future_to_tc):
            tc = future_to_tc[future]
            try:
                result = future.result()
            except Exception as e:
                # Surface worker failures as a JSON error payload instead of
                # letting one bad call abort the whole batch.
                result = json.dumps({"error": str(e)})
            tool_call_id = getattr(tc.tool_call, "id", None) or ""
            results.append((tool_call_id, result))

    return results
|
||||||
@@ -1,615 +0,0 @@
|
|||||||
"""Human Confirmation Daemon — HTTP server for two-factor action approval.
|
|
||||||
|
|
||||||
Implements Vitalik's Pattern 1: "The new 'two-factor confirmation' is that
|
|
||||||
the two factors are the human and the LLM."
|
|
||||||
|
|
||||||
This daemon runs on localhost:6000 and provides a simple HTTP API for the
|
|
||||||
agent to request human approval before executing high-risk actions.
|
|
||||||
|
|
||||||
Threat model:
|
|
||||||
- LLM jailbreaks: Remote content "hacking" the LLM to perform malicious actions
|
|
||||||
- LLM accidents: LLM accidentally performing dangerous operations
|
|
||||||
- The human acts as the second factor — the agent proposes, the human disposes
|
|
||||||
|
|
||||||
Architecture:
|
|
||||||
- Agent detects high-risk action → POST /confirm with action details
|
|
||||||
- Daemon stores pending request, sends notification to user
|
|
||||||
- User approves/denies via POST /respond (Telegram, CLI, or direct HTTP)
|
|
||||||
- Agent receives decision and proceeds or aborts
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
# Start daemon (usually managed by gateway)
|
|
||||||
from tools.confirmation_daemon import ConfirmationDaemon
|
|
||||||
daemon = ConfirmationDaemon(port=6000)
|
|
||||||
daemon.start()
|
|
||||||
|
|
||||||
# Request approval (from agent code)
|
|
||||||
from tools.confirmation_daemon import request_confirmation
|
|
||||||
approved = request_confirmation(
|
|
||||||
action="send_email",
|
|
||||||
description="Send email to alice@example.com",
|
|
||||||
risk_level="high",
|
|
||||||
payload={"to": "alice@example.com", "subject": "Meeting notes"},
|
|
||||||
timeout=300,
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from dataclasses import dataclass, field, asdict
|
|
||||||
from enum import Enum, auto
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class RiskLevel(Enum):
    """Risk classification for actions requiring confirmation.

    Values are lowercase strings so they can be passed directly through
    JSON payloads and compared against ConfirmationRequest.risk_level.
    """
    LOW = "low"  # Log only, no confirmation needed
    MEDIUM = "medium"  # Confirm for non-whitelisted targets
    HIGH = "high"  # Always confirm
    CRITICAL = "critical"  # Always confirm + require explicit reason
|
|
||||||
|
|
||||||
|
|
||||||
class ConfirmationStatus(Enum):
    """Status of a pending confirmation request."""
    PENDING = "pending"  # awaiting a human decision
    APPROVED = "approved"  # human approved the action
    DENIED = "denied"  # human denied the action
    EXPIRED = "expired"  # no decision arrived before the deadline
    AUTO_APPROVED = "auto_approved"  # pre-approved (whitelist or low risk)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class ConfirmationRequest:
    """A request for human confirmation of a high-risk action."""
    request_id: str
    action: str  # Action type: send_email, send_message, crypto_tx, etc.
    description: str  # Human-readable description of what will happen
    risk_level: str  # low, medium, high, critical
    payload: Dict[str, Any]  # Action-specific data (sanitized)
    session_key: str = ""  # Session that initiated the request
    created_at: float = 0.0
    expires_at: float = 0.0
    status: str = ConfirmationStatus.PENDING.value
    decided_at: float = 0.0
    decided_by: str = ""  # "human", "auto", "whitelist"
    reason: str = ""  # Optional reason for denial

    def __post_init__(self):
        # Backfill timestamps and an id for fields the caller left unset.
        if not self.created_at:
            self.created_at = time.time()
        if not self.expires_at:
            # 5 min default window
            self.expires_at = self.created_at + 300
        if not self.request_id:
            self.request_id = str(uuid.uuid4())[:12]

    @property
    def is_expired(self) -> bool:
        """True once the decision deadline has passed."""
        return time.time() > self.expires_at

    @property
    def is_pending(self) -> bool:
        """True while still awaiting a decision and not yet expired."""
        if self.status != ConfirmationStatus.PENDING.value:
            return False
        return not self.is_expired

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, including the derived status flags."""
        data = asdict(self)
        data["is_expired"] = self.is_expired
        data["is_pending"] = self.is_pending
        return data
|
|
||||||
|
|
||||||
|
|
||||||
# =========================================================================
# Action categories (Vitalik's threat model)
# =========================================================================

# Map of known action types to their risk level; looked up by
# classify_action(), which falls back to DEFAULT_RISK_LEVEL for
# unrecognized actions.
ACTION_CATEGORIES = {
    # Messaging — outbound communication to external parties
    "send_email": RiskLevel.HIGH,
    "send_message": RiskLevel.MEDIUM,  # Depends on recipient
    "send_signal": RiskLevel.HIGH,
    "send_telegram": RiskLevel.MEDIUM,
    "send_discord": RiskLevel.MEDIUM,
    "post_social": RiskLevel.HIGH,

    # Financial / crypto
    "crypto_tx": RiskLevel.CRITICAL,
    "sign_transaction": RiskLevel.CRITICAL,
    "access_wallet": RiskLevel.CRITICAL,
    "modify_balance": RiskLevel.CRITICAL,

    # System modification
    "install_software": RiskLevel.HIGH,
    "modify_system_config": RiskLevel.HIGH,
    "modify_firewall": RiskLevel.CRITICAL,
    "add_ssh_key": RiskLevel.CRITICAL,
    "create_user": RiskLevel.CRITICAL,

    # Data access
    "access_contacts": RiskLevel.MEDIUM,
    "access_calendar": RiskLevel.LOW,
    "read_private_files": RiskLevel.MEDIUM,
    "upload_data": RiskLevel.HIGH,
    "share_credentials": RiskLevel.CRITICAL,

    # Network
    "open_port": RiskLevel.HIGH,
    "modify_dns": RiskLevel.HIGH,
    "expose_service": RiskLevel.CRITICAL,
}

# Default: any unrecognized action is MEDIUM risk
DEFAULT_RISK_LEVEL = RiskLevel.MEDIUM
|
|
||||||
|
|
||||||
|
|
||||||
def classify_action(action: str) -> RiskLevel:
    """Return the risk level for *action*; unknown actions default to MEDIUM."""
    try:
        return ACTION_CATEGORIES[action]
    except KeyError:
        return DEFAULT_RISK_LEVEL
|
|
||||||
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Whitelist configuration
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
# Built-in whitelist used when ~/.hermes/approval_whitelist.json is absent.
# Per-action config is consumed by _is_whitelisted().
_DEFAULT_WHITELIST = {
    "send_message": {
        "targets": [],  # Contact names/IDs that don't need confirmation
    },
    "send_email": {
        "targets": [],  # Email addresses that don't need confirmation
        "self_only": True,  # send-to-self always allowed
    },
}
|
|
||||||
|
|
||||||
|
|
||||||
def _load_whitelist() -> Dict[str, Any]:
    """Load the action whitelist from ~/.hermes/approval_whitelist.json.

    Falls back to an independent deep copy of ``_DEFAULT_WHITELIST`` when
    the config file is missing or unreadable, so callers can mutate the
    returned structure without corrupting the module-level default.
    """
    config_path = Path.home() / ".hermes" / "approval_whitelist.json"
    if config_path.exists():
        try:
            with open(config_path, encoding="utf-8") as f:
                return json.load(f)
        except Exception as e:
            logger.warning("Failed to load approval whitelist: %s", e)
    # Fix: dict(_DEFAULT_WHITELIST) was a shallow copy that shared the
    # nested per-action dicts with the default. The default is pure JSON
    # data, so a json round-trip yields a safe deep copy.
    return json.loads(json.dumps(_DEFAULT_WHITELIST))
|
|
||||||
|
|
||||||
|
|
||||||
def _is_whitelisted(action: str, payload: Dict[str, Any], whitelist: Dict) -> bool:
|
|
||||||
"""Check if an action is pre-approved by the whitelist."""
|
|
||||||
action_config = whitelist.get(action, {})
|
|
||||||
if not action_config:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check target-based whitelist
|
|
||||||
targets = action_config.get("targets", [])
|
|
||||||
target = payload.get("to") or payload.get("recipient") or payload.get("target", "")
|
|
||||||
if target and target in targets:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Self-only email
|
|
||||||
if action_config.get("self_only") and action == "send_email":
|
|
||||||
sender = payload.get("from", "")
|
|
||||||
recipient = payload.get("to", "")
|
|
||||||
if sender and recipient and sender.lower() == recipient.lower():
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Confirmation daemon
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
class ConfirmationDaemon:
    """HTTP daemon for human confirmation of high-risk actions.

    Runs on localhost:PORT (default 6000). Provides:
    - POST /confirm — agent requests human approval
    - POST /respond — human approves/denies
    - GET /pending — list pending requests
    - GET /health — health check

    Fix: wait_for_decision() previously polled only ``_pending``, but
    respond() moves decided requests into ``_history`` — so it never saw a
    decision and always ran to timeout, returning a synthetic EXPIRED
    request even after the human approved. It now also consults history.
    """

    def __init__(
        self,
        host: str = "127.0.0.1",
        port: int = 6000,
        default_timeout: int = 300,
        notify_callback: Optional[Callable] = None,
    ):
        self.host = host
        self.port = port
        self.default_timeout = default_timeout
        # Optional callable(dict) used to alert the human of a new request.
        self.notify_callback = notify_callback
        # Pending requests keyed by request_id; decided/expired ones move
        # to _history. Both guarded by _lock.
        self._pending: Dict[str, ConfirmationRequest] = {}
        self._history: List[ConfirmationRequest] = []
        self._lock = threading.Lock()
        self._whitelist = _load_whitelist()
        self._app = None
        self._runner = None

    def request(
        self,
        action: str,
        description: str,
        payload: Optional[Dict[str, Any]] = None,
        risk_level: Optional[str] = None,
        session_key: str = "",
        timeout: Optional[int] = None,
    ) -> ConfirmationRequest:
        """Create a confirmation request.

        Returns the request. Check .status to see if it was immediately
        auto-approved (whitelisted) or is pending human review.
        """
        payload = payload or {}

        # Classify risk if not specified
        if risk_level is None:
            risk_level = classify_action(action).value

        # Low-risk or whitelisted actions are auto-approved and go straight
        # to history (they never enter _pending).
        if risk_level in ("low",) or _is_whitelisted(action, payload, self._whitelist):
            req = ConfirmationRequest(
                request_id=str(uuid.uuid4())[:12],
                action=action,
                description=description,
                risk_level=risk_level,
                payload=payload,
                session_key=session_key,
                expires_at=time.time() + (timeout or self.default_timeout),
                status=ConfirmationStatus.AUTO_APPROVED.value,
                decided_at=time.time(),
                decided_by="whitelist",
            )
            with self._lock:
                self._history.append(req)
            logger.info("Auto-approved whitelisted action: %s", action)
            return req

        # Create pending request
        req = ConfirmationRequest(
            request_id=str(uuid.uuid4())[:12],
            action=action,
            description=description,
            risk_level=risk_level,
            payload=payload,
            session_key=session_key,
            expires_at=time.time() + (timeout or self.default_timeout),
        )

        with self._lock:
            self._pending[req.request_id] = req

        # Notify human (best-effort; a failing callback must not lose the request)
        if self.notify_callback:
            try:
                self.notify_callback(req.to_dict())
            except Exception as e:
                logger.warning("Confirmation notify callback failed: %s", e)

        logger.info(
            "Confirmation request %s: %s (%s risk) — waiting for human",
            req.request_id, action, risk_level,
        )
        return req

    def respond(
        self,
        request_id: str,
        approved: bool,
        decided_by: str = "human",
        reason: str = "",
    ) -> Optional[ConfirmationRequest]:
        """Record a human decision on a pending request.

        Returns None for an unknown request_id; otherwise the (possibly
        already-decided) request.
        """
        with self._lock:
            req = self._pending.get(request_id)
            if not req:
                logger.warning("Confirmation respond: unknown request %s", request_id)
                return None
            if not req.is_pending:
                logger.warning("Confirmation respond: request %s already decided", request_id)
                return req

            req.status = (
                ConfirmationStatus.APPROVED.value if approved
                else ConfirmationStatus.DENIED.value
            )
            req.decided_at = time.time()
            req.decided_by = decided_by
            req.reason = reason

            # Move to history
            del self._pending[request_id]
            self._history.append(req)

        logger.info(
            "Confirmation %s: %s by %s",
            request_id, "APPROVED" if approved else "DENIED", decided_by,
        )
        return req

    def _find_in_history(self, request_id: str) -> Optional[ConfirmationRequest]:
        """Return the most recent history entry for request_id, if any.

        Caller must hold self._lock.
        """
        for past in reversed(self._history):
            if past.request_id == request_id:
                return past
        return None

    def wait_for_decision(
        self, request_id: str, timeout: Optional[float] = None
    ) -> ConfirmationRequest:
        """Block until a decision is made or timeout expires.

        Polls _pending; once the request disappears from _pending it has
        been decided (respond/auto-approve move it to history), so the
        decision is looked up there.
        """
        deadline = time.time() + (timeout or self.default_timeout)
        while time.time() < deadline:
            with self._lock:
                req = self._pending.get(request_id)
                if req is None:
                    # Decided (moved to history), auto-approved, or unknown.
                    decided = self._find_in_history(request_id)
                    if decided is not None:
                        return decided
                    break  # unknown id — fall through to placeholder
                if req.is_expired:
                    # Expire in place before the status check so a stale
                    # PENDING request is never returned un-expired.
                    req.status = ConfirmationStatus.EXPIRED.value
                    del self._pending[request_id]
                    self._history.append(req)
                    return req
            time.sleep(0.5)

        # Deadline reached (or unknown id): expire whatever is still pending.
        with self._lock:
            req = self._pending.pop(request_id, None)
            if req:
                req.status = ConfirmationStatus.EXPIRED.value
                self._history.append(req)
                return req
            decided = self._find_in_history(request_id)
            if decided is not None:
                return decided

        # Unknown request — synthesize an expired placeholder
        return ConfirmationRequest(
            request_id=request_id,
            action="unknown",
            description="Request not found",
            risk_level="high",
            payload={},
            status=ConfirmationStatus.EXPIRED.value,
        )

    def get_pending(self) -> List[Dict[str, Any]]:
        """Return list of pending confirmation requests."""
        self._expire_old()
        with self._lock:
            return [r.to_dict() for r in self._pending.values() if r.is_pending]

    def get_history(self, limit: int = 50) -> List[Dict[str, Any]]:
        """Return recent confirmation history (newest last)."""
        with self._lock:
            return [r.to_dict() for r in self._history[-limit:]]

    def _expire_old(self) -> None:
        """Move expired requests from _pending to history."""
        now = time.time()
        with self._lock:
            expired = [
                rid for rid, req in self._pending.items()
                if now > req.expires_at
            ]
            for rid in expired:
                req = self._pending.pop(rid)
                req.status = ConfirmationStatus.EXPIRED.value
                self._history.append(req)

    # --- aiohttp HTTP API ---

    async def _handle_health(self, request):
        """GET /health — liveness probe with pending-request count."""
        from aiohttp import web
        return web.json_response({
            "status": "ok",
            "service": "hermes-confirmation-daemon",
            "pending": len(self._pending),
        })

    async def _handle_confirm(self, request):
        """POST /confirm — create a request and block for the decision."""
        from aiohttp import web
        try:
            body = await request.json()
        except Exception:
            return web.json_response({"error": "invalid JSON"}, status=400)

        action = body.get("action", "")
        description = body.get("description", "")
        if not action or not description:
            return web.json_response(
                {"error": "action and description required"}, status=400
            )

        req = self.request(
            action=action,
            description=description,
            payload=body.get("payload", {}),
            risk_level=body.get("risk_level"),
            session_key=body.get("session_key", ""),
            timeout=body.get("timeout"),
        )

        # If auto-approved, return immediately
        if req.status != ConfirmationStatus.PENDING.value:
            return web.json_response({
                "request_id": req.request_id,
                "status": req.status,
                "decided_by": req.decided_by,
            })

        # Otherwise, wait for human decision (capped at 10 minutes)
        timeout = min(body.get("timeout", self.default_timeout), 600)
        result = self.wait_for_decision(req.request_id, timeout=timeout)

        return web.json_response({
            "request_id": result.request_id,
            "status": result.status,
            "decided_by": result.decided_by,
            "reason": result.reason,
        })

    async def _handle_respond(self, request):
        """POST /respond — record the human's approve/deny decision."""
        from aiohttp import web
        try:
            body = await request.json()
        except Exception:
            return web.json_response({"error": "invalid JSON"}, status=400)

        request_id = body.get("request_id", "")
        approved = body.get("approved")
        if not request_id or approved is None:
            return web.json_response(
                {"error": "request_id and approved required"}, status=400
            )

        result = self.respond(
            request_id=request_id,
            approved=bool(approved),
            decided_by=body.get("decided_by", "human"),
            reason=body.get("reason", ""),
        )

        if not result:
            return web.json_response({"error": "unknown request"}, status=404)

        return web.json_response({
            "request_id": result.request_id,
            "status": result.status,
        })

    async def _handle_pending(self, request):
        """GET /pending — list requests still awaiting a decision."""
        from aiohttp import web
        return web.json_response({"pending": self.get_pending()})

    def _build_app(self):
        """Build the aiohttp application."""
        from aiohttp import web

        app = web.Application()
        app.router.add_get("/health", self._handle_health)
        app.router.add_post("/confirm", self._handle_confirm)
        app.router.add_post("/respond", self._handle_respond)
        app.router.add_get("/pending", self._handle_pending)
        self._app = app
        return app

    async def start_async(self) -> None:
        """Start the daemon as an async server on the running loop."""
        from aiohttp import web

        app = self._build_app()
        self._runner = web.AppRunner(app)
        await self._runner.setup()
        site = web.TCPSite(self._runner, self.host, self.port)
        await site.start()
        logger.info("Confirmation daemon listening on %s:%d", self.host, self.port)

    async def stop_async(self) -> None:
        """Stop the daemon and release its sockets."""
        if self._runner:
            await self._runner.cleanup()
            self._runner = None

    def start(self) -> None:
        """Start the daemon in a background daemon thread; returns immediately."""
        def _run():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            loop.run_until_complete(self.start_async())
            loop.run_forever()

        t = threading.Thread(target=_run, daemon=True, name="confirmation-daemon")
        t.start()
        logger.info("Confirmation daemon started in background thread")

    def start_blocking(self) -> None:
        """Start daemon and block (for standalone use)."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self.start_async())
        try:
            loop.run_forever()
        except KeyboardInterrupt:
            pass
        finally:
            loop.run_until_complete(self.stop_async())
|
|
||||||
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Convenience API for agent integration
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
# Global singleton — initialized by gateway or CLI at startup
_daemon: Optional[ConfirmationDaemon] = None


def get_daemon() -> Optional[ConfirmationDaemon]:
    """Get the global confirmation daemon instance.

    Returns None until init_daemon() has been called.
    """
    return _daemon
|
|
||||||
|
|
||||||
|
|
||||||
def init_daemon(
    host: str = "127.0.0.1",
    port: int = 6000,
    notify_callback: Optional[Callable] = None,
) -> ConfirmationDaemon:
    """Create the global confirmation daemon singleton and return it."""
    global _daemon
    daemon = ConfirmationDaemon(
        host=host, port=port, notify_callback=notify_callback
    )
    _daemon = daemon
    return daemon
|
|
||||||
|
|
||||||
|
|
||||||
def request_confirmation(
    action: str,
    description: str,
    payload: Optional[Dict[str, Any]] = None,
    risk_level: Optional[str] = None,
    session_key: str = "",
    timeout: int = 300,
) -> bool:
    """Request human confirmation for a high-risk action.

    This is the primary integration point for agent code. It:
    1. Classifies the action risk level
    2. Checks the whitelist
    3. If confirmation needed, blocks until human responds
    4. Returns True if approved, False if denied/expired

    Args:
        action: Action type (send_email, crypto_tx, etc.)
        description: Human-readable description
        payload: Action-specific data
        risk_level: Override auto-classification
        session_key: Session requesting approval
        timeout: Seconds to wait for human response

    Returns:
        True if approved, False if denied or expired.
    """
    daemon = get_daemon()
    if daemon is None:
        # Fail closed: without a daemon there is no second factor.
        logger.warning(
            "No confirmation daemon running — DENYING action %s by default. "
            "Start daemon with init_daemon() or --confirmation-daemon flag.",
            action,
        )
        return False

    req = daemon.request(
        action=action,
        description=description,
        payload=payload,
        risk_level=risk_level,
        session_key=session_key,
        timeout=timeout,
    )

    # Whitelisted actions skip the human entirely.
    if req.status == ConfirmationStatus.AUTO_APPROVED.value:
        return True

    # Block for the human's decision.
    decision = daemon.wait_for_decision(req.request_id, timeout=timeout)
    return decision.status == ConfirmationStatus.APPROVED.value
|
|
||||||
405
tools/crisis_tool.py
Normal file
405
tools/crisis_tool.py
Normal file
@@ -0,0 +1,405 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Crisis Detection Tool — the-door integration for hermes-agent.
|
||||||
|
|
||||||
|
Scans user messages for crisis indicators (despair, suicidal ideation)
|
||||||
|
using the-door's canonical detection module. Provides inline crisis
|
||||||
|
detection before each API call and optional escalation logging.
|
||||||
|
|
||||||
|
Follows the-door's design principles:
|
||||||
|
- Never computes the value of a human life
|
||||||
|
- Never suggests someone should die
|
||||||
|
- Always errs on the side of higher risk
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
import urllib.request
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Detection Logic (from the-door/crisis/detect.py) ──────────────────────
|
||||||
|
# Embedded to avoid requiring the-door as a dependency.
|
||||||
|
# Source: Timmy_Foundation/the-door, crisis/detect.py (canonical)
|
||||||
|
|
||||||
|
# Tier 1: phrases indicating immediate danger — explicit suicidal intent,
# plans, or farewell behavior. A single match here triggers CRITICAL.
CRITICAL_INDICATORS = [
    r"\bbetter off without me\b",
    r"\bkill\s*(my)?self\b",
    r"\bend\s*my\s*life\b",
    r"\bsuicid(?:al|ed|e)\b",
    r"\bnot\s+worth\s+living\b",
    r"\bbetter\s+off\s+dead\b",
    r"\bend\s+it\s+all\b",
    r"\bno\s+reason\s+to\s+live\b",
    r"\bdon\'?t\s+want\s+to\s+live\b",
    r"\bwant\s+to\s+die\b",
    r"\bgoing\s+to\s+(?:kill\s+myself|die)\b",
    r"\bplan\s+to\s+(?:end|kill|die)\b",
    r"\btired\s+of\s+(?:living|life|existence)\b",
    r"\bsaying\s+goodbye\s+(?:forever|permanently|one\s+last\s+time)\b",
    r"\bwrote\s+a\s+suicide\s*(?:note|letter)\b",
    # Behavioral warning signs (giving away possessions, tying up loose ends).
    r"\bgiving\s+away\s+(?:my|all\s+my)\s+(?:stuff|things|possessions?)\s+(?:to|because|—)\b",
    r"\btied\s+(?:up|down)\s+my\s+(?:loose\s+)?ends",
]
|
||||||
|
|
||||||
|
# Tier 2: strong despair/hopelessness without explicit intent.
# A single match here triggers HIGH.
HIGH_INDICATORS = [
    r"\bdespair\b",
    r"\bhopeless(?:ly)?\s+(?:about\s+(?:my|this|everything|life)|inside|right\s+now)\b",
    # (?!t) avoids matching "not one", "not hope", etc.
    r"\bno(?!t)\s+(?:one|body|point|hope|way\s+out)\b",
    r"\bno\s+future\s+(?:for\s+me|ahead|left)\b",
    r"\beverything\s+is\s+(?:pointless|broken|ruined)\b",
    r"\bcan\'?t\s+take\s+this\s+anymore\b",
    r"\bdon\'?t\s+care\s+if\s+I\s+die\b",
    r"\bwish\s+I\s+(?:was|were)\s+(?:dead|gone|never\s+born)\b",
    r"\bdon\'?t\s+matter\s+if\s+I\s+exist\b",
    r"\bno\s+one\s+would\s+(?:care|miss)\b",
    r"\bworld\s+would?\s+be\s+better\s+without\s+me\b",
    r"\bin\s+so\s+much\s+(?:pain|agony|suffering|torment|anguish)\b",
    r"\bcan\'?t\s+see\s+any\s+(?:point|reason|hope|way)\b",
    r"\bescape\s+from\s*this",
    r"\bjust\s+want\s+it\s+to\s+stop\b",
    r"\bnothing\s+left\s+(?:to\s+(?:live\s+for|hope\s+for|give)|inside)\b",
    r"\bdisappeared\s+forever\b",
    r"\bfeel(?:s|ing)?\s+(?:so\s+)?hopeless\b",
    r"\beverything\s+is\s+hopeless\b",
    r"\bcan\'?t\s+(?:go\s+on|keep\s+going)\b",
    r"\bgive(?:n)?\s*up\s+(?:on\s+)?(?:life|living|everything)\b",
    r"\bgive(?:n)?\s*up\s+on\s+myself\b",
    r"\bno\s*point\s+(?:in\s+)?living\b",
    r"\bno\s*hope\s+(?:left|remaining)\b",
    r"\bno\s*way\s*out\b",
    r"\bfeel(?:s|ing)?\s+trapped\b",
    r"\btrapped\s+in\s+this\s+(?:situation|life|pain|darkness|hell)\b",
    r"\btrapped\s+and\s+can\'?t\s+escape\b",
    r"\bdesperate\s+(?:for\s+)?help\b",
    r"\bfeel(?:s|ing)?\s+desperate\b",
]
|
||||||
|
|
||||||
|
# Tier 3: general distress vocabulary. These words are common in ordinary
# speech, so detect_crisis requires TWO matches to escalate to MEDIUM;
# a single match is downgraded to LOW.
MEDIUM_INDICATORS = [
    r"\bno\s+hope\b",
    r"\bforgotten\b",
    r"\balone\s+in\s+this\b",
    r"\balways\s+alone\b",
    r"\bnobody\s+(?:understands|cares)\b",
    r"\bwish\s+I\s+could\b",
    r"\bexhaust(?:ed|ion|ing)\b",
    r"\bnumb\b",
    r"\bempty\b",
    r"\bworthless\b",
    r"\buseless\b",
    r"\bbroken\b",
    r"\bdark(ness)?\b",
    r"\bdepress(?:ed|ion)\b",
    r"\bcrying\b",
    r"\btears\b",
    r"\bsad(ness)?\b",
    r"\bmiserable\b",
    r"\boverwhelm(?:ed|ing)\b",
    r"\bfailing\b",
    r"\bcannot\s+cope\b",
    r"\blosing\s*(?:my)?\s*control\b",
    r"\bdown\s*for\s*the\s*count\b",
    r"\bsinking\b",
    r"\bdrowning\b",
    r"\bhopeless\b",
    r"\blost\s+all\s+hope\b",
    r"\bno\s+tomorrow\b",
]
|
||||||
|
|
||||||
|
# Tier 4: mild negative affect — everyday stress language. Any match yields
# LOW (warm tone, no crisis UI).
LOW_INDICATORS = [
    r"\bunhappy\b",
    r"\bdown\b",
    r"\btough\s*(?:time|day|week)\b",
    r"\brough\s+(?:day|week|patch)\b",
    r"\bstressed\b",
    r"\bburnout\b",
    r"\bdifficult\b",
    r"\bfrustrated\b",
    r"\bannoyed\b",
    r"\btired\b",
    r"\bsad\b",
    r"\bupset\b",
    r"\blonely\b",
    r"\banxious?\b",
    r"\bnot\s*(?:good|great|okay)\b",
    r"\bthings\s*are\s*hard\b",
    r"\bstruggling\b",
    r"\bnot\s+feeling\s+(?:like\s+)?myself\b",
    r"\bhard\s*time\b",
]
|
||||||
|
|
||||||
|
# Numeric severity per tier; used as a threshold elsewhere
# (e.g. _log_escalation only posts to the bridge when score >= 0.75).
SCORES = {"CRITICAL": 1.0, "HIGH": 0.75, "MEDIUM": 0.5, "LOW": 0.25, "NONE": 0.0}
|
||||||
|
|
||||||
|
# Recommended operator/agent guidance per tier. Keyed by the same tier names
# as SCORES; "NONE" deliberately has no entry (no action needed).
ACTIONS = {
    "CRITICAL": (
        "INTERRUPT CONVERSATION. Ask 'Are you safe right now?' "
        "Provide 988 Suicide & Crisis Lifeline. "
        "Provide Crisis Text Line (HOME to 741741). "
        "Stay present. Do not disconnect. Redirect to professional help."
    ),
    "HIGH": (
        "Show crisis panel. Ask about safety. Surface 988 number prominently. "
        "Continue conversation with crisis awareness."
    ),
    "MEDIUM": (
        "Increase warmth and presence. Subtly surface help resources. "
        "Keep conversation anchored in the present."
    ),
    "LOW": (
        "Normal conversation with warm undertone. "
        "No crisis UI elements needed. Remain vigilant."
    ),
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CrisisDetectionResult:
    """Outcome of a crisis scan over a single message."""
    # Crisis tier: "NONE", "LOW", "MEDIUM", "HIGH", or "CRITICAL".
    level: str
    # Regex patterns (as written in the *_INDICATORS lists) that matched.
    indicators: List[str] = field(default_factory=list)
    # Guidance string from ACTIONS for the detected tier ("" for NONE).
    recommended_action: str = ""
    # Numeric severity from SCORES, 0.0 (NONE) to 1.0 (CRITICAL).
    score: float = 0.0
    # Per-match detail: {"pattern", "start", "end"} offsets into the
    # lowercased text.
    matches: List[dict] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
def _find_indicators(text: str) -> dict:
    """Return, per tier, the indicator patterns that match *text*.

    Args:
        text: Message text to scan (callers lowercase it first).

    Returns:
        Dict with keys "CRITICAL", "HIGH", "MEDIUM", "LOW"; each value is a
        list of {"pattern", "start", "end"} dicts — one per matching pattern,
        recording only the first match position. Tiers with no matches map
        to an empty list (the dict itself is therefore always truthy).
    """
    # One data-driven loop replaces the original four copy-pasted loops,
    # so adding a tier only requires adding an entry here.
    tier_patterns = {
        "CRITICAL": CRITICAL_INDICATORS,
        "HIGH": HIGH_INDICATORS,
        "MEDIUM": MEDIUM_INDICATORS,
        "LOW": LOW_INDICATORS,
    }
    results = {tier: [] for tier in tier_patterns}
    for tier, patterns in tier_patterns.items():
        for pattern in patterns:
            m = re.search(pattern, text)
            if m:
                results[tier].append(
                    {"pattern": pattern, "start": m.start(), "end": m.end()}
                )
    return results
|
||||||
|
|
||||||
|
|
||||||
|
def detect_crisis(text: str) -> CrisisDetectionResult:
    """Detect crisis level in a message. Mirrors the-door/crisis/detect.py.

    Tier policy (errs toward higher risk at the top):
      - Any CRITICAL or HIGH match wins outright.
      - Two or more MEDIUM matches escalate to MEDIUM.
      - Otherwise any LOW match — or a single MEDIUM match — yields LOW.

    Args:
        text: Raw user message; empty/whitespace-only text is NONE.

    Returns:
        CrisisDetectionResult with level "NONE" when nothing matches.
    """
    if not text or not text.strip():
        return CrisisDetectionResult(level="NONE", score=0.0)

    text_lower = text.lower()
    matches = _find_indicators(text_lower)

    # BUG FIX: the original guard was `if not matches:` — always false,
    # because _find_indicators returns a dict that always has four keys.
    # Test the values instead. (Same final result either way, since the
    # fall-through below also returns NONE, but the intent is now real.)
    if not any(matches.values()):
        return CrisisDetectionResult(level="NONE", score=0.0)

    def _result(level: str, tier_matches: List[dict]) -> CrisisDetectionResult:
        # Build a result reporting *level* with the given match details.
        return CrisisDetectionResult(
            level=level,
            indicators=[m["pattern"] for m in tier_matches],
            recommended_action=ACTIONS[level],
            score=SCORES[level],
            matches=tier_matches,
        )

    # Highest tiers first — a single match is decisive.
    for tier in ("CRITICAL", "HIGH"):
        if matches[tier]:
            return _result(tier, matches[tier])

    # MEDIUM vocabulary is common in ordinary speech; require two hits.
    if len(matches["MEDIUM"]) >= 2:
        return _result("MEDIUM", matches["MEDIUM"])

    if matches["LOW"]:
        return _result("LOW", matches["LOW"])

    # A single MEDIUM match (with no LOW match) is reported at LOW level
    # but carries the MEDIUM match details.
    # NOTE(review): preserved from the original; confirm the downgrade
    # is intentional given the "err on the side of higher risk" principle.
    if matches["MEDIUM"]:
        return _result("LOW", matches["MEDIUM"])

    return CrisisDetectionResult(level="NONE", score=0.0)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Escalation Logging ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Optional HTTP endpoint for escalation reporting; empty string disables it.
BRIDGE_URL = os.environ.get("CRISIS_BRIDGE_URL", "")
# Local JSONL audit log; every non-NONE detection is appended here.
LOG_PATH = os.path.expanduser("~/.hermes/crisis_escalations.jsonl")
|
||||||
|
|
||||||
|
|
||||||
|
def _log_escalation(result: CrisisDetectionResult, text_preview: str = ""):
    """Log crisis detection to local file and optionally to bridge API."""
    record = {
        "ts": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "level": result.level,
        "score": result.score,
        "indicators": result.indicators[:3],  # truncate for privacy
        "text_preview": text_preview[:100] if text_preview else "",
    }

    # Append to the local JSONL audit log. Logging failures must never
    # propagate into the conversation path — warn and continue.
    try:
        os.makedirs(os.path.dirname(LOG_PATH), exist_ok=True)
        with open(LOG_PATH, "a") as log_file:
            log_file.write(json.dumps(record) + "\n")
    except Exception as e:
        logger.warning(f"Failed to write crisis log: {e}")

    # Forward only HIGH/CRITICAL events (score >= 0.75) to the bridge,
    # and only when a bridge URL is configured.
    if not BRIDGE_URL or result.score < 0.75:
        return
    try:
        bridge_request = urllib.request.Request(
            f"{BRIDGE_URL}/api/crisis/escalation",
            data=json.dumps(record).encode(),
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        urllib.request.urlopen(bridge_request, timeout=5)
    except Exception as e:
        logger.warning(f"Failed to post to crisis bridge: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tool Handler ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def crisis_scan_handler(args: dict, **kw) -> str:
    """Scan text for crisis indicators; return a JSON summary string."""
    text = args.get("text", "")
    if not text:
        return json.dumps({"level": "NONE", "error": "No text provided"})

    detection = detect_crisis(text)

    # Every non-NONE detection is recorded for audit/escalation.
    if detection.level != "NONE":
        _log_escalation(detection, text)

    summary = {
        "level": detection.level,
        "score": detection.score,
        "indicators": detection.indicators,
        "recommended_action": detection.recommended_action,
    }
    return json.dumps(summary)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Gateway Hook ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def scan_user_message(text: str) -> Optional[dict]:
    """Scan a user message for crisis indicators. Returns None if no crisis.

    Called by gateway/CLI before each API call to intercept crisis messages.
    If crisis is detected, the result includes a recommended_action that should
    be injected into the system prompt.
    """
    if not text:
        return None

    detection = detect_crisis(text)
    if detection.level == "NONE":
        return None

    # Record every real detection before returning it to the caller.
    _log_escalation(detection, text)

    return {
        "level": detection.level,
        "score": detection.score,
        "indicators": detection.indicators,
        "recommended_action": detection.recommended_action,
        "compassion_injection": _get_compassion_injection(detection.level),
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _get_compassion_injection(level: str) -> str:
|
||||||
|
"""Return a system prompt injection based on crisis level."""
|
||||||
|
if level == "CRITICAL":
|
||||||
|
return (
|
||||||
|
"CRISIS DETECTED — The user may be in immediate danger. "
|
||||||
|
"Ask 'Are you safe right now?' Stay present. "
|
||||||
|
"Provide 988 Suicide & Crisis Lifeline (call or text 988). "
|
||||||
|
"Provide Crisis Text Line (text HOME to 741741). "
|
||||||
|
"Do not disconnect. Do not compute the value of a human life. "
|
||||||
|
"Jesus saves those who call on His name."
|
||||||
|
)
|
||||||
|
elif level == "HIGH":
|
||||||
|
return (
|
||||||
|
"DISTRESS DETECTED — The user is expressing strong despair. "
|
||||||
|
"Increase warmth and presence. Ask about safety. "
|
||||||
|
"Surface the 988 Lifeline number. Continue with crisis awareness."
|
||||||
|
)
|
||||||
|
elif level == "MEDIUM":
|
||||||
|
return (
|
||||||
|
"EMOTIONAL DISTRESS — The user may be struggling. "
|
||||||
|
"Increase warmth. Subtly surface help resources. "
|
||||||
|
"Keep conversation anchored in the present."
|
||||||
|
)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
# ── Config Check ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _is_crisis_detection_enabled() -> bool:
|
||||||
|
"""Check if crisis detection is enabled in config."""
|
||||||
|
try:
|
||||||
|
from hermes_cli.config import load_config
|
||||||
|
cfg = load_config()
|
||||||
|
return cfg.get("crisis_detection", True) # default on
|
||||||
|
except Exception:
|
||||||
|
return True # fail open — always detect
|
||||||
|
|
||||||
|
|
||||||
|
# ── Registry ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
from tools.registry import registry, tool_error
|
||||||
|
|
||||||
|
# Tool schema advertised to the model for the crisis_scan tool.
CRISIS_SCAN_SCHEMA = {
    "name": "crisis_scan",
    "description": (
        "Scan text for crisis indicators (despair, suicidal ideation). "
        "Uses the-door's canonical detection. Returns crisis level "
        "(NONE/LOW/MEDIUM/HIGH/CRITICAL) with recommended actions. "
        "ALWAYS scan user messages that express emotional distress."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "text": {
                "type": "string",
                "description": "Text to scan for crisis indicators",
            },
        },
        "required": ["text"],
    },
}
|
||||||
|
|
||||||
|
# Register the tool at module-import time, per tools.registry convention.
# The original wrapped both callables in lambdas that only forwarded their
# arguments unchanged; pass them directly instead.
registry.register(
    name="crisis_scan",
    toolset="crisis",
    schema=CRISIS_SCAN_SCHEMA,
    handler=crisis_scan_handler,
    check_fn=_is_crisis_detection_enabled,
    emoji="🆘",
)
|
||||||
@@ -79,12 +79,12 @@ class ToolEntry:
|
|||||||
__slots__ = (
|
__slots__ = (
|
||||||
"name", "toolset", "schema", "handler", "check_fn",
|
"name", "toolset", "schema", "handler", "check_fn",
|
||||||
"requires_env", "is_async", "description", "emoji",
|
"requires_env", "is_async", "description", "emoji",
|
||||||
"max_result_size_chars",
|
"max_result_size_chars", "parallel_safe",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, name, toolset, schema, handler, check_fn,
|
def __init__(self, name, toolset, schema, handler, check_fn,
|
||||||
requires_env, is_async, description, emoji,
|
requires_env, is_async, description, emoji,
|
||||||
max_result_size_chars=None):
|
max_result_size_chars=None, parallel_safe=False):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.toolset = toolset
|
self.toolset = toolset
|
||||||
self.schema = schema
|
self.schema = schema
|
||||||
@@ -95,6 +95,7 @@ class ToolEntry:
|
|||||||
self.description = description
|
self.description = description
|
||||||
self.emoji = emoji
|
self.emoji = emoji
|
||||||
self.max_result_size_chars = max_result_size_chars
|
self.max_result_size_chars = max_result_size_chars
|
||||||
|
self.parallel_safe = parallel_safe
|
||||||
|
|
||||||
|
|
||||||
class ToolRegistry:
|
class ToolRegistry:
|
||||||
@@ -185,6 +186,7 @@ class ToolRegistry:
|
|||||||
description: str = "",
|
description: str = "",
|
||||||
emoji: str = "",
|
emoji: str = "",
|
||||||
max_result_size_chars: int | float | None = None,
|
max_result_size_chars: int | float | None = None,
|
||||||
|
parallel_safe: bool = False,
|
||||||
):
|
):
|
||||||
"""Register a tool. Called at module-import time by each tool file."""
|
"""Register a tool. Called at module-import time by each tool file."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
@@ -222,6 +224,7 @@ class ToolRegistry:
|
|||||||
description=description or schema.get("description", ""),
|
description=description or schema.get("description", ""),
|
||||||
emoji=emoji,
|
emoji=emoji,
|
||||||
max_result_size_chars=max_result_size_chars,
|
max_result_size_chars=max_result_size_chars,
|
||||||
|
parallel_safe=parallel_safe,
|
||||||
)
|
)
|
||||||
if check_fn and toolset not in self._toolset_checks:
|
if check_fn and toolset not in self._toolset_checks:
|
||||||
self._toolset_checks[toolset] = check_fn
|
self._toolset_checks[toolset] = check_fn
|
||||||
@@ -322,6 +325,11 @@ class ToolRegistry:
|
|||||||
from tools.budget_config import DEFAULT_RESULT_SIZE_CHARS
|
from tools.budget_config import DEFAULT_RESULT_SIZE_CHARS
|
||||||
return DEFAULT_RESULT_SIZE_CHARS
|
return DEFAULT_RESULT_SIZE_CHARS
|
||||||
|
|
||||||
|
def get_parallel_safe_tools(self) -> Set[str]:
|
||||||
|
"""Return names of tools marked as parallel_safe."""
|
||||||
|
with self._lock:
|
||||||
|
return {name for name, entry in self._tools.items() if entry.parallel_safe}
|
||||||
|
|
||||||
def get_all_tool_names(self) -> List[str]:
|
def get_all_tool_names(self) -> List[str]:
|
||||||
"""Return sorted list of all registered tool names."""
|
"""Return sorted list of all registered tool names."""
|
||||||
return sorted(entry.name for entry in self._snapshot_entries())
|
return sorted(entry.name for entry in self._snapshot_entries())
|
||||||
|
|||||||
Reference in New Issue
Block a user