Files
hermes-agent/tools/poka_yoke.py

299 lines
11 KiB
Python

"""
poka_yoke.py — Validation firewall for tool calls.

Poka-yoke (mistake-proofing): validates tool calls against the registry
before execution. Catches hallucinated tool names, malformed parameters,
missing required arguments, and type mismatches.

Usage:
    from tools.poka_yoke import validate_tool_call

    is_valid, corrected_name, corrected_params, messages = validate_tool_call(
        "read_file", {"path": "test.txt"}
    )
"""
import json
import logging
import re
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
def validate_tool_call(
    function_name: str,
    function_args: Dict[str, Any],
) -> Tuple[bool, Optional[str], Optional[Dict[str, Any]], List[str]]:
    """Validate a tool call against the registry before execution.

    The checks run in order: tool existence (with a closest-name
    suggestion), required parameters, unknown parameters, type validation
    with coercion, enum membership and finally string patterns.

    Missing required parameters, type errors and enum violations block the
    call (``is_valid`` is False). Unknown parameters are stripped, coercions
    are applied and pattern mismatches are reported, but none of these three
    blocks the call on its own.

    Args:
        function_name: The tool name from the LLM's function_call.
        function_args: The arguments dict from the LLM's function_call.
            ``None`` is treated as an empty dict; any other non-dict value
            blocks the call.

    Returns:
        Tuple of (is_valid, corrected_name, corrected_params, messages):
        - is_valid: False if the call should be blocked entirely.
        - corrected_name: Suggested name if a close match was found (None if OK).
        - corrected_params: Cleaned/coerced params if anything was changed
          (None if the original arguments can be used as-is).
        - messages: List of error/warning/info messages.
    """
    from tools.registry import registry

    messages: List[str] = []
    corrected_name: Optional[str] = None
    corrected_params: Optional[Dict[str, Any]] = None

    # ── 1. Check if tool exists ───────────────────────────────────────────
    entry = registry.get_entry(function_name)
    if entry is None:
        # Tool not found — suggest closest match
        all_names = registry.get_all_tool_names()
        suggestion = _find_closest_name(function_name, all_names)
        if not suggestion:
            available = ", ".join(sorted(all_names)[:20])
            messages.append(
                f"Unknown tool '{function_name}'. "
                f"Available tools: {available}{'...' if len(all_names) > 20 else ''}"
            )
            return False, None, None, messages
        messages.append(
            f"Unknown tool '{function_name}'. Did you mean '{suggestion}'?"
        )
        corrected_name = suggestion
        # Continue validating against the suggested tool's schema
        entry = registry.get_entry(suggestion)
        if entry is None:
            return False, corrected_name, None, messages

    # ── 1b. Normalise the argument container ──────────────────────────────
    # Every later step does membership tests and .items() on the arguments,
    # so anything that is not a dict would crash instead of being reported.
    if function_args is None:
        function_args = {}
    elif not isinstance(function_args, dict):
        messages.append(
            f"Arguments for '{function_name}' must be an object (dict), "
            f"got {type(function_args).__name__}"
        )
        return False, corrected_name, None, messages

    # ── 2. Validate parameters against schema ─────────────────────────────
    schema = entry.schema
    params_schema = schema.get("parameters") or {}
    properties = params_schema.get("properties") or {}
    required = set(params_schema.get("required") or [])

    # Missing required parameters block the call; the messages are returned
    # so the agent can self-correct and retry.
    missing = [name for name in sorted(required) if name not in function_args]
    if missing:
        for param_name in missing:
            param_type = properties.get(param_name, {}).get("type", "any")
            messages.append(
                f"Missing required parameter '{param_name}' "
                f"(expected type: {param_type}). "
                f"Tool: {function_name}"
            )
        return False, corrected_name, corrected_params, messages

    # ``effective_args`` is always the dict the remaining checks operate on.
    # It is tracked explicitly because ``corrected_params or function_args``
    # would fall back to the uncleaned arguments whenever the cleaned dict
    # happens to be empty.
    effective_args: Dict[str, Any] = function_args

    # ── 3. Check for unknown parameters ───────────────────────────────────
    if properties:
        known_params = set(properties)
        # Allow extra params that start with _ (internal convention)
        unknown = [
            p for p in function_args
            if p not in known_params and not p.startswith("_")
        ]
        if unknown:
            known_str = ", ".join(sorted(known_params))
            unknown_str = ", ".join(sorted(unknown))
            messages.append(
                f"Unknown parameter(s) for '{function_name}': {unknown_str}. "
                f"Known parameters: {known_str}"
            )
            # Remove unknown params (don't block, just clean)
            corrected_params = {
                k: v for k, v in function_args.items()
                if k in known_params or k.startswith("_")
            }
            effective_args = corrected_params

    # ── 4. Type validation ────────────────────────────────────────────────
    type_errors: List[str] = []
    coerced = dict(effective_args)
    # An explicit flag is needed: comparing the dicts with ``!=`` misses
    # coercions such as 3.0 -> 3 because 3 == 3.0 in Python.
    coercion_applied = False
    for param_name, param_value in effective_args.items():
        if param_name.startswith("_"):
            continue
        param_schema = properties.get(param_name)
        if not param_schema:
            continue
        expected_type = param_schema.get("type")
        if not expected_type:
            continue
        is_valid_type, coerced_value = _validate_type(
            param_name, param_value, expected_type
        )
        if not is_valid_type:
            type_errors.append(
                f"Parameter '{param_name}': expected {expected_type}, "
                f"got {type(param_value).__name__} ({_truncate(str(param_value), 50)})"
            )
        elif coerced_value is not param_value:
            coerced[param_name] = coerced_value
            coercion_applied = True
            messages.append(
                f"Parameter '{param_name}': coerced from "
                f"{type(param_value).__name__} to {expected_type}"
            )
    if type_errors:
        messages.extend(type_errors)
        return False, corrected_name, corrected_params, messages
    if coercion_applied:
        corrected_params = coerced
        effective_args = coerced

    # ── 5. Enum validation ────────────────────────────────────────────────
    for param_name, param_value in effective_args.items():
        enum_values = properties.get(param_name, {}).get("enum")
        if enum_values and param_value not in enum_values:
            messages.append(
                f"Parameter '{param_name}': value '{param_value}' not in "
                f"allowed values: {enum_values}"
            )
            return False, corrected_name, corrected_params, messages

    # ── 6. Pattern validation (reported, non-blocking) ────────────────────
    for param_name, param_value in effective_args.items():
        if not isinstance(param_value, str):
            continue
        pattern = properties.get(param_name, {}).get("pattern")
        if not pattern or not isinstance(pattern, str):
            continue
        try:
            # JSON Schema patterns are not implicitly anchored, so search
            # anywhere in the value rather than only at the start.
            matched = re.search(pattern, param_value) is not None
        except re.error as exc:
            # A broken pattern is a schema bug, not the caller's mistake:
            # log it and skip the check instead of crashing validation.
            logging.getLogger(__name__).warning(
                "Invalid pattern %r for parameter %r: %s",
                pattern, param_name, exc,
            )
            continue
        if not matched:
            messages.append(
                f"Parameter '{param_name}': value '{_truncate(param_value, 50)}' "
                f"does not match pattern '{pattern}'"
            )

    # ── Done ──────────────────────────────────────────────────────────────
    # Every blocking condition has already returned above. With no messages,
    # corrected_name and corrected_params are both still None.
    return True, corrected_name, corrected_params, messages
