299 lines
11 KiB
Python
299 lines
11 KiB
Python
"""
poka_yoke.py — Validation firewall for tool calls.

Poka-yoke (mistake-proofing): validates tool calls against the registry
before execution. Catches hallucinated tool names, malformed parameters,
missing required arguments, and type mismatches.

Usage:
    from tools.poka_yoke import validate_tool_call

    is_valid, corrected_name, corrected_params, messages = validate_tool_call(
        "read_file", {"path": "test.txt"}
    )
"""
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def validate_tool_call(
    function_name: str,
    function_args: Dict[str, Any],
) -> Tuple[bool, Optional[str], Optional[Dict[str, Any]], List[str]]:
    """Validate a tool call against the registry before execution.

    The checks run in this order and stop at the first blocking failure:

    1. The tool exists. An unknown name is auto-corrected to a close match
       from the registry when one is found; otherwise the call is blocked.
    2. All ``required`` parameters are present (blocking).
    3. Parameters not declared in ``properties`` are dropped (non-blocking).
       Names starting with ``_`` are always kept (internal convention).
    4. Values match their JSON-schema ``type``; safe coercions (e.g. "5" -> 5)
       are applied and reported, anything else is blocking.
    5. Values are members of their ``enum`` (blocking).
    6. String values match their ``pattern`` (reported, non-blocking).

    Args:
        function_name: The tool name from the LLM's function_call.
        function_args: The arguments dict from the LLM's function_call.
            ``None`` is treated as ``{}``; any other non-dict is rejected.

    Returns:
        Tuple of (is_valid, corrected_name, corrected_params, messages):
        - is_valid: False if the call should be blocked entirely.
        - corrected_name: Suggested name if a close match was found (None if OK).
        - corrected_params: Cleaned/coerced params when they differ from
          ``function_args`` (None if the original args can be used as-is).
        - messages: List of error/warning/info messages.
    """
    from tools.registry import registry

    messages: List[str] = []
    corrected_name: Optional[str] = None
    corrected_params: Optional[Dict[str, Any]] = None

    # ── 0. Normalise the argument container ───────────────────────────────
    # A None container is treated as an empty object rather than crashing on
    # the membership tests below; anything else that is not a dict is rejected.
    if function_args is None:
        function_args = {}
        corrected_params = {}
        messages.append(
            f"No arguments supplied for '{function_name}'; treating as {{}}."
        )
    elif not isinstance(function_args, dict):
        messages.append(
            f"Arguments for '{function_name}' must be an object, "
            f"got {type(function_args).__name__}"
        )
        return False, None, None, messages

    # ── 1. Check if tool exists ───────────────────────────────────────────
    entry = registry.get_entry(function_name)
    if entry is None:
        all_names = registry.get_all_tool_names()
        suggestion = _find_closest_name(function_name, all_names)

        if not suggestion:
            available = ", ".join(sorted(all_names)[:20])
            more = "..." if len(all_names) > 20 else ""
            messages.append(
                f"Unknown tool '{function_name}'. "
                f"Available tools: {available}{more}"
            )
            return False, None, None, messages

        messages.append(
            f"Unknown tool '{function_name}'. Did you mean '{suggestion}'?"
        )
        corrected_name = suggestion
        # Re-validate against the corrected tool's schema.
        entry = registry.get_entry(suggestion)
        if entry is None:
            return False, corrected_name, None, messages

    # Name of the tool whose schema is actually being checked.
    tool_name = corrected_name or function_name

    # ── 2. Validate parameters against schema ─────────────────────────────
    schema = entry.schema
    params_schema = schema.get("parameters") or {}
    properties = params_schema.get("properties") or {}
    required = set(params_schema.get("required") or [])

    missing = [p for p in sorted(required) if p not in function_args]
    for param_name in missing:
        param_type = properties.get(param_name, {}).get("type", "any")
        messages.append(
            f"Missing required parameter '{param_name}' "
            f"(expected type: {param_type}). "
            f"Tool: {tool_name}"
        )
    if missing:
        # Not raised: the messages go back to the agent so it can self-correct,
        # while is_valid=False tells the caller not to execute the call.
        return False, corrected_name, corrected_params, messages

    # ── 3. Check for unknown parameters ───────────────────────────────────
    if properties:
        known_params = set(properties)
        # Allow extra params that start with _ (internal convention).
        unknown = sorted(
            p for p in function_args
            if p not in known_params and not p.startswith("_")
        )
        if unknown:
            known_str = ", ".join(sorted(known_params))
            unknown_str = ", ".join(unknown)
            messages.append(
                f"Unknown parameter(s) for '{tool_name}': {unknown_str}. "
                f"Known parameters: {known_str}"
            )
            # Remove unknown params (don't block, just clean).
            corrected_params = {
                k: v for k, v in function_args.items()
                if k in known_params or k.startswith("_")
            }

    # `is not None` (not `or`): a cleaned-but-empty dict is still a correction.
    current = corrected_params if corrected_params is not None else function_args

    # ── 4. Type validation ────────────────────────────────────────────────
    type_errors: List[str] = []
    coerced = dict(current)
    # Tracked explicitly: dict equality cannot see e.g. 3.0 -> 3 (3.0 == 3).
    changed = False

    for param_name, param_value in current.items():
        if param_name.startswith("_"):
            continue
        param_schema = properties.get(param_name)
        if not param_schema:
            continue
        expected_type = param_schema.get("type")
        if not expected_type:
            continue

        is_valid_type, coerced_value = _validate_type(
            param_name, param_value, expected_type
        )
        if not is_valid_type:
            type_errors.append(
                f"Parameter '{param_name}': expected {expected_type}, "
                f"got {type(param_value).__name__} ({_truncate(str(param_value), 50)})"
            )
        elif coerced_value is not param_value:
            coerced[param_name] = coerced_value
            changed = True
            messages.append(
                f"Parameter '{param_name}': coerced from "
                f"{type(param_value).__name__} to {expected_type}"
            )

    if type_errors:
        messages.extend(type_errors)
        return False, corrected_name, corrected_params, messages

    if changed:
        corrected_params = coerced
        current = coerced

    # ── 5. Enum validation ────────────────────────────────────────────────
    for param_name, param_value in current.items():
        enum_values = properties.get(param_name, {}).get("enum")
        if enum_values and param_value not in enum_values:
            messages.append(
                f"Parameter '{param_name}': value '{param_value}' not in "
                f"allowed values: {enum_values}"
            )
            return False, corrected_name, corrected_params, messages

    # ── 6. Pattern validation (warning only) ──────────────────────────────
    for param_name, param_value in current.items():
        if not isinstance(param_value, str):
            continue
        pattern = properties.get(param_name, {}).get("pattern")
        if not pattern:
            continue
        try:
            # JSON Schema patterns are unanchored, hence search() not match().
            matched = re.search(pattern, param_value) is not None
        except re.error as exc:
            # A broken schema regex must not take down the validator.
            logger.warning(
                "Invalid pattern %r for parameter '%s' of tool '%s': %s",
                pattern, param_name, tool_name, exc,
            )
            continue
        if not matched:
            messages.append(
                f"Parameter '{param_name}': value '{_truncate(param_value, 50)}' "
                f"does not match pattern '{pattern}'"
            )

    # ── Done ──────────────────────────────────────────────────────────────
    # Every blocking failure returned above, so the call is valid here.
    if not messages:
        return True, None, None, []
    return True, corrected_name, corrected_params, messages
