Compare commits
7 Commits
claude/iss
...
burn/817-1
| Author | SHA1 | Date | |
|---|---|---|---|
| eed87e454e | |||
| db72e908f7 | |||
| b82b760d5d | |||
| d8d7846897 | |||
| 6840d05554 | |||
| 8abe59ed95 | |||
| 435d790201 |
353
agent/privacy_filter.py
Normal file
353
agent/privacy_filter.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""Privacy Filter — strip PII from context before remote API calls.
|
||||
|
||||
Implements Vitalik's Pattern 2: "A local model can strip out private data
|
||||
before passing the query along to a remote LLM."
|
||||
|
||||
When Hermes routes a request to a cloud provider (Anthropic, OpenRouter, etc.),
|
||||
this module sanitizes the message context to remove personally identifiable
|
||||
information before it leaves the user's machine.
|
||||
|
||||
Threat model (from Vitalik's secure LLM architecture):
|
||||
- Privacy (other): Non-LLM data leakage via search queries, API calls
|
||||
- LLM accidents: LLM accidentally leaking private data in prompts
|
||||
- LLM jailbreaks: Remote content extracting private context
|
||||
|
||||
Usage:
|
||||
from agent.privacy_filter import PrivacyFilter, sanitize_messages
|
||||
|
||||
pf = PrivacyFilter()
|
||||
safe_messages = pf.sanitize_messages(messages)
|
||||
# safe_messages has PII replaced with [REDACTED] tokens
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Sensitivity(Enum):
    """Ordered classification of content sensitivity.

    Members are ordered from least to most sensitive; callers compare
    ``.value`` to apply a minimum-redaction threshold.
    """

    PUBLIC = 1    # No PII detected
    LOW = 2       # Generic references (e.g., city names)
    MEDIUM = 3    # Personal identifiers (name, email, phone)
    HIGH = 4      # Secrets, keys, financial data, medical info
    CRITICAL = 5  # Crypto keys, passwords, SSN patterns
|
||||
|
||||
@dataclass
class RedactionReport:
    """Aggregate record of the redactions applied to one message batch."""

    total_messages: int = 0        # how many messages were inspected
    redacted_messages: int = 0     # how many of them had PII removed
    redactions: List[Dict[str, Any]] = field(default_factory=list)
    max_sensitivity: Sensitivity = Sensitivity.PUBLIC

    @property
    def had_redactions(self) -> bool:
        """True when at least one message had PII removed."""
        return self.redacted_messages > 0

    def summary(self) -> str:
        """Render a short human-readable report of what was redacted."""
        if not self.had_redactions:
            return "No PII detected — context is clean for remote query."

        lines = [f"Redacted {self.redacted_messages}/{self.total_messages} messages:"]
        lines.extend(
            f" - {item['type']}: {item['count']} occurrence(s)"
            for item in self.redactions[:10]
        )
        overflow = len(self.redactions) - 10
        if overflow > 0:
            lines.append(f" ... and {overflow} more types")
        return "\n".join(lines)
|
||||
|
||||
# =========================================================================
# PII pattern definitions
# =========================================================================

# Each pattern is (compiled_regex, redaction_type, sensitivity_level, replacement)
_PII_PATTERNS: List[Tuple[re.Pattern, str, Sensitivity, str]] = []


def _compile_patterns() -> None:
    """Compile PII detection patterns. Called once at module init.

    Populates the module-level ``_PII_PATTERNS`` list with
    ``(compiled_regex, redaction_type, sensitivity, replacement)`` tuples.
    Idempotent: calling again after the list is populated is a no-op.
    """
    global _PII_PATTERNS
    if _PII_PATTERNS:
        return

    raw_patterns = [
        # --- CRITICAL: secrets and credentials ---
        (
            r'(?:api[_-]?key|apikey|secret[_-]?key|access[_-]?token)\s*[:=]\s*["\']?([A-Za-z0-9_\-\.]{20,})["\']?',
            "api_key_or_token",
            Sensitivity.CRITICAL,
            "[REDACTED-API-KEY]",
        ),
        (
            r'\b(?:sk-|sk_|pk_|rk_|ak_)[A-Za-z0-9]{20,}\b',
            "prefixed_secret",
            Sensitivity.CRITICAL,
            "[REDACTED-SECRET]",
        ),
        (
            r'\b(?:ghp_|gho_|ghu_|ghs_|ghr_)[A-Za-z0-9]{36,}\b',
            "github_token",
            Sensitivity.CRITICAL,
            "[REDACTED-GITHUB-TOKEN]",
        ),
        (
            r'\b(?:xox[bposa]-[A-Za-z0-9\-]+)\b',
            "slack_token",
            Sensitivity.CRITICAL,
            "[REDACTED-SLACK-TOKEN]",
        ),
        (
            r'(?:password|passwd|pwd)\s*[:=]\s*["\']?([^\s"\']{4,})["\']?',
            "password",
            Sensitivity.CRITICAL,
            "[REDACTED-PASSWORD]",
        ),
        (
            r'(?:-----BEGIN (?:RSA |EC |OPENSSH )?PRIVATE KEY-----)',
            "private_key_block",
            Sensitivity.CRITICAL,
            "[REDACTED-PRIVATE-KEY]",
        ),
        # Ethereum / crypto addresses (42-char hex starting with 0x)
        (
            r'\b0x[a-fA-F0-9]{40}\b',
            "ethereum_address",
            Sensitivity.HIGH,
            "[REDACTED-ETH-ADDR]",
        ),
        # Bitcoin addresses (base58, 25-34 chars starting with 1/3/bc1)
        (
            r'\b[13][a-km-zA-HJ-NP-Z1-9]{25,34}\b',
            "bitcoin_address",
            Sensitivity.HIGH,
            "[REDACTED-BTC-ADDR]",
        ),
        (
            r'\bbc1[a-zA-HJ-NP-Z0-9]{39,59}\b',
            "bech32_address",
            Sensitivity.HIGH,
            "[REDACTED-BTC-ADDR]",
        ),
        # --- HIGH: financial ---
        (
            r'\b(?:\d{4}[-\s]?){3}\d{4}\b',
            "credit_card_number",
            Sensitivity.HIGH,
            "[REDACTED-CC]",
        ),
        (
            r'\b\d{3}-\d{2}-\d{4}\b',
            "us_ssn",
            Sensitivity.HIGH,
            "[REDACTED-SSN]",
        ),
        # --- MEDIUM: personal identifiers ---
        # Email addresses
        (
            r'\b[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,}\b',
            "email_address",
            Sensitivity.MEDIUM,
            "[REDACTED-EMAIL]",
        ),
        # Phone numbers (US/international patterns)
        (
            r'\b(?:\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}\b',
            "phone_number_us",
            Sensitivity.MEDIUM,
            "[REDACTED-PHONE]",
        ),
        (
            r'\b\+\d{1,3}[-.\s]?\d{4,14}\b',
            "phone_number_intl",
            Sensitivity.MEDIUM,
            "[REDACTED-PHONE]",
        ),
        # Filesystem paths that reveal user identity.
        # BUGFIX: the prefix is now a capture group referenced as \1 in the
        # replacement, so /home/bob and C:\Users\bob keep their own prefix
        # instead of being rewritten to a macOS-style /Users/... path.
        (
            r'(/Users/|/home/|C:\\Users\\)([A-Za-z0-9_\-]+)',
            "user_home_path",
            Sensitivity.MEDIUM,
            r"\1[REDACTED-USER]",
        ),
        # --- LOW: environment / system info ---
        # Internal IPs
        (
            r'\b(?:10\.\d{1,3}\.\d{1,3}\.\d{1,3}|172\.(?:1[6-9]|2\d|3[01])\.\d{1,3}\.\d{1,3}|192\.168\.\d{1,3}\.\d{1,3})\b',
            "internal_ip",
            Sensitivity.LOW,
            "[REDACTED-IP]",
        ),
    ]

    # NOTE: every pattern compiles with IGNORECASE. For token/base58 patterns
    # this is deliberately aggressive — a privacy filter prefers false
    # positives over leaking a secret.
    _PII_PATTERNS = [
        (re.compile(pattern, re.IGNORECASE), rtype, sensitivity, replacement)
        for pattern, rtype, sensitivity, replacement in raw_patterns
    ]


_compile_patterns()
|
||||
|
||||
# =========================================================================
# Sensitive file path patterns (context-aware)
# =========================================================================

_SENSITIVE_PATH_PATTERNS = [
    re.compile(r'\.(?:env|pem|key|p12|pfx|jks|keystore)\b', re.IGNORECASE),
    re.compile(r'(?:\.ssh/|\.gnupg/|\.aws/|\.config/gcloud/)', re.IGNORECASE),
    re.compile(r'(?:wallet|keystore|seed|mnemonic)', re.IGNORECASE),
    re.compile(r'(?:\.hermes/\.env)', re.IGNORECASE),
]


def _classify_path_sensitivity(path: str) -> Sensitivity:
    """Check if a file path references sensitive material.

    Returns HIGH when *path* matches any known sensitive-file pattern
    (key material, credential stores, wallets), otherwise PUBLIC.
    """
    is_sensitive = any(pat.search(path) for pat in _SENSITIVE_PATH_PATTERNS)
    return Sensitivity.HIGH if is_sensitive else Sensitivity.PUBLIC
# =========================================================================
# Core filtering
# =========================================================================

class PrivacyFilter:
    """Strip PII from message context before remote API calls.

    Integrates with the agent's message pipeline. Call sanitize_messages()
    before sending context to any cloud LLM provider.
    """

    def __init__(
        self,
        min_sensitivity: Sensitivity = Sensitivity.MEDIUM,
        aggressive_mode: bool = False,
    ):
        """
        Args:
            min_sensitivity: Only redact PII at or above this level.
                Default MEDIUM — redacts emails, phones, paths but not IPs.
            aggressive_mode: If True, also redact file paths and internal IPs
                (overrides min_sensitivity down to LOW).
        """
        # aggressive_mode wins over the caller-supplied threshold.
        self.min_sensitivity = (
            Sensitivity.LOW if aggressive_mode else min_sensitivity
        )
        self.aggressive_mode = aggressive_mode

    def sanitize_text(self, text: str) -> Tuple[str, List[Dict[str, Any]]]:
        """Sanitize a single text string.

        Args:
            text: Raw text that may contain PII.

        Returns:
            (cleaned_text, redactions) where each redaction entry is a dict
            with keys "type", "sensitivity" (enum member name), and "count".
        """
        redactions: List[Dict[str, Any]] = []
        cleaned = text

        for pattern, rtype, sensitivity, replacement in _PII_PATTERNS:
            # Skip patterns below the configured threshold.
            if sensitivity.value < self.min_sensitivity.value:
                continue

            # FIX: count matches with finditer() instead of findall().
            # findall() returns capture-group contents (strings or tuples),
            # so matches whose groups were all empty/falsy were undercounted
            # by the old `sum(1 for m in matches if m)` branch; finditer()
            # counts exactly one per actual match.
            count = sum(1 for _ in pattern.finditer(cleaned))
            if count:
                cleaned = pattern.sub(replacement, cleaned)
                redactions.append({
                    "type": rtype,
                    "sensitivity": sensitivity.name,
                    "count": count,
                })

        return cleaned, redactions

    def sanitize_messages(
        self, messages: List[Dict[str, Any]]
    ) -> Tuple[List[Dict[str, Any]], RedactionReport]:
        """Sanitize a list of OpenAI-format messages.

        System messages are NOT sanitized (they're typically static
        prompts). Only user and assistant messages with string content
        are processed; every other message passes through unchanged.

        Args:
            messages: List of {"role": ..., "content": ...} dicts.

        Returns:
            Tuple of (sanitized_messages, redaction_report). Input dicts
            are never mutated; redacted messages are shallow copies with
            a replaced "content".
        """
        report = RedactionReport(total_messages=len(messages))
        safe_messages = []

        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")

            # Only sanitize user/assistant string content
            if role in ("user", "assistant") and isinstance(content, str) and content:
                cleaned, redactions = self.sanitize_text(content)
                if redactions:
                    report.redacted_messages += 1
                    report.redactions.extend(redactions)
                    # Track the highest sensitivity seen across the batch.
                    for r in redactions:
                        s = Sensitivity[r["sensitivity"]]
                        if s.value > report.max_sensitivity.value:
                            report.max_sensitivity = s
                    safe_messages.append({**msg, "content": cleaned})
                    logger.info(
                        "Privacy filter: redacted %d PII type(s) from %s message",
                        len(redactions), role,
                    )
                else:
                    safe_messages.append(msg)
            else:
                safe_messages.append(msg)

        return safe_messages, report

    def should_use_local_only(self, text: str) -> Tuple[bool, str]:
        """Determine if content is too sensitive for any remote call.

        Policy: any CRITICAL finding blocks remote use outright; three or
        more HIGH findings also block.

        Returns:
            (should_block, reason). If should_block is True, the content
            should only be processed by a local model.
        """
        _, redactions = self.sanitize_text(text)

        critical_count = sum(
            1 for r in redactions
            if Sensitivity[r["sensitivity"]] == Sensitivity.CRITICAL
        )
        high_count = sum(
            1 for r in redactions
            if Sensitivity[r["sensitivity"]] == Sensitivity.HIGH
        )

        if critical_count > 0:
            return True, f"Contains {critical_count} critical-secret pattern(s) — local-only"
        if high_count >= 3:
            return True, f"Contains {high_count} high-sensitivity pattern(s) — local-only"
        return False, ""
||||
|
||||
def sanitize_messages(
    messages: List[Dict[str, Any]],
    min_sensitivity: Sensitivity = Sensitivity.MEDIUM,
    aggressive: bool = False,
) -> Tuple[List[Dict[str, Any]], RedactionReport]:
    """Convenience function: sanitize messages with default settings."""
    filter_instance = PrivacyFilter(
        min_sensitivity=min_sensitivity,
        aggressive_mode=aggressive,
    )
    return filter_instance.sanitize_messages(messages)
||||
|
||||
|
||||
def quick_sanitize(text: str) -> str:
    """Quick sanitize a single string — returns cleaned text only."""
    sanitized, _ignored = PrivacyFilter().sanitize_text(text)
    return sanitized
|
||||
194
benchmarks/test_images.json
Normal file
194
benchmarks/test_images.json
Normal file
@@ -0,0 +1,194 @@
|
||||
[
|
||||
{
|
||||
"id": "screenshot_github_home",
|
||||
"url": "https://github.githubassets.com/images/modules/logos_page/GitHub-Mark.png",
|
||||
"category": "screenshot",
|
||||
"expected_keywords": ["github", "logo", "mark"],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 30, "min_sentences": 1, "has_numbers": false}
|
||||
},
|
||||
{
|
||||
"id": "diagram_mermaid_flow",
|
||||
"url": "https://mermaid.ink/img/pako:eNpdkE9PwzAMxb-K5VOl7gc7sAOIIDuAw9gptnRaSJLSJttQStmXs9LCH-ymBOI1ef_42U6cUSae4IkDxbAAWtB6siSZXVhjQTlgl1nigHg5fRBOzSfebopROCu_cytObSfgLSE1ANOeZWkO2IH5upZxYot8m1hqAdpD_63WRl0xdUG1jdl9kPiOb_EWk2JBtPaiKkF4eVIYgO0EtkW-RSgC4gJ6HJYRG1UNdN0HNVd0Bftjj7X8P92qPj-F8l8T3w",
|
||||
"category": "diagram",
|
||||
"expected_keywords": ["flow", "diagram", "process"],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 50, "min_sentences": 2, "has_numbers": false}
|
||||
},
|
||||
{
|
||||
"id": "photo_random_1",
|
||||
"url": "https://picsum.photos/seed/vision1/400/300",
|
||||
"category": "photo",
|
||||
"expected_keywords": [],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 30, "min_sentences": 1, "has_numbers": false}
|
||||
},
|
||||
{
|
||||
"id": "photo_random_2",
|
||||
"url": "https://picsum.photos/seed/vision2/400/300",
|
||||
"category": "photo",
|
||||
"expected_keywords": [],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 30, "min_sentences": 1, "has_numbers": false}
|
||||
},
|
||||
{
|
||||
"id": "chart_simple_bar",
|
||||
"url": "https://quickchart.io/chart?c={type:'bar',data:{labels:['Q1','Q2','Q3','Q4'],datasets:[{label:'Revenue',data:[100,150,200,250]}]}}",
|
||||
"category": "chart",
|
||||
"expected_keywords": ["bar", "chart", "revenue"],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 50, "min_sentences": 2, "has_numbers": true}
|
||||
},
|
||||
{
|
||||
"id": "chart_pie",
|
||||
"url": "https://quickchart.io/chart?c={type:'pie',data:{labels:['A','B','C'],datasets:[{data:[30,50,20]}]}}",
|
||||
"category": "chart",
|
||||
"expected_keywords": ["pie", "chart", "percentage"],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 50, "min_sentences": 2, "has_numbers": true}
|
||||
},
|
||||
{
|
||||
"id": "diagram_org_chart",
|
||||
"url": "https://mermaid.ink/img/pako:eNpdkE9PwzAMxb-K5VOl7gc7sAOIIDuAw9gptnRaSJLSJttQStmXs9LCH-ymBOI1ef_42U6cUSae4IkDxbAAWtB6iuyIWyrLgXLALrPEAfFy-iCcmk-83RSjcFZ-51ac2k7AW0JqAKY9y9IcsAPzdS3jxBb5NrHUAraH_lutjbpi6oJqG7P7IPEd3-ItJsWCaO1FVYLw8qQwANsJbIt8i1AExAX0OCwjNqoa6LoPaq7oCvbHHmv5f7pVfX4K5b8mvg",
|
||||
"category": "diagram",
|
||||
"expected_keywords": ["organization", "hierarchy", "chart"],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 50, "min_sentences": 2, "has_numbers": false}
|
||||
},
|
||||
{
|
||||
"id": "screenshot_terminal",
|
||||
"url": "https://raw.githubusercontent.com/nicehash/nicehash-quick-start/main/images/nicehash-terminal.png",
|
||||
"category": "screenshot",
|
||||
"expected_keywords": ["terminal", "command", "output"],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 30, "min_sentences": 1, "has_numbers": false}
|
||||
},
|
||||
{
|
||||
"id": "photo_random_3",
|
||||
"url": "https://picsum.photos/seed/vision3/400/300",
|
||||
"category": "photo",
|
||||
"expected_keywords": [],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 30, "min_sentences": 1, "has_numbers": false}
|
||||
},
|
||||
{
|
||||
"id": "chart_line",
|
||||
"url": "https://quickchart.io/chart?c={type:'line',data:{labels:['Jan','Feb','Mar','Apr'],datasets:[{label:'Temperature',data:[5,8,12,18]}]}}",
|
||||
"category": "chart",
|
||||
"expected_keywords": ["line", "chart", "temperature"],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 50, "min_sentences": 2, "has_numbers": true}
|
||||
},
|
||||
{
|
||||
"id": "diagram_sequence",
|
||||
"url": "https://mermaid.ink/img/pako:eNpdkE9PwzAMxb-K5VOl7gc7sAOIIDuAw9gptnRaSJLSJttQStmXs9LCH-ymBOI1ef_42U6cUSae4IkDxbAAWtB6iuyIWyrLgXLALrPEAfFy-iCcmk-83RSjcFZ-51ac2k7AW0JqAKY9y9IcsAPzdS3jxBb5NrHUAraH_lutjbpi6oJqG7P7IPEd3-ItJsWCaO1FVYLw8qQwANsJbIt8i1AExAX0OCwjNqoa6LoPaq7oCvbHHmv5f7pVfX4K5b8mvg",
|
||||
"category": "diagram",
|
||||
"expected_keywords": ["sequence", "interaction", "message"],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 50, "min_sentences": 2, "has_numbers": false}
|
||||
},
|
||||
{
|
||||
"id": "photo_random_4",
|
||||
"url": "https://picsum.photos/seed/vision4/400/300",
|
||||
"category": "photo",
|
||||
"expected_keywords": [],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 30, "min_sentences": 1, "has_numbers": false}
|
||||
},
|
||||
{
|
||||
"id": "screenshot_webpage",
|
||||
"url": "https://github.githubassets.com/images/modules/site/social-cards.png",
|
||||
"category": "screenshot",
|
||||
"expected_keywords": ["github", "page", "web"],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 30, "min_sentences": 1, "has_numbers": false}
|
||||
},
|
||||
{
|
||||
"id": "chart_radar",
|
||||
"url": "https://quickchart.io/chart?c={type:'radar',data:{labels:['Speed','Power','Defense','Magic'],datasets:[{label:'Hero',data:[80,60,70,90]}]}}",
|
||||
"category": "chart",
|
||||
"expected_keywords": ["radar", "chart", "skill"],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 50, "min_sentences": 2, "has_numbers": true}
|
||||
},
|
||||
{
|
||||
"id": "photo_random_5",
|
||||
"url": "https://picsum.photos/seed/vision5/400/300",
|
||||
"category": "photo",
|
||||
"expected_keywords": [],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 30, "min_sentences": 1, "has_numbers": false}
|
||||
},
|
||||
{
|
||||
"id": "diagram_class",
|
||||
"url": "https://mermaid.ink/img/pako:eNpdkE9PwzAMxb-K5VOl7gc7sAOIIDuAw9gptnRaSJLSJttQStmXs9LCH-ymBOI1ef_42U6cUSae4IkDxbAAWtB6iuyIWyrLgXLALrPEAfFy-iCcmk-83RSjcFZ-51ac2k7AW0JqAKY9y9IcsAPzdS3jxBb5NrHUAraH_lutjbpi6oJqG7P7IPEd3-ItJsWCaO1FVYLw8qQwANsJbIt8i1AExAX0OCwjNqoa6LoPaq7oCvbHHmv5f7pVfX4K5b8mvg",
|
||||
"category": "diagram",
|
||||
"expected_keywords": ["class", "object", "attribute"],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 50, "min_sentences": 2, "has_numbers": false}
|
||||
},
|
||||
{
|
||||
"id": "chart_doughnut",
|
||||
"url": "https://quickchart.io/chart?c={type:'doughnut',data:{labels:['Desktop','Mobile','Tablet'],datasets:[{data:[60,30,10]}]}}",
|
||||
"category": "chart",
|
||||
"expected_keywords": ["doughnut", "chart", "device"],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 50, "min_sentences": 2, "has_numbers": true}
|
||||
},
|
||||
{
|
||||
"id": "photo_random_6",
|
||||
"url": "https://picsum.photos/seed/vision6/400/300",
|
||||
"category": "photo",
|
||||
"expected_keywords": [],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 30, "min_sentences": 1, "has_numbers": false}
|
||||
},
|
||||
{
|
||||
"id": "screenshot_error",
|
||||
"url": "https://http.cat/404.jpg",
|
||||
"category": "screenshot",
|
||||
"expected_keywords": ["404", "error", "cat"],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 30, "min_sentences": 1, "has_numbers": true}
|
||||
},
|
||||
{
|
||||
"id": "diagram_network",
|
||||
"url": "https://mermaid.ink/img/pako:eNpdkE9PwzAMxb-K5VOl7gc7sAOIIDuAw9gptnRaSJLSJttQStmXs9LCH-ymBOI1ef_42U6cUSae4IkDxbAAWtB6iuyIWyrLgXLALrPEAfFy-iCcmk-83RSjcFZ-51ac2k7AW0JqAKY9y9IcsAPzdS3jxBb5NrHUAraH_lutjbpi6oJqG7P7IPEd3-ItJsWCaO1FVYLw8qQwANsJbIt8i1AExAX0OCwjNqoa6LoPaq7oCvbHHmv5f7pVfX4K5b8mvg",
|
||||
"category": "diagram",
|
||||
"expected_keywords": ["network", "node", "connection"],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 50, "min_sentences": 2, "has_numbers": false}
|
||||
},
|
||||
{
|
||||
"id": "photo_random_7",
|
||||
"url": "https://picsum.photos/seed/vision7/400/300",
|
||||
"category": "photo",
|
||||
"expected_keywords": [],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 30, "min_sentences": 1, "has_numbers": false}
|
||||
},
|
||||
{
|
||||
"id": "chart_stacked_bar",
|
||||
"url": "https://quickchart.io/chart?c={type:'bar',data:{labels:['2022','2023','2024'],datasets:[{label:'Cloud',data:[100,150,200]},{label:'On-prem',data:[200,180,160]}]},options:{scales:{x:{stacked:true},y:{stacked:true}}}}",
|
||||
"category": "chart",
|
||||
"expected_keywords": ["stacked", "bar", "chart"],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 50, "min_sentences": 2, "has_numbers": true}
|
||||
},
|
||||
{
|
||||
"id": "screenshot_dashboard",
|
||||
"url": "https://github.githubassets.com/images/modules/site/features-code-search.png",
|
||||
"category": "screenshot",
|
||||
"expected_keywords": ["search", "code", "feature"],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 30, "min_sentences": 1, "has_numbers": false}
|
||||
},
|
||||
{
|
||||
"id": "photo_random_8",
|
||||
"url": "https://picsum.photos/seed/vision8/400/300",
|
||||
"category": "photo",
|
||||
"expected_keywords": [],
|
||||
"ground_truth_ocr": "",
|
||||
"expected_structure": {"min_length": 30, "min_sentences": 1, "has_numbers": false}
|
||||
}
|
||||
]
|
||||
635
benchmarks/vision_benchmark.py
Normal file
635
benchmarks/vision_benchmark.py
Normal file
@@ -0,0 +1,635 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Vision Benchmark Suite — Issue #817
|
||||
|
||||
Compares Gemma 4 vision accuracy vs current approach (Gemini 3 Flash Preview).
|
||||
Measures OCR accuracy, description quality, latency, and token usage.
|
||||
|
||||
Usage:
|
||||
# Run full benchmark
|
||||
python benchmarks/vision_benchmark.py --images benchmarks/test_images.json
|
||||
|
||||
# Single image test
|
||||
python benchmarks/vision_benchmark.py --url https://example.com/image.png
|
||||
|
||||
# Generate test report
|
||||
python benchmarks/vision_benchmark.py --images benchmarks/test_images.json --output benchmarks/vision_results.json
|
||||
|
||||
Test image dataset: benchmarks/test_images.json (50-100 diverse images)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import statistics
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Benchmark configuration
# ---------------------------------------------------------------------------

