"""
Poka-Yoke: Tool Hallucination Detection — #922.

Validation firewall between LLM tool-call output and actual execution.

Detects and blocks:
1. Unknown tool names (hallucinated tools)
2. Malformed parameters (wrong types)
3. Missing required arguments
4. Extra unknown parameters

Poka-Yoke Type: Detection (catches errors at the boundary before harm)
"""

import json
import logging
import re
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Set

logger = logging.getLogger(__name__)


class ValidationSeverity(Enum):
    """Severity of validation failure."""
    BLOCK = "block"  # Must block execution
    WARN = "warn"    # Warning, may proceed
    INFO = "info"    # Informational


@dataclass
class ValidationIssue:
    """A validation issue found."""
    severity: ValidationSeverity
    code: str
    message: str
    tool_name: str
    parameter: Optional[str] = None
    expected: Optional[str] = None
    actual: Optional[Any] = None


@dataclass
class ValidationResult:
    """Result of tool call validation."""
    valid: bool
    tool_name: str
    issues: List[ValidationIssue] = field(default_factory=list)
    corrected_args: Optional[Dict[str, Any]] = None

    @property
    def blocking_issues(self) -> List[ValidationIssue]:
        """Issues severe enough to block execution."""
        return [i for i in self.issues if i.severity == ValidationSeverity.BLOCK]

    @property
    def warnings(self) -> List[ValidationIssue]:
        """Non-blocking advisory issues."""
        return [i for i in self.issues if i.severity == ValidationSeverity.WARN]


