275 lines
9.2 KiB
Python
275 lines
9.2 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
test_poka_yoke.py — Tests for the tool call validation firewall.
|
|
|
|
Covers: unknown tool, bad param type, missing required arg,
|
|
extra unknown param, enum validation, closest-name suggestion.
|
|
"""
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
import pytest
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
# Prepend the parent of this file's directory to sys.path ahead of the
# `tools.poka_yoke` import below (presumably the repo root — confirm layout).
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
|
|
|
from tools.poka_yoke import (
|
|
validate_tool_call,
|
|
_find_closest_name,
|
|
_validate_type,
|
|
_truncate,
|
|
)
|
|
|
|
|
|
# ── Mock Registry ─────────────────────────────────────────────────────────────
|
|
|
|
class MockEntry:
    """Minimal stand-in for a registry entry: a tool name plus its schema."""

    def __init__(self, name, schema):
        """Store *name* and *schema* on the instance.

        Args:
            name: Tool name the entry is registered under.
            schema: Schema dict describing the tool's parameters.
        """
        self.name = name
        self.schema = schema
        # Every mock entry carries the same fixed toolset label.
        self.toolset = "test"
|
|
|
|
|
|
# Tool name -> MockEntry fixtures used by the tests in this module.
# Each schema lists its properties with their types and its required keys.
MOCK_TOOLS = {
    "read_file": MockEntry("read_file", {
        "name": "read_file",
        "description": "Read a file",
        "parameters": {
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "File path"},
                "offset": {"type": "integer", "description": "Start line"},
                "limit": {"type": "integer", "description": "Max lines"},
            },
            "required": ["path"],
        },
    }),
    "web_search": MockEntry("web_search", {
        "name": "web_search",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string"},
                "max_results": {"type": "integer"},
            },
            "required": ["query"],
        },
    }),
    "write_file": MockEntry("write_file", {
        "name": "write_file",
        "parameters": {
            "type": "object",
            "properties": {
                "path": {"type": "string"},
                "content": {"type": "string"},
            },
            "required": ["path", "content"],
        },
    }),
    "terminal": MockEntry("terminal", {
        "name": "terminal",
        "parameters": {
            "type": "object",
            "properties": {
                "command": {"type": "string"},
                "timeout": {"type": "integer"},
                "background": {"type": "boolean"},
            },
            "required": ["command"],
        },
    }),
}
|
|
|
|
|
|
def _mock_registry():
    """Create a mock registry backed by ``MOCK_TOOLS``.

    Returns:
        A ``MagicMock`` whose ``get_entry(name)`` returns the matching
        ``MOCK_TOOLS`` entry (``None`` when the name is absent) and whose
        ``get_all_tool_names()`` returns a list of the ``MOCK_TOOLS`` keys.
    """
    mock_reg = MagicMock()
    # Plain callables replace the auto-created mock attributes so lookups
    # return real fixture data instead of nested MagicMock objects.
    mock_reg.get_entry = lambda name: MOCK_TOOLS.get(name)
    mock_reg.get_all_tool_names = lambda: list(MOCK_TOOLS)
    return mock_reg
|
|
|
|
|
|
# ── Test: Unknown Tool ────────────────────────────────────────────────────────
|
|
|
|
class TestUnknownTool:
    """Calls naming a tool the registry does not know about."""

    @patch("tools.poka_yoke.registry")
    def test_unknown_tool_rejected(self, mock_reg):
        """An unregistered name is rejected and named in the first message."""
        mock_reg.get_entry.return_value = None
        mock_reg.get_all_tool_names.return_value = list(MOCK_TOOLS)

        is_valid, _, _, msgs = validate_tool_call("nonexistent_tool", {})

        assert is_valid is False
        assert len(msgs) > 0
        assert "nonexistent_tool" in msgs[0]
        assert "Unknown tool" in msgs[0]

    @patch("tools.poka_yoke.registry")
    def test_unknown_tool_lists_available(self, mock_reg):
        """The rejection message mentions an available tool name."""
        mock_reg.get_entry.return_value = None
        mock_reg.get_all_tool_names.return_value = list(MOCK_TOOLS)

        is_valid, _, _, msgs = validate_tool_call("foo", {})

        assert is_valid is False
        # Guard the index so an empty list fails as an assertion, not IndexError.
        assert len(msgs) > 0
        assert "read_file" in msgs[0]

    @patch("tools.poka_yoke.registry")
    def test_close_name_suggests_correction(self, mock_reg):
        """A near-miss name yields the closest registered name as correction."""
        mock_reg.get_entry.return_value = None
        mock_reg.get_all_tool_names.return_value = list(MOCK_TOOLS)

        _, name, _, msgs = validate_tool_call("readfile", {})

        # Guard the index so an empty list fails as an assertion, not IndexError.
        assert len(msgs) > 0
        assert "read_file" in msgs[0]
        assert name == "read_file"
|
|
|
|
|
|
# ── Test: Missing Required Args ───────────────────────────────────────────────
|
|
|
|
class TestMissingRequired:
    """Calls that omit (or supply) a schema's required arguments."""

    @patch("tools.poka_yoke.registry")
    def test_missing_required_rejected(self, mock_reg):
        """Omitting the single required arg is rejected and the arg is named."""
        mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]

        is_valid, _, _, msgs = validate_tool_call("read_file", {})

        assert is_valid is False
        assert any("Missing required" in m for m in msgs)
        assert any("'path'" in m for m in msgs)

    @patch("tools.poka_yoke.registry")
    def test_multiple_missing_required(self, mock_reg):
        """Every missing required arg is reported, not just the first."""
        mock_reg.get_entry.return_value = MOCK_TOOLS["write_file"]

        is_valid, _, _, msgs = validate_tool_call("write_file", {})

        assert is_valid is False
        assert any("'path'" in m for m in msgs)
        assert any("'content'" in m for m in msgs)

    @patch("tools.poka_yoke.registry")
    def test_required_present_passes(self, mock_reg):
        """Supplying the required arg alone is enough for a valid call."""
        mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]

        is_valid, _, _, _ = validate_tool_call(
            "read_file", {"path": "test.txt"}
        )

        assert is_valid is True