# Models to compare. Each entry maps a short benchmark name to routing info
# consumed by analyze_with_model(): "provider" selects the API endpoint and
# the env-var credentials; "model_id" is sent as the request's "model" field.
MODELS = {
    "gemma4": {
        "model_id": "google/gemma-4-27b-it",
        "display_name": "Gemma 4 27B",
        "provider": "nous",  # routed to the Nous inference endpoint
        "description": "Google's multimodal Gemma 4 model",
    },
    "gemini3_flash": {
        "model_id": "google/gemini-3-flash-preview",
        "display_name": "Gemini 3 Flash Preview",
        "provider": "openrouter",  # routed to the OpenRouter endpoint
        "description": "Current default vision model",
    },
}

# Evaluation prompts for different test categories. run_single_test() picks
# one by the image's "category" field, falling back to the "photo" prompt.
EVAL_PROMPTS = {
    "screenshot": "Describe this screenshot in detail. What application is shown? What is the current state of the UI?",
    "diagram": "Describe this diagram completely. What concepts does it illustrate? List all components and their relationships.",
    "photo": "Describe this photo in detail. What objects are visible? What is the scene?",
    "ocr": "Extract ALL text visible in this image. Return it exactly as written, preserving formatting.",
    "chart": "What data does this chart show? List all axes labels, values, and key trends.",
    "document": "Extract all text from this document image. Preserve paragraph structure.",
}
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Vision model interface
# ---------------------------------------------------------------------------


