Some checks failed
Contributor Attribution Check / check-attribution (pull_request) Successful in 24s
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 22s
Tests / e2e (pull_request) Successful in 3m6s
Tests / test (pull_request) Failing after 41m24s
183 lines
6.2 KiB
Python
183 lines
6.2 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Tests for tool_pokayoke.py — Tool Hallucination Prevention
|
|
"""
|
|
|
|
import json
|
|
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from tools.tool_pokayoke import (
|
|
levenshtein_distance,
|
|
find_similar_names,
|
|
auto_correct_parameter,
|
|
ToolCallValidator,
|
|
validate_tool_call,
|
|
reset_circuit_breaker,
|
|
get_hallucination_stats,
|
|
)
|
|
|
|
|
|
class TestLevenshteinDistance:
    """Unit tests for the Levenshtein edit-distance helper."""

    def test_identical_strings(self):
        # Zero edits separate a string from itself.
        assert levenshtein_distance("hello", "hello") == 0

    def test_single_insertion(self):
        # A single character added or removed costs exactly 1, in
        # either argument order.
        for left, right in (("hello", "hell"), ("hell", "hello")):
            assert levenshtein_distance(left, right) == 1

    def test_single_substitution(self):
        # Replacing one character costs 1.
        assert levenshtein_distance("hello", "hallo") == 1

    def test_multiple_edits(self):
        # Classic textbook pair: "kitten" -> "sitting" requires 3 edits.
        assert levenshtein_distance("kitten", "sitting") == 3

    def test_empty_strings(self):
        # Distance to or from the empty string equals the other
        # string's length; two empty strings are distance 0 apart.
        assert levenshtein_distance("", "hello") == 5
        assert levenshtein_distance("hello", "") == 5
        assert levenshtein_distance("", "") == 0
|
|
|
|
|
|
class TestFindSimilarNames:
    """Unit tests for fuzzy tool-name lookup."""

    def test_exact_match_excluded(self):
        candidates = ["browser_type", "browser_click", "browser_navigate"]
        matches = find_similar_names("browser_type", candidates, max_distance=2)
        # A distance-0 (exact) match must never be suggested back.
        assert not any(name == "browser_type" for name, _ in matches)

    def test_close_matches_found(self):
        candidates = ["browser_type", "browser_click", "terminal"]
        matches = find_similar_names("browser_typo", candidates, max_distance=1)
        # Exactly one candidate lies within edit distance 1.
        assert len(matches) == 1
        name, distance = matches[0]
        assert name == "browser_type"
        assert distance == 1

    def test_no_matches_beyond_distance(self):
        candidates = ["browser_type", "terminal"]
        # Nothing resembles "xyz" within a single edit.
        assert not find_similar_names("xyz", candidates, max_distance=1)
|
|
|
|
|
|
class TestAutoCorrectParameter:
    """Unit tests for parameter-name auto-correction."""

    def test_exact_correction(self):
        known = ["path", "content", "mode"]
        # An exact match needs no correction, so None comes back.
        assert auto_correct_parameter("path", known) is None

    def test_single_edit_correction(self):
        known = ["path", "content", "mode"]
        # More than one edit away: no correction is offered.
        assert auto_correct_parameter("file_path", known) is None
        # Off by a single character: corrected to the valid name.
        assert auto_correct_parameter("pathe", known) == "path"

    def test_no_correction_for_far_match(self):
        known = ["path", "content"]
        # "xyz" is nowhere near any valid parameter name.
        assert auto_correct_parameter("xyz", known) is None
|
|
|
|
|
|
class TestToolCallValidator:
    """Unit tests for the stateful ToolCallValidator."""

    @pytest.fixture
    def validator(self):
        # Build a validator with a low failure threshold and inject a
        # minimal schema registry so no real tool discovery runs.
        instance = ToolCallValidator(failure_threshold=3)
        instance.tool_schemas = {
            "browser_type": {
                "parameters": {
                    "properties": {
                        "ref": {"type": "string"},
                        "text": {"type": "string"},
                    }
                }
            },
            "terminal": {
                "parameters": {
                    "properties": {
                        "command": {"type": "string"},
                        "timeout": {"type": "integer"},
                    }
                }
            },
        }
        instance._initialized = True
        return instance

    def test_valid_tool_passes(self, validator):
        ok, corrected, _params, messages = validator.validate(
            "browser_type", {"ref": "@e1"}
        )
        # A known tool with a known parameter validates cleanly.
        assert ok is True
        assert corrected is None
        assert not messages

    def test_invalid_tool_suggests(self, validator):
        ok, _corrected, _params, messages = validator.validate(
            "browser_typo", {"ref": "@e1"}
        )
        assert ok is False
        # The near-miss "browser_type" should appear in the guidance.
        assert "browser_type" in str(messages)

    def test_auto_correct_tool_name(self, validator):
        ok, corrected, _params, messages = validator.validate(
            "browser_tipe", {"ref": "@e1"}
        )
        # A one-edit typo is silently fixed and reported.
        assert ok is True
        assert corrected == "browser_type"
        assert any("Auto-corrected" in message for message in messages)

    def test_parameter_correction(self, validator):
        ok, _corrected, params, messages = validator.validate(
            "browser_type", {"reff": "@e1"}
        )
        assert ok is True
        # The typo key is replaced by the canonical "ref" key...
        assert "ref" in params
        # ...and at least one message mentions both spellings.
        assert any("reff" in message and "ref" in message for message in messages)

    def test_circuit_breaker(self, validator):
        # Three straight failures for the same unknown tool...
        for _ in range(3):
            validator.validate("nonexistent_tool", {})

        # ...so the fourth attempt trips the breaker.
        ok, _corrected, _params, messages = validator.validate(
            "nonexistent_tool", {}
        )
        assert ok is False
        assert any("CIRCUIT BREAKER" in message for message in messages)

    def test_success_resets_circuit_breaker(self, validator):
        # Two failures: still under the threshold of three.
        validator.validate("nonexistent_tool", {})
        validator.validate("nonexistent_tool", {})

        # One successful call wipes the failure bookkeeping.
        validator.validate("browser_type", {"ref": "@e1"})

        assert "nonexistent_tool" not in validator.consecutive_failures
|
|
|
|
|
|
class TestValidateToolCall:
    """Unit tests for the module-level validate_tool_call wrapper."""

    def test_integration(self):
        # The real tool registry may be unavailable in unit tests, so
        # stub the module-level validator and check that the wrapper
        # forwards its result unchanged.
        with patch("tools.tool_pokayoke._validator") as fake_validator:
            fake_validator.validate.return_value = (True, None, {}, [])
            ok, _corrected, _params, _messages = validate_tool_call(
                "test_tool", {}
            )
            assert ok is True
|
|
|
|
|
|
class TestCircuitBreakerReset:
    """Unit tests for resetting circuit-breaker state."""

    def test_reset_specific_tool(self):
        # Resetting a named tool clears that tool's failure counter.
        reset_circuit_breaker("test_tool")
        stats = get_hallucination_stats()
        assert "test_tool" not in stats["consecutive_failures"]

    def test_reset_all(self):
        # Calling with no argument wipes every failure counter.
        reset_circuit_breaker()
        stats = get_hallucination_stats()
        assert not stats["consecutive_failures"]
|
|
|
|
|
|
# Allow running this file directly (outside a pytest invocation):
# delegates to pytest in verbose mode over this module only.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|