|
|
|
|
|
|
# ── Test: Type Validation ─────────────────────────────────────────────────────
|
|
|
|
class TestTypeValidation:
    """Parameter type checking and string-to-type coercion."""

    @patch("tools.poka_yoke.registry")
    def test_wrong_type_rejected(self, mock_reg):
        """A non-numeric string for an integer param is rejected."""
        mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]

        is_valid, _, _, msgs = validate_tool_call(
            "read_file", {"path": "test.txt", "offset": "not_a_number"}
        )

        assert is_valid is False
        assert any("offset" in m and "integer" in m for m in msgs)

    @patch("tools.poka_yoke.registry")
    def test_string_to_int_coercion(self, mock_reg):
        """A numeric string for an integer param is coerced to ``int``."""
        mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]

        is_valid, _, params, _ = validate_tool_call(
            "read_file", {"path": "test.txt", "offset": "42"}
        )

        assert is_valid is True
        assert params is not None
        assert params["offset"] == 42

    @patch("tools.poka_yoke.registry")
    def test_boolean_coercion(self, mock_reg):
        """The string ``"true"`` for a boolean param is coerced to ``True``."""
        mock_reg.get_entry.return_value = MOCK_TOOLS["terminal"]

        is_valid, _, params, _ = validate_tool_call(
            "terminal", {"command": "ls", "background": "true"}
        )

        assert is_valid is True
        assert params is not None
        assert params["background"] is True
|
|
|
|
|
|
# ── Test: Unknown Parameters ──────────────────────────────────────────────────
|
|
|
|
class TestUnknownParams:
    """Calls carrying parameters the tool schema does not declare."""

    @patch("tools.poka_yoke.registry")
    def test_unknown_param_removed(self, mock_reg):
        """An undeclared param is dropped and reported; the call stays valid."""
        mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]

        is_valid, _, params, msgs = validate_tool_call(
            "read_file", {"path": "test.txt", "bogus_param": "value"}
        )

        assert is_valid is True
        assert params is not None
        assert "bogus_param" not in params
        assert "path" in params
        assert any("Unknown parameter" in m for m in msgs)
|
|
|
|
|
|
# ── Test: Valid Calls Pass Through ────────────────────────────────────────────
|
|
|
|
class TestValidCalls:
    """Well-formed calls that should pass through validation."""

    @patch("tools.poka_yoke.registry")
    def test_valid_read_file(self, mock_reg):
        """A fully valid call yields no corrected name/params and no messages."""
        mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]

        is_valid, name, params, msgs = validate_tool_call(
            "read_file", {"path": "test.txt", "offset": 1, "limit": 100}
        )

        assert is_valid is True
        assert name is None
        assert params is None
        assert msgs == []

    @patch("tools.poka_yoke.registry")
    def test_valid_write_file(self, mock_reg):
        """A call with both required string args is valid."""
        mock_reg.get_entry.return_value = MOCK_TOOLS["write_file"]

        is_valid, _, _, _ = validate_tool_call(
            "write_file", {"path": "out.txt", "content": "hello"}
        )

        assert is_valid is True
|
|
|
|
|
|
# ── Test: Helper Functions ────────────────────────────────────────────────────
|
|
|
|
class TestHelpers:
    """Direct tests of the private helper functions."""

    def test_find_closest_exact_prefix(self):
        """A truncated name resolves to the candidate it is closest to."""
        assert _find_closest_name("readfil", ["read_file", "write_file"]) == "read_file"

    def test_find_closest_substring(self):
        """A fragment contained in one candidate resolves to that candidate."""
        assert _find_closest_name("file", ["read_file", "web_search"]) == "read_file"

    def test_find_closest_no_match(self):
        """An unrelated name yields no suggestion."""
        assert _find_closest_name("xyzzy", ["read_file", "write_file"]) is None

    def test_validate_type_string(self):
        """A ``str`` value is accepted for the ``string`` type."""
        ok, _ = _validate_type("x", "hello", "string")
        assert ok is True

    def test_validate_type_int_coercion(self):
        """A numeric string is accepted for ``integer`` and coerced to ``int``."""
        ok, val = _validate_type("x", "42", "integer")
        assert ok is True
        assert val == 42

    def test_validate_type_int_bad(self):
        """A non-numeric string is refused for the ``integer`` type."""
        ok, _ = _validate_type("x", "not_int", "integer")
        assert ok is False

    def test_truncate(self):
        """Short text is unchanged; long text is cut and ends with ``...``."""
        assert _truncate("hello", 10) == "hello"
        assert _truncate("hello world", 8) == "hello..."
|