"""
poka_yoke.py — Validation firewall for tool calls.

Poka-yoke (mistake-proofing): validates tool calls against the registry
before execution. Catches hallucinated tool names, malformed parameters,
missing required arguments, and type mismatches.

Usage:
    from tools.poka_yoke import validate_tool_call

    is_valid, corrected_name, corrected_params, messages = validate_tool_call(
        "read_file", {"path": "test.txt"}
    )
"""

import json
import logging
import re
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)

# JSON Schema primitive type names -> Python types used for the direct check.
_TYPE_MAP: Dict[str, Any] = {
    "string": str,
    "integer": int,
    "number": (int, float),
    "boolean": bool,
    "array": list,
    "object": dict,
    "null": type(None),
}

# How many tool names to list when no close match for an unknown tool exists.
_MAX_LISTED_TOOLS = 20


def validate_tool_call(
    function_name: str,
    function_args: Dict[str, Any],
) -> Tuple[bool, Optional[str], Optional[Dict[str, Any]], List[str]]:
    """Validate a tool call against the registry before execution.

    Checks run in order: tool existence (with a closest-name suggestion),
    required parameters, unknown parameters (stripped, not blocking),
    type validation/coercion, enum membership, and regex ``pattern``
    (reported as a warning, not blocking).

    Args:
        function_name: The tool name from the LLM's function_call.
        function_args: The arguments dict from the LLM's function_call.
            ``None`` is treated as an empty dict.

    Returns:
        Tuple of (is_valid, corrected_name, corrected_params, messages):
          - is_valid: False if the call should be blocked entirely.
          - corrected_name: Suggested name if a close match was found (None if OK).
          - corrected_params: Corrected params if stripping unknown params or
            type coercion changed the arguments (None if OK).
          - messages: List of error/warning/info messages.
    """
    from tools.registry import registry

    messages: List[str] = []
    corrected_name: Optional[str] = None
    corrected_params: Optional[Dict[str, Any]] = None

    # LLMs sometimes emit `null` arguments for calls that take none.
    if function_args is None:
        function_args = {}

    # ── 1. Check if tool exists ───────────────────────────────────────────
    entry = registry.get_entry(function_name)
    if entry is None:
        # Materialize once: the names are scanned several times below.
        all_names = list(registry.get_all_tool_names())
        suggestion = _find_closest_name(function_name, all_names)
        if not suggestion:
            available = ", ".join(sorted(all_names)[:_MAX_LISTED_TOOLS])
            more = "..." if len(all_names) > _MAX_LISTED_TOOLS else ""
            messages.append(
                f"Unknown tool '{function_name}'. "
                f"Available tools: {available}{more}"
            )
            return False, None, None, messages
        messages.append(
            f"Unknown tool '{function_name}'. Did you mean '{suggestion}'?"
        )
        corrected_name = suggestion
        # Keep validating the arguments, now against the suggested tool.
        entry = registry.get_entry(suggestion)
        if entry is None:
            return False, corrected_name, None, messages

    # Name the tool whose schema is actually being checked in later messages.
    tool_label = corrected_name or function_name

    if not isinstance(function_args, dict):
        messages.append(
            f"Arguments for '{tool_label}' must be an object, "
            f"got {type(function_args).__name__}"
        )
        return False, corrected_name, None, messages

    # ── 2. Validate parameters against schema ─────────────────────────────
    schema = entry.schema or {}
    params_schema = schema.get("parameters") or {}
    properties = params_schema.get("properties") or {}
    required = set(params_schema.get("required") or [])

    missing = sorted(p for p in required if p not in function_args)
    for param_name in missing:
        param_type = (properties.get(param_name) or {}).get("type", "any")
        messages.append(
            f"Missing required parameter '{param_name}' "
            f"(expected type: {param_type}). "
            f"Tool: {tool_label}"
        )
    if missing:
        # Blocked, but the messages go back to the agent so it can retry.
        return False, corrected_name, corrected_params, messages

    # ── 3. Check for unknown parameters ───────────────────────────────────
    if properties:
        known_params = set(properties)
        # Extra params starting with "_" are an internal convention: allowed.
        unknown = sorted(
            p for p in function_args
            if p not in known_params and not p.startswith("_")
        )
        if unknown:
            messages.append(
                f"Unknown parameter(s) for '{tool_label}': {', '.join(unknown)}. "
                f"Known parameters: {', '.join(sorted(known_params))}"
            )
            # Strip unknown params instead of blocking the call.
            corrected_params = {
                k: v for k, v in function_args.items()
                if k in known_params or k.startswith("_")
            }

    # `is None` rather than `or`: stripping may legitimately leave {}.
    current = function_args if corrected_params is None else corrected_params

    # ── 4. Type validation ────────────────────────────────────────────────
    coerced = dict(current)
    changed = False  # explicit flag: 3.0 -> 3 compares equal but must stick
    type_errors: List[str] = []
    for param_name, param_value in current.items():
        if param_name.startswith("_"):
            continue
        expected_type = (properties.get(param_name) or {}).get("type")
        if not expected_type:
            continue
        is_valid_type, coerced_value = _validate_type(
            param_name, param_value, expected_type
        )
        if not is_valid_type:
            type_errors.append(
                f"Parameter '{param_name}': expected {expected_type}, "
                f"got {type(param_value).__name__} ({_truncate(str(param_value), 50)})"
            )
        elif coerced_value is not param_value:
            coerced[param_name] = coerced_value
            changed = True
            messages.append(
                f"Parameter '{param_name}': coerced from "
                f"{type(param_value).__name__} to {expected_type}"
            )

    if type_errors:
        messages.extend(type_errors)
        return False, corrected_name, corrected_params, messages

    if changed:
        corrected_params = coerced
        current = coerced

    # ── 5. Enum validation ────────────────────────────────────────────────
    for param_name, param_value in current.items():
        enum_values = (properties.get(param_name) or {}).get("enum")
        if enum_values and param_value not in enum_values:
            messages.append(
                f"Parameter '{param_name}': value '{param_value}' not in "
                f"allowed values: {enum_values}"
            )
            return False, corrected_name, corrected_params, messages

    # ── 6. Pattern validation (warning only, does not block) ──────────────
    for param_name, param_value in current.items():
        if not isinstance(param_value, str):
            continue
        pattern = (properties.get(param_name) or {}).get("pattern")
        if not pattern:
            continue
        try:
            # JSON Schema `pattern` is unanchored, hence search() not match().
            matched = re.search(pattern, param_value) is not None
        except re.error as exc:
            logger.warning(
                "Invalid pattern %r for parameter '%s' of tool '%s': %s",
                pattern, param_name, tool_label, exc,
            )
            continue
        if not matched:
            messages.append(
                f"Parameter '{param_name}': value '{_truncate(param_value, 50)}' "
                f"does not match pattern '{pattern}'"
            )

    # ── Done ──────────────────────────────────────────────────────────────
    # Every blocking condition returned early above.
    if not messages:
        return True, None, None, []
    return True, corrected_name, corrected_params, messages


