- Validates tool names against registered tools
- Auto-corrects parameter names within Levenshtein distance 1
- Circuit breaker for consecutive failures (threshold: 3)
- Structured error messages with suggestions

Closes #836
277 lines
10 KiB
Python
277 lines
10 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Poka-Yoke: Tool Hallucination Prevention
|
|
|
|
Detects and blocks tool hallucination before API calls:
|
|
1. Validates tool names against registered tools
|
|
2. Auto-corrects parameter names within Levenshtein distance 1
|
|
3. Circuit breaker for consecutive failures
|
|
|
|
Usage:
|
|
from tools.tool_pokayoke import validate_tool_call, ToolCallValidator
|
|
|
|
# One-shot validation
|
|
result = validate_tool_call("browser_fill", {"file_path": "/tmp/test.txt"})
|
|
|
|
# Stateful validator with circuit breaker
|
|
validator = ToolCallValidator()
|
|
result = validator.validate("browser_fill", {"file_path": "/tmp/test.txt"})
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
from typing import Dict, List, Optional, Tuple, Any
|
|
from difflib import SequenceMatcher
|
|
|
|
# Module-level logger; this module's helpers and ToolCallValidator report
# auto-corrections, circuit-breaker trips and initialization problems here.
logger = logging.getLogger(__name__)


def levenshtein_distance(s1: str, s2: str) -> int:
    """Return the Levenshtein (edit) distance between ``s1`` and ``s2``.

    Counts the minimum number of single-character insertions, deletions and
    substitutions needed to turn one string into the other. The comparison is
    case-sensitive; callers lowercase both inputs when they want otherwise.

    Uses the two-row dynamic-programming formulation, so extra memory is
    proportional to the length of the shorter string.
    """
    # Keep the shorter string as the inner (column) dimension.
    if len(s1) < len(s2):
        s1, s2 = s2, s1

    if not s2:
        return len(s1)

    prev_row = list(range(len(s2) + 1))
    for i, c1 in enumerate(s1):
        curr_row = [i + 1]
        for j, c2 in enumerate(s2):
            insertions = prev_row[j + 1] + 1
            deletions = curr_row[j] + 1
            substitutions = prev_row[j] + (c1 != c2)
            curr_row.append(min(insertions, deletions, substitutions))
        prev_row = curr_row

    return prev_row[-1]


def find_similar_names(name: str, valid_names: List[str], max_distance: int = 2) -> List[Tuple[str, int]]:
    """Return ``(candidate, distance)`` pairs for names close to ``name``.

    Distance is the Levenshtein distance computed case-insensitively, so a
    candidate that differs from ``name`` only in letter case is reported with
    distance 0 (it is the strongest possible match, not a non-match). Only a
    candidate identical to ``name`` is excluded. Results are sorted by
    ascending distance; ties keep the order of ``valid_names``.

    Args:
        name: The (possibly hallucinated) name to match.
        valid_names: Candidate names to compare against.
        max_distance: Largest distance (inclusive) to report.
    """
    target = name.lower()
    suggestions = []
    for valid_name in valid_names:
        if valid_name == name:
            # The name itself is not a "similar" alternative.
            continue
        dist = levenshtein_distance(target, valid_name.lower())
        if dist <= max_distance:
            suggestions.append((valid_name, dist))
    # sorted() is stable, so equal distances keep their input order.
    return sorted(suggestions, key=lambda pair: pair[1])


def auto_correct_parameter(param_name: str, valid_params: List[str]) -> Optional[str]:
    """
    Auto-correct a parameter name that is within Levenshtein distance 1.

    Matching is case-insensitive. A candidate differing only in letter case
    (distance 0) wins over distance-1 candidates; otherwise the first
    distance-1 candidate in ``valid_params`` order is chosen.

    Returns:
        The corrected name, or None if ``param_name`` is already valid or
        there is no close match.
    """
    if param_name in valid_params:
        # Nothing to correct; never "correct" a valid name into a neighbour.
        return None

    target = param_name.lower()
    case_only_match = None
    first_typo_match = None
    for valid_param in valid_params:
        dist = levenshtein_distance(target, valid_param.lower())
        if dist == 0:
            case_only_match = valid_param
            break
        if dist == 1 and first_typo_match is None:
            first_typo_match = valid_param

    match = case_only_match if case_only_match is not None else first_typo_match
    if match is not None:
        logger.info("Poka-yoke: Auto-corrected parameter '%s' -> '%s'", param_name, match)
    return match