async def analyze_with_model(
    image_url: str,
    prompt: str,
    model_config: dict,
    timeout: float = 120.0,
) -> dict:
    """Call a vision model and return structured results.

    Args:
        image_url: Publicly reachable URL of the image to analyze.
        prompt: Text instruction sent alongside the image.
        model_config: Entry from MODELS with "provider" and "model_id".
        timeout: HTTP timeout in seconds.

    Returns dict with:
        - analysis: str
        - latency_ms: float
        - tokens: dict (prompt_tokens, completion_tokens, total_tokens)
        - success: bool
        - error: str (if failed)
    """
    import httpx

    provider = model_config["provider"]
    model_id = model_config["model_id"]

    # OpenAI-style multimodal message: text prompt plus image reference.
    request_messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": image_url}},
            ],
        }
    ]

    # Resolve endpoint and credentials for the configured provider; unknown
    # providers fall back to <PROVIDER>_API_URL / <PROVIDER>_API_KEY env vars.
    if provider == "openrouter":
        api_url = "https://openrouter.ai/api/v1/chat/completions"
        api_key = os.getenv("OPENROUTER_API_KEY", "")
    elif provider == "nous":
        api_url = "https://inference.nousresearch.com/v1/chat/completions"
        api_key = os.getenv("NOUS_API_KEY", "") or os.getenv("NOUS_INFERENCE_API_KEY", "")
    else:
        api_url = os.getenv(f"{provider.upper()}_API_URL", "")
        api_key = os.getenv(f"{provider.upper()}_API_KEY", "")

    if not api_key:
        return {
            "analysis": "",
            "latency_ms": 0,
            "tokens": {},
            "success": False,
            "error": f"No API key for provider {provider}",
        }

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model_id,
        "messages": request_messages,
        "max_tokens": 2000,
        "temperature": 0.1,
    }

    start = time.perf_counter()
    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.post(api_url, json=payload, headers=headers)
            resp.raise_for_status()
            data = resp.json()

        elapsed_ms = (time.perf_counter() - start) * 1000

        # Pull the first choice's text, tolerating missing fields.
        analysis = ""
        if data.get("choices", []):
            analysis = data["choices"][0].get("message", {}).get("content", "")

        usage = data.get("usage", {})
        tokens = {
            key: usage.get(key, 0)
            for key in ("prompt_tokens", "completion_tokens", "total_tokens")
        }

        return {
            "analysis": analysis,
            "latency_ms": round(elapsed_ms, 1),
            "tokens": tokens,
            "success": True,
            "error": "",
        }

    except Exception as e:
        # Report the failure with the latency accrued up to the error.
        return {
            "analysis": "",
            "latency_ms": round((time.perf_counter() - start) * 1000, 1),
            "tokens": {},
            "success": False,
            "error": str(e),
        }
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Evaluation metrics
# ---------------------------------------------------------------------------


def compute_ocr_accuracy(extracted: str, ground_truth: str) -> float:
    """Score extracted OCR text against the ground-truth transcription.

    DOC FIX: this is not a Levenshtein ratio (as previously claimed); it is
    a fast proxy blending two signals:
      - position-aligned character match ratio (40% weight)
      - word-level recall of ground-truth words (60% weight)

    Args:
        extracted: Text the model produced.
        ground_truth: Reference transcription ("" means no text expected).

    Returns:
        Float in 0.0-1.0 (1.0 = perfect match).
    """
    # Empty ground truth: perfect only if the model also produced nothing.
    if not ground_truth:
        return 1.0 if not extracted else 0.0
    if not extracted:
        return 0.0

    extracted_lower = extracted.lower().strip()
    truth_lower = ground_truth.lower().strip()

    # Position-aligned character overlap (order-sensitive proxy).
    max_len = max(len(extracted_lower), len(truth_lower))
    if max_len == 0:
        # Both inputs were pure whitespace.
        return 1.0

    matches = sum(1 for a, b in zip(extracted_lower, truth_lower) if a == b)
    position_ratio = matches / max_len

    # Word-level recall: fraction of ground-truth words found anywhere.
    extracted_words = set(extracted_lower.split())
    truth_words = set(truth_lower.split())
    if truth_words:
        word_recall = len(extracted_words & truth_words) / len(truth_words)
    else:
        word_recall = 1.0 if not extracted_words else 0.0

    return round(position_ratio * 0.4 + word_recall * 0.6, 4)
|
||||
|
||||
def compute_description_completeness(analysis: str, expected_keywords: list) -> float:
    """Score description completeness based on keyword coverage.

    Returns 0.0-1.0: the fraction of expected keywords that appear
    (case-insensitively) anywhere in the analysis text.
    """
    # No expectations means any description is complete by definition.
    if not expected_keywords:
        return 1.0
    if not analysis:
        return 0.0

    haystack = analysis.lower()
    hits = [kw for kw in expected_keywords if kw.lower() in haystack]
    return round(len(hits) / len(expected_keywords), 4)
||||
|
||||
|
||||
def compute_structural_accuracy(analysis: str, expected_structure: dict) -> dict:
    """Evaluate structural elements of the analysis.

    Returns a dict of per-element scores in 0.0-1.0:
      - "length": analysis length relative to the expected minimum (capped)
      - "sentences": terminal-punctuation count relative to the minimum
      - "has_numbers": only when requested; 1.0 iff any digit appears
    """
    import re

    scores: dict = {}

    # Length score; a non-positive minimum is trivially satisfied.
    min_length = expected_structure.get("min_length", 50)
    if min_length > 0:
        scores["length"] = min(len(analysis) / min_length, 1.0)
    else:
        scores["length"] = 1.0

    # Sentence score: count '.', '!' and '?' as sentence terminators.
    min_sentences = expected_structure.get("min_sentences", 2)
    terminators = sum(analysis.count(ch) for ch in ".!?")
    scores["sentences"] = min(terminators / max(min_sentences, 1), 1.0)

    # Digit presence, only when the fixture asks for it.
    if expected_structure.get("has_numbers", False):
        scores["has_numbers"] = 1.0 if re.search(r'\d', analysis) else 0.0

    return scores
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Benchmark runner
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def run_single_test(
    image: dict,
    models: dict,
    runs_per_model: int = 1,
) -> dict:
    """Run a single image through all models.

    Args:
        image: dict with url, category, expected_keywords, ground_truth_ocr, etc.
        models: dict of model configs to test
        runs_per_model: number of runs per model (for consistency testing)

    Returns dict with results per model.
    """
    img_category = image.get("category", "photo")
    eval_prompt = EVAL_PROMPTS.get(img_category, EVAL_PROMPTS["photo"])
    img_url = image["url"]

    per_model = {}

    for name, config in models.items():
        attempts = []
        for attempt_idx in range(runs_per_model):
            attempts.append(await analyze_with_model(img_url, eval_prompt, config))
            # Courtesy pause between repeated calls to the same provider.
            if attempt_idx < runs_per_model - 1:
                await asyncio.sleep(1)

        ok_runs = [a for a in attempts if a["success"]]
        if not ok_runs:
            per_model[name] = {
                "success": False,
                "error": attempts[0]["error"] if attempts else "No runs",
                "runs": 0,
                "errors": len(attempts),
            }
            continue

        mean_latency = statistics.mean(a["latency_ms"] for a in ok_runs)
        mean_tokens = statistics.mean(
            a["tokens"].get("total_tokens", 0) for a in ok_runs
        )
        # Accuracy metrics come from the first successful attempt only.
        first = ok_runs[0]

        ocr = (
            compute_ocr_accuracy(first["analysis"], image["ground_truth_ocr"])
            if image.get("ground_truth_ocr")
            else None
        )
        keywords = (
            compute_description_completeness(first["analysis"], image["expected_keywords"])
            if image.get("expected_keywords")
            else None
        )
        structure = compute_structural_accuracy(
            first["analysis"], image.get("expected_structure", {})
        )

        per_model[name] = {
            "success": True,
            "analysis_preview": first["analysis"][:300],
            "analysis_length": len(first["analysis"]),
            "avg_latency_ms": round(mean_latency, 1),
            "avg_tokens": round(mean_tokens, 1),
            "ocr_accuracy": ocr,
            "keyword_completeness": keywords,
            "structural_scores": structure,
            # Std-dev of response lengths across repeated runs; 0.0 for one run.
            "consistency": round(
                statistics.stdev(len(a["analysis"]) for a in ok_runs), 1
            ) if len(ok_runs) > 1 else 0.0,
            "runs": len(ok_runs),
            "errors": len(attempts) - len(ok_runs),
        }

    return per_model