class ToolHallucinationDetector:
    """
    Poka-yoke detector for tool hallucinations.

    Validates tool calls against registered schemas before execution.
    """

    def __init__(self, tool_registry: Optional[Dict] = None):
        """
        Initialize detector.

        Args:
            tool_registry: Dict of tool_name -> tool_schema
        """
        self.registry = tool_registry or {}
        # Bounded in-memory log of rejected calls (see _log_rejection).
        self._rejection_log: List[Dict] = []

    def register_tool(self, name: str, schema: Dict):
        """Register a tool with its JSON Schema."""
        self.registry[name] = schema

    def register_tools(self, tools: Dict[str, Dict]):
        """Register multiple tools."""
        self.registry.update(tools)

    def validate_tool_call(
        self,
        tool_name: str,
        arguments: Dict[str, Any],
        model: str = "unknown",
    ) -> ValidationResult:
        """
        Validate a tool call against the registry.

        Args:
            tool_name: Name of the tool being called
            arguments: Arguments passed to the tool
            model: Model that generated the call (for logging)

        Returns:
            ValidationResult with validation status
        """
        issues: List[ValidationIssue] = []

        # 1. Check if tool exists
        if tool_name not in self.registry:
            # Show at most 10 tool names; only add an ellipsis when the
            # list is actually truncated (previously "..." was appended
            # unconditionally, which was misleading for small registries).
            available = ", ".join(sorted(self.registry.keys())[:10])
            if len(self.registry) > 10:
                available += ", ..."
            issue = ValidationIssue(
                severity=ValidationSeverity.BLOCK,
                code="UNKNOWN_TOOL",
                message=f"Tool '{tool_name}' does not exist. Available: {available}",
                tool_name=tool_name,
            )
            issues.append(issue)
            self._log_rejection(tool_name, arguments, model, "UNKNOWN_TOOL")
            return ValidationResult(valid=False, tool_name=tool_name, issues=issues)

        schema = self.registry[tool_name]
        params_schema = schema.get("parameters", {}).get("properties", {})
        required = set(schema.get("parameters", {}).get("required", []))

        # 2. Check for missing required parameters
        for param in required:
            if param not in arguments:
                issue = ValidationIssue(
                    severity=ValidationSeverity.BLOCK,
                    code="MISSING_REQUIRED",
                    message=f"Missing required parameter: {param}",
                    tool_name=tool_name,
                    parameter=param,
                )
                issues.append(issue)

        # 3. Check parameter types
        for param_name, param_value in arguments.items():
            if param_name not in params_schema:
                # Unknown parameter: warn but do not block — extra args may
                # simply be ignored by the tool.
                issue = ValidationIssue(
                    severity=ValidationSeverity.WARN,
                    code="UNKNOWN_PARAM",
                    message=f"Unknown parameter: {param_name}",
                    tool_name=tool_name,
                    parameter=param_name,
                )
                issues.append(issue)
                continue

            param_schema = params_schema[param_name]
            expected_type = param_schema.get("type")

            if expected_type and not self._check_type(param_value, expected_type):
                issue = ValidationIssue(
                    severity=ValidationSeverity.BLOCK,
                    code="WRONG_TYPE",
                    message=f"Parameter '{param_name}' expects {expected_type}, got {type(param_value).__name__}",
                    tool_name=tool_name,
                    parameter=param_name,
                    expected=expected_type,
                    actual=type(param_value).__name__,
                )
                issues.append(issue)

        # 4. Check for common hallucination patterns
        hallucination_issues = self._detect_hallucination_patterns(tool_name, arguments)
        issues.extend(hallucination_issues)

        # Determine validity: any BLOCK-severity issue rejects the call.
        has_blocking = any(i.severity == ValidationSeverity.BLOCK for i in issues)

        if has_blocking:
            self._log_rejection(tool_name, arguments, model,
                                "; ".join(i.code for i in issues if i.severity == ValidationSeverity.BLOCK))

        return ValidationResult(
            valid=not has_blocking,
            tool_name=tool_name,
            issues=issues,
        )

    def _check_type(self, value: Any, expected_type: str) -> bool:
        """Check if value matches expected JSON Schema type."""
        type_map = {
            "string": str,
            "number": (int, float),
            "integer": int,
            "boolean": bool,
            "array": list,
            "object": dict,
        }

        expected = type_map.get(expected_type)
        if expected is None:
            return True  # Unknown type, assume OK

        # bool is a subclass of int in Python, so without this guard
        # True/False would incorrectly validate as "integer"/"number" —
        # JSON Schema treats boolean as a distinct type.
        if isinstance(value, bool) and expected_type in ("integer", "number"):
            return False

        return isinstance(value, expected)

    def _detect_hallucination_patterns(self, tool_name: str, arguments: Dict) -> List[ValidationIssue]:
        """Detect common hallucination patterns (all WARN severity)."""
        issues = []

        # Pattern 1: Placeholder values
        placeholder_patterns = [
            r"^<.*>$",            # <placeholder>
            r"^\[.*\]$",          # [placeholder]
            r"^TODO$|^FIXME$",    # TODO/FIXME
            r"^example\.com$",    # example.com
            r"^127\.0\.0\.1$",    # localhost
        ]

        for param_name, param_value in arguments.items():
            if isinstance(param_value, str):
                for pattern in placeholder_patterns:
                    if re.match(pattern, param_value, re.IGNORECASE):
                        issues.append(ValidationIssue(
                            severity=ValidationSeverity.WARN,
                            code="PLACEHOLDER_VALUE",
                            message=f"Parameter '{param_name}' contains placeholder: {param_value}",
                            tool_name=tool_name,
                            parameter=param_name,
                        ))

        # Pattern 2: Suspiciously long strings (might be hallucinated content)
        for param_name, param_value in arguments.items():
            if isinstance(param_value, str) and len(param_value) > 10000:
                issues.append(ValidationIssue(
                    severity=ValidationSeverity.WARN,
                    code="SUSPICIOUS_LENGTH",
                    message=f"Parameter '{param_name}' is unusually long ({len(param_value)} chars)",
                    tool_name=tool_name,
                    parameter=param_name,
                ))

        return issues

    def _log_rejection(self, tool_name: str, arguments: Dict, model: str, reason: str):
        """Log a rejected tool call for analysis."""
        entry = {
            "timestamp": time.time(),
            "tool_name": tool_name,
            # Truncate argument values so the log stays small.
            "arguments": {k: str(v)[:100] for k, v in arguments.items()},
            "model": model,
            "reason": reason,
        }

        self._rejection_log.append(entry)

        # Keep log bounded
        if len(self._rejection_log) > 1000:
            self._rejection_log = self._rejection_log[-500:]

        logger.warning(
            "Tool hallucination blocked: tool=%s, model=%s, reason=%s",
            tool_name, model, reason
        )

    def get_rejection_stats(self) -> Dict:
        """Get statistics on rejected tool calls."""
        if not self._rejection_log:
            return {"total": 0, "by_reason": {}, "by_tool": {}}

        by_reason: Dict[str, int] = {}
        by_tool: Dict[str, int] = {}

        for entry in self._rejection_log:
            reason = entry["reason"]
            tool = entry["tool_name"]

            by_reason[reason] = by_reason.get(reason, 0) + 1
            by_tool[tool] = by_tool.get(tool, 0) + 1

        return {
            "total": len(self._rejection_log),
            "by_reason": by_reason,
            "by_tool": by_tool,
        }

    def format_validation_report(self, result: ValidationResult) -> str:
        """Format validation result as human-readable report."""
        if result.valid:
            return f"✅ {result.tool_name}: valid"

        lines = [f"❌ {result.tool_name}: BLOCKED"]
        for issue in result.blocking_issues:
            lines.append(f"  [{issue.code}] {issue.message}")

        for issue in result.warnings:
            lines.append(f"  ⚠️ [{issue.code}] {issue.message}")

        return "\n".join(lines)


def create_rejection_response(result: ValidationResult) -> Dict:
    """
    Create a tool result for a rejected tool call.

    This allows the agent to see the rejection and self-correct.
    """
    issues_text = "\n".join(
        f"- [{i.code}] {i.message}"
        for i in result.blocking_issues
    )

    return {
        "role": "tool",
        "content": f"""Tool call rejected: {result.tool_name}

Issues found:
{issues_text}

Please check the tool name and parameters, then try again with valid arguments.""",
    }