def _levenshtein(a: str, b: str) -> int:
    """Return the Levenshtein edit distance between *a* and *b*."""
    if len(a) < len(b):
        a, b = b, a
    if not b:
        return len(a)
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a):
        curr = [i + 1]
        for j, cb in enumerate(b):
            curr.append(min(
                prev[j + 1] + 1,       # deletion
                curr[j] + 1,           # insertion
                prev[j] + (ca != cb),  # substitution
            ))
        prev = curr
    return prev[-1]


def _find_closest_name(target: str, candidates: List[str]) -> Optional[str]:
    """Find the closest tool name to *target* among *candidates*.

    Heuristics, in priority order:
      1. case-insensitive exact match;
      2. candidates sharing the first 4 characters (targets longer than 3);
      3. candidates where one name contains the other;
      4. the nearest candidate if within edit distance max(3, len(target) // 3).
    Within tiers 2 and 3 the candidate with the smallest edit distance wins
    (ties go to the earlier candidate).

    Returns None when *target* or *candidates* is empty, or nothing is close.
    """
    if not target or not candidates:
        return None
    target_lower = target.lower()
    pairs = [(name, name.lower()) for name in candidates]

    for name, lowered in pairs:
        if lowered == target_lower:
            return name

    def _nearest(pool: List[Tuple[str, str]]) -> str:
        return min(pool, key=lambda pair: _levenshtein(target_lower, pair[1]))[0]

    if len(target_lower) > 3:
        prefix = target_lower[:4]
        pool = [pair for pair in pairs if pair[1].startswith(prefix)]
        if pool:
            return _nearest(pool)

    pool = [
        pair for pair in pairs
        if pair[1] and (target_lower in pair[1] or pair[1] in target_lower)
    ]
    if pool:
        return _nearest(pool)

    best_name, best_dist = min(
        ((name, _levenshtein(target_lower, lowered)) for name, lowered in pairs),
        key=lambda item: item[1],
    )
    if best_dist <= max(3, len(target) // 3):
        return best_name
    return None


def _matches_type(value: Any, type_name: Any) -> Optional[bool]:
    """Return whether *value* already has JSON type *type_name*.

    Returns None when *type_name* is not a known JSON Schema type. Booleans
    are not integers/numbers in JSON Schema even though bool subclasses int.
    """
    if not isinstance(type_name, str) or type_name not in _TYPE_MAP:
        return None
    if isinstance(value, bool) and type_name in ("integer", "number"):
        return False
    return isinstance(value, _TYPE_MAP[type_name])


def _coerce_value(value: Any, type_name: str) -> Tuple[bool, Any]:
    """Try to convert *value* to JSON type *type_name*; return (ok, result)."""
    if type_name == "string":
        if isinstance(value, (dict, list)):
            # JSON text, not a Python repr (single quotes, True/None).
            try:
                return True, json.dumps(value, ensure_ascii=False)
            except (TypeError, ValueError):
                pass
        return True, str(value)

    if type_name == "integer":
        if isinstance(value, str):
            try:
                return True, int(value.strip())
            except ValueError:
                return False, value
        # is_integer() is False for inf/nan, so int() cannot overflow here.
        if isinstance(value, float) and value.is_integer():
            return True, int(value)
        return False, value

    if type_name == "number":
        if isinstance(value, str):
            try:
                return True, float(value)
            except ValueError:
                return False, value
        return False, value

    if type_name == "boolean":
        if isinstance(value, str):
            lowered = value.lower()
            if lowered in ("true", "1", "yes"):
                return True, True
            if lowered in ("false", "0", "no"):
                return True, False
        return False, value

    if type_name in ("array", "object") and isinstance(value, str):
        # LLMs frequently send arrays/objects as JSON-encoded strings.
        try:
            parsed = json.loads(value)
        except ValueError:
            return False, value
        if isinstance(parsed, _TYPE_MAP[type_name]):
            return True, parsed

    return False, value


def _validate_type(
    param_name: str, value: Any, expected_type: Any
) -> Tuple[bool, Any]:
    """Validate and optionally coerce a parameter value to the expected type.

    *expected_type* is a JSON Schema type name or a list of names (a union
    such as ``["string", "null"]``). Unknown type names skip validation.
    An exact match on any union member wins before coercion is attempted;
    coercions are then tried in member order. *param_name* is currently unused.

    Returns (is_valid, coerced_value).
    coerced_value is value itself if no coercion was needed.
    """
    if isinstance(expected_type, (list, tuple)):
        type_names = list(expected_type)
    else:
        type_names = [expected_type]

    verdicts = [_matches_type(value, name) for name in type_names]
    if not type_names or any(v is None for v in verdicts):
        return True, value  # Unknown type, skip validation
    if any(verdicts):
        return True, value

    for name in type_names:
        ok, coerced = _coerce_value(value, name)
        if ok:
            return True, coerced
    return False, value


def _truncate(s: str, max_len: int) -> str:
    """Truncate *s* to at most *max_len* characters, marking cuts with '...'."""
    if len(s) <= max_len:
        return s
    if max_len <= 3:
        # Too short for an ellipsis; never return more than max_len chars.
        return s[:max(max_len, 0)]
    return s[:max_len - 3] + "..."