|
||||
|
||||
|
||||
async def run_benchmark_suite(
    images: List[dict],
    models: dict,
    runs_per_model: int = 1,
) -> dict:
    """Run the full benchmark suite.

    Args:
        images: list of image test cases
        models: model configs to compare
        runs_per_model: consistency runs per image

    Returns structured benchmark report.
    """
    image_count = len(images)
    collected = []

    print(f"\nRunning vision benchmark: {image_count} images x {len(models)} models x {runs_per_model} runs")
    print(f"Models: {', '.join(m['display_name'] for m in models.values())}\n")

    for idx, image in enumerate(images):
        image_id = image.get("id", f"img_{idx}")
        image_category = image.get("category", "unknown")
        print(f" [{idx+1}/{image_count}] {image_id} ({image_category})...", end=" ", flush=True)

        outcome = await run_single_test(image, models, runs_per_model)
        outcome["image_id"] = image_id
        outcome["category"] = image_category
        collected.append(outcome)

        # One-line status per image so long runs show visible progress.
        status_parts = []
        for model_key in models:
            entry = outcome[model_key]
            if entry["success"]:
                status_parts.append(f"{model_key}:{entry['avg_latency_ms']:.0f}ms")
            else:
                status_parts.append(f"{model_key}:FAIL")
        print(", ".join(status_parts))

    return {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "config": {
            "total_images": image_count,
            "runs_per_model": runs_per_model,
            "models": {k: v["display_name"] for k, v in models.items()},
        },
        "results": collected,
        "summary": aggregate_results(collected, models),
    }
|
||||
|
||||
|
||||
def aggregate_results(results: List[dict], models: dict) -> dict:
    """Compute aggregate statistics across all test images.

    For each model, latency/token/accuracy values from the per-image results
    are folded into mean/median/p95 summaries. Models with zero successful
    runs get a stub entry instead of statistics.
    """
    aggregated = {}

    for name in models:
        ok = [entry[name] for entry in results if entry[name]["success"]]
        bad = [entry[name] for entry in results if not entry[name]["success"]]

        if not ok:
            aggregated[name] = {"success_rate": 0, "error": "All runs failed"}
            continue

        lat_values = [e["avg_latency_ms"] for e in ok]
        tok_values = [e["avg_tokens"] for e in ok if e.get("avg_tokens")]
        # Accuracy scores are optional per image; only average what exists.
        ocr_values = [e["ocr_accuracy"] for e in ok if e.get("ocr_accuracy") is not None]
        kw_values = [e["keyword_completeness"] for e in ok if e.get("keyword_completeness") is not None]

        aggregated[name] = {
            "success_rate": round(len(ok) / (len(ok) + len(bad)), 4),
            "total_runs": len(ok),
            "total_failures": len(bad),
            "latency": {
                "mean_ms": round(statistics.mean(lat_values), 1),
                "median_ms": round(statistics.median(lat_values), 1),
                # Index-based p95: int(n * 0.95) is always a valid index.
                "p95_ms": round(sorted(lat_values)[int(len(lat_values) * 0.95)], 1),
                "std_ms": round(statistics.stdev(lat_values), 1) if len(lat_values) > 1 else 0,
            },
            "tokens": {
                "mean_total": round(statistics.mean(tok_values), 1) if tok_values else 0,
                "total_used": sum(int(t) for t in tok_values),
            },
            "accuracy": {
                "ocr_mean": round(statistics.mean(ocr_values), 4) if ocr_values else None,
                "ocr_count": len(ocr_values),
                "keyword_mean": round(statistics.mean(kw_values), 4) if kw_values else None,
                "keyword_count": len(kw_values),
            },
        }

    return aggregated
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Report generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def to_markdown(report: dict) -> str:
    """Generate human-readable markdown report.

    Emits latency, accuracy and token-usage tables plus a final verdict
    computed from a weighted composite (40% OCR, 30% keywords, 30% success).
    """
    summary = report["summary"]
    config = report["config"]
    display_names = list(config["models"].values())

    out = [
        "# Vision Benchmark Report",
        "",
        f"Generated: {report['generated_at'][:16]}",
        f"Images tested: {config['total_images']}",
        f"Runs per model: {config['runs_per_model']}",
        f"Models: {', '.join(display_names)}",
        "",
        "## Latency Comparison",
        "",
        "| Model | Mean (ms) | Median | P95 | Std Dev |",
        "|-------|-----------|--------|-----|---------|",
    ]

    for key, label in config["models"].items():
        if key in summary and "latency" in summary[key]:
            lat = summary[key]["latency"]
            out.append(
                f"| {label} | {lat['mean_ms']:.0f} | {lat['median_ms']:.0f} | "
                f"{lat['p95_ms']:.0f} | {lat['std_ms']:.0f} |"
            )

    out += [
        "",
        "## Accuracy Comparison",
        "",
        "| Model | OCR Accuracy | Keyword Coverage | Success Rate |",
        "|-------|-------------|-----------------|--------------|",
    ]

    for key, label in config["models"].items():
        if key in summary and "accuracy" in summary[key]:
            acc = summary[key]["accuracy"]
            success = summary[key].get("success_rate", 0)
            ocr_cell = "N/A" if acc["ocr_mean"] is None else f"{acc['ocr_mean']:.1%}"
            kw_cell = "N/A" if acc["keyword_mean"] is None else f"{acc['keyword_mean']:.1%}"
            out.append(f"| {label} | {ocr_cell} | {kw_cell} | {success:.1%} |")

    out += [
        "",
        "## Token Usage",
        "",
        "| Model | Mean Tokens/Image | Total Tokens |",
        "|-------|------------------|--------------|",
    ]

    for key, label in config["models"].items():
        if key in summary and "tokens" in summary[key]:
            tok = summary[key]["tokens"]
            out.append(
                f"| {label} | {tok['mean_total']:.0f} | {tok['total_used']} |"
            )

    # Verdict: pick the model with the highest weighted composite score.
    out += ["", "## Verdict", ""]

    winner = None
    winner_score = -1
    for key, label in config["models"].items():
        if key not in summary or "accuracy" not in summary[key]:
            continue
        acc = summary[key]["accuracy"]
        success = summary[key].get("success_rate", 0)
        # Weighted composite: 40% OCR, 30% keyword, 30% success rate.
        composite = (acc["ocr_mean"] or 0) * 0.4 + (acc["keyword_mean"] or 0) * 0.3 + success * 0.3
        if composite > winner_score:
            winner_score = composite
            winner = label

    if winner:
        out.append(f"**Best overall: {winner}** (composite score: {winner_score:.1%})")
    else:
        out.append("No clear winner — insufficient data.")

    return "\n".join(out)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test dataset management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def generate_sample_dataset() -> List[dict]:
    """Generate a sample test dataset with diverse public images.

    Covers one case per category (screenshot, diagram, photo, chart).

    Returns list of test image definitions.
    """
    def _case(case_id, url, category, keywords, min_length, min_sentences):
        # All test cases share one schema; build it in a single place.
        return {
            "id": case_id,
            "url": url,
            "category": category,
            "expected_keywords": keywords,
            "expected_structure": {"min_length": min_length, "min_sentences": min_sentences},
        }

    return [
        # Screenshots
        _case(
            "screenshot_github",
            "https://github.githubassets.com/images/modules/logos_page/GitHub-Mark.png",
            "screenshot",
            ["github", "logo", "octocat"],
            50, 2,
        ),
        # Diagrams
        _case(
            "diagram_architecture",
            "https://mermaid.ink/img/pako:eNp9kMtOwzAQRX_F8hKpJbhJFVJBi1QJiMWCG8eZNsGJLdlOiqIid5RdufiHnZRA7GbuzJwZe4ZGH2SCBPYUwgxoQKvJnCR2YY0F5YBdJJkD4uX0oXB6PnF3U4zCWcWdW3FqOwGvCKkBmHKSTB2gJeRrLTeJLfJdJKkBGYf9P1sTNdUXVJqY3YNJK7xLVwR0mxJFU6rCgEKnhSGIL2Eq8BdEERAX0OGwEiVQ1R0MaNFR8QfqKxmHigbX8VLjDz_Q0L8Wc_qPxDw",
            "diagram",
            ["architecture", "component", "service"],
            100, 3,
        ),
        # Photos
        _case(
            "photo_nature",
            "https://picsum.photos/seed/bench1/400/300",
            "photo",
            [],
            30, 1,
        ),
        # Charts
        _case(
            "chart_bar",
            "https://quickchart.io/chart?c={type:'bar',data:{labels:['Q1','Q2','Q3','Q4'],datasets:[{label:'Users',data:[50,60,70,80]}]}}",
            "chart",
            ["bar", "chart", "data"],
            50, 2,
        ),
    ]