|
|
|
|
|
|
class ToolCallValidator:
    """
    Stateful tool-call validator with a per-tool circuit breaker.

    Tool schemas are loaded lazily from ``tools.registry`` on first use. Each
    failed name validation increments a per-name counter; once a name reaches
    ``failure_threshold`` consecutive failures, further attempts are rejected
    with a CIRCUIT BREAKER message until a successful validation resets it.
    """

    def __init__(self, failure_threshold: int = 3):
        self.failure_threshold = failure_threshold
        self.consecutive_failures: Dict[str, int] = {}  # tool_name -> consecutive failure count
        self.tool_schemas: Dict[str, dict] = {}  # tool_name -> schema from the registry
        self._initialized = False

    def _ensure_initialized(self):
        """Load tool schemas from the registry once.

        On failure the error is logged and ``_initialized`` stays False, so a
        later call retries. Schemas are collected into a local dict first so a
        failure part-way through never leaves ``tool_schemas`` half-populated.
        """
        if self._initialized:
            return

        try:
            from tools.registry import registry

            schemas = {}
            for name in registry.get_all_tool_names():
                schema = registry.get_schema(name)
                if schema:
                    schemas[name] = schema
        except Exception as e:
            # Deliberate best-effort boundary: validation degrades instead of crashing.
            logger.warning(f"Could not initialize poka-yoke from registry: {e}")
            return

        self.tool_schemas.update(schemas)
        self._initialized = True
        logger.debug(f"Poka-yoke initialized with {len(self.tool_schemas)} tool schemas")

    def _format_tool_list(self, limit: int) -> str:
        """Comma-join up to ``limit`` sorted tool names, adding '...' only when truncated."""
        names = sorted(self.tool_schemas)
        listed = ", ".join(names[:limit])
        return f"{listed}..." if len(names) > limit else listed

    def validate_tool_name(self, tool_name: str) -> Tuple[bool, Optional[str], List[str]]:
        """
        Validate a tool name against the registered tools.

        Resolution order:
          1. Exact match -> valid.
          2. No schemas loaded -> valid (fail open): with nothing to compare
             against, rejecting would block every call and trip every breaker.
          3. Circuit breaker tripped for this name -> invalid.
          4. A unique closest match within distance 1 -> valid, auto-corrected.
          5. Matches within distance 2 -> invalid, with suggestions.
          6. Otherwise -> invalid, with a sample of available tools.

        Returns:
            (is_valid, suggested_name, error_messages)
        """
        self._ensure_initialized()

        if tool_name in self.tool_schemas:
            return True, None, []

        if not self.tool_schemas:
            return True, None, [
                "Poka-yoke: no tool schemas loaded; tool name validation skipped."
            ]

        # Check circuit breaker
        if self.consecutive_failures.get(tool_name, 0) >= self.failure_threshold:
            return False, None, [
                f"CIRCUIT BREAKER: Tool '{tool_name}' has failed {self.failure_threshold}+ times consecutively.",
                "This may indicate a persistent hallucination. Halt and inject diagnostic.",
                f"Valid tools: {self._format_tool_list(20)}",
            ]

        suggestions = find_similar_names(tool_name, list(self.tool_schemas), max_distance=2)

        if suggestions:
            best_match, distance = suggestions[0]
            # Only auto-correct an unambiguous best match: silently picking one
            # of several equally close tools could invoke the wrong tool.
            tied = len(suggestions) > 1 and suggestions[1][1] == distance
            if distance <= 1 and not tied:
                logger.info(f"Poka-yoke: Auto-corrected tool '{tool_name}' -> '{best_match}'")
                return True, best_match, [f"Auto-corrected: '{tool_name}' -> '{best_match}'"]

            suggestion_list = [f"'{s}' (distance {d})" for s, d in suggestions[:3]]
            return False, None, [
                f"Unknown tool: '{tool_name}'",
                f"Did you mean: {', '.join(suggestion_list)}?",
            ]

        return False, None, [
            f"Unknown tool: '{tool_name}'",
            f"No similar tools found. Available: {self._format_tool_list(10)}",
        ]

    def validate_parameters(self, tool_name: str, params: Dict[str, Any]) -> Tuple[Dict[str, Any], List[str]]:
        """
        Validate parameter names for a known tool and auto-correct near misses.

        Valid names come from ``schema["parameters"]["properties"]``. Unknown
        tools, and schemas without properties, pass ``params`` through
        unchanged. A near-miss name is renamed only if its corrected name is
        not already present, so a correction never overwrites a value the
        caller supplied. Unknown names are left in place and reported.

        Returns:
            (corrected_params, warnings); ``params`` itself is not mutated.
        """
        self._ensure_initialized()

        schema = self.tool_schemas.get(tool_name)
        if schema is None:
            return params, []

        # Tolerate "parameters"/"properties" being absent or explicitly None.
        properties = (schema.get("parameters") or {}).get("properties") or {}
        valid_params = list(properties)
        if not valid_params:
            return params, []

        valid_set = set(valid_params)
        corrected = dict(params)
        warnings = []

        for param_name in params:
            if param_name in valid_set:
                continue
            corrected_name = auto_correct_parameter(param_name, valid_params)
            if corrected_name is None:
                warnings.append(f"Unknown parameter: '{param_name}' (valid: {', '.join(valid_params[:10])})")
            elif corrected_name in corrected:
                # Renaming would clobber a value already supplied under the target name.
                warnings.append(
                    f"Could not auto-correct parameter '{param_name}' -> '{corrected_name}': "
                    f"'{corrected_name}' is already set; leaving '{param_name}' unchanged"
                )
            else:
                corrected[corrected_name] = corrected.pop(param_name)
                warnings.append(f"Auto-corrected parameter: '{param_name}' -> '{corrected_name}'")

        return corrected, warnings

    def validate(self, tool_name: str, params: Dict[str, Any]) -> Tuple[bool, Optional[str], Dict[str, Any], List[str]]:
        """
        Fully validate a tool call and update circuit-breaker state.

        On an invalid name the failure is recorded and ``params`` is returned
        untouched. On success the failure streak is reset and parameters are
        validated against the (possibly corrected) tool's schema.

        Returns:
            (is_valid, corrected_tool_name, corrected_params, messages)
            ``corrected_tool_name`` is None when no name correction was made.
        """
        name_valid, corrected_name, name_messages = self.validate_tool_name(tool_name)

        if not name_valid:
            self._record_failure(tool_name)
            return False, None, params, name_messages

        actual_tool = corrected_name if corrected_name else tool_name
        if corrected_name:
            name_messages.append(f"Tool name corrected: '{tool_name}' -> '{corrected_name}'")

        corrected_params, param_warnings = self.validate_parameters(actual_tool, params)

        # A valid call ends the streak for both the name as sent and the name used.
        self._record_success(tool_name)
        if actual_tool != tool_name:
            self._record_success(actual_tool)

        return True, corrected_name, corrected_params, name_messages + param_warnings

    def _record_failure(self, tool_name: str):
        """Increment the consecutive-failure count; warn once it reaches the threshold."""
        count = self.consecutive_failures.get(tool_name, 0) + 1
        self.consecutive_failures[tool_name] = count

        if count >= self.failure_threshold:
            logger.warning(
                f"Poka-yoke circuit breaker triggered for '{tool_name}': "
                f"{count} consecutive failures"
            )

    def _record_success(self, tool_name: str):
        """Record a success by clearing the tool's failure counter."""
        self.consecutive_failures.pop(tool_name, None)

    def get_diagnostic_message(self, tool_name: str) -> str:
        """Build a multi-line diagnostic for a tool that keeps failing validation.

        Includes the consecutive-failure count and up to five registered tool
        names within edit distance 3, closest first.
        """
        self._ensure_initialized()

        count = self.consecutive_failures.get(tool_name, 0)
        suggestions = find_similar_names(tool_name, list(self.tool_schemas), max_distance=3)

        lines = [
            "=== TOOL HALLUCINATION DETECTED ===",
            f"Tool '{tool_name}' has failed {count} times consecutively.",
            "",
            "This likely means the model is hallucinating a tool name.",
            "",
            "Closest valid tools:",
        ]
        lines.extend(f" - {name} (edit distance: {dist})" for name, dist in suggestions[:5])
        if not suggestions:
            lines.append(" (no similar tools found)")

        lines.extend([
            "",
            "Action: The agent should stop retrying and use a valid tool name.",
            "If this persists, the model may need fine-tuning or prompt adjustment.",
        ])

        return "\n".join(lines)
