Files
hermes-agent/tools/tool_pokayoke.py
Alexander Whitestone 8ef766beac feat: add tool hallucination prevention module (#836)
- Validates tool names against registered tools
- Auto-corrects parameter names within Levenshtein distance 1
- Circuit breaker for consecutive failures (threshold: 3)
- Structured error messages with suggestions

Closes #836
2026-04-16 02:10:39 +00:00

277 lines
10 KiB
Python

#!/usr/bin/env python3
"""
Poka-Yoke: Tool Hallucination Prevention
Detects and blocks tool hallucination before API calls:
1. Validates tool names against registered tools
2. Auto-corrects parameter names within Levenshtein distance 1
3. Circuit breaker for consecutive failures
Usage:
from tools.tool_pokayoke import validate_tool_call, ToolCallValidator
# One-shot validation
result = validate_tool_call("browser_fill", {"file_path": "/tmp/test.txt"})
# Stateful validator with circuit breaker
validator = ToolCallValidator()
result = validator.validate("browser_fill", {"file_path": "/tmp/test.txt"})
"""
import json
import logging
from typing import Dict, List, Optional, Tuple, Any
from difflib import SequenceMatcher
# Module-level logger named after this module; every poka-yoke log record in
# this file is emitted through it.
logger = logging.getLogger(__name__)
def levenshtein_distance(s1: str, s2: str) -> int:
    """Return the Levenshtein (edit) distance between ``s1`` and ``s2``.

    The distance is the minimum number of single-character insertions,
    deletions and substitutions needed to turn one string into the other.
    The comparison is case-sensitive; callers lower-case their inputs when
    they want a case-insensitive distance.

    Only two rows of the dynamic-programming table are kept at a time, and
    the shorter string is always used for the row axis.
    """
    # Put the longer string on the outer loop so each row is as short as possible.
    longer, shorter = (s1, s2) if len(s1) >= len(s2) else (s2, s1)
    if not shorter:
        # Turning a string into the empty string costs one deletion per char.
        return len(longer)

    # previous[k] is the distance between the processed prefix of `longer`
    # and the first k characters of `shorter`.
    previous = list(range(len(shorter) + 1))
    for row_no, ch_long in enumerate(longer, start=1):
        current = [row_no]
        for col, ch_short in enumerate(shorter):
            cost_insert = previous[col + 1] + 1
            cost_delete = current[col] + 1
            cost_replace = previous[col] + (0 if ch_long == ch_short else 1)
            current.append(min(cost_insert, cost_delete, cost_replace))
        previous = current
    return previous[-1]
def find_similar_names(name: str, valid_names: List[str], max_distance: int = 2) -> List[Tuple[str, int]]:
    """Find valid names that are close to ``name`` by edit distance.

    The comparison is case-insensitive: both sides are lower-cased before the
    Levenshtein distance is computed. A candidate that is byte-identical to
    ``name`` is never returned (it is a match, not a suggestion), but a
    candidate that differs only in letter case IS returned, with distance 0.

    Args:
        name: The (possibly misspelled) name to look up.
        valid_names: Candidate names to compare against.
        max_distance: Largest case-insensitive edit distance to accept.

    Returns:
        ``(valid_name, distance)`` tuples sorted by ascending distance.
        Candidates with equal distance keep their order from ``valid_names``.
    """
    target = name.lower()  # loop-invariant, so lower-case it once
    suggestions = []
    for valid_name in valid_names:
        if valid_name == name:
            # Exact match: nothing to suggest for this candidate.
            continue
        dist = levenshtein_distance(target, valid_name.lower())
        # dist == 0 here means a case-only mismatch (e.g. "Read_File" vs
        # "read_file"), which is the closest possible suggestion and must
        # not be filtered out.
        if dist <= max_distance:
            suggestions.append((valid_name, dist))
    return sorted(suggestions, key=lambda x: x[1])
def auto_correct_parameter(param_name: str, valid_params: List[str]) -> Optional[str]:
    """Auto-correct a parameter name that is a near miss of a valid one.

    Names are compared case-insensitively. A valid parameter that differs
    from ``param_name`` only in letter case is preferred; otherwise the first
    valid parameter (in ``valid_params`` order) at Levenshtein distance 1 is
    used. A candidate byte-identical to ``param_name`` is skipped.

    Args:
        param_name: The parameter name supplied by the caller.
        valid_params: Parameter names accepted by the tool.

    Returns:
        The corrected parameter name, or ``None`` if no close match exists.
    """
    target = param_name.lower()
    correction: Optional[str] = None
    for valid_param in valid_params:
        if valid_param == param_name:
            continue
        dist = levenshtein_distance(target, valid_param.lower())
        if dist == 0:
            # Case-only mismatch: the best possible correction, stop looking.
            correction = valid_param
            break
        if dist == 1 and correction is None:
            # Remember the first one-edit candidate but keep scanning in case
            # a case-only match appears later in the list.
            correction = valid_param
    if correction is not None:
        logger.info("Poka-yoke: Auto-corrected parameter '%s' -> '%s'", param_name, correction)
    return correction
class ToolCallValidator:
    """Stateful tool-call validator with a per-tool-name circuit breaker.

    The validator checks tool names against the schemas it has loaded,
    auto-corrects near-miss tool and parameter names, and counts failures per
    requested tool name. Once a name has failed ``failure_threshold`` times,
    ``validate_tool_name`` answers with a circuit-breaker message instead of
    suggestions.
    """

    def __init__(self, failure_threshold: int = 3):
        """Create a validator.

        Args:
            failure_threshold: Number of recorded failures for one tool name
                at which the circuit breaker trips.
        """
        self.failure_threshold = failure_threshold
        self.consecutive_failures: Dict[str, int] = {}  # tool_name -> count
        self.tool_schemas: Dict[str, dict] = {}  # tool_name -> schema
        self._initialized = False

    def _ensure_initialized(self):
        """Lazily load tool schemas from ``tools.registry`` (best effort).

        The import is done here rather than at module import time. If loading
        fails, a warning is logged and ``_initialized`` stays ``False``, so
        the next call tries again.
        """
        if self._initialized:
            return
        try:
            from tools.registry import registry
            for name in registry.get_all_tool_names():
                schema = registry.get_schema(name)
                if schema:
                    self.tool_schemas[name] = schema
            self._initialized = True
            logger.debug("Poka-yoke initialized with %d tool schemas", len(self.tool_schemas))
        except Exception as e:
            # Deliberately broad: validation must keep working (with whatever
            # schemas are loaded) even when the registry is unavailable.
            logger.warning("Could not initialize poka-yoke from registry: %s", e)

    def validate_tool_name(self, tool_name: str) -> Tuple[bool, Optional[str], List[str]]:
        """Validate a tool name against the loaded tool schemas.

        Args:
            tool_name: The tool name requested by the model.

        Returns:
            ``(is_valid, suggested_name, messages)``. ``is_valid`` is ``True``
            with ``suggested_name`` set when the name was auto-corrected to a
            known tool; ``suggested_name`` is ``None`` otherwise.
        """
        self._ensure_initialized()
        if tool_name in self.tool_schemas:
            return True, None, []

        # Circuit breaker: after too many failures stop offering suggestions
        # and tell the caller to halt instead.
        if self.consecutive_failures.get(tool_name, 0) >= self.failure_threshold:
            return False, None, [
                f"CIRCUIT BREAKER: Tool '{tool_name}' has failed {self.failure_threshold}+ times consecutively.",
                "This may indicate a persistent hallucination. Halt and inject diagnostic.",
                f"Valid tools: {', '.join(sorted(self.tool_schemas.keys())[:20])}..."
            ]

        suggestions = find_similar_names(tool_name, list(self.tool_schemas.keys()), max_distance=2)
        if suggestions:
            best_match, distance = suggestions[0]
            # The helper compares case-insensitively; anything at most one
            # edit away (including a case-only match, should the helper
            # report one with distance 0) is close enough to auto-correct.
            if distance <= 1:
                logger.info("Poka-yoke: Auto-corrected tool '%s' -> '%s'", tool_name, best_match)
                return True, best_match, [f"Auto-corrected: '{tool_name}' -> '{best_match}'"]
            # Further away: only suggest, never silently rewrite.
            suggestion_list = [f"'{s}' (distance {d})" for s, d in suggestions[:3]]
            return False, None, [
                f"Unknown tool: '{tool_name}'",
                f"Did you mean: {', '.join(suggestion_list)}?"
            ]

        return False, None, [
            f"Unknown tool: '{tool_name}'",
            f"No similar tools found. Available: {', '.join(sorted(self.tool_schemas.keys())[:10])}..."
        ]

    def validate_parameters(self, tool_name: str, params: Dict[str, Any]) -> Tuple[Dict[str, Any], List[str]]:
        """Validate parameter names for a tool and auto-correct near misses.

        Unknown parameters that cannot be corrected are left in place and
        reported via a warning. The input mapping is never mutated.

        Args:
            tool_name: Name of a tool; if it has no loaded schema the
                parameters are returned unchanged.
            params: Parameters supplied for the call. ``None`` is treated as
                an empty mapping.

        Returns:
            ``(corrected_params, warnings)``.
        """
        self._ensure_initialized()
        if params is None:
            # A call without arguments may arrive as None; treat it as "no
            # parameters" instead of crashing in dict(params) below.
            params = {}
        if tool_name not in self.tool_schemas:
            return params, []

        schema = self.tool_schemas[tool_name]
        # `or {}` guards against schemas that carry an explicit null for
        # "parameters" or "properties".
        properties = (schema.get("parameters") or {}).get("properties") or {}
        valid_params = list(properties.keys())
        if not valid_params:
            return params, []

        corrected = dict(params)
        warnings = []
        for param_name in list(params.keys()):
            if param_name in valid_params:
                continue
            corrected_name = auto_correct_parameter(param_name, valid_params)
            if not corrected_name:
                warnings.append(f"Unknown parameter: '{param_name}' (valid: {', '.join(valid_params[:10])})")
            elif corrected_name in corrected:
                # The target name already has a value (supplied explicitly or
                # by an earlier correction). Overwriting it would silently
                # discard data, so keep both and report the conflict.
                warnings.append(
                    f"Unknown parameter: '{param_name}' "
                    f"(not auto-corrected: '{corrected_name}' is already set)"
                )
            else:
                corrected[corrected_name] = corrected.pop(param_name)
                warnings.append(f"Auto-corrected parameter: '{param_name}' -> '{corrected_name}'")
        return corrected, warnings

    def validate(self, tool_name: str, params: Dict[str, Any]) -> Tuple[bool, Optional[str], Dict[str, Any], List[str]]:
        """Run full validation of a tool call (name first, then parameters).

        A failed name check is recorded against ``tool_name`` for the circuit
        breaker; a successful validation clears the counter of the tool that
        will actually be called.

        Args:
            tool_name: The tool name requested by the model.
            params: Parameters supplied for the call.

        Returns:
            ``(is_valid, corrected_tool_name, corrected_params, messages)``.
            ``corrected_tool_name`` is ``None`` unless the name was
            auto-corrected. On failure ``params`` is returned untouched.
        """
        name_valid, corrected_name, name_messages = self.validate_tool_name(tool_name)
        if not name_valid:
            self._record_failure(tool_name)
            return False, None, params, name_messages

        # Parameters are validated against the tool that will really run.
        actual_tool = corrected_name if corrected_name else tool_name
        if corrected_name:
            name_messages.append(f"Tool name corrected: '{tool_name}' -> '{corrected_name}'")

        corrected_params, param_warnings = self.validate_parameters(actual_tool, params)

        # Success resets the failure counter for the resolved tool.
        self._record_success(actual_tool)
        all_messages = name_messages + param_warnings
        return True, corrected_name, corrected_params, all_messages

    def _record_failure(self, tool_name: str):
        """Increment the failure counter for ``tool_name``; warn at threshold."""
        self.consecutive_failures[tool_name] = self.consecutive_failures.get(tool_name, 0) + 1
        count = self.consecutive_failures[tool_name]
        if count >= self.failure_threshold:
            logger.warning(
                "Poka-yoke circuit breaker triggered for '%s': %d consecutive failures",
                tool_name,
                count,
            )

    def _record_success(self, tool_name: str):
        """Clear the failure counter for ``tool_name`` (no-op if absent)."""
        self.consecutive_failures.pop(tool_name, None)

    def get_diagnostic_message(self, tool_name: str) -> str:
        """Build a multi-line diagnostic for a tool name that keeps failing.

        The message reports the current failure count and up to five known
        tools within edit distance 3 of ``tool_name``.

        Args:
            tool_name: The failing tool name.

        Returns:
            The diagnostic text, lines joined with newlines.
        """
        self._ensure_initialized()
        count = self.consecutive_failures.get(tool_name, 0)
        suggestions = find_similar_names(tool_name, list(self.tool_schemas.keys()), max_distance=3)
        lines = [
            "=== TOOL HALLUCINATION DETECTED ===",
            f"Tool '{tool_name}' has failed {count} times consecutively.",
            "",
            "This likely means the model is hallucinating a tool name.",
            "",
            "Closest valid tools:"
        ]
        for name, dist in suggestions[:5]:
            lines.append(f" - {name} (edit distance: {dist})")
        if not suggestions:
            lines.append(" (no similar tools found)")
        lines.extend([
            "",
            "Action: The agent should stop retrying and use a valid tool name.",
            "If this persists, the model may need fine-tuning or prompt adjustment."
        ])
        return "\n".join(lines)
# Global validator instance shared by the module-level helper functions below
# (validate_tool_call, reset_circuit_breaker, get_hallucination_stats), so its
# circuit-breaker counters persist across calls to those helpers.
_validator = ToolCallValidator()
def validate_tool_call(tool_name: str, params: Dict[str, Any]) -> Tuple[bool, Optional[str], Dict[str, Any], List[str]]:
    """Validate a single tool call using the module-level validator.

    This is a convenience wrapper around ``_validator.validate``; because the
    validator is shared, failures recorded here count toward its circuit
    breaker.

    Returns:
        (is_valid, corrected_tool_name, corrected_params, messages)
    """
    is_valid, corrected_tool_name, corrected_params, messages = _validator.validate(
        tool_name, params
    )
    return is_valid, corrected_tool_name, corrected_params, messages
def reset_circuit_breaker(tool_name: Optional[str] = None):
    """Reset the module-level validator's circuit breaker.

    Args:
        tool_name: Name whose failure counter should be cleared. When ``None``
            (the default) the counters for all names are cleared. Any string,
            including the empty string, resets only that one name.
    """
    # Compare against None explicitly: an empty string is a possible
    # (hallucinated) tool name and must not be mistaken for "reset everything".
    if tool_name is not None:
        _validator.consecutive_failures.pop(tool_name, None)
    else:
        _validator.consecutive_failures.clear()
def get_hallucination_stats() -> Dict[str, Any]:
    """Report the module-level validator's current hallucination counters.

    Returns:
        A dict with a copy of the per-name failure counts
        (``consecutive_failures``), the number of tool schemas currently
        loaded (``tools_tracked``) and the circuit-breaker ``threshold``.
    """
    # Copy the counters so callers cannot mutate the validator's state.
    failures_snapshot = dict(_validator.consecutive_failures)
    schema_count = len(_validator.tool_schemas)
    stats: Dict[str, Any] = {}
    stats["consecutive_failures"] = failures_snapshot
    stats["tools_tracked"] = schema_count
    stats["threshold"] = _validator.failure_threshold
    return stats