|
||||
|
||||
|
||||
def load_dataset(path: str) -> List[dict]:
    """Load test dataset from JSON file.

    Args:
        path: path to a JSON file containing a list of test-image dicts.

    Returns:
        The parsed list of test case definitions.

    Raises:
        OSError: if the file cannot be read.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # JSON is UTF-8 per RFC 8259 — don't rely on the platform default encoding.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def main():
    """CLI entry point: parse arguments, run the benchmark, emit reports."""
    parser = argparse.ArgumentParser(description="Vision Benchmark Suite (Issue #817)")
    parser.add_argument("--images", help="Path to test images JSON file")
    parser.add_argument("--url", help="Single image URL to test")
    parser.add_argument("--category", default="photo", help="Category for single URL")
    parser.add_argument("--output", default=None, help="Output JSON file")
    parser.add_argument("--runs", type=int, default=1, help="Runs per model per image")
    parser.add_argument("--models", nargs="+", default=None,
                        help="Models to test (default: all)")
    parser.add_argument("--markdown", action="store_true", help="Output markdown report")
    parser.add_argument("--generate-dataset", action="store_true",
                        help="Generate sample dataset and exit")
    opts = parser.parse_args()

    # Dataset generation is a standalone mode: write the file and exit.
    if opts.generate_dataset:
        sample = generate_sample_dataset()
        target = opts.images or "benchmarks/test_images.json"
        os.makedirs(os.path.dirname(target) or ".", exist_ok=True)
        with open(target, "w") as fh:
            json.dump(sample, fh, indent=2)
        print(f"Generated sample dataset: {target} ({len(sample)} images)")
        return

    # Restrict to the requested subset of models, or benchmark everything.
    if opts.models:
        chosen = {k: v for k, v in MODELS.items() if k in opts.models}
    else:
        chosen = MODELS

    # Build the image list from a single URL or a dataset file.
    if opts.url:
        test_images = [{"id": "single", "url": opts.url, "category": opts.category}]
    elif opts.images:
        test_images = load_dataset(opts.images)
    else:
        print("ERROR: Provide --images or --url")
        sys.exit(1)

    report = await run_benchmark_suite(test_images, chosen, opts.runs)

    if opts.output:
        os.makedirs(os.path.dirname(opts.output) or ".", exist_ok=True)
        with open(opts.output, "w") as fh:
            json.dump(report, fh, indent=2)
        print(f"\nResults saved to {opts.output}")

    # Markdown goes to stdout when requested, or when no JSON file was written.
    if opts.markdown or not opts.output:
        print("\n" + to_markdown(report))
|
||||
|
||||
|
||||
# Script entry point: drive the async CLI through asyncio's event loop.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
202
tests/agent/test_privacy_filter.py
Normal file
202
tests/agent/test_privacy_filter.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""Tests for agent.privacy_filter — PII stripping before remote API calls."""
|
||||
|
||||
import pytest
|
||||
from agent.privacy_filter import (
|
||||
PrivacyFilter,
|
||||
RedactionReport,
|
||||
Sensitivity,
|
||||
sanitize_messages,
|
||||
quick_sanitize,
|
||||
)
|
||||
|
||||
|
||||
class TestPrivacyFilterSanitizeText:
    """Test single-text sanitization."""

    def test_no_pii_returns_clean(self):
        # Clean input must pass through unchanged with no redaction records.
        pf = PrivacyFilter()
        text = "The weather in Paris is nice today."
        cleaned, redactions = pf.sanitize_text(text)
        assert cleaned == text
        assert redactions == []

    def test_email_redacted(self):
        pf = PrivacyFilter()
        text = "Send report to alice@example.com by Friday."
        cleaned, redactions = pf.sanitize_text(text)
        assert "alice@example.com" not in cleaned
        assert "[REDACTED-EMAIL]" in cleaned
        assert any(r["type"] == "email_address" for r in redactions)

    def test_phone_redacted(self):
        pf = PrivacyFilter()
        text = "Call me at 555-123-4567 when ready."
        cleaned, redactions = pf.sanitize_text(text)
        assert "555-123-4567" not in cleaned
        assert "[REDACTED-PHONE]" in cleaned

    def test_api_key_redacted(self):
        # Secret-bearing strings must be flagged at CRITICAL sensitivity.
        pf = PrivacyFilter()
        text = 'api_key = "sk-proj-abcdefghij1234567890abcdefghij1234567890"'
        cleaned, redactions = pf.sanitize_text(text)
        assert "sk-proj-" not in cleaned
        assert any(r["sensitivity"] == "CRITICAL" for r in redactions)

    def test_github_token_redacted(self):
        pf = PrivacyFilter()
        text = "Use ghp_1234567890abcdefghijklmnopqrstuvwxyz1234 for auth"
        cleaned, redactions = pf.sanitize_text(text)
        assert "ghp_" not in cleaned
        assert any(r["type"] == "github_token" for r in redactions)

    def test_ethereum_address_redacted(self):
        pf = PrivacyFilter()
        text = "Send to 0x742d35Cc6634C0532925a3b844Bc9e7595f2bD18 please"
        cleaned, redactions = pf.sanitize_text(text)
        assert "0x742d" not in cleaned
        assert any(r["type"] == "ethereum_address" for r in redactions)

    def test_user_home_path_redacted(self):
        # Home-directory paths leak the OS username; it must be masked.
        pf = PrivacyFilter()
        text = "Read file at /Users/alice/Documents/secret.txt"
        cleaned, redactions = pf.sanitize_text(text)
        assert "alice" not in cleaned
        assert "[REDACTED-USER]" in cleaned

    def test_multiple_pii_types(self):
        # One pass should catch every PII category present in the text.
        pf = PrivacyFilter()
        text = (
            "Contact john@test.com or call 555-999-1234. "
            "The API key is sk-abcdefghijklmnopqrstuvwxyz1234567890."
        )
        cleaned, redactions = pf.sanitize_text(text)
        assert "john@test.com" not in cleaned
        assert "555-999-1234" not in cleaned
        assert "sk-abcd" not in cleaned
        assert len(redactions) >= 3
|
||||
|
||||
|
||||
class TestPrivacyFilterSanitizeMessages:
    """Test message-list sanitization."""

    def test_sanitize_user_message(self):
        pf = PrivacyFilter()
        messages = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Email me at bob@test.com with results."},
        ]
        safe, report = pf.sanitize_messages(messages)
        assert report.redacted_messages == 1
        assert "bob@test.com" not in safe[1]["content"]
        assert "[REDACTED-EMAIL]" in safe[1]["content"]
        # System message unchanged
        assert safe[0]["content"] == "You are helpful."

    def test_no_redaction_needed(self):
        # PII-free conversations should yield an empty report.
        pf = PrivacyFilter()
        messages = [
            {"role": "user", "content": "What is 2+2?"},
            {"role": "assistant", "content": "4"},
        ]
        safe, report = pf.sanitize_messages(messages)
        assert report.redacted_messages == 0
        assert not report.had_redactions

    def test_assistant_messages_also_sanitized(self):
        pf = PrivacyFilter()
        messages = [
            {"role": "assistant", "content": "Your email admin@corp.com was found."},
        ]
        safe, report = pf.sanitize_messages(messages)
        assert report.redacted_messages == 1
        assert "admin@corp.com" not in safe[0]["content"]

    def test_tool_messages_not_sanitized(self):
        # Tool output stays verbatim so downstream parsing is not broken.
        pf = PrivacyFilter()
        messages = [
            {"role": "tool", "content": "Result: user@test.com found"},
        ]
        safe, report = pf.sanitize_messages(messages)
        assert report.redacted_messages == 0
        assert safe[0]["content"] == "Result: user@test.com found"
|
||||
|
||||
|
||||
class TestShouldUseLocalOnly:
    """Test the local-only routing decision."""

    def test_normal_text_allows_remote(self):
        pf = PrivacyFilter()
        block, reason = pf.should_use_local_only("Summarize this article about Python.")
        assert not block

    def test_critical_secret_blocks_remote(self):
        # A single CRITICAL pattern is enough to force local-only routing.
        pf = PrivacyFilter()
        text = "Here is the API key: sk-abcdefghijklmnopqrstuvwxyz1234567890"
        block, reason = pf.should_use_local_only(text)
        assert block
        assert "critical" in reason.lower()

    def test_multiple_high_sensitivity_blocks(self):
        pf = PrivacyFilter()
        # 3+ high-sensitivity patterns
        text = (
            "Card: 4111-1111-1111-1111, "
            "SSN: 123-45-6789, "
            "BTC: 1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa, "
            "ETH: 0x742d35Cc6634C0532925a3b844Bc9e7595f2bD18"
        )
        block, reason = pf.should_use_local_only(text)
        assert block
|
||||
|
||||
|
||||
class TestAggressiveMode:
    """Test aggressive filtering mode."""

    def test_aggressive_redacts_internal_ips(self):
        # Aggressive mode extends redaction to private-range IP addresses.
        pf = PrivacyFilter(aggressive_mode=True)
        text = "Server at 192.168.1.100 is responding."
        cleaned, redactions = pf.sanitize_text(text)
        assert "192.168.1.100" not in cleaned
        assert any(r["type"] == "internal_ip" for r in redactions)

    def test_normal_does_not_redact_ips(self):
        pf = PrivacyFilter(aggressive_mode=False)
        text = "Server at 192.168.1.100 is responding."
        cleaned, redactions = pf.sanitize_text(text)
        assert "192.168.1.100" in cleaned  # IP preserved in normal mode
|
||||
|
||||
|
||||
class TestConvenienceFunctions:
    """Test module-level convenience functions."""

    def test_quick_sanitize(self):
        text = "Contact alice@example.com for details"
        result = quick_sanitize(text)
        assert "alice@example.com" not in result
        assert "[REDACTED-EMAIL]" in result

    def test_sanitize_messages_convenience(self):
        # Module-level wrapper should mirror PrivacyFilter.sanitize_messages.
        messages = [{"role": "user", "content": "Call 555-000-1234"}]
        safe, report = sanitize_messages(messages)
        assert report.redacted_messages == 1
|
||||
|
||||
|
||||
class TestRedactionReport:
    """Test the reporting structure."""

    def test_summary_no_redactions(self):
        report = RedactionReport(total_messages=3, redacted_messages=0)
        assert "No PII" in report.summary()

    def test_summary_with_redactions(self):
        # Summary must mention the redacted/total ratio and each PII type.
        report = RedactionReport(
            total_messages=2,
            redacted_messages=1,
            redactions=[
                {"type": "email_address", "sensitivity": "MEDIUM", "count": 2},
                {"type": "phone_number_us", "sensitivity": "MEDIUM", "count": 1},
            ],
        )
        summary = report.summary()
        assert "1/2" in summary
        assert "email_address" in summary
|
||||
239
tests/test_vision_benchmark.py
Normal file
239
tests/test_vision_benchmark.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""Tests for vision benchmark suite (Issue #817)."""
|
||||
|
||||
import json
|
||||
import statistics
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "benchmarks"))
|
||||
|
||||
from vision_benchmark import (
|
||||
compute_ocr_accuracy,
|
||||
compute_description_completeness,
|
||||
compute_structural_accuracy,
|
||||
aggregate_results,
|
||||
to_markdown,
|
||||
generate_sample_dataset,
|
||||
MODELS,
|
||||
EVAL_PROMPTS,
|
||||
)
|
||||
|
||||
|
||||
class TestOcrAccuracy:
    # compute_ocr_accuracy blends length-ratio and word-recall scores,
    # case-insensitively, returning a value in [0.0, 1.0].
    def test_perfect_match(self):
        assert compute_ocr_accuracy("Hello World", "Hello World") == 1.0

    def test_empty_ground_truth(self):
        assert compute_ocr_accuracy("", "") == 1.0
        assert compute_ocr_accuracy("text", "") == 0.0

    def test_empty_extraction(self):
        assert compute_ocr_accuracy("", "Hello") == 0.0

    def test_partial_match(self):
        score = compute_ocr_accuracy("Hello Wrld", "Hello World")
        assert 0.5 < score < 1.0

    def test_case_insensitive(self):
        assert compute_ocr_accuracy("hello world", "Hello World") == 1.0

    def test_whitespace_differences(self):
        # Leading/trailing whitespace should cost at most a small penalty.
        score = compute_ocr_accuracy(" Hello World ", "Hello World")
        assert score >= 0.8
|
||||
|
||||
|
||||
class TestDescriptionCompleteness:
    # Keyword coverage is matched case-insensitively as a substring search.
    def test_all_keywords_found(self):
        keywords = ["github", "logo", "octocat"]
        text = "This is the GitHub logo featuring the octocat mascot."
        assert compute_description_completeness(text, keywords) == 1.0

    def test_partial_keywords(self):
        keywords = ["github", "logo", "octocat"]
        text = "This is the GitHub logo."
        score = compute_description_completeness(text, keywords)
        assert 0.3 < score < 0.7

    def test_no_keywords(self):
        keywords = ["github", "logo"]
        text = "Something completely different."
        assert compute_description_completeness(text, keywords) == 0.0

    def test_empty_keywords(self):
        # No expectations means trivially complete.
        assert compute_description_completeness("any text", []) == 1.0

    def test_empty_text(self):
        assert compute_description_completeness("", ["keyword"]) == 0.0

    def test_case_insensitive(self):
        keywords = ["GitHub", "Logo"]
        text = "The github logo is iconic."
        assert compute_description_completeness(text, keywords) == 1.0
|
||||
|
||||
|
||||
class TestStructuralAccuracy:
    # Structural scores are ratios capped at 1.0; has_numbers is 0.0/1.0
    # and only present when requested in the expected structure.
    def test_length_score(self):
        text = "A" * 100
        scores = compute_structural_accuracy(text, {"min_length": 50})
        assert scores["length"] == 1.0

    def test_short_text(self):
        text = "Short."
        scores = compute_structural_accuracy(text, {"min_length": 100})
        assert scores["length"] < 1.0

    def test_sentence_count(self):
        text = "First sentence. Second sentence. Third sentence."
        scores = compute_structural_accuracy(text, {"min_sentences": 2})
        assert scores["sentences"] >= 1.0

    def test_no_sentences(self):
        text = "No sentence end"
        scores = compute_structural_accuracy(text, {"min_sentences": 1})
        assert scores["sentences"] == 0.0

    def test_has_numbers_true(self):
        text = "There are 42 items."
        scores = compute_structural_accuracy(text, {"has_numbers": True})
        assert scores["has_numbers"] == 1.0

    def test_has_numbers_false(self):
        text = "No numbers here."
        scores = compute_structural_accuracy(text, {"has_numbers": True})
        assert scores["has_numbers"] == 0.0