def _find_closest_name(target: str, candidates: List[str]) -> Optional[str]:
"""Find the closest tool name using simple edit distance heuristics."""
if not candidates:
return None
target_lower = target.lower()
# Exact prefix match
for name in candidates:
if name.lower().startswith(target_lower[:4]) and len(target_lower) > 3:
return name
# Substring match
for name in candidates:
if target_lower in name.lower() or name.lower() in target_lower:
return name
# Levenshtein distance (simple, for short strings)
def _levenshtein(a: str, b: str) -> int:
if len(a) < len(b):
return _levenshtein(b, a)
if len(b) == 0:
return len(a)
prev = range(len(b) + 1)
for i, ca in enumerate(a):
curr = [i + 1]
for j, cb in enumerate(b):
curr.append(min(
prev[j + 1] + 1,
curr[j] + 1,
prev[j] + (0 if ca == cb else 1),
))
prev = curr
return prev[-1]
distances = [(name, _levenshtein(target_lower, name.lower())) for name in candidates]
distances.sort(key=lambda x: x[1])
# Return if edit distance is small enough
if distances and distances[0][1] <= max(3, len(target) // 3):
return distances[0][0]
return None
def _validate_type(
param_name: str, value: Any, expected_type: str
) -> Tuple[bool, Any]:
"""Validate and optionally coerce a parameter value to the expected type.
Returns (is_valid, coerced_value). coerced_value is value itself if no
coercion was needed.
"""
type_map = {
"string": str,
"integer": int,
"number": (int, float),
"boolean": bool,
"array": list,
"object": dict,
}
expected = type_map.get(expected_type)
if expected is None:
return True, value # Unknown type, skip validation
# Direct type check
if isinstance(value, expected):
return True, value
# Coercion attempts
if expected_type == "string":
return True, str(value)
if expected_type == "integer":
if isinstance(value, str) and value.isdigit():
return True, int(value)
if isinstance(value, float) and value == int(value):
return True, int(value)
return False, value
if expected_type == "number":
if isinstance(value, str):
try:
return True, float(value)
except ValueError:
return False, value
return False, value
if expected_type == "boolean":
if isinstance(value, str):
lower = value.lower()
if lower in ("true", "1", "yes"):
return True, True
if lower in ("false", "0", "no"):
return True, False
return False, value
return False, value
def _truncate(s: str, max_len: int) -> str:
"""Truncate a string for display."""
if len(s) <= max_len:
return s
return s[:max_len - 3] + "..."