|
|
|
|
|
|
# Global validator instance shared by the module-level helpers below
# (validate_tool_call, reset_circuit_breaker, get_hallucination_stats), so
# circuit-breaker counts persist across calls while this module is loaded.
# Construction does no I/O; tool schemas are loaded lazily on first validation.
_validator = ToolCallValidator()
|
|
|
|
|
|
def validate_tool_call(tool_name: str, params: Dict[str, Any]) -> Tuple[bool, Optional[str], Dict[str, Any], List[str]]:
    """
    Validate a tool call with the shared module-level validator.

    This is a convenience wrapper, not a stateless check: it reuses one
    ToolCallValidator, so consecutive-failure counts (and the circuit breaker)
    carry over between calls.

    Returns:
        (is_valid, corrected_tool_name, corrected_params, messages)
    """
    shared = _validator
    outcome = shared.validate(tool_name, params)
    return outcome
|
|
|
|
|
|
def reset_circuit_breaker(tool_name: Optional[str] = None):
    """
    Reset the shared validator's circuit breaker.

    Args:
        tool_name: Clear the consecutive-failure count for just this tool.
            When None (the default), clear the counts for all tools. The
            check is against None explicitly: an empty string is a tool name
            the validator can record failures under, and passing it must not
            wipe every other tool's count.
    """
    failures = _validator.consecutive_failures
    if tool_name is None:
        failures.clear()
    else:
        failures.pop(tool_name, None)
|
|
|
|
|
|
def get_hallucination_stats() -> Dict[str, Any]:
    """
    Snapshot the shared validator's hallucination-tracking state.

    Returns a dict with:
        consecutive_failures: copy of the per-tool consecutive-failure counts
        tools_tracked: number of tool schemas loaded so far (schemas load
            lazily, so this may be 0 before anything has been validated)
        threshold: failure count at which a tool's circuit breaker trips
    """
    shared = _validator
    failures_copy = dict(shared.consecutive_failures)
    return dict(
        consecutive_failures=failures_copy,
        tools_tracked=len(shared.tool_schemas),
        threshold=shared.failure_threshold,
    )
|