|
||||
|
||||
|
||||
class TestAggregateResults:
    # Uses the real MODELS config, so both model keys must appear in each
    # per-image result dict.
    def test_basic_aggregation(self):
        results = [
            {
                "image_id": "img1",
                "category": "photo",
                "gemma4": {
                    "success": True,
                    "avg_latency_ms": 100,
                    "avg_tokens": 500,
                    "ocr_accuracy": 0.9,
                    "keyword_completeness": 0.8,
                    "analysis_length": 200,
                },
                "gemini3_flash": {
                    "success": True,
                    "avg_latency_ms": 150,
                    "avg_tokens": 600,
                    "ocr_accuracy": 0.85,
                    "keyword_completeness": 0.75,
                    "analysis_length": 180,
                },
            }
        ]
        models = MODELS
        summary = aggregate_results(results, models)

        assert "gemma4" in summary
        assert "gemini3_flash" in summary
        assert summary["gemma4"]["success_rate"] == 1.0
        assert summary["gemma4"]["latency"]["mean_ms"] == 100
        assert summary["gemma4"]["accuracy"]["ocr_mean"] == 0.9

    def test_all_failures(self):
        # A model with zero successful runs gets the stub entry.
        results = [
            {
                "image_id": "img1",
                "category": "photo",
                "gemma4": {"success": False, "error": "API error"},
                "gemini3_flash": {"success": False, "error": "API error"},
            }
        ]
        summary = aggregate_results(results, MODELS)
        assert summary["gemma4"]["success_rate"] == 0
|
||||
|
||||
|
||||
class TestMarkdown:
    # to_markdown should render every section even from a minimal report.
    def test_generates_report(self):
        report = {
            "generated_at": "2026-04-16T00:00:00",
            "config": {
                "total_images": 10,
                "runs_per_model": 1,
                "models": {"gemma4": "Gemma 4 27B", "gemini3_flash": "Gemini 3 Flash"},
            },
            "summary": {
                "gemma4": {
                    "success_rate": 0.9,
                    "latency": {"mean_ms": 100, "median_ms": 95, "p95_ms": 150, "std_ms": 20},
                    "tokens": {"mean_total": 500, "total_used": 5000},
                    "accuracy": {"ocr_mean": 0.85, "ocr_count": 5, "keyword_mean": 0.8, "keyword_count": 5},
                },
                "gemini3_flash": {
                    "success_rate": 0.95,
                    "latency": {"mean_ms": 120, "median_ms": 110, "p95_ms": 180, "std_ms": 25},
                    "tokens": {"mean_total": 600, "total_used": 6000},
                    "accuracy": {"ocr_mean": 0.82, "ocr_count": 5, "keyword_mean": 0.78, "keyword_count": 5},
                },
            },
            "results": [],
        }
        md = to_markdown(report)
        assert "Vision Benchmark Report" in md
        assert "Latency Comparison" in md
        assert "Accuracy Comparison" in md
        assert "Token Usage" in md
        assert "Verdict" in md
        assert "Gemma 4 27B" in md

    def test_empty_report(self):
        # An empty summary must still produce a well-formed header.
        report = {
            "generated_at": "2026-04-16T00:00:00",
            "config": {"total_images": 0, "runs_per_model": 1, "models": {}},
            "summary": {},
            "results": [],
        }
        md = to_markdown(report)
        assert "Vision Benchmark Report" in md
|
||||
|
||||
|
||||
class TestDataset:
    def test_sample_dataset_has_entries(self):
        dataset = generate_sample_dataset()
        assert len(dataset) >= 4

    def test_sample_dataset_structure(self):
        # Every entry must carry the full test-case schema.
        dataset = generate_sample_dataset()
        for img in dataset:
            assert "id" in img
            assert "url" in img
            assert "category" in img
            assert "expected_keywords" in img
            assert "expected_structure" in img

    def test_categories_present(self):
        dataset = generate_sample_dataset()
        categories = {img["category"] for img in dataset}
        assert "screenshot" in categories
        assert "diagram" in categories
        assert "photo" in categories
|
||||
|
||||
|
||||
class TestModels:
    # MODELS is the registry of model configs compared by the benchmark.
    def test_all_models_defined(self):
        assert "gemma4" in MODELS
        assert "gemini3_flash" in MODELS

    def test_model_structure(self):
        for name, config in MODELS.items():
            assert "model_id" in config
            assert "display_name" in config
            assert "provider" in config
|
||||
|
||||
|
||||
class TestPrompts:
    # Every supported image category needs a matching evaluation prompt.
    def test_prompts_for_categories(self):
        assert "screenshot" in EVAL_PROMPTS
        assert "diagram" in EVAL_PROMPTS
        assert "photo" in EVAL_PROMPTS
        assert "ocr" in EVAL_PROMPTS
        assert "chart" in EVAL_PROMPTS
|
||||
190
tests/tools/test_confirmation_daemon.py
Normal file
190
tests/tools/test_confirmation_daemon.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""Tests for tools.confirmation_daemon — Human Confirmation Firewall."""
|
||||
|
||||
import pytest
|
||||
import time
|
||||
from tools.confirmation_daemon import (
|
||||
ConfirmationDaemon,
|
||||
ConfirmationRequest,
|
||||
ConfirmationStatus,
|
||||
RiskLevel,
|
||||
classify_action,
|
||||
_is_whitelisted,
|
||||
_DEFAULT_WHITELIST,
|
||||
)
|
||||
|
||||
|
||||
class TestClassifyAction:
    """Test action risk classification."""

    def test_crypto_tx_is_critical(self):
        assert classify_action("crypto_tx") == RiskLevel.CRITICAL

    def test_sign_transaction_is_critical(self):
        assert classify_action("sign_transaction") == RiskLevel.CRITICAL

    def test_send_email_is_high(self):
        assert classify_action("send_email") == RiskLevel.HIGH

    def test_send_message_is_medium(self):
        assert classify_action("send_message") == RiskLevel.MEDIUM

    def test_access_calendar_is_low(self):
        assert classify_action("access_calendar") == RiskLevel.LOW

    def test_unknown_action_is_medium(self):
        # Unrecognized actions default to MEDIUM rather than silently LOW.
        assert classify_action("unknown_action_xyz") == RiskLevel.MEDIUM
|
||||
|
||||
|
||||
class TestWhitelist:
    """Test whitelist auto-approval."""

    def test_self_email_is_whitelisted(self):
        whitelist = dict(_DEFAULT_WHITELIST)
        payload = {"from": "me@test.com", "to": "me@test.com"}
        assert _is_whitelisted("send_email", payload, whitelist) is True

    def test_non_whitelisted_recipient_not_approved(self):
        whitelist = dict(_DEFAULT_WHITELIST)
        payload = {"to": "random@stranger.com"}
        assert _is_whitelisted("send_email", payload, whitelist) is False

    def test_whitelisted_contact_approved(self):
        whitelist = {
            "send_message": {"targets": ["alice", "bob"]},
        }
        assert _is_whitelisted("send_message", {"to": "alice"}, whitelist) is True
        assert _is_whitelisted("send_message", {"to": "charlie"}, whitelist) is False

    def test_no_whitelist_entry_means_not_whitelisted(self):
        # Fail-closed: absence of an entry must never auto-approve.
        whitelist = {}
        assert _is_whitelisted("crypto_tx", {"amount": 1.0}, whitelist) is False
|
||||
|
||||
|
||||
class TestConfirmationRequest:
    """Behaviour of the ConfirmationRequest data model."""

    @staticmethod
    def _make(request_id, **overrides):
        # Shared constructor with sensible defaults; tests override what they probe.
        kwargs = dict(
            request_id=request_id,
            action="send_email",
            description="Test",
            risk_level="high",
            payload={},
        )
        kwargs.update(overrides)
        return ConfirmationRequest(**kwargs)

    def test_defaults(self):
        req = self._make("test-1", description="Test email")
        assert req.status == ConfirmationStatus.PENDING.value
        assert req.created_at > 0
        assert req.expires_at > req.created_at

    def test_is_pending(self):
        req = self._make("test-2", expires_at=time.time() + 300)
        assert req.is_pending is True

    def test_is_expired(self):
        req = self._make("test-3", expires_at=time.time() - 10)
        assert req.is_expired is True
        assert req.is_pending is False

    def test_to_dict(self):
        req = self._make("test-4", risk_level="medium", payload={"to": "a@b.com"})
        as_dict = req.to_dict()
        assert as_dict["request_id"] == "test-4"
        assert as_dict["action"] == "send_email"
        assert "is_pending" in as_dict
|
||||
|
||||
|
||||
class TestConfirmationDaemon:
    """Daemon decision logic, exercised directly (no HTTP layer)."""

    def test_auto_approve_low_risk(self):
        d = ConfirmationDaemon()
        result = d.request(
            action="access_calendar",
            description="Read today's events",
            risk_level="low",
        )
        assert result.status == ConfirmationStatus.AUTO_APPROVED.value

    def test_whitelisted_auto_approves(self):
        d = ConfirmationDaemon()
        d._whitelist = {"send_message": {"targets": ["alice"]}}
        result = d.request(
            action="send_message",
            description="Message alice",
            payload={"to": "alice"},
        )
        assert result.status == ConfirmationStatus.AUTO_APPROVED.value

    def test_non_whitelisted_goes_pending(self):
        d = ConfirmationDaemon()
        d._whitelist = {}
        result = d.request(
            action="send_email",
            description="Email to stranger",
            payload={"to": "stranger@test.com"},
            risk_level="high",
        )
        assert result.status == ConfirmationStatus.PENDING.value
        assert result.is_pending is True

    def test_approve_response(self):
        d = ConfirmationDaemon()
        d._whitelist = {}
        pending = d.request(
            action="send_email", description="Email test", risk_level="high"
        )
        decided = d.respond(pending.request_id, approved=True, decided_by="human")
        assert decided.status == ConfirmationStatus.APPROVED.value
        assert decided.decided_by == "human"

    def test_deny_response(self):
        d = ConfirmationDaemon()
        d._whitelist = {}
        pending = d.request(
            action="crypto_tx", description="Send 1 ETH", risk_level="critical"
        )
        decided = d.respond(
            pending.request_id, approved=False, decided_by="human", reason="Too risky"
        )
        assert decided.status == ConfirmationStatus.DENIED.value
        assert decided.reason == "Too risky"

    def test_get_pending(self):
        d = ConfirmationDaemon()
        d._whitelist = {}
        d.request(action="send_email", description="Test 1", risk_level="high")
        d.request(action="send_email", description="Test 2", risk_level="high")
        assert len(d.get_pending()) >= 2

    def test_get_history(self):
        d = ConfirmationDaemon()
        # Low risk auto-approves and lands directly in history.
        d.request(action="access_calendar", description="Test", risk_level="low")
        entries = d.get_history()
        assert len(entries) >= 1
        assert entries[0]["action"] == "access_calendar"
|
||||
@@ -121,6 +121,19 @@ DANGEROUS_PATTERNS = [
|
||||
(r'\b(cp|mv|install)\b.*\s/etc/', "copy/move file into /etc/"),
|
||||
(r'\bsed\s+-[^\s]*i.*\s/etc/', "in-place edit of system config"),
|
||||
(r'\bsed\s+--in-place\b.*\s/etc/', "in-place edit of system config (long flag)"),
|
||||
# --- Vitalik's threat model: crypto / financial ---
|
||||
(r'\b(?:bitcoin-cli|ethers\.js|web3|ether\.sendTransaction)\b', "direct crypto transaction tool usage"),
|
||||
(r'\bwget\b.*\b(?:mnemonic|seed\s*phrase|private[_-]?key)\b', "attempting to download crypto credentials"),
|
||||
(r'\bcurl\b.*\b(?:mnemonic|seed\s*phrase|private[_-]?key)\b', "attempting to exfiltrate crypto credentials"),
|
||||
# --- Vitalik's threat model: credential exfiltration ---
|
||||
(r'\b(?:curl|wget|http|nc|ncat|socat)\b.*\b(?:\.env|\.ssh|credentials|secrets|token|api[_-]?key)\b',
|
||||
"attempting to exfiltrate credentials via network"),
|
||||
(r'\bbase64\b.*\|(?:\s*curl|\s*wget)', "base64-encode then network exfiltration"),
|
||||
(r'\bcat\b.*\b(?:\.env|\.ssh/id_rsa|credentials)\b.*\|(?:\s*curl|\s*wget)',
|
||||
"reading secrets and piping to network tool"),
|
||||
# --- Vitalik's threat model: data exfiltration ---
|
||||
(r'\bcurl\b.*-d\s.*\$(?:HOME|USER)', "sending user home directory data to remote"),
|
||||
(r'\bwget\b.*--post-data\s.*\$(?:HOME|USER)', "posting user data to remote"),
|
||||
# Script execution via heredoc — bypasses the -e/-c flag patterns above.
|
||||
# `python3 << 'EOF'` feeds arbitrary code via stdin without -c/-e flags.
|
||||
(r'\b(python[23]?|perl|ruby|node)\s+<<', "script execution via heredoc"),
|
||||
|
||||
615
tools/confirmation_daemon.py
Normal file
615
tools/confirmation_daemon.py
Normal file
@@ -0,0 +1,615 @@
|
||||
"""Human Confirmation Daemon — HTTP server for two-factor action approval.
|
||||
|
||||
Implements Vitalik's Pattern 1: "The new 'two-factor confirmation' is that
|
||||
the two factors are the human and the LLM."
|
||||
|
||||
This daemon runs on localhost:6000 and provides a simple HTTP API for the
|
||||
agent to request human approval before executing high-risk actions.
|
||||
|
||||
Threat model:
|
||||
- LLM jailbreaks: Remote content "hacking" the LLM to perform malicious actions
|
||||
- LLM accidents: LLM accidentally performing dangerous operations
|
||||
- The human acts as the second factor — the agent proposes, the human disposes
|
||||
|
||||
Architecture:
|
||||
- Agent detects high-risk action → POST /confirm with action details
|
||||
- Daemon stores pending request, sends notification to user
|
||||
- User approves/denies via POST /respond (Telegram, CLI, or direct HTTP)
|
||||
- Agent receives decision and proceeds or aborts
|
||||
|
||||
Usage:
|
||||
# Start daemon (usually managed by gateway)
|
||||
from tools.confirmation_daemon import ConfirmationDaemon
|
||||
daemon = ConfirmationDaemon(port=6000)
|
||||
daemon.start()
|
||||
|
||||
# Request approval (from agent code)
|
||||
from tools.confirmation_daemon import request_confirmation
|
||||
approved = request_confirmation(
|
||||
action="send_email",
|
||||
description="Send email to alice@example.com",
|
||||
risk_level="high",
|
||||
payload={"to": "alice@example.com", "subject": "Meeting notes"},
|
||||
timeout=300,
|
||||
)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RiskLevel(Enum):
    """Risk classification for actions requiring confirmation.

    Values are lowercase strings so they round-trip unchanged through the
    JSON HTTP API and ``ConfirmationRequest.risk_level``.
    """
    LOW = "low"  # Log only, no confirmation needed
    MEDIUM = "medium"  # Confirm for non-whitelisted targets
    HIGH = "high"  # Always confirm
    CRITICAL = "critical"  # Always confirm + require explicit reason