|
|
|
|
|
|
def _find_closest_name(target: str, candidates: List[str]) -> Optional[str]:
|
|
"""Find the closest tool name using simple edit distance heuristics."""
|
|
if not candidates:
|
|
return None
|
|
|
|
target_lower = target.lower()
|
|
|
|
# Exact prefix match
|
|
for name in candidates:
|
|
if name.lower().startswith(target_lower[:4]) and len(target_lower) > 3:
|
|
return name
|
|
|
|
# Substring match
|
|
for name in candidates:
|
|
if target_lower in name.lower() or name.lower() in target_lower:
|
|
return name
|
|
|
|
# Levenshtein distance (simple, for short strings)
|
|
def _levenshtein(a: str, b: str) -> int:
|
|
if len(a) < len(b):
|
|
return _levenshtein(b, a)
|
|
if len(b) == 0:
|
|
return len(a)
|
|
prev = range(len(b) + 1)
|
|
for i, ca in enumerate(a):
|
|
curr = [i + 1]
|
|
for j, cb in enumerate(b):
|
|
curr.append(min(
|
|
prev[j + 1] + 1,
|
|
curr[j] + 1,
|
|
prev[j] + (0 if ca == cb else 1),
|
|
))
|
|
prev = curr
|
|
return prev[-1]
|
|
|
|
distances = [(name, _levenshtein(target_lower, name.lower())) for name in candidates]
|
|
distances.sort(key=lambda x: x[1])
|
|
|
|
# Return if edit distance is small enough
|
|
if distances and distances[0][1] <= max(3, len(target) // 3):
|
|
return distances[0][0]
|
|
|
|
return None
|
|
|
|
|
|
def _validate_type(
|
|
param_name: str, value: Any, expected_type: str
|
|
) -> Tuple[bool, Any]:
|
|
"""Validate and optionally coerce a parameter value to the expected type.
|
|
|
|
Returns (is_valid, coerced_value). coerced_value is value itself if no
|
|
coercion was needed.
|
|
"""
|
|
type_map = {
|
|
"string": str,
|
|
"integer": int,
|
|
"number": (int, float),
|
|
"boolean": bool,
|
|
"array": list,
|
|
"object": dict,
|
|
}
|
|
|
|
expected = type_map.get(expected_type)
|
|
if expected is None:
|
|
return True, value # Unknown type, skip validation
|
|
|
|
# Direct type check
|
|
if isinstance(value, expected):
|
|
return True, value
|
|
|
|
# Coercion attempts
|
|
if expected_type == "string":
|
|
return True, str(value)
|
|
|
|
if expected_type == "integer":
|
|
if isinstance(value, str) and value.isdigit():
|
|
return True, int(value)
|
|
if isinstance(value, float) and value == int(value):
|
|
return True, int(value)
|
|
return False, value
|
|
|
|
if expected_type == "number":
|
|
if isinstance(value, str):
|
|
try:
|
|
return True, float(value)
|
|
except ValueError:
|
|
return False, value
|
|
return False, value
|
|
|
|
if expected_type == "boolean":
|
|
if isinstance(value, str):
|
|
lower = value.lower()
|
|
if lower in ("true", "1", "yes"):
|
|
return True, True
|
|
if lower in ("false", "0", "no"):
|
|
return True, False
|
|
return False, value
|
|
|
|
return False, value
|
|
|
|
|
|
def _truncate(s: str, max_len: int) -> str:
|
|
"""Truncate a string for display."""
|
|
if len(s) <= max_len:
|
|
return s
|
|
return s[:max_len - 3] + "..."
|