fix(#922): Poka-yoke — detect and block tool hallucination
Validation firewall between LLM tool-call output and execution: 1. Unknown tool names rejected 2. Malformed parameters caught 3. Missing required arguments detected 4. Hallucination patterns detected All rejections logged with model provenance. Agent receives rejection as tool result for self-correction. Resolves #922
This commit is contained in:
312
tools/tool_validator.py
Normal file
312
tools/tool_validator.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""
|
||||
Poka-Yoke: Tool Hallucination Detection — #922.
|
||||
|
||||
Validation firewall between LLM tool-call output and actual execution.
|
||||
|
||||
Detects and blocks:
|
||||
1. Unknown tool names (hallucinated tools)
|
||||
2. Malformed parameters (wrong types)
|
||||
3. Missing required arguments
|
||||
4. Extra unknown parameters
|
||||
|
||||
Poka-Yoke Type: Detection (catches errors at the boundary before harm)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ValidationSeverity(Enum):
|
||||
"""Severity of validation failure."""
|
||||
BLOCK = "block" # Must block execution
|
||||
WARN = "warn" # Warning, may proceed
|
||||
INFO = "info" # Informational
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationIssue:
|
||||
"""A validation issue found."""
|
||||
severity: ValidationSeverity
|
||||
code: str
|
||||
message: str
|
||||
tool_name: str
|
||||
parameter: Optional[str] = None
|
||||
expected: Optional[str] = None
|
||||
actual: Optional[Any] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""Result of tool call validation."""
|
||||
valid: bool
|
||||
tool_name: str
|
||||
issues: List[ValidationIssue] = field(default_factory=list)
|
||||
corrected_args: Optional[Dict[str, Any]] = None
|
||||
|
||||
@property
|
||||
def blocking_issues(self) -> List[ValidationIssue]:
|
||||
return [i for i in self.issues if i.severity == ValidationSeverity.BLOCK]
|
||||
|
||||
@property
|
||||
def warnings(self) -> List[ValidationIssue]:
|
||||
return [i for i in self.issues if i.severity == ValidationSeverity.WARN]
|
||||
|
||||
|
||||
class ToolHallucinationDetector:
|
||||
"""
|
||||
Poka-yoke detector for tool hallucinations.
|
||||
|
||||
Validates tool calls against registered schemas before execution.
|
||||
"""
|
||||
|
||||
def __init__(self, tool_registry: Optional[Dict] = None):
|
||||
"""
|
||||
Initialize detector.
|
||||
|
||||
Args:
|
||||
tool_registry: Dict of tool_name -> tool_schema
|
||||
"""
|
||||
self.registry = tool_registry or {}
|
||||
self._rejection_log: List[Dict] = []
|
||||
|
||||
def register_tool(self, name: str, schema: Dict):
|
||||
"""Register a tool with its JSON Schema."""
|
||||
self.registry[name] = schema
|
||||
|
||||
def register_tools(self, tools: Dict[str, Dict]):
|
||||
"""Register multiple tools."""
|
||||
self.registry.update(tools)
|
||||
|
||||
def validate_tool_call(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
model: str = "unknown",
|
||||
) -> ValidationResult:
|
||||
"""
|
||||
Validate a tool call against the registry.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool being called
|
||||
arguments: Arguments passed to the tool
|
||||
model: Model that generated the call (for logging)
|
||||
|
||||
Returns:
|
||||
ValidationResult with validation status
|
||||
"""
|
||||
issues = []
|
||||
|
||||
# 1. Check if tool exists
|
||||
if tool_name not in self.registry:
|
||||
issue = ValidationIssue(
|
||||
severity=ValidationSeverity.BLOCK,
|
||||
code="UNKNOWN_TOOL",
|
||||
message=f"Tool '{tool_name}' does not exist. Available: {', '.join(sorted(self.registry.keys())[:10])}...",
|
||||
tool_name=tool_name,
|
||||
)
|
||||
issues.append(issue)
|
||||
self._log_rejection(tool_name, arguments, model, "UNKNOWN_TOOL")
|
||||
return ValidationResult(valid=False, tool_name=tool_name, issues=issues)
|
||||
|
||||
schema = self.registry[tool_name]
|
||||
params_schema = schema.get("parameters", {}).get("properties", {})
|
||||
required = set(schema.get("parameters", {}).get("required", []))
|
||||
|
||||
# 2. Check for missing required parameters
|
||||
for param in required:
|
||||
if param not in arguments:
|
||||
issue = ValidationIssue(
|
||||
severity=ValidationSeverity.BLOCK,
|
||||
code="MISSING_REQUIRED",
|
||||
message=f"Missing required parameter: {param}",
|
||||
tool_name=tool_name,
|
||||
parameter=param,
|
||||
)
|
||||
issues.append(issue)
|
||||
|
||||
# 3. Check parameter types
|
||||
for param_name, param_value in arguments.items():
|
||||
if param_name not in params_schema:
|
||||
# Unknown parameter
|
||||
issue = ValidationIssue(
|
||||
severity=ValidationSeverity.WARN,
|
||||
code="UNKNOWN_PARAM",
|
||||
message=f"Unknown parameter: {param_name}",
|
||||
tool_name=tool_name,
|
||||
parameter=param_name,
|
||||
)
|
||||
issues.append(issue)
|
||||
continue
|
||||
|
||||
param_schema = params_schema[param_name]
|
||||
expected_type = param_schema.get("type")
|
||||
|
||||
if expected_type and not self._check_type(param_value, expected_type):
|
||||
issue = ValidationIssue(
|
||||
severity=ValidationSeverity.BLOCK,
|
||||
code="WRONG_TYPE",
|
||||
message=f"Parameter '{param_name}' expects {expected_type}, got {type(param_value).__name__}",
|
||||
tool_name=tool_name,
|
||||
parameter=param_name,
|
||||
expected=expected_type,
|
||||
actual=type(param_value).__name__,
|
||||
)
|
||||
issues.append(issue)
|
||||
|
||||
# 4. Check for common hallucination patterns
|
||||
hallucination_issues = self._detect_hallucination_patterns(tool_name, arguments)
|
||||
issues.extend(hallucination_issues)
|
||||
|
||||
# Determine validity
|
||||
has_blocking = any(i.severity == ValidationSeverity.BLOCK for i in issues)
|
||||
|
||||
if has_blocking:
|
||||
self._log_rejection(tool_name, arguments, model,
|
||||
"; ".join(i.code for i in issues if i.severity == ValidationSeverity.BLOCK))
|
||||
|
||||
return ValidationResult(
|
||||
valid=not has_blocking,
|
||||
tool_name=tool_name,
|
||||
issues=issues,
|
||||
)
|
||||
|
||||
def _check_type(self, value: Any, expected_type: str) -> bool:
|
||||
"""Check if value matches expected JSON Schema type."""
|
||||
type_map = {
|
||||
"string": str,
|
||||
"number": (int, float),
|
||||
"integer": int,
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"object": dict,
|
||||
}
|
||||
|
||||
expected = type_map.get(expected_type)
|
||||
if expected is None:
|
||||
return True # Unknown type, assume OK
|
||||
|
||||
return isinstance(value, expected)
|
||||
|
||||
def _detect_hallucination_patterns(self, tool_name: str, arguments: Dict) -> List[ValidationIssue]:
|
||||
"""Detect common hallucination patterns."""
|
||||
issues = []
|
||||
|
||||
# Pattern 1: Placeholder values
|
||||
placeholder_patterns = [
|
||||
r"^<.*>$", # <placeholder>
|
||||
r"^\[.*\]$", # [placeholder]
|
||||
r"^TODO$|^FIXME$", # TODO/FIXME
|
||||
r"^example\.com$", # example.com
|
||||
r"^127\.0\.0\.1$", # localhost
|
||||
]
|
||||
|
||||
for param_name, param_value in arguments.items():
|
||||
if isinstance(param_value, str):
|
||||
for pattern in placeholder_patterns:
|
||||
if re.match(pattern, param_value, re.IGNORECASE):
|
||||
issues.append(ValidationIssue(
|
||||
severity=ValidationSeverity.WARN,
|
||||
code="PLACEHOLDER_VALUE",
|
||||
message=f"Parameter '{param_name}' contains placeholder: {param_value}",
|
||||
tool_name=tool_name,
|
||||
parameter=param_name,
|
||||
))
|
||||
|
||||
# Pattern 2: Suspiciously long strings (might be hallucinated content)
|
||||
for param_name, param_value in arguments.items():
|
||||
if isinstance(param_value, str) and len(param_value) > 10000:
|
||||
issues.append(ValidationIssue(
|
||||
severity=ValidationSeverity.WARN,
|
||||
code="SUSPICIOUS_LENGTH",
|
||||
message=f"Parameter '{param_name}' is unusually long ({len(param_value)} chars)",
|
||||
tool_name=tool_name,
|
||||
parameter=param_name,
|
||||
))
|
||||
|
||||
return issues
|
||||
|
||||
def _log_rejection(self, tool_name: str, arguments: Dict, model: str, reason: str):
|
||||
"""Log a rejected tool call for analysis."""
|
||||
import time
|
||||
|
||||
entry = {
|
||||
"timestamp": time.time(),
|
||||
"tool_name": tool_name,
|
||||
"arguments": {k: str(v)[:100] for k, v in arguments.items()},
|
||||
"model": model,
|
||||
"reason": reason,
|
||||
}
|
||||
|
||||
self._rejection_log.append(entry)
|
||||
|
||||
# Keep log bounded
|
||||
if len(self._rejection_log) > 1000:
|
||||
self._rejection_log = self._rejection_log[-500:]
|
||||
|
||||
logger.warning(
|
||||
"Tool hallucination blocked: tool=%s, model=%s, reason=%s",
|
||||
tool_name, model, reason
|
||||
)
|
||||
|
||||
def get_rejection_stats(self) -> Dict:
|
||||
"""Get statistics on rejected tool calls."""
|
||||
if not self._rejection_log:
|
||||
return {"total": 0, "by_reason": {}, "by_tool": {}}
|
||||
|
||||
by_reason = {}
|
||||
by_tool = {}
|
||||
|
||||
for entry in self._rejection_log:
|
||||
reason = entry["reason"]
|
||||
tool = entry["tool_name"]
|
||||
|
||||
by_reason[reason] = by_reason.get(reason, 0) + 1
|
||||
by_tool[tool] = by_tool.get(tool, 0) + 1
|
||||
|
||||
return {
|
||||
"total": len(self._rejection_log),
|
||||
"by_reason": by_reason,
|
||||
"by_tool": by_tool,
|
||||
}
|
||||
|
||||
def format_validation_report(self, result: ValidationResult) -> str:
|
||||
"""Format validation result as human-readable report."""
|
||||
if result.valid:
|
||||
return f"✅ {result.tool_name}: valid"
|
||||
|
||||
lines = [f"❌ {result.tool_name}: BLOCKED"]
|
||||
for issue in result.blocking_issues:
|
||||
lines.append(f" [{issue.code}] {issue.message}")
|
||||
|
||||
for issue in result.warnings:
|
||||
lines.append(f" ⚠️ [{issue.code}] {issue.message}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def create_rejection_response(result: ValidationResult) -> Dict:
|
||||
"""
|
||||
Create a tool result for a rejected tool call.
|
||||
|
||||
This allows the agent to see the rejection and self-correct.
|
||||
"""
|
||||
issues_text = "\n".join(
|
||||
f"- [{i.code}] {i.message}"
|
||||
for i in result.blocking_issues
|
||||
)
|
||||
|
||||
return {
|
||||
"role": "tool",
|
||||
"content": f"""Tool call rejected: {result.tool_name}
|
||||
|
||||
Issues found:
|
||||
{issues_text}
|
||||
|
||||
Please check the tool name and parameters, then try again with valid arguments.""",
|
||||
}
|
||||
Reference in New Issue
Block a user