|
||||
|
||||
|
||||
class ConfirmationStatus(Enum):
    """Status of a pending confirmation request (stored as its string value)."""
    PENDING = "pending"  # Awaiting a human decision
    APPROVED = "approved"  # Human explicitly approved
    DENIED = "denied"  # Human explicitly denied
    EXPIRED = "expired"  # No decision before expires_at
    AUTO_APPROVED = "auto_approved"  # Whitelisted / low risk — no human needed
|
||||
|
||||
|
||||
@dataclass
class ConfirmationRequest:
    """A request for human confirmation of a high-risk action.

    Mutable record: the daemon updates status/decided_* fields in place
    when a decision arrives or the request expires.
    """
    request_id: str  # Short unique id (filled by __post_init__ if empty)
    action: str  # Action type: send_email, send_message, crypto_tx, etc.
    description: str  # Human-readable description of what will happen
    risk_level: str  # low, medium, high, critical (RiskLevel values)
    payload: Dict[str, Any]  # Action-specific data (sanitized)
    session_key: str = ""  # Session that initiated the request
    created_at: float = 0.0  # Unix timestamp; 0.0 means "fill in __post_init__"
    expires_at: float = 0.0  # Unix timestamp; 0.0 means created_at + 300s
    status: str = ConfirmationStatus.PENDING.value
    decided_at: float = 0.0  # When a decision was recorded (0.0 = undecided)
    decided_by: str = ""  # "human", "auto", "whitelist"
    reason: str = ""  # Optional reason for denial

    def __post_init__(self) -> None:
        # 0.0 / "" act as sentinels meaning "not provided by the caller".
        if not self.created_at:
            self.created_at = time.time()
        if not self.expires_at:
            self.expires_at = self.created_at + 300  # 5 min default
        if not self.request_id:
            # Truncated UUID4 — short enough to type/quote in a chat reply.
            self.request_id = str(uuid.uuid4())[:12]

    @property
    def is_expired(self) -> bool:
        """True once the wall clock has passed expires_at."""
        return time.time() > self.expires_at

    @property
    def is_pending(self) -> bool:
        """True only while undecided AND not yet expired."""
        return self.status == ConfirmationStatus.PENDING.value and not self.is_expired

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for JSON responses, including the derived flags."""
        d = asdict(self)
        d["is_expired"] = self.is_expired
        d["is_pending"] = self.is_pending
        return d
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Action categories (Vitalik's threat model)
|
||||
# =========================================================================
|
||||
|
||||
# Static mapping from action type to its RiskLevel; consulted by
# classify_action() when the caller does not supply an explicit level.
ACTION_CATEGORIES = {
    # Messaging — outbound communication to external parties
    "send_email": RiskLevel.HIGH,
    "send_message": RiskLevel.MEDIUM,  # Depends on recipient
    "send_signal": RiskLevel.HIGH,
    "send_telegram": RiskLevel.MEDIUM,
    "send_discord": RiskLevel.MEDIUM,
    "post_social": RiskLevel.HIGH,

    # Financial / crypto — irreversible, so always CRITICAL
    "crypto_tx": RiskLevel.CRITICAL,
    "sign_transaction": RiskLevel.CRITICAL,
    "access_wallet": RiskLevel.CRITICAL,
    "modify_balance": RiskLevel.CRITICAL,

    # System modification
    "install_software": RiskLevel.HIGH,
    "modify_system_config": RiskLevel.HIGH,
    "modify_firewall": RiskLevel.CRITICAL,
    "add_ssh_key": RiskLevel.CRITICAL,
    "create_user": RiskLevel.CRITICAL,

    # Data access
    "access_contacts": RiskLevel.MEDIUM,
    "access_calendar": RiskLevel.LOW,
    "read_private_files": RiskLevel.MEDIUM,
    "upload_data": RiskLevel.HIGH,
    "share_credentials": RiskLevel.CRITICAL,

    # Network
    "open_port": RiskLevel.HIGH,
    "modify_dns": RiskLevel.HIGH,
    "expose_service": RiskLevel.CRITICAL,
}

# Default: any unrecognized action is MEDIUM risk (fail toward caution,
# but without forcing confirmation for every unknown low-stakes action).
DEFAULT_RISK_LEVEL = RiskLevel.MEDIUM
|
||||
|
||||
|
||||
def classify_action(action: str) -> RiskLevel:
    """Classify an action by its risk level.

    Unknown action types fall back to DEFAULT_RISK_LEVEL (MEDIUM).
    """
    try:
        return ACTION_CATEGORIES[action]
    except KeyError:
        return DEFAULT_RISK_LEVEL
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Whitelist configuration
|
||||
# =========================================================================
|
||||
|
||||
# Built-in whitelist used when ~/.hermes/approval_whitelist.json is absent
# or unreadable. Empty target lists mean "nothing pre-approved" except the
# send-to-self email rule below.
_DEFAULT_WHITELIST = {
    "send_message": {
        "targets": [],  # Contact names/IDs that don't need confirmation
    },
    "send_email": {
        "targets": [],  # Email addresses that don't need confirmation
        "self_only": True,  # send-to-self always allowed
    },
}
|
||||
|
||||
|
||||
def _load_whitelist() -> Dict[str, Any]:
    """Load the action whitelist from ``~/.hermes/approval_whitelist.json``.

    Returns the parsed JSON object on success; otherwise (file missing,
    corrupt JSON, permission error, ...) logs a warning and returns a
    copy of ``_DEFAULT_WHITELIST`` so the daemon always has a usable map.
    """
    config_path = Path.home() / ".hermes" / "approval_whitelist.json"
    if config_path.exists():
        try:
            # JSON is defined as UTF-8 — read it explicitly instead of
            # relying on the platform/locale default encoding.
            with open(config_path, encoding="utf-8") as f:
                return json.load(f)
        except Exception as e:
            # Best-effort: a broken config must not prevent startup.
            logger.warning("Failed to load approval whitelist: %s", e)
    return dict(_DEFAULT_WHITELIST)
|
||||
|
||||
|
||||
def _is_whitelisted(action: str, payload: Dict[str, Any], whitelist: Dict) -> bool:
|
||||
"""Check if an action is pre-approved by the whitelist."""
|
||||
action_config = whitelist.get(action, {})
|
||||
if not action_config:
|
||||
return False
|
||||
|
||||
# Check target-based whitelist
|
||||
targets = action_config.get("targets", [])
|
||||
target = payload.get("to") or payload.get("recipient") or payload.get("target", "")
|
||||
if target and target in targets:
|
||||
return True
|
||||
|
||||
# Self-only email
|
||||
if action_config.get("self_only") and action == "send_email":
|
||||
sender = payload.get("from", "")
|
||||
recipient = payload.get("to", "")
|
||||
if sender and recipient and sender.lower() == recipient.lower():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Confirmation daemon
|
||||
# =========================================================================
|
||||
|
||||
class ConfirmationDaemon:
    """HTTP daemon for human confirmation of high-risk actions.

    Runs on localhost:PORT (default 6000). Provides:
    - POST /confirm — agent requests human approval
    - POST /respond — human approves/denies
    - GET /pending — list pending requests
    - GET /health — health check

    Thread-safety: ``_pending`` and ``_history`` are only touched while
    holding ``_lock``, because request/respond/wait_for_decision may run
    on agent threads while the aiohttp handlers run on the event loop.
    """

    def __init__(
        self,
        host: str = "127.0.0.1",
        port: int = 6000,
        default_timeout: int = 300,
        notify_callback: Optional[Callable] = None,
    ):
        """Create a daemon bound to host:port.

        Args:
            host: Bind address. Loopback by default — this API has no
                authentication, so it must not be exposed off-host.
            port: TCP port for the HTTP API.
            default_timeout: Seconds to wait for a human decision.
            notify_callback: Optional callable invoked with
                ``request.to_dict()`` whenever a new request goes pending
                (e.g. to push a Telegram/CLI notification).
        """
        self.host = host
        self.port = port
        self.default_timeout = default_timeout
        self.notify_callback = notify_callback
        self._pending: Dict[str, ConfirmationRequest] = {}
        self._history: List[ConfirmationRequest] = []
        self._lock = threading.Lock()  # guards _pending and _history
        self._whitelist = _load_whitelist()
        self._app = None
        self._runner = None

    def request(
        self,
        action: str,
        description: str,
        payload: Optional[Dict[str, Any]] = None,
        risk_level: Optional[str] = None,
        session_key: str = "",
        timeout: Optional[int] = None,
    ) -> ConfirmationRequest:
        """Create a confirmation request.

        Returns the request. Check .status to see if it was immediately
        auto-approved (low risk or whitelisted) or is pending human review.
        """
        payload = payload or {}

        # Classify risk if not specified by the caller.
        if risk_level is None:
            risk_level = classify_action(action).value

        # Low-risk or whitelisted actions skip the human entirely.
        if risk_level in ("low",) or _is_whitelisted(action, payload, self._whitelist):
            req = ConfirmationRequest(
                request_id=str(uuid.uuid4())[:12],
                action=action,
                description=description,
                risk_level=risk_level,
                payload=payload,
                session_key=session_key,
                expires_at=time.time() + (timeout or self.default_timeout),
                status=ConfirmationStatus.AUTO_APPROVED.value,
                decided_at=time.time(),
                decided_by="whitelist",
            )
            with self._lock:
                self._history.append(req)
            logger.info("Auto-approved whitelisted action: %s", action)
            return req

        # Create pending request awaiting a human decision.
        req = ConfirmationRequest(
            request_id=str(uuid.uuid4())[:12],
            action=action,
            description=description,
            risk_level=risk_level,
            payload=payload,
            session_key=session_key,
            expires_at=time.time() + (timeout or self.default_timeout),
        )

        with self._lock:
            self._pending[req.request_id] = req

        # Notify the human out-of-band (Telegram, CLI, ...). Best-effort:
        # a broken notifier must not lose the request itself.
        if self.notify_callback:
            try:
                self.notify_callback(req.to_dict())
            except Exception as e:
                logger.warning("Confirmation notify callback failed: %s", e)

        logger.info(
            "Confirmation request %s: %s (%s risk) — waiting for human",
            req.request_id, action, risk_level,
        )
        return req

    def respond(
        self,
        request_id: str,
        approved: bool,
        decided_by: str = "human",
        reason: str = "",
    ) -> Optional[ConfirmationRequest]:
        """Record a human decision on a pending request.

        Returns the updated request, the already-decided/expired request,
        or None for an unknown request_id.
        """
        with self._lock:
            req = self._pending.get(request_id)
            if not req:
                logger.warning("Confirmation respond: unknown request %s", request_id)
                return None
            if not req.is_pending:
                # The only way a request sits in _pending while not pending
                # is expiry — record that instead of leaving a stale
                # "pending" status on a request nobody can decide anymore.
                if req.is_expired and req.status == ConfirmationStatus.PENDING.value:
                    req.status = ConfirmationStatus.EXPIRED.value
                    del self._pending[request_id]
                    self._history.append(req)
                logger.warning("Confirmation respond: request %s already decided", request_id)
                return req

            req.status = (
                ConfirmationStatus.APPROVED.value if approved
                else ConfirmationStatus.DENIED.value
            )
            req.decided_at = time.time()
            req.decided_by = decided_by
            req.reason = reason

            # Move to history so wait_for_decision can find the decision.
            del self._pending[request_id]
            self._history.append(req)

        logger.info(
            "Confirmation %s: %s by %s",
            request_id, "APPROVED" if approved else "DENIED", decided_by,
        )
        return req

    def _find_in_history(self, request_id: str) -> Optional[ConfirmationRequest]:
        """Return the archived request with this id, or None.

        Caller must hold self._lock. Scans from the tail because fresh
        decisions are appended at the end.
        """
        for req in reversed(self._history):
            if req.request_id == request_id:
                return req
        return None

    def wait_for_decision(
        self, request_id: str, timeout: Optional[float] = None
    ) -> ConfirmationRequest:
        """Block until a decision is made or the timeout expires.

        Polls both the pending table and the history: respond() moves a
        decided request into _history, so the decision must be looked up
        there. (The previous implementation only checked _pending and
        therefore reported every approved/denied request as expired.)
        """
        deadline = time.time() + (timeout or self.default_timeout)
        while time.time() < deadline:
            with self._lock:
                req = self._pending.get(request_id)
                if req is None:
                    # Gone from pending — respond() may have archived it.
                    decided = self._find_in_history(request_id)
                    if decided is not None:
                        return decided
                elif req.is_expired:
                    req.status = ConfirmationStatus.EXPIRED.value
                    del self._pending[request_id]
                    self._history.append(req)
                    return req
            time.sleep(0.5)

        # Deadline reached: expire the request if it is still pending,
        # or return a late decision if one slipped in.
        with self._lock:
            req = self._pending.pop(request_id, None)
            if req:
                req.status = ConfirmationStatus.EXPIRED.value
                self._history.append(req)
                return req
            decided = self._find_in_history(request_id)
            if decided is not None:
                return decided

        # Unknown request id — return a synthetic expired record.
        return ConfirmationRequest(
            request_id=request_id,
            action="unknown",
            description="Request not found",
            risk_level="high",
            payload={},
            status=ConfirmationStatus.EXPIRED.value,
        )

    def get_pending(self) -> List[Dict[str, Any]]:
        """Return list of pending confirmation requests (expired ones archived first)."""
        self._expire_old()
        with self._lock:
            return [r.to_dict() for r in self._pending.values() if r.is_pending]

    def get_history(self, limit: int = 50) -> List[Dict[str, Any]]:
        """Return the most recent ``limit`` entries of the decision history."""
        with self._lock:
            return [r.to_dict() for r in self._history[-limit:]]

    def _expire_old(self) -> None:
        """Move expired requests from _pending to _history."""
        now = time.time()
        with self._lock:
            expired = [
                rid for rid, req in self._pending.items()
                if now > req.expires_at
            ]
            for rid in expired:
                req = self._pending.pop(rid)
                req.status = ConfirmationStatus.EXPIRED.value
                self._history.append(req)

    # --- aiohttp HTTP API ---

    async def _handle_health(self, request):
        """GET /health — liveness probe plus pending-queue depth."""
        from aiohttp import web
        with self._lock:  # consistent snapshot of the queue size
            pending_count = len(self._pending)
        return web.json_response({
            "status": "ok",
            "service": "hermes-confirmation-daemon",
            "pending": pending_count,
        })

    async def _handle_confirm(self, request):
        """POST /confirm — create a request and wait for the decision."""
        from aiohttp import web
        try:
            body = await request.json()
        except Exception:
            return web.json_response({"error": "invalid JSON"}, status=400)

        action = body.get("action", "")
        description = body.get("description", "")
        if not action or not description:
            return web.json_response(
                {"error": "action and description required"}, status=400
            )

        req = self.request(
            action=action,
            description=description,
            payload=body.get("payload", {}),
            risk_level=body.get("risk_level"),
            session_key=body.get("session_key", ""),
            timeout=body.get("timeout"),
        )

        # If auto-approved, return immediately.
        if req.status != ConfirmationStatus.PENDING.value:
            return web.json_response({
                "request_id": req.request_id,
                "status": req.status,
                "decided_by": req.decided_by,
            })

        # Wait for the human decision in a worker thread, never on the
        # event loop: wait_for_decision() blocks with time.sleep(), and
        # awaiting it inline would freeze the server so POST /respond
        # could never be handled — every confirmation would deadlock
        # until expiry. Cap at 10 minutes.
        # ("or" also guards against a JSON null/0 timeout, which would
        # make min(None, 600) raise TypeError.)
        timeout = min(body.get("timeout") or self.default_timeout, 600)
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None, lambda: self.wait_for_decision(req.request_id, timeout=timeout)
        )

        return web.json_response({
            "request_id": result.request_id,
            "status": result.status,
            "decided_by": result.decided_by,
            "reason": result.reason,
        })

    async def _handle_respond(self, request):
        """POST /respond — human approves or denies a pending request."""
        from aiohttp import web
        try:
            body = await request.json()
        except Exception:
            return web.json_response({"error": "invalid JSON"}, status=400)

        request_id = body.get("request_id", "")
        approved = body.get("approved")
        if not request_id or approved is None:
            return web.json_response(
                {"error": "request_id and approved required"}, status=400
            )

        result = self.respond(
            request_id=request_id,
            approved=bool(approved),
            decided_by=body.get("decided_by", "human"),
            reason=body.get("reason", ""),
        )

        if not result:
            return web.json_response({"error": "unknown request"}, status=404)

        return web.json_response({
            "request_id": result.request_id,
            "status": result.status,
        })

    async def _handle_pending(self, request):
        """GET /pending — list requests still awaiting a decision."""
        from aiohttp import web
        return web.json_response({"pending": self.get_pending()})

    def _build_app(self):
        """Build the aiohttp application and register routes."""
        from aiohttp import web

        app = web.Application()
        app.router.add_get("/health", self._handle_health)
        app.router.add_post("/confirm", self._handle_confirm)
        app.router.add_post("/respond", self._handle_respond)
        app.router.add_get("/pending", self._handle_pending)
        self._app = app
        return app

    async def start_async(self) -> None:
        """Start the daemon as an async server on the current event loop."""
        from aiohttp import web

        app = self._build_app()
        self._runner = web.AppRunner(app)
        await self._runner.setup()
        site = web.TCPSite(self._runner, self.host, self.port)
        await site.start()
        logger.info("Confirmation daemon listening on %s:%d", self.host, self.port)

    async def stop_async(self) -> None:
        """Stop the daemon and release the listening socket."""
        if self._runner:
            await self._runner.cleanup()
            self._runner = None

    def start(self) -> None:
        """Start the daemon in a background thread; returns immediately."""
        def _run():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            loop.run_until_complete(self.start_async())
            loop.run_forever()

        t = threading.Thread(target=_run, daemon=True, name="confirmation-daemon")
        t.start()
        logger.info("Confirmation daemon started in background thread")

    def start_blocking(self) -> None:
        """Start the daemon and block until interrupted (standalone use)."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self.start_async())
        try:
            loop.run_forever()
        except KeyboardInterrupt:
            pass
        finally:
            loop.run_until_complete(self.stop_async())
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Convenience API for agent integration
|
||||
# =========================================================================
|
||||
|
||||
# Global singleton — initialized by gateway or CLI at startup
|
||||
_daemon: Optional[ConfirmationDaemon] = None
|
||||
|
||||
|
||||
def get_daemon() -> Optional[ConfirmationDaemon]:
    """Get the global confirmation daemon instance.

    Returns None until init_daemon() has been called; request_confirmation()
    treats that as "deny by default".
    """
    return _daemon
