Compare commits
2 Commits
fix/701
...
feat/a2a-h
| Author | SHA1 | Date | |
|---|---|---|---|
| 2fba6dade4 | |||
| d7670b98cf |
@@ -1,238 +0,0 @@
|
||||
"""
|
||||
hybrid_search.py — Hybrid search combining FTS5, vector, and HRR.
|
||||
|
||||
Three-backend search router:
|
||||
1. FTS5 (SQLite full-text) — fast keyword matching, always available
|
||||
2. Vector search (Qdrant/ChromaDB) — semantic similarity, optional
|
||||
3. HRR (Holographic Reduced Representations) — compositional recall, optional
|
||||
|
||||
Graceful degradation: if vector or HRR backends are unavailable,
|
||||
falls back to FTS5-only.
|
||||
|
||||
Usage:
|
||||
from agent.hybrid_search import hybrid_search
|
||||
|
||||
results = hybrid_search(query, db=session_db, limit=10)
|
||||
# Returns merged, deduplicated, ranked results from all available backends
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Any, Optional, Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
"""Single search result from any backend."""
|
||||
session_id: str
|
||||
message_content: str
|
||||
score: float
|
||||
source: str # "fts5", "vector", "hrr"
|
||||
role: str = ""
|
||||
timestamp: str = ""
|
||||
metadata: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HybridSearchConfig:
|
||||
"""Configuration for hybrid search."""
|
||||
fts5_enabled: bool = True
|
||||
vector_enabled: bool = False
|
||||
hrr_enabled: bool = False
|
||||
vector_weight: float = 0.4
|
||||
fts5_weight: float = 0.4
|
||||
hrr_weight: float = 0.2
|
||||
dedup_threshold: float = 0.9 # similarity threshold for dedup
|
||||
|
||||
|
||||
def search_fts5(query: str, db, limit: int = 50, role_filter: list = None) -> List[SearchResult]:
|
||||
"""Search using FTS5 full-text search."""
|
||||
try:
|
||||
raw = db.search_messages(
|
||||
query=query,
|
||||
role_filter=role_filter,
|
||||
limit=limit,
|
||||
offset=0,
|
||||
)
|
||||
results = []
|
||||
for r in raw:
|
||||
results.append(SearchResult(
|
||||
session_id=r.get("session_id", ""),
|
||||
message_content=r.get("content", ""),
|
||||
score=r.get("rank", 0.0),
|
||||
source="fts5",
|
||||
role=r.get("role", ""),
|
||||
timestamp=str(r.get("timestamp", "")),
|
||||
))
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.warning(f"FTS5 search failed: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def search_vector(query: str, limit: int = 50) -> List[SearchResult]:
|
||||
"""Search using vector similarity (Qdrant/ChromaDB).
|
||||
|
||||
Returns empty list if vector backend unavailable.
|
||||
"""
|
||||
try:
|
||||
# Try ChromaDB first
|
||||
import chromadb
|
||||
client = chromadb.PersistentClient(path="~/.hermes/memory/chroma")
|
||||
collection = client.get_or_create_collection("sessions")
|
||||
results = collection.query(
|
||||
query_texts=[query],
|
||||
n_results=limit,
|
||||
)
|
||||
search_results = []
|
||||
for i, doc in enumerate(results.get("documents", [[]])[0]):
|
||||
metadata = results.get("metadatas", [[]])[0]
|
||||
meta = metadata[i] if i < len(metadata) else {}
|
||||
distance = results.get("distances", [[]])[0]
|
||||
score = 1.0 - (distance[i] if i < len(distance) else 1.0)
|
||||
search_results.append(SearchResult(
|
||||
session_id=meta.get("session_id", ""),
|
||||
message_content=doc,
|
||||
score=score,
|
||||
source="vector",
|
||||
role=meta.get("role", ""),
|
||||
timestamp=meta.get("timestamp", ""),
|
||||
))
|
||||
return search_results
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
# Try Qdrant
|
||||
from qdrant_client import QdrantClient
|
||||
client = QdrantClient(host="localhost", port=6333)
|
||||
results = client.query_points(
|
||||
collection_name="sessions",
|
||||
query_text=query,
|
||||
limit=limit,
|
||||
)
|
||||
return [
|
||||
SearchResult(
|
||||
session_id=pt.payload.get("session_id", ""),
|
||||
message_content=pt.payload.get("content", ""),
|
||||
score=pt.score,
|
||||
source="vector",
|
||||
role=pt.payload.get("role", ""),
|
||||
)
|
||||
for pt in results.points
|
||||
]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def search_hrr(query: str, limit: int = 50) -> List[SearchResult]:
|
||||
"""Search using Holographic Reduced Representations.
|
||||
|
||||
Returns empty list if HRR backend unavailable.
|
||||
"""
|
||||
try:
|
||||
from agent.holographic_memory import holographic_recall
|
||||
results = holographic_recall(query, limit=limit)
|
||||
return [
|
||||
SearchResult(
|
||||
session_id=r.get("session_id", ""),
|
||||
message_content=r.get("content", ""),
|
||||
score=r.get("binding_score", 0.0),
|
||||
source="hrr",
|
||||
role=r.get("role", ""),
|
||||
)
|
||||
for r in results
|
||||
]
|
||||
except Exception:
|
||||
pass
|
||||
return []
|
||||
|
||||
|
||||
def merge_results(
|
||||
fts5_results: List[SearchResult],
|
||||
vector_results: List[SearchResult],
|
||||
hrr_results: List[SearchResult],
|
||||
config: HybridSearchConfig,
|
||||
limit: int = 10,
|
||||
) -> List[SearchResult]:
|
||||
"""Merge results from multiple backends with weighted scoring."""
|
||||
all_results = []
|
||||
|
||||
# Apply weights
|
||||
for r in fts5_results:
|
||||
r.score *= config.fts5_weight
|
||||
all_results.append(r)
|
||||
for r in vector_results:
|
||||
r.score *= config.vector_weight
|
||||
all_results.append(r)
|
||||
for r in hrr_results:
|
||||
r.score *= config.hrr_weight
|
||||
all_results.append(r)
|
||||
|
||||
# Sort by weighted score
|
||||
all_results.sort(key=lambda r: r.score, reverse=True)
|
||||
|
||||
# Deduplicate by session_id + content similarity
|
||||
seen = set()
|
||||
deduped = []
|
||||
for r in all_results:
|
||||
key = f"{r.session_id}:{r.message_content[:100]}"
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
deduped.append(r)
|
||||
|
||||
return deduped[:limit]
|
||||
|
||||
|
||||
def hybrid_search(
|
||||
query: str,
|
||||
db=None,
|
||||
limit: int = 10,
|
||||
role_filter: list = None,
|
||||
config: HybridSearchConfig = None,
|
||||
) -> List[SearchResult]:
|
||||
"""Hybrid search across FTS5, vector, and HRR backends.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
db: Session database (for FTS5)
|
||||
limit: Max results
|
||||
role_filter: Filter by message role
|
||||
config: Hybrid search configuration
|
||||
|
||||
Returns:
|
||||
List of SearchResult, ranked by weighted score
|
||||
"""
|
||||
if config is None:
|
||||
config = HybridSearchConfig()
|
||||
|
||||
fts5_results = []
|
||||
vector_results = []
|
||||
hrr_results = []
|
||||
|
||||
# FTS5 (always available if db provided)
|
||||
if config.fts5_enabled and db:
|
||||
fts5_results = search_fts5(query, db, limit=50, role_filter=role_filter)
|
||||
logger.debug(f"FTS5: {len(fts5_results)} results")
|
||||
|
||||
# Vector search (optional)
|
||||
if config.vector_enabled:
|
||||
vector_results = search_vector(query, limit=50)
|
||||
logger.debug(f"Vector: {len(vector_results)} results")
|
||||
|
||||
# HRR (optional)
|
||||
if config.hrr_enabled:
|
||||
hrr_results = search_hrr(query, limit=50)
|
||||
logger.debug(f"HRR: {len(hrr_results)} results")
|
||||
|
||||
# If only FTS5 available, just return those
|
||||
if not vector_results and not hrr_results:
|
||||
return fts5_results[:limit]
|
||||
|
||||
# Merge and rank
|
||||
return merge_results(fts5_results, vector_results, hrr_results, config, limit)
|
||||
@@ -1,353 +0,0 @@
|
||||
"""Privacy Filter — strip PII from context before remote API calls.
|
||||
|
||||
Implements Vitalik's Pattern 2: "A local model can strip out private data
|
||||
before passing the query along to a remote LLM."
|
||||
|
||||
When Hermes routes a request to a cloud provider (Anthropic, OpenRouter, etc.),
|
||||
this module sanitizes the message context to remove personally identifiable
|
||||
information before it leaves the user's machine.
|
||||
|
||||
Threat model (from Vitalik's secure LLM architecture):
|
||||
- Privacy (other): Non-LLM data leakage via search queries, API calls
|
||||
- LLM accidents: LLM accidentally leaking private data in prompts
|
||||
- LLM jailbreaks: Remote content extracting private context
|
||||
|
||||
Usage:
|
||||
from agent.privacy_filter import PrivacyFilter, sanitize_messages
|
||||
|
||||
pf = PrivacyFilter()
|
||||
safe_messages = pf.sanitize_messages(messages)
|
||||
# safe_messages has PII replaced with [REDACTED] tokens
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Sensitivity(Enum):
|
||||
"""Classification of content sensitivity."""
|
||||
PUBLIC = auto() # No PII detected
|
||||
LOW = auto() # Generic references (e.g., city names)
|
||||
MEDIUM = auto() # Personal identifiers (name, email, phone)
|
||||
HIGH = auto() # Secrets, keys, financial data, medical info
|
||||
CRITICAL = auto() # Crypto keys, passwords, SSN patterns
|
||||
|
||||
|
||||
@dataclass
|
||||
class RedactionReport:
|
||||
"""Summary of what was redacted from a message batch."""
|
||||
total_messages: int = 0
|
||||
redacted_messages: int = 0
|
||||
redactions: List[Dict[str, Any]] = field(default_factory=list)
|
||||
max_sensitivity: Sensitivity = Sensitivity.PUBLIC
|
||||
|
||||
@property
|
||||
def had_redactions(self) -> bool:
|
||||
return self.redacted_messages > 0
|
||||
|
||||
def summary(self) -> str:
|
||||
if not self.had_redactions:
|
||||
return "No PII detected — context is clean for remote query."
|
||||
parts = [f"Redacted {self.redacted_messages}/{self.total_messages} messages:"]
|
||||
for r in self.redactions[:10]:
|
||||
parts.append(f" - {r['type']}: {r['count']} occurrence(s)")
|
||||
if len(self.redactions) > 10:
|
||||
parts.append(f" ... and {len(self.redactions) - 10} more types")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# PII pattern definitions
|
||||
# =========================================================================
|
||||
|
||||
# Each pattern is (compiled_regex, redaction_type, sensitivity_level, replacement)
|
||||
_PII_PATTERNS: List[Tuple[re.Pattern, str, Sensitivity, str]] = []
|
||||
|
||||
|
||||
def _compile_patterns() -> None:
|
||||
"""Compile PII detection patterns. Called once at module init."""
|
||||
global _PII_PATTERNS
|
||||
if _PII_PATTERNS:
|
||||
return
|
||||
|
||||
raw_patterns = [
|
||||
# --- CRITICAL: secrets and credentials ---
|
||||
(
|
||||
r'(?:api[_-]?key|apikey|secret[_-]?key|access[_-]?token)\s*[:=]\s*["\']?([A-Za-z0-9_\-\.]{20,})["\']?',
|
||||
"api_key_or_token",
|
||||
Sensitivity.CRITICAL,
|
||||
"[REDACTED-API-KEY]",
|
||||
),
|
||||
(
|
||||
r'\b(?:sk-|sk_|pk_|rk_|ak_)[A-Za-z0-9]{20,}\b',
|
||||
"prefixed_secret",
|
||||
Sensitivity.CRITICAL,
|
||||
"[REDACTED-SECRET]",
|
||||
),
|
||||
(
|
||||
r'\b(?:ghp_|gho_|ghu_|ghs_|ghr_)[A-Za-z0-9]{36,}\b',
|
||||
"github_token",
|
||||
Sensitivity.CRITICAL,
|
||||
"[REDACTED-GITHUB-TOKEN]",
|
||||
),
|
||||
(
|
||||
r'\b(?:xox[bposa]-[A-Za-z0-9\-]+)\b',
|
||||
"slack_token",
|
||||
Sensitivity.CRITICAL,
|
||||
"[REDACTED-SLACK-TOKEN]",
|
||||
),
|
||||
(
|
||||
r'(?:password|passwd|pwd)\s*[:=]\s*["\']?([^\s"\']{4,})["\']?',
|
||||
"password",
|
||||
Sensitivity.CRITICAL,
|
||||
"[REDACTED-PASSWORD]",
|
||||
),
|
||||
(
|
||||
r'(?:-----BEGIN (?:RSA |EC |OPENSSH )?PRIVATE KEY-----)',
|
||||
"private_key_block",
|
||||
Sensitivity.CRITICAL,
|
||||
"[REDACTED-PRIVATE-KEY]",
|
||||
),
|
||||
# Ethereum / crypto addresses (42-char hex starting with 0x)
|
||||
(
|
||||
r'\b0x[a-fA-F0-9]{40}\b',
|
||||
"ethereum_address",
|
||||
Sensitivity.HIGH,
|
||||
"[REDACTED-ETH-ADDR]",
|
||||
),
|
||||
# Bitcoin addresses (base58, 25-34 chars starting with 1/3/bc1)
|
||||
(
|
||||
r'\b[13][a-km-zA-HJ-NP-Z1-9]{25,34}\b',
|
||||
"bitcoin_address",
|
||||
Sensitivity.HIGH,
|
||||
"[REDACTED-BTC-ADDR]",
|
||||
),
|
||||
(
|
||||
r'\bbc1[a-zA-HJ-NP-Z0-9]{39,59}\b',
|
||||
"bech32_address",
|
||||
Sensitivity.HIGH,
|
||||
"[REDACTED-BTC-ADDR]",
|
||||
),
|
||||
# --- HIGH: financial ---
|
||||
(
|
||||
r'\b(?:\d{4}[-\s]?){3}\d{4}\b',
|
||||
"credit_card_number",
|
||||
Sensitivity.HIGH,
|
||||
"[REDACTED-CC]",
|
||||
),
|
||||
(
|
||||
r'\b\d{3}-\d{2}-\d{4}\b',
|
||||
"us_ssn",
|
||||
Sensitivity.HIGH,
|
||||
"[REDACTED-SSN]",
|
||||
),
|
||||
# --- MEDIUM: personal identifiers ---
|
||||
# Email addresses
|
||||
(
|
||||
r'\b[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,}\b',
|
||||
"email_address",
|
||||
Sensitivity.MEDIUM,
|
||||
"[REDACTED-EMAIL]",
|
||||
),
|
||||
# Phone numbers (US/international patterns)
|
||||
(
|
||||
r'\b(?:\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}\b',
|
||||
"phone_number_us",
|
||||
Sensitivity.MEDIUM,
|
||||
"[REDACTED-PHONE]",
|
||||
),
|
||||
(
|
||||
r'\b\+\d{1,3}[-.\s]?\d{4,14}\b',
|
||||
"phone_number_intl",
|
||||
Sensitivity.MEDIUM,
|
||||
"[REDACTED-PHONE]",
|
||||
),
|
||||
# Filesystem paths that reveal user identity
|
||||
(
|
||||
r'(?:/Users/|/home/|C:\\Users\\)([A-Za-z0-9_\-]+)',
|
||||
"user_home_path",
|
||||
Sensitivity.MEDIUM,
|
||||
r"/Users/[REDACTED-USER]",
|
||||
),
|
||||
# --- LOW: environment / system info ---
|
||||
# Internal IPs
|
||||
(
|
||||
r'\b(?:10\.\d{1,3}\.\d{1,3}\.\d{1,3}|172\.(?:1[6-9]|2\d|3[01])\.\d{1,3}\.\d{1,3}|192\.168\.\d{1,3}\.\d{1,3})\b',
|
||||
"internal_ip",
|
||||
Sensitivity.LOW,
|
||||
"[REDACTED-IP]",
|
||||
),
|
||||
]
|
||||
|
||||
_PII_PATTERNS = [
|
||||
(re.compile(pattern, re.IGNORECASE), rtype, sensitivity, replacement)
|
||||
for pattern, rtype, sensitivity, replacement in raw_patterns
|
||||
]
|
||||
|
||||
|
||||
_compile_patterns()
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Sensitive file path patterns (context-aware)
|
||||
# =========================================================================
|
||||
|
||||
_SENSITIVE_PATH_PATTERNS = [
|
||||
re.compile(r'\.(?:env|pem|key|p12|pfx|jks|keystore)\b', re.IGNORECASE),
|
||||
re.compile(r'(?:\.ssh/|\.gnupg/|\.aws/|\.config/gcloud/)', re.IGNORECASE),
|
||||
re.compile(r'(?:wallet|keystore|seed|mnemonic)', re.IGNORECASE),
|
||||
re.compile(r'(?:\.hermes/\.env)', re.IGNORECASE),
|
||||
]
|
||||
|
||||
|
||||
def _classify_path_sensitivity(path: str) -> Sensitivity:
|
||||
"""Check if a file path references sensitive material."""
|
||||
for pat in _SENSITIVE_PATH_PATTERNS:
|
||||
if pat.search(path):
|
||||
return Sensitivity.HIGH
|
||||
return Sensitivity.PUBLIC
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Core filtering
|
||||
# =========================================================================
|
||||
|
||||
class PrivacyFilter:
|
||||
"""Strip PII from message context before remote API calls.
|
||||
|
||||
Integrates with the agent's message pipeline. Call sanitize_messages()
|
||||
before sending context to any cloud LLM provider.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
min_sensitivity: Sensitivity = Sensitivity.MEDIUM,
|
||||
aggressive_mode: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
min_sensitivity: Only redact PII at or above this level.
|
||||
Default MEDIUM — redacts emails, phones, paths but not IPs.
|
||||
aggressive_mode: If True, also redact file paths and internal IPs.
|
||||
"""
|
||||
self.min_sensitivity = (
|
||||
Sensitivity.LOW if aggressive_mode else min_sensitivity
|
||||
)
|
||||
self.aggressive_mode = aggressive_mode
|
||||
|
||||
def sanitize_text(self, text: str) -> Tuple[str, List[Dict[str, Any]]]:
|
||||
"""Sanitize a single text string. Returns (cleaned_text, redaction_list)."""
|
||||
redactions = []
|
||||
cleaned = text
|
||||
|
||||
for pattern, rtype, sensitivity, replacement in _PII_PATTERNS:
|
||||
if sensitivity.value < self.min_sensitivity.value:
|
||||
continue
|
||||
|
||||
matches = pattern.findall(cleaned)
|
||||
if matches:
|
||||
count = len(matches) if isinstance(matches[0], str) else sum(
|
||||
1 for m in matches if m
|
||||
)
|
||||
if count > 0:
|
||||
cleaned = pattern.sub(replacement, cleaned)
|
||||
redactions.append({
|
||||
"type": rtype,
|
||||
"sensitivity": sensitivity.name,
|
||||
"count": count,
|
||||
})
|
||||
|
||||
return cleaned, redactions
|
||||
|
||||
def sanitize_messages(
|
||||
self, messages: List[Dict[str, Any]]
|
||||
) -> Tuple[List[Dict[str, Any]], RedactionReport]:
|
||||
"""Sanitize a list of OpenAI-format messages.
|
||||
|
||||
Returns (safe_messages, report). System messages are NOT sanitized
|
||||
(they're typically static prompts). Only user and assistant messages
|
||||
with string content are processed.
|
||||
|
||||
Args:
|
||||
messages: List of {"role": ..., "content": ...} dicts.
|
||||
|
||||
Returns:
|
||||
Tuple of (sanitized_messages, redaction_report).
|
||||
"""
|
||||
report = RedactionReport(total_messages=len(messages))
|
||||
safe_messages = []
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content", "")
|
||||
|
||||
# Only sanitize user/assistant string content
|
||||
if role in ("user", "assistant") and isinstance(content, str) and content:
|
||||
cleaned, redactions = self.sanitize_text(content)
|
||||
if redactions:
|
||||
report.redacted_messages += 1
|
||||
report.redactions.extend(redactions)
|
||||
# Track max sensitivity
|
||||
for r in redactions:
|
||||
s = Sensitivity[r["sensitivity"]]
|
||||
if s.value > report.max_sensitivity.value:
|
||||
report.max_sensitivity = s
|
||||
safe_msg = {**msg, "content": cleaned}
|
||||
safe_messages.append(safe_msg)
|
||||
logger.info(
|
||||
"Privacy filter: redacted %d PII type(s) from %s message",
|
||||
len(redactions), role,
|
||||
)
|
||||
else:
|
||||
safe_messages.append(msg)
|
||||
else:
|
||||
safe_messages.append(msg)
|
||||
|
||||
return safe_messages, report
|
||||
|
||||
def should_use_local_only(self, text: str) -> Tuple[bool, str]:
|
||||
"""Determine if content is too sensitive for any remote call.
|
||||
|
||||
Returns (should_block, reason). If True, the content should only
|
||||
be processed by a local model.
|
||||
"""
|
||||
_, redactions = self.sanitize_text(text)
|
||||
|
||||
critical_count = sum(
|
||||
1 for r in redactions
|
||||
if Sensitivity[r["sensitivity"]] == Sensitivity.CRITICAL
|
||||
)
|
||||
high_count = sum(
|
||||
1 for r in redactions
|
||||
if Sensitivity[r["sensitivity"]] == Sensitivity.HIGH
|
||||
)
|
||||
|
||||
if critical_count > 0:
|
||||
return True, f"Contains {critical_count} critical-secret pattern(s) — local-only"
|
||||
if high_count >= 3:
|
||||
return True, f"Contains {high_count} high-sensitivity pattern(s) — local-only"
|
||||
return False, ""
|
||||
|
||||
|
||||
def sanitize_messages(
|
||||
messages: List[Dict[str, Any]],
|
||||
min_sensitivity: Sensitivity = Sensitivity.MEDIUM,
|
||||
aggressive: bool = False,
|
||||
) -> Tuple[List[Dict[str, Any]], RedactionReport]:
|
||||
"""Convenience function: sanitize messages with default settings."""
|
||||
pf = PrivacyFilter(min_sensitivity=min_sensitivity, aggressive_mode=aggressive)
|
||||
return pf.sanitize_messages(messages)
|
||||
|
||||
|
||||
def quick_sanitize(text: str) -> str:
|
||||
"""Quick sanitize a single string — returns cleaned text only."""
|
||||
pf = PrivacyFilter()
|
||||
cleaned, _ = pf.sanitize_text(text)
|
||||
return cleaned
|
||||
292
hermes_cli/a2a_health.py
Normal file
292
hermes_cli/a2a_health.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""A2A Health Check — periodic heartbeat monitoring for fleet agents.
|
||||
|
||||
Pings each registered agent's A2A endpoint, records response time and status,
|
||||
tracks consecutive failures, and fires alerts when thresholds are breached.
|
||||
|
||||
Usage:
|
||||
from hermes_cli.a2a_health import HealthMonitor
|
||||
monitor = HealthMonitor()
|
||||
results = monitor.check_all()
|
||||
monitor.print_dashboard(results)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
|
||||
|
||||
STATE_FILE = os.path.expanduser("~/.hermes/a2a_health.json")
|
||||
FLEET_FILE = os.path.expanduser("~/.hermes/fleet_agents.json")
|
||||
|
||||
# Thresholds
|
||||
CONSECUTIVE_FAILURE_THRESHOLD = 3
|
||||
RESPONSE_TIME_THRESHOLD_SEC = 10.0
|
||||
CHECK_INTERVAL_SEC = 300 # 5 minutes
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentStatus:
|
||||
name: str
|
||||
url: str
|
||||
last_check: Optional[str] = None
|
||||
last_healthy: Optional[str] = None
|
||||
response_time_ms: Optional[float] = None
|
||||
status: str = "unknown" # healthy | degraded | down | unknown
|
||||
consecutive_failures: int = 0
|
||||
available_tools: int = 0
|
||||
agent_status: str = ""
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class HealthState:
|
||||
agents: dict = field(default_factory=dict) # name -> AgentStatus dict
|
||||
last_full_check: Optional[str] = None
|
||||
alerts_sent: dict = field(default_factory=dict) # name -> timestamp of last alert
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _load_state() -> HealthState:
|
||||
if os.path.exists(STATE_FILE):
|
||||
try:
|
||||
with open(STATE_FILE) as f:
|
||||
data = json.load(f)
|
||||
state = HealthState()
|
||||
state.last_full_check = data.get("last_full_check")
|
||||
state.alerts_sent = data.get("alerts_sent", {})
|
||||
for name, agent_data in data.get("agents", {}).items():
|
||||
status = AgentStatus(name=agent_data["name"], url=agent_data["url"])
|
||||
status.last_check = agent_data.get("last_check")
|
||||
status.last_healthy = agent_data.get("last_healthy")
|
||||
status.response_time_ms = agent_data.get("response_time_ms")
|
||||
status.status = agent_data.get("status", "unknown")
|
||||
status.consecutive_failures = agent_data.get("consecutive_failures", 0)
|
||||
status.available_tools = agent_data.get("available_tools", 0)
|
||||
status.agent_status = agent_data.get("agent_status", "")
|
||||
status.error = agent_data.get("error")
|
||||
state.agents[name] = status
|
||||
return state
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
return HealthState()
|
||||
|
||||
|
||||
def _save_state(state: HealthState) -> None:
|
||||
os.makedirs(os.path.dirname(STATE_FILE), exist_ok=True)
|
||||
data = {
|
||||
"last_full_check": state.last_full_check,
|
||||
"alerts_sent": state.alerts_sent,
|
||||
"agents": {},
|
||||
}
|
||||
for name, status in state.agents.items():
|
||||
data["agents"][name] = asdict(status)
|
||||
with open(STATE_FILE, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
|
||||
def _load_fleet() -> list[dict]:
|
||||
"""Load fleet agent definitions from ~/.hermes/fleet_agents.json.
|
||||
|
||||
Format:
|
||||
[
|
||||
{"name": "ezra", "url": "http://143.198.27.163:8080"},
|
||||
{"name": "bezalel", "url": "http://167.99.126.228:8080"},
|
||||
{"name": "allegro", "url": "http://167.99.126.229:8080"}
|
||||
]
|
||||
"""
|
||||
if os.path.exists(FLEET_FILE):
|
||||
try:
|
||||
with open(FLEET_FILE) as f:
|
||||
return json.load(f)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
return []
|
||||
|
||||
|
||||
def _ping_agent(url: str, timeout: float = 15.0) -> dict:
|
||||
"""Ping an A2A endpoint and return health data.
|
||||
|
||||
Tries /health first, falls back to /a2a with a minimal request.
|
||||
"""
|
||||
start = time.monotonic()
|
||||
|
||||
# Try dedicated health endpoint first
|
||||
health_url = f"{url.rstrip('/')}/health"
|
||||
try:
|
||||
req = urllib.request.Request(health_url, method="GET")
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
elapsed = time.monotonic() - start
|
||||
data = json.loads(resp.read())
|
||||
return {
|
||||
"ok": True,
|
||||
"response_time_ms": round(elapsed * 1000, 1),
|
||||
"status": data.get("status", "healthy"),
|
||||
"available_tools": data.get("available_tools", 0),
|
||||
}
|
||||
except (urllib.error.URLError, urllib.error.HTTPError, TimeoutError, OSError):
|
||||
pass
|
||||
|
||||
# Fallback: try A2A agent card endpoint
|
||||
card_url = f"{url.rstrip('/')}/.well-known/agent-card.json"
|
||||
try:
|
||||
req = urllib.request.Request(card_url, method="GET")
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
elapsed = time.monotonic() - start
|
||||
data = json.loads(resp.read())
|
||||
tools = data.get("capabilities", {}).get("tools", [])
|
||||
return {
|
||||
"ok": True,
|
||||
"response_time_ms": round(elapsed * 1000, 1),
|
||||
"status": "healthy",
|
||||
"available_tools": len(tools) if isinstance(tools, list) else 0,
|
||||
}
|
||||
except (urllib.error.URLError, urllib.error.HTTPError, TimeoutError, OSError) as e:
|
||||
elapsed = time.monotonic() - start
|
||||
return {
|
||||
"ok": False,
|
||||
"response_time_ms": round(elapsed * 1000, 1),
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
|
||||
class HealthMonitor:
|
||||
"""A2A fleet health monitor."""
|
||||
|
||||
def __init__(self, state_file: str = STATE_FILE, fleet_file: str = FLEET_FILE):
|
||||
self.state_file = state_file
|
||||
self.fleet_file = fleet_file
|
||||
self.state = _load_state()
|
||||
|
||||
def check_agent(self, name: str, url: str) -> AgentStatus:
|
||||
"""Check a single agent and update state."""
|
||||
now = _now_iso()
|
||||
result = _ping_agent(url)
|
||||
|
||||
if name not in self.state.agents:
|
||||
self.state.agents[name] = AgentStatus(name=name, url=url)
|
||||
|
||||
agent = self.state.agents[name]
|
||||
agent.last_check = now
|
||||
agent.response_time_ms = result.get("response_time_ms")
|
||||
|
||||
if result["ok"]:
|
||||
agent.status = "healthy"
|
||||
agent.consecutive_failures = 0
|
||||
agent.available_tools = result.get("available_tools", 0)
|
||||
agent.agent_status = result.get("status", "healthy")
|
||||
agent.last_healthy = now
|
||||
agent.error = None
|
||||
else:
|
||||
agent.consecutive_failures += 1
|
||||
agent.error = result.get("error", "unknown error")
|
||||
|
||||
if agent.consecutive_failures >= CONSECUTIVE_FAILURE_THRESHOLD:
|
||||
agent.status = "down"
|
||||
else:
|
||||
agent.status = "degraded"
|
||||
|
||||
# Check response time even on success
|
||||
if result["ok"] and result.get("response_time_ms", 0) > RESPONSE_TIME_THRESHOLD_SEC * 1000:
|
||||
if agent.status == "healthy":
|
||||
agent.status = "degraded"
|
||||
|
||||
return agent
|
||||
|
||||
def check_all(self) -> list[AgentStatus]:
|
||||
"""Check all registered fleet agents."""
|
||||
fleet = _load_fleet()
|
||||
results = []
|
||||
for agent_def in fleet:
|
||||
name = agent_def.get("name", "unknown")
|
||||
url = agent_def.get("url", "")
|
||||
if not url:
|
||||
continue
|
||||
status = self.check_agent(name, url)
|
||||
results.append(status)
|
||||
|
||||
self.state.last_full_check = _now_iso()
|
||||
_save_state(self.state)
|
||||
return results
|
||||
|
||||
def get_alerts(self, results: list[AgentStatus]) -> list[dict]:
|
||||
"""Generate alerts for agents that just went down."""
|
||||
alerts = []
|
||||
now = time.time()
|
||||
|
||||
for agent in results:
|
||||
if agent.status == "down":
|
||||
last_alert = self.state.alerts_sent.get(agent.name, 0)
|
||||
# Alert at most once per hour per agent
|
||||
if now - last_alert > 3600:
|
||||
alerts.append({
|
||||
"agent": agent.name,
|
||||
"consecutive_failures": agent.consecutive_failures,
|
||||
"last_healthy": agent.last_healthy,
|
||||
"error": agent.error,
|
||||
})
|
||||
self.state.alerts_sent[agent.name] = now
|
||||
|
||||
if alerts:
|
||||
_save_state(self.state)
|
||||
|
||||
return alerts
|
||||
|
||||
def print_dashboard(self, results: list[AgentStatus]) -> str:
|
||||
"""Format a text dashboard of fleet health."""
|
||||
lines = []
|
||||
lines.append("=== A2A Fleet Health ===")
|
||||
lines.append(f"Checked: {_now_iso()[:19]}Z")
|
||||
lines.append("")
|
||||
|
||||
if not results:
|
||||
lines.append("No agents registered.")
|
||||
lines.append(f"Add agents to {self.fleet_file}")
|
||||
return "\n".join(lines)
|
||||
|
||||
# Header
|
||||
lines.append(f"{'AGENT':<12} {'STATUS':<10} {'MS':>8} {'TOOLS':>6} {'FAIL':>5} {'LAST HEALTHY'}")
|
||||
lines.append("-" * 72)
|
||||
|
||||
for agent in sorted(results, key=lambda a: a.name):
|
||||
status_icon = {
|
||||
"healthy": "OK",
|
||||
"degraded": "WARN",
|
||||
"down": "DOWN",
|
||||
"unknown": "???",
|
||||
}.get(agent.status, "???")
|
||||
|
||||
ms = f"{agent.response_time_ms:.0f}" if agent.response_time_ms else "-"
|
||||
tools = str(agent.available_tools) if agent.available_tools else "-"
|
||||
fail = str(agent.consecutive_failures) if agent.consecutive_failures else "-"
|
||||
last_h = agent.last_healthy[:19] + "Z" if agent.last_healthy else "never"
|
||||
|
||||
lines.append(f"{agent.name:<12} {status_icon:<10} {ms:>8} {tools:>6} {fail:>5} {last_h}")
|
||||
|
||||
# Summary
|
||||
healthy = sum(1 for a in results if a.status == "healthy")
|
||||
degraded = sum(1 for a in results if a.status == "degraded")
|
||||
down = sum(1 for a in results if a.status == "down")
|
||||
lines.append("-" * 72)
|
||||
lines.append(f"Total: {len(results)} Healthy: {healthy} Degraded: {degraded} Down: {down}")
|
||||
|
||||
# Alerts
|
||||
alerts = self.get_alerts(results)
|
||||
if alerts:
|
||||
lines.append("")
|
||||
lines.append("ALERTS:")
|
||||
for alert in alerts:
|
||||
lines.append(f" !! {alert['agent']} is DOWN ({alert['consecutive_failures']} consecutive failures)")
|
||||
if alert.get("error"):
|
||||
lines.append(f" Error: {alert['error'][:100]}")
|
||||
|
||||
return "\n".join(lines)
|
||||
64
hermes_cli/a2a_health_commands.py
Normal file
64
hermes_cli/a2a_health_commands.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""CLI command: hermes a2a health
|
||||
|
||||
Fleet agent health check via A2A heartbeat.
|
||||
|
||||
Usage:
|
||||
hermes a2a health # Check all agents, show dashboard
|
||||
hermes a2a health --json # JSON output for scripting
|
||||
hermes a2a health --agent NAME # Check single agent
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
|
||||
from hermes_cli.a2a_health import HealthMonitor, _load_fleet, _ping_agent
|
||||
|
||||
|
||||
def register_subparser(subparsers):
|
||||
"""Register 'hermes a2a health' subcommand."""
|
||||
a2a_parser = subparsers.add_parser("a2a", help="A2A agent communication")
|
||||
a2a_sub = a2a_parser.add_subparsers(dest="a2a_command")
|
||||
|
||||
health_parser = a2a_sub.add_parser("health", help="Check fleet agent health")
|
||||
health_parser.add_argument("--json", action="store_true", help="JSON output")
|
||||
health_parser.add_argument("--agent", type=str, help="Check single agent by name")
|
||||
health_parser.set_defaults(func=run_health_check)
|
||||
|
||||
return a2a_parser
|
||||
|
||||
|
||||
def run_health_check(args: argparse.Namespace) -> int:
|
||||
"""Execute the health check command."""
|
||||
monitor = HealthMonitor()
|
||||
|
||||
if args.agent:
|
||||
# Single agent check
|
||||
fleet = _load_fleet()
|
||||
agent_def = next((a for a in fleet if a.get("name") == args.agent), None)
|
||||
if not agent_def:
|
||||
print(f"Error: agent '{args.agent}' not found in fleet", file=sys.stderr)
|
||||
print(f"Add it to ~/.hermes/fleet_agents.json", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
result = monitor.check_agent(args.agent, agent_def["url"])
|
||||
if args.json:
|
||||
from dataclasses import asdict
|
||||
print(json.dumps(asdict(result), indent=2))
|
||||
else:
|
||||
print(monitor.print_dashboard([result]))
|
||||
|
||||
return 0 if result.status in ("healthy", "degraded") else 1
|
||||
|
||||
# Full fleet check
|
||||
results = monitor.check_all()
|
||||
|
||||
if args.json:
|
||||
from dataclasses import asdict
|
||||
print(json.dumps([asdict(r) for r in results], indent=2))
|
||||
else:
|
||||
print(monitor.print_dashboard(results))
|
||||
|
||||
# Exit code: 0 if all healthy, 1 if any down
|
||||
has_down = any(r.status == "down" for r in results)
|
||||
return 1 if has_down else 0
|
||||
@@ -1,58 +0,0 @@
|
||||
"""Tests for hybrid search router."""
|
||||
|
||||
import pytest
|
||||
from agent.hybrid_search import (
|
||||
SearchResult,
|
||||
HybridSearchConfig,
|
||||
merge_results,
|
||||
hybrid_search,
|
||||
search_fts5,
|
||||
)
|
||||
|
||||
|
||||
class TestSearchResult:
|
||||
def test_creation(self):
|
||||
r = SearchResult(session_id="s1", message_content="hello", score=0.9, source="fts5")
|
||||
assert r.session_id == "s1"
|
||||
assert r.source == "fts5"
|
||||
|
||||
|
||||
class TestMergeResults:
|
||||
def test_merges_and_ranks(self):
|
||||
fts5 = [SearchResult("s1", "alpha content", 1.0, "fts5")]
|
||||
vec = [SearchResult("s2", "beta content", 0.9, "vector")]
|
||||
hrr = [SearchResult("s3", "gamma content", 0.5, "hrr")]
|
||||
config = HybridSearchConfig(fts5_weight=0.4, vector_weight=0.4, hrr_weight=0.2)
|
||||
results = merge_results(fts5, vec, hrr, config, limit=10)
|
||||
assert len(results) == 3
|
||||
# s1: 1.0*0.4=0.4, s2: 0.9*0.4=0.36, s3: 0.5*0.2=0.1
|
||||
assert results[0].session_id == "s1"
|
||||
|
||||
def test_deduplicates(self):
|
||||
fts5 = [SearchResult("s1", "same content", 1.0, "fts5")]
|
||||
vec = [SearchResult("s1", "same content", 0.8, "vector")]
|
||||
config = HybridSearchConfig()
|
||||
results = merge_results(fts5, vec, [], config, limit=10)
|
||||
assert len(results) == 1
|
||||
|
||||
def test_respects_limit(self):
|
||||
fts5 = [SearchResult(f"s{i}", f"content {i}", 1.0/i, "fts5") for i in range(1, 20)]
|
||||
results = merge_results(fts5, [], [], HybridSearchConfig(), limit=5)
|
||||
assert len(results) == 5
|
||||
|
||||
def test_empty_inputs(self):
|
||||
results = merge_results([], [], [], HybridSearchConfig())
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
class TestHybridSearchFallback:
|
||||
def test_falls_back_to_fts5_only(self):
|
||||
"""When vector and HRR unavailable, returns FTS5 results."""
|
||||
# Mock db
|
||||
class MockDB:
|
||||
def search_messages(self, **kwargs):
|
||||
return [{"session_id": "s1", "content": "test", "rank": 1.0, "role": "user"}]
|
||||
|
||||
results = hybrid_search("test", db=MockDB(), config=HybridSearchConfig(vector_enabled=False, hrr_enabled=False))
|
||||
assert len(results) == 1
|
||||
assert results[0].source == "fts5"
|
||||
@@ -1,202 +0,0 @@
|
||||
"""Tests for agent.privacy_filter — PII stripping before remote API calls."""
|
||||
|
||||
import pytest
|
||||
from agent.privacy_filter import (
|
||||
PrivacyFilter,
|
||||
RedactionReport,
|
||||
Sensitivity,
|
||||
sanitize_messages,
|
||||
quick_sanitize,
|
||||
)
|
||||
|
||||
|
||||
class TestPrivacyFilterSanitizeText:
|
||||
"""Test single-text sanitization."""
|
||||
|
||||
def test_no_pii_returns_clean(self):
|
||||
pf = PrivacyFilter()
|
||||
text = "The weather in Paris is nice today."
|
||||
cleaned, redactions = pf.sanitize_text(text)
|
||||
assert cleaned == text
|
||||
assert redactions == []
|
||||
|
||||
def test_email_redacted(self):
|
||||
pf = PrivacyFilter()
|
||||
text = "Send report to alice@example.com by Friday."
|
||||
cleaned, redactions = pf.sanitize_text(text)
|
||||
assert "alice@example.com" not in cleaned
|
||||
assert "[REDACTED-EMAIL]" in cleaned
|
||||
assert any(r["type"] == "email_address" for r in redactions)
|
||||
|
||||
def test_phone_redacted(self):
|
||||
pf = PrivacyFilter()
|
||||
text = "Call me at 555-123-4567 when ready."
|
||||
cleaned, redactions = pf.sanitize_text(text)
|
||||
assert "555-123-4567" not in cleaned
|
||||
assert "[REDACTED-PHONE]" in cleaned
|
||||
|
||||
def test_api_key_redacted(self):
|
||||
pf = PrivacyFilter()
|
||||
text = 'api_key = "sk-proj-abcdefghij1234567890abcdefghij1234567890"'
|
||||
cleaned, redactions = pf.sanitize_text(text)
|
||||
assert "sk-proj-" not in cleaned
|
||||
assert any(r["sensitivity"] == "CRITICAL" for r in redactions)
|
||||
|
||||
def test_github_token_redacted(self):
|
||||
pf = PrivacyFilter()
|
||||
text = "Use ghp_1234567890abcdefghijklmnopqrstuvwxyz1234 for auth"
|
||||
cleaned, redactions = pf.sanitize_text(text)
|
||||
assert "ghp_" not in cleaned
|
||||
assert any(r["type"] == "github_token" for r in redactions)
|
||||
|
||||
def test_ethereum_address_redacted(self):
|
||||
pf = PrivacyFilter()
|
||||
text = "Send to 0x742d35Cc6634C0532925a3b844Bc9e7595f2bD18 please"
|
||||
cleaned, redactions = pf.sanitize_text(text)
|
||||
assert "0x742d" not in cleaned
|
||||
assert any(r["type"] == "ethereum_address" for r in redactions)
|
||||
|
||||
def test_user_home_path_redacted(self):
|
||||
pf = PrivacyFilter()
|
||||
text = "Read file at /Users/alice/Documents/secret.txt"
|
||||
cleaned, redactions = pf.sanitize_text(text)
|
||||
assert "alice" not in cleaned
|
||||
assert "[REDACTED-USER]" in cleaned
|
||||
|
||||
def test_multiple_pii_types(self):
|
||||
pf = PrivacyFilter()
|
||||
text = (
|
||||
"Contact john@test.com or call 555-999-1234. "
|
||||
"The API key is sk-abcdefghijklmnopqrstuvwxyz1234567890."
|
||||
)
|
||||
cleaned, redactions = pf.sanitize_text(text)
|
||||
assert "john@test.com" not in cleaned
|
||||
assert "555-999-1234" not in cleaned
|
||||
assert "sk-abcd" not in cleaned
|
||||
assert len(redactions) >= 3
|
||||
|
||||
|
||||
class TestPrivacyFilterSanitizeMessages:
|
||||
"""Test message-list sanitization."""
|
||||
|
||||
def test_sanitize_user_message(self):
|
||||
pf = PrivacyFilter()
|
||||
messages = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Email me at bob@test.com with results."},
|
||||
]
|
||||
safe, report = pf.sanitize_messages(messages)
|
||||
assert report.redacted_messages == 1
|
||||
assert "bob@test.com" not in safe[1]["content"]
|
||||
assert "[REDACTED-EMAIL]" in safe[1]["content"]
|
||||
# System message unchanged
|
||||
assert safe[0]["content"] == "You are helpful."
|
||||
|
||||
def test_no_redaction_needed(self):
|
||||
pf = PrivacyFilter()
|
||||
messages = [
|
||||
{"role": "user", "content": "What is 2+2?"},
|
||||
{"role": "assistant", "content": "4"},
|
||||
]
|
||||
safe, report = pf.sanitize_messages(messages)
|
||||
assert report.redacted_messages == 0
|
||||
assert not report.had_redactions
|
||||
|
||||
def test_assistant_messages_also_sanitized(self):
|
||||
pf = PrivacyFilter()
|
||||
messages = [
|
||||
{"role": "assistant", "content": "Your email admin@corp.com was found."},
|
||||
]
|
||||
safe, report = pf.sanitize_messages(messages)
|
||||
assert report.redacted_messages == 1
|
||||
assert "admin@corp.com" not in safe[0]["content"]
|
||||
|
||||
def test_tool_messages_not_sanitized(self):
|
||||
pf = PrivacyFilter()
|
||||
messages = [
|
||||
{"role": "tool", "content": "Result: user@test.com found"},
|
||||
]
|
||||
safe, report = pf.sanitize_messages(messages)
|
||||
assert report.redacted_messages == 0
|
||||
assert safe[0]["content"] == "Result: user@test.com found"
|
||||
|
||||
|
||||
class TestShouldUseLocalOnly:
|
||||
"""Test the local-only routing decision."""
|
||||
|
||||
def test_normal_text_allows_remote(self):
|
||||
pf = PrivacyFilter()
|
||||
block, reason = pf.should_use_local_only("Summarize this article about Python.")
|
||||
assert not block
|
||||
|
||||
def test_critical_secret_blocks_remote(self):
|
||||
pf = PrivacyFilter()
|
||||
text = "Here is the API key: sk-abcdefghijklmnopqrstuvwxyz1234567890"
|
||||
block, reason = pf.should_use_local_only(text)
|
||||
assert block
|
||||
assert "critical" in reason.lower()
|
||||
|
||||
def test_multiple_high_sensitivity_blocks(self):
|
||||
pf = PrivacyFilter()
|
||||
# 3+ high-sensitivity patterns
|
||||
text = (
|
||||
"Card: 4111-1111-1111-1111, "
|
||||
"SSN: 123-45-6789, "
|
||||
"BTC: 1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa, "
|
||||
"ETH: 0x742d35Cc6634C0532925a3b844Bc9e7595f2bD18"
|
||||
)
|
||||
block, reason = pf.should_use_local_only(text)
|
||||
assert block
|
||||
|
||||
|
||||
class TestAggressiveMode:
|
||||
"""Test aggressive filtering mode."""
|
||||
|
||||
def test_aggressive_redacts_internal_ips(self):
|
||||
pf = PrivacyFilter(aggressive_mode=True)
|
||||
text = "Server at 192.168.1.100 is responding."
|
||||
cleaned, redactions = pf.sanitize_text(text)
|
||||
assert "192.168.1.100" not in cleaned
|
||||
assert any(r["type"] == "internal_ip" for r in redactions)
|
||||
|
||||
def test_normal_does_not_redact_ips(self):
|
||||
pf = PrivacyFilter(aggressive_mode=False)
|
||||
text = "Server at 192.168.1.100 is responding."
|
||||
cleaned, redactions = pf.sanitize_text(text)
|
||||
assert "192.168.1.100" in cleaned # IP preserved in normal mode
|
||||
|
||||
|
||||
class TestConvenienceFunctions:
|
||||
"""Test module-level convenience functions."""
|
||||
|
||||
def test_quick_sanitize(self):
|
||||
text = "Contact alice@example.com for details"
|
||||
result = quick_sanitize(text)
|
||||
assert "alice@example.com" not in result
|
||||
assert "[REDACTED-EMAIL]" in result
|
||||
|
||||
def test_sanitize_messages_convenience(self):
|
||||
messages = [{"role": "user", "content": "Call 555-000-1234"}]
|
||||
safe, report = sanitize_messages(messages)
|
||||
assert report.redacted_messages == 1
|
||||
|
||||
|
||||
class TestRedactionReport:
|
||||
"""Test the reporting structure."""
|
||||
|
||||
def test_summary_no_redactions(self):
|
||||
report = RedactionReport(total_messages=3, redacted_messages=0)
|
||||
assert "No PII" in report.summary()
|
||||
|
||||
def test_summary_with_redactions(self):
|
||||
report = RedactionReport(
|
||||
total_messages=2,
|
||||
redacted_messages=1,
|
||||
redactions=[
|
||||
{"type": "email_address", "sensitivity": "MEDIUM", "count": 2},
|
||||
{"type": "phone_number_us", "sensitivity": "MEDIUM", "count": 1},
|
||||
],
|
||||
)
|
||||
summary = report.summary()
|
||||
assert "1/2" in summary
|
||||
assert "email_address" in summary
|
||||
@@ -1,190 +0,0 @@
|
||||
"""Tests for tools.confirmation_daemon — Human Confirmation Firewall."""
|
||||
|
||||
import pytest
|
||||
import time
|
||||
from tools.confirmation_daemon import (
|
||||
ConfirmationDaemon,
|
||||
ConfirmationRequest,
|
||||
ConfirmationStatus,
|
||||
RiskLevel,
|
||||
classify_action,
|
||||
_is_whitelisted,
|
||||
_DEFAULT_WHITELIST,
|
||||
)
|
||||
|
||||
|
||||
class TestClassifyAction:
|
||||
"""Test action risk classification."""
|
||||
|
||||
def test_crypto_tx_is_critical(self):
|
||||
assert classify_action("crypto_tx") == RiskLevel.CRITICAL
|
||||
|
||||
def test_sign_transaction_is_critical(self):
|
||||
assert classify_action("sign_transaction") == RiskLevel.CRITICAL
|
||||
|
||||
def test_send_email_is_high(self):
|
||||
assert classify_action("send_email") == RiskLevel.HIGH
|
||||
|
||||
def test_send_message_is_medium(self):
|
||||
assert classify_action("send_message") == RiskLevel.MEDIUM
|
||||
|
||||
def test_access_calendar_is_low(self):
|
||||
assert classify_action("access_calendar") == RiskLevel.LOW
|
||||
|
||||
def test_unknown_action_is_medium(self):
|
||||
assert classify_action("unknown_action_xyz") == RiskLevel.MEDIUM
|
||||
|
||||
|
||||
class TestWhitelist:
|
||||
"""Test whitelist auto-approval."""
|
||||
|
||||
def test_self_email_is_whitelisted(self):
|
||||
whitelist = dict(_DEFAULT_WHITELIST)
|
||||
payload = {"from": "me@test.com", "to": "me@test.com"}
|
||||
assert _is_whitelisted("send_email", payload, whitelist) is True
|
||||
|
||||
def test_non_whitelisted_recipient_not_approved(self):
|
||||
whitelist = dict(_DEFAULT_WHITELIST)
|
||||
payload = {"to": "random@stranger.com"}
|
||||
assert _is_whitelisted("send_email", payload, whitelist) is False
|
||||
|
||||
def test_whitelisted_contact_approved(self):
|
||||
whitelist = {
|
||||
"send_message": {"targets": ["alice", "bob"]},
|
||||
}
|
||||
assert _is_whitelisted("send_message", {"to": "alice"}, whitelist) is True
|
||||
assert _is_whitelisted("send_message", {"to": "charlie"}, whitelist) is False
|
||||
|
||||
def test_no_whitelist_entry_means_not_whitelisted(self):
|
||||
whitelist = {}
|
||||
assert _is_whitelisted("crypto_tx", {"amount": 1.0}, whitelist) is False
|
||||
|
||||
|
||||
class TestConfirmationRequest:
|
||||
"""Test the request data model."""
|
||||
|
||||
def test_defaults(self):
|
||||
req = ConfirmationRequest(
|
||||
request_id="test-1",
|
||||
action="send_email",
|
||||
description="Test email",
|
||||
risk_level="high",
|
||||
payload={},
|
||||
)
|
||||
assert req.status == ConfirmationStatus.PENDING.value
|
||||
assert req.created_at > 0
|
||||
assert req.expires_at > req.created_at
|
||||
|
||||
def test_is_pending(self):
|
||||
req = ConfirmationRequest(
|
||||
request_id="test-2",
|
||||
action="send_email",
|
||||
description="Test",
|
||||
risk_level="high",
|
||||
payload={},
|
||||
expires_at=time.time() + 300,
|
||||
)
|
||||
assert req.is_pending is True
|
||||
|
||||
def test_is_expired(self):
|
||||
req = ConfirmationRequest(
|
||||
request_id="test-3",
|
||||
action="send_email",
|
||||
description="Test",
|
||||
risk_level="high",
|
||||
payload={},
|
||||
expires_at=time.time() - 10,
|
||||
)
|
||||
assert req.is_expired is True
|
||||
assert req.is_pending is False
|
||||
|
||||
def test_to_dict(self):
|
||||
req = ConfirmationRequest(
|
||||
request_id="test-4",
|
||||
action="send_email",
|
||||
description="Test",
|
||||
risk_level="medium",
|
||||
payload={"to": "a@b.com"},
|
||||
)
|
||||
d = req.to_dict()
|
||||
assert d["request_id"] == "test-4"
|
||||
assert d["action"] == "send_email"
|
||||
assert "is_pending" in d
|
||||
|
||||
|
||||
class TestConfirmationDaemon:
|
||||
"""Test the daemon logic (without HTTP layer)."""
|
||||
|
||||
def test_auto_approve_low_risk(self):
|
||||
daemon = ConfirmationDaemon()
|
||||
req = daemon.request(
|
||||
action="access_calendar",
|
||||
description="Read today's events",
|
||||
risk_level="low",
|
||||
)
|
||||
assert req.status == ConfirmationStatus.AUTO_APPROVED.value
|
||||
|
||||
def test_whitelisted_auto_approves(self):
|
||||
daemon = ConfirmationDaemon()
|
||||
daemon._whitelist = {"send_message": {"targets": ["alice"]}}
|
||||
req = daemon.request(
|
||||
action="send_message",
|
||||
description="Message alice",
|
||||
payload={"to": "alice"},
|
||||
)
|
||||
assert req.status == ConfirmationStatus.AUTO_APPROVED.value
|
||||
|
||||
def test_non_whitelisted_goes_pending(self):
|
||||
daemon = ConfirmationDaemon()
|
||||
daemon._whitelist = {}
|
||||
req = daemon.request(
|
||||
action="send_email",
|
||||
description="Email to stranger",
|
||||
payload={"to": "stranger@test.com"},
|
||||
risk_level="high",
|
||||
)
|
||||
assert req.status == ConfirmationStatus.PENDING.value
|
||||
assert req.is_pending is True
|
||||
|
||||
def test_approve_response(self):
|
||||
daemon = ConfirmationDaemon()
|
||||
daemon._whitelist = {}
|
||||
req = daemon.request(
|
||||
action="send_email",
|
||||
description="Email test",
|
||||
risk_level="high",
|
||||
)
|
||||
result = daemon.respond(req.request_id, approved=True, decided_by="human")
|
||||
assert result.status == ConfirmationStatus.APPROVED.value
|
||||
assert result.decided_by == "human"
|
||||
|
||||
def test_deny_response(self):
|
||||
daemon = ConfirmationDaemon()
|
||||
daemon._whitelist = {}
|
||||
req = daemon.request(
|
||||
action="crypto_tx",
|
||||
description="Send 1 ETH",
|
||||
risk_level="critical",
|
||||
)
|
||||
result = daemon.respond(
|
||||
req.request_id, approved=False, decided_by="human", reason="Too risky"
|
||||
)
|
||||
assert result.status == ConfirmationStatus.DENIED.value
|
||||
assert result.reason == "Too risky"
|
||||
|
||||
def test_get_pending(self):
|
||||
daemon = ConfirmationDaemon()
|
||||
daemon._whitelist = {}
|
||||
daemon.request(action="send_email", description="Test 1", risk_level="high")
|
||||
daemon.request(action="send_email", description="Test 2", risk_level="high")
|
||||
pending = daemon.get_pending()
|
||||
assert len(pending) >= 2
|
||||
|
||||
def test_get_history(self):
|
||||
daemon = ConfirmationDaemon()
|
||||
req = daemon.request(
|
||||
action="access_calendar", description="Test", risk_level="low"
|
||||
)
|
||||
history = daemon.get_history()
|
||||
assert len(history) >= 1
|
||||
assert history[0]["action"] == "access_calendar"
|
||||
@@ -121,19 +121,6 @@ DANGEROUS_PATTERNS = [
|
||||
(r'\b(cp|mv|install)\b.*\s/etc/', "copy/move file into /etc/"),
|
||||
(r'\bsed\s+-[^\s]*i.*\s/etc/', "in-place edit of system config"),
|
||||
(r'\bsed\s+--in-place\b.*\s/etc/', "in-place edit of system config (long flag)"),
|
||||
# --- Vitalik's threat model: crypto / financial ---
|
||||
(r'\b(?:bitcoin-cli|ethers\.js|web3|ether\.sendTransaction)\b', "direct crypto transaction tool usage"),
|
||||
(r'\bwget\b.*\b(?:mnemonic|seed\s*phrase|private[_-]?key)\b', "attempting to download crypto credentials"),
|
||||
(r'\bcurl\b.*\b(?:mnemonic|seed\s*phrase|private[_-]?key)\b', "attempting to exfiltrate crypto credentials"),
|
||||
# --- Vitalik's threat model: credential exfiltration ---
|
||||
(r'\b(?:curl|wget|http|nc|ncat|socat)\b.*\b(?:\.env|\.ssh|credentials|secrets|token|api[_-]?key)\b',
|
||||
"attempting to exfiltrate credentials via network"),
|
||||
(r'\bbase64\b.*\|(?:\s*curl|\s*wget)', "base64-encode then network exfiltration"),
|
||||
(r'\bcat\b.*\b(?:\.env|\.ssh/id_rsa|credentials)\b.*\|(?:\s*curl|\s*wget)',
|
||||
"reading secrets and piping to network tool"),
|
||||
# --- Vitalik's threat model: data exfiltration ---
|
||||
(r'\bcurl\b.*-d\s.*\$(?:HOME|USER)', "sending user home directory data to remote"),
|
||||
(r'\bwget\b.*--post-data\s.*\$(?:HOME|USER)', "posting user data to remote"),
|
||||
# Script execution via heredoc — bypasses the -e/-c flag patterns above.
|
||||
# `python3 << 'EOF'` feeds arbitrary code via stdin without -c/-e flags.
|
||||
(r'\b(python[23]?|perl|ruby|node)\s+<<', "script execution via heredoc"),
|
||||
|
||||
@@ -1,615 +0,0 @@
|
||||
"""Human Confirmation Daemon — HTTP server for two-factor action approval.
|
||||
|
||||
Implements Vitalik's Pattern 1: "The new 'two-factor confirmation' is that
|
||||
the two factors are the human and the LLM."
|
||||
|
||||
This daemon runs on localhost:6000 and provides a simple HTTP API for the
|
||||
agent to request human approval before executing high-risk actions.
|
||||
|
||||
Threat model:
|
||||
- LLM jailbreaks: Remote content "hacking" the LLM to perform malicious actions
|
||||
- LLM accidents: LLM accidentally performing dangerous operations
|
||||
- The human acts as the second factor — the agent proposes, the human disposes
|
||||
|
||||
Architecture:
|
||||
- Agent detects high-risk action → POST /confirm with action details
|
||||
- Daemon stores pending request, sends notification to user
|
||||
- User approves/denies via POST /respond (Telegram, CLI, or direct HTTP)
|
||||
- Agent receives decision and proceeds or aborts
|
||||
|
||||
Usage:
|
||||
# Start daemon (usually managed by gateway)
|
||||
from tools.confirmation_daemon import ConfirmationDaemon
|
||||
daemon = ConfirmationDaemon(port=6000)
|
||||
daemon.start()
|
||||
|
||||
# Request approval (from agent code)
|
||||
from tools.confirmation_daemon import request_confirmation
|
||||
approved = request_confirmation(
|
||||
action="send_email",
|
||||
description="Send email to alice@example.com",
|
||||
risk_level="high",
|
||||
payload={"to": "alice@example.com", "subject": "Meeting notes"},
|
||||
timeout=300,
|
||||
)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RiskLevel(Enum):
|
||||
"""Risk classification for actions requiring confirmation."""
|
||||
LOW = "low" # Log only, no confirmation needed
|
||||
MEDIUM = "medium" # Confirm for non-whitelisted targets
|
||||
HIGH = "high" # Always confirm
|
||||
CRITICAL = "critical" # Always confirm + require explicit reason
|
||||
|
||||
|
||||
class ConfirmationStatus(Enum):
|
||||
"""Status of a pending confirmation request."""
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
DENIED = "denied"
|
||||
EXPIRED = "expired"
|
||||
AUTO_APPROVED = "auto_approved"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfirmationRequest:
|
||||
"""A request for human confirmation of a high-risk action."""
|
||||
request_id: str
|
||||
action: str # Action type: send_email, send_message, crypto_tx, etc.
|
||||
description: str # Human-readable description of what will happen
|
||||
risk_level: str # low, medium, high, critical
|
||||
payload: Dict[str, Any] # Action-specific data (sanitized)
|
||||
session_key: str = "" # Session that initiated the request
|
||||
created_at: float = 0.0
|
||||
expires_at: float = 0.0
|
||||
status: str = ConfirmationStatus.PENDING.value
|
||||
decided_at: float = 0.0
|
||||
decided_by: str = "" # "human", "auto", "whitelist"
|
||||
reason: str = "" # Optional reason for denial
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.created_at:
|
||||
self.created_at = time.time()
|
||||
if not self.expires_at:
|
||||
self.expires_at = self.created_at + 300 # 5 min default
|
||||
if not self.request_id:
|
||||
self.request_id = str(uuid.uuid4())[:12]
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
return time.time() > self.expires_at
|
||||
|
||||
@property
|
||||
def is_pending(self) -> bool:
|
||||
return self.status == ConfirmationStatus.PENDING.value and not self.is_expired
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
d = asdict(self)
|
||||
d["is_expired"] = self.is_expired
|
||||
d["is_pending"] = self.is_pending
|
||||
return d
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Action categories (Vitalik's threat model)
|
||||
# =========================================================================
|
||||
|
||||
ACTION_CATEGORIES = {
|
||||
# Messaging — outbound communication to external parties
|
||||
"send_email": RiskLevel.HIGH,
|
||||
"send_message": RiskLevel.MEDIUM, # Depends on recipient
|
||||
"send_signal": RiskLevel.HIGH,
|
||||
"send_telegram": RiskLevel.MEDIUM,
|
||||
"send_discord": RiskLevel.MEDIUM,
|
||||
"post_social": RiskLevel.HIGH,
|
||||
|
||||
# Financial / crypto
|
||||
"crypto_tx": RiskLevel.CRITICAL,
|
||||
"sign_transaction": RiskLevel.CRITICAL,
|
||||
"access_wallet": RiskLevel.CRITICAL,
|
||||
"modify_balance": RiskLevel.CRITICAL,
|
||||
|
||||
# System modification
|
||||
"install_software": RiskLevel.HIGH,
|
||||
"modify_system_config": RiskLevel.HIGH,
|
||||
"modify_firewall": RiskLevel.CRITICAL,
|
||||
"add_ssh_key": RiskLevel.CRITICAL,
|
||||
"create_user": RiskLevel.CRITICAL,
|
||||
|
||||
# Data access
|
||||
"access_contacts": RiskLevel.MEDIUM,
|
||||
"access_calendar": RiskLevel.LOW,
|
||||
"read_private_files": RiskLevel.MEDIUM,
|
||||
"upload_data": RiskLevel.HIGH,
|
||||
"share_credentials": RiskLevel.CRITICAL,
|
||||
|
||||
# Network
|
||||
"open_port": RiskLevel.HIGH,
|
||||
"modify_dns": RiskLevel.HIGH,
|
||||
"expose_service": RiskLevel.CRITICAL,
|
||||
}
|
||||
|
||||
# Default: any unrecognized action is MEDIUM risk
|
||||
DEFAULT_RISK_LEVEL = RiskLevel.MEDIUM
|
||||
|
||||
|
||||
def classify_action(action: str) -> RiskLevel:
|
||||
"""Classify an action by its risk level."""
|
||||
return ACTION_CATEGORIES.get(action, DEFAULT_RISK_LEVEL)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Whitelist configuration
|
||||
# =========================================================================
|
||||
|
||||
_DEFAULT_WHITELIST = {
|
||||
"send_message": {
|
||||
"targets": [], # Contact names/IDs that don't need confirmation
|
||||
},
|
||||
"send_email": {
|
||||
"targets": [], # Email addresses that don't need confirmation
|
||||
"self_only": True, # send-to-self always allowed
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _load_whitelist() -> Dict[str, Any]:
|
||||
"""Load action whitelist from config."""
|
||||
config_path = Path.home() / ".hermes" / "approval_whitelist.json"
|
||||
if config_path.exists():
|
||||
try:
|
||||
with open(config_path) as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load approval whitelist: %s", e)
|
||||
return dict(_DEFAULT_WHITELIST)
|
||||
|
||||
|
||||
def _is_whitelisted(action: str, payload: Dict[str, Any], whitelist: Dict) -> bool:
|
||||
"""Check if an action is pre-approved by the whitelist."""
|
||||
action_config = whitelist.get(action, {})
|
||||
if not action_config:
|
||||
return False
|
||||
|
||||
# Check target-based whitelist
|
||||
targets = action_config.get("targets", [])
|
||||
target = payload.get("to") or payload.get("recipient") or payload.get("target", "")
|
||||
if target and target in targets:
|
||||
return True
|
||||
|
||||
# Self-only email
|
||||
if action_config.get("self_only") and action == "send_email":
|
||||
sender = payload.get("from", "")
|
||||
recipient = payload.get("to", "")
|
||||
if sender and recipient and sender.lower() == recipient.lower():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Confirmation daemon
|
||||
# =========================================================================
|
||||
|
||||
class ConfirmationDaemon:
|
||||
"""HTTP daemon for human confirmation of high-risk actions.
|
||||
|
||||
Runs on localhost:PORT (default 6000). Provides:
|
||||
- POST /confirm — agent requests human approval
|
||||
- POST /respond — human approves/denies
|
||||
- GET /pending — list pending requests
|
||||
- GET /health — health check
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 6000,
|
||||
default_timeout: int = 300,
|
||||
notify_callback: Optional[Callable] = None,
|
||||
):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.default_timeout = default_timeout
|
||||
self.notify_callback = notify_callback
|
||||
self._pending: Dict[str, ConfirmationRequest] = {}
|
||||
self._history: List[ConfirmationRequest] = []
|
||||
self._lock = threading.Lock()
|
||||
self._whitelist = _load_whitelist()
|
||||
self._app = None
|
||||
self._runner = None
|
||||
|
||||
def request(
|
||||
self,
|
||||
action: str,
|
||||
description: str,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
risk_level: Optional[str] = None,
|
||||
session_key: str = "",
|
||||
timeout: Optional[int] = None,
|
||||
) -> ConfirmationRequest:
|
||||
"""Create a confirmation request.
|
||||
|
||||
Returns the request. Check .status to see if it was immediately
|
||||
auto-approved (whitelisted) or is pending human review.
|
||||
"""
|
||||
payload = payload or {}
|
||||
|
||||
# Classify risk if not specified
|
||||
if risk_level is None:
|
||||
risk_level = classify_action(action).value
|
||||
|
||||
# Check whitelist
|
||||
if risk_level in ("low",) or _is_whitelisted(action, payload, self._whitelist):
|
||||
req = ConfirmationRequest(
|
||||
request_id=str(uuid.uuid4())[:12],
|
||||
action=action,
|
||||
description=description,
|
||||
risk_level=risk_level,
|
||||
payload=payload,
|
||||
session_key=session_key,
|
||||
expires_at=time.time() + (timeout or self.default_timeout),
|
||||
status=ConfirmationStatus.AUTO_APPROVED.value,
|
||||
decided_at=time.time(),
|
||||
decided_by="whitelist",
|
||||
)
|
||||
with self._lock:
|
||||
self._history.append(req)
|
||||
logger.info("Auto-approved whitelisted action: %s", action)
|
||||
return req
|
||||
|
||||
# Create pending request
|
||||
req = ConfirmationRequest(
|
||||
request_id=str(uuid.uuid4())[:12],
|
||||
action=action,
|
||||
description=description,
|
||||
risk_level=risk_level,
|
||||
payload=payload,
|
||||
session_key=session_key,
|
||||
expires_at=time.time() + (timeout or self.default_timeout),
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
self._pending[req.request_id] = req
|
||||
|
||||
# Notify human
|
||||
if self.notify_callback:
|
||||
try:
|
||||
self.notify_callback(req.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning("Confirmation notify callback failed: %s", e)
|
||||
|
||||
logger.info(
|
||||
"Confirmation request %s: %s (%s risk) — waiting for human",
|
||||
req.request_id, action, risk_level,
|
||||
)
|
||||
return req
|
||||
|
||||
def respond(
|
||||
self,
|
||||
request_id: str,
|
||||
approved: bool,
|
||||
decided_by: str = "human",
|
||||
reason: str = "",
|
||||
) -> Optional[ConfirmationRequest]:
|
||||
"""Record a human decision on a pending request."""
|
||||
with self._lock:
|
||||
req = self._pending.get(request_id)
|
||||
if not req:
|
||||
logger.warning("Confirmation respond: unknown request %s", request_id)
|
||||
return None
|
||||
if not req.is_pending:
|
||||
logger.warning("Confirmation respond: request %s already decided", request_id)
|
||||
return req
|
||||
|
||||
req.status = (
|
||||
ConfirmationStatus.APPROVED.value if approved
|
||||
else ConfirmationStatus.DENIED.value
|
||||
)
|
||||
req.decided_at = time.time()
|
||||
req.decided_by = decided_by
|
||||
req.reason = reason
|
||||
|
||||
# Move to history
|
||||
del self._pending[request_id]
|
||||
self._history.append(req)
|
||||
|
||||
logger.info(
|
||||
"Confirmation %s: %s by %s",
|
||||
request_id, "APPROVED" if approved else "DENIED", decided_by,
|
||||
)
|
||||
return req
|
||||
|
||||
def wait_for_decision(
|
||||
self, request_id: str, timeout: Optional[float] = None
|
||||
) -> ConfirmationRequest:
|
||||
"""Block until a decision is made or timeout expires."""
|
||||
deadline = time.time() + (timeout or self.default_timeout)
|
||||
while time.time() < deadline:
|
||||
with self._lock:
|
||||
req = self._pending.get(request_id)
|
||||
if req and not req.is_pending:
|
||||
return req
|
||||
if req and req.is_expired:
|
||||
req.status = ConfirmationStatus.EXPIRED.value
|
||||
del self._pending[request_id]
|
||||
self._history.append(req)
|
||||
return req
|
||||
time.sleep(0.5)
|
||||
|
||||
# Timeout
|
||||
with self._lock:
|
||||
req = self._pending.pop(request_id, None)
|
||||
if req:
|
||||
req.status = ConfirmationStatus.EXPIRED.value
|
||||
self._history.append(req)
|
||||
return req
|
||||
|
||||
# Shouldn't reach here
|
||||
return ConfirmationRequest(
|
||||
request_id=request_id,
|
||||
action="unknown",
|
||||
description="Request not found",
|
||||
risk_level="high",
|
||||
payload={},
|
||||
status=ConfirmationStatus.EXPIRED.value,
|
||||
)
|
||||
|
||||
def get_pending(self) -> List[Dict[str, Any]]:
|
||||
"""Return list of pending confirmation requests."""
|
||||
self._expire_old()
|
||||
with self._lock:
|
||||
return [r.to_dict() for r in self._pending.values() if r.is_pending]
|
||||
|
||||
def get_history(self, limit: int = 50) -> List[Dict[str, Any]]:
|
||||
"""Return recent confirmation history."""
|
||||
with self._lock:
|
||||
return [r.to_dict() for r in self._history[-limit:]]
|
||||
|
||||
def _expire_old(self) -> None:
|
||||
"""Move expired requests to history."""
|
||||
now = time.time()
|
||||
with self._lock:
|
||||
expired = [
|
||||
rid for rid, req in self._pending.items()
|
||||
if now > req.expires_at
|
||||
]
|
||||
for rid in expired:
|
||||
req = self._pending.pop(rid)
|
||||
req.status = ConfirmationStatus.EXPIRED.value
|
||||
self._history.append(req)
|
||||
|
||||
# --- aiohttp HTTP API ---
|
||||
|
||||
async def _handle_health(self, request):
|
||||
from aiohttp import web
|
||||
return web.json_response({
|
||||
"status": "ok",
|
||||
"service": "hermes-confirmation-daemon",
|
||||
"pending": len(self._pending),
|
||||
})
|
||||
|
||||
async def _handle_confirm(self, request):
|
||||
from aiohttp import web
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
return web.json_response({"error": "invalid JSON"}, status=400)
|
||||
|
||||
action = body.get("action", "")
|
||||
description = body.get("description", "")
|
||||
if not action or not description:
|
||||
return web.json_response(
|
||||
{"error": "action and description required"}, status=400
|
||||
)
|
||||
|
||||
req = self.request(
|
||||
action=action,
|
||||
description=description,
|
||||
payload=body.get("payload", {}),
|
||||
risk_level=body.get("risk_level"),
|
||||
session_key=body.get("session_key", ""),
|
||||
timeout=body.get("timeout"),
|
||||
)
|
||||
|
||||
# If auto-approved, return immediately
|
||||
if req.status != ConfirmationStatus.PENDING.value:
|
||||
return web.json_response({
|
||||
"request_id": req.request_id,
|
||||
"status": req.status,
|
||||
"decided_by": req.decided_by,
|
||||
})
|
||||
|
||||
# Otherwise, wait for human decision (with timeout)
|
||||
timeout = min(body.get("timeout", self.default_timeout), 600)
|
||||
result = self.wait_for_decision(req.request_id, timeout=timeout)
|
||||
|
||||
return web.json_response({
|
||||
"request_id": result.request_id,
|
||||
"status": result.status,
|
||||
"decided_by": result.decided_by,
|
||||
"reason": result.reason,
|
||||
})
|
||||
|
||||
async def _handle_respond(self, request):
|
||||
from aiohttp import web
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
return web.json_response({"error": "invalid JSON"}, status=400)
|
||||
|
||||
request_id = body.get("request_id", "")
|
||||
approved = body.get("approved")
|
||||
if not request_id or approved is None:
|
||||
return web.json_response(
|
||||
{"error": "request_id and approved required"}, status=400
|
||||
)
|
||||
|
||||
result = self.respond(
|
||||
request_id=request_id,
|
||||
approved=bool(approved),
|
||||
decided_by=body.get("decided_by", "human"),
|
||||
reason=body.get("reason", ""),
|
||||
)
|
||||
|
||||
if not result:
|
||||
return web.json_response({"error": "unknown request"}, status=404)
|
||||
|
||||
return web.json_response({
|
||||
"request_id": result.request_id,
|
||||
"status": result.status,
|
||||
})
|
||||
|
||||
async def _handle_pending(self, request):
|
||||
from aiohttp import web
|
||||
return web.json_response({"pending": self.get_pending()})
|
||||
|
||||
def _build_app(self):
|
||||
"""Build the aiohttp application."""
|
||||
from aiohttp import web
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/health", self._handle_health)
|
||||
app.router.add_post("/confirm", self._handle_confirm)
|
||||
app.router.add_post("/respond", self._handle_respond)
|
||||
app.router.add_get("/pending", self._handle_pending)
|
||||
self._app = app
|
||||
return app
|
||||
|
||||
async def start_async(self) -> None:
|
||||
"""Start the daemon as an async server."""
|
||||
from aiohttp import web
|
||||
|
||||
app = self._build_app()
|
||||
self._runner = web.AppRunner(app)
|
||||
await self._runner.setup()
|
||||
site = web.TCPSite(self._runner, self.host, self.port)
|
||||
await site.start()
|
||||
logger.info("Confirmation daemon listening on %s:%d", self.host, self.port)
|
||||
|
||||
async def stop_async(self) -> None:
|
||||
"""Stop the daemon."""
|
||||
if self._runner:
|
||||
await self._runner.cleanup()
|
||||
self._runner = None
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start daemon in a background thread (blocking caller)."""
|
||||
def _run():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(self.start_async())
|
||||
loop.run_forever()
|
||||
|
||||
t = threading.Thread(target=_run, daemon=True, name="confirmation-daemon")
|
||||
t.start()
|
||||
logger.info("Confirmation daemon started in background thread")
|
||||
|
||||
def start_blocking(self) -> None:
|
||||
"""Start daemon and block (for standalone use)."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(self.start_async())
|
||||
try:
|
||||
loop.run_forever()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
loop.run_until_complete(self.stop_async())
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Convenience API for agent integration
|
||||
# =========================================================================
|
||||
|
||||
# Global singleton — initialized by gateway or CLI at startup
|
||||
_daemon: Optional[ConfirmationDaemon] = None
|
||||
|
||||
|
||||
def get_daemon() -> Optional[ConfirmationDaemon]:
|
||||
"""Get the global confirmation daemon instance."""
|
||||
return _daemon
|
||||
|
||||
|
||||
def init_daemon(
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 6000,
|
||||
notify_callback: Optional[Callable] = None,
|
||||
) -> ConfirmationDaemon:
|
||||
"""Initialize the global confirmation daemon."""
|
||||
global _daemon
|
||||
_daemon = ConfirmationDaemon(
|
||||
host=host, port=port, notify_callback=notify_callback
|
||||
)
|
||||
return _daemon
|
||||
|
||||
|
||||
def request_confirmation(
|
||||
action: str,
|
||||
description: str,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
risk_level: Optional[str] = None,
|
||||
session_key: str = "",
|
||||
timeout: int = 300,
|
||||
) -> bool:
|
||||
"""Request human confirmation for a high-risk action.
|
||||
|
||||
This is the primary integration point for agent code. It:
|
||||
1. Classifies the action risk level
|
||||
2. Checks the whitelist
|
||||
3. If confirmation needed, blocks until human responds
|
||||
4. Returns True if approved, False if denied/expired
|
||||
|
||||
Args:
|
||||
action: Action type (send_email, crypto_tx, etc.)
|
||||
description: Human-readable description
|
||||
payload: Action-specific data
|
||||
risk_level: Override auto-classification
|
||||
session_key: Session requesting approval
|
||||
timeout: Seconds to wait for human response
|
||||
|
||||
Returns:
|
||||
True if approved, False if denied or expired.
|
||||
"""
|
||||
daemon = get_daemon()
|
||||
if not daemon:
|
||||
logger.warning(
|
||||
"No confirmation daemon running — DENYING action %s by default. "
|
||||
"Start daemon with init_daemon() or --confirmation-daemon flag.",
|
||||
action,
|
||||
)
|
||||
return False
|
||||
|
||||
req = daemon.request(
|
||||
action=action,
|
||||
description=description,
|
||||
payload=payload,
|
||||
risk_level=risk_level,
|
||||
session_key=session_key,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Auto-approved (whitelisted)
|
||||
if req.status == ConfirmationStatus.AUTO_APPROVED.value:
|
||||
return True
|
||||
|
||||
# Wait for human
|
||||
result = daemon.wait_for_decision(req.request_id, timeout=timeout)
|
||||
return result.status == ConfirmationStatus.APPROVED.value
|
||||
Reference in New Issue
Block a user