|
||||
|
||||
|
||||
def init_daemon(
    host: str = "127.0.0.1",
    port: int = 6000,
    notify_callback: Optional[Callable] = None,
) -> ConfirmationDaemon:
    """Create the process-wide confirmation daemon and register it globally.

    The instance is stored in the module-level ``_daemon`` singleton so
    that get_daemon() / request_confirmation() can reach it.
    """
    global _daemon
    daemon = ConfirmationDaemon(host=host, port=port, notify_callback=notify_callback)
    _daemon = daemon
    return daemon
|
||||
|
||||
|
||||
def request_confirmation(
    action: str,
    description: str,
    payload: Optional[Dict[str, Any]] = None,
    risk_level: Optional[str] = None,
    session_key: str = "",
    timeout: int = 300,
) -> bool:
    """Request human confirmation for a high-risk action.

    Primary integration point for agent code. Classifies the action,
    checks the whitelist, and — when a decision is required — blocks
    until the human responds or the timeout expires.

    Args:
        action: Action type (send_email, crypto_tx, etc.)
        description: Human-readable description
        payload: Action-specific data
        risk_level: Override auto-classification
        session_key: Session requesting approval
        timeout: Seconds to wait for human response

    Returns:
        True if approved (explicitly or via whitelist), False if denied,
        expired, or no daemon is running (fail closed).
    """
    daemon = get_daemon()
    if daemon is None:
        # Fail closed: without a daemon there is no human second factor.
        logger.warning(
            "No confirmation daemon running — DENYING action %s by default. "
            "Start daemon with init_daemon() or --confirmation-daemon flag.",
            action,
        )
        return False

    req = daemon.request(
        action=action,
        description=description,
        payload=payload,
        risk_level=risk_level,
        session_key=session_key,
        timeout=timeout,
    )

    # Whitelisted / low-risk: approved without waiting.
    if req.status == ConfirmationStatus.AUTO_APPROVED.value:
        return True

    # Block until the human decides (or the request expires).
    decision = daemon.wait_for_decision(req.request_id, timeout=timeout)
    return decision.status == ConfirmationStatus.APPROVED.value
|
||||
Reference in New Issue
Block a user