"""Tests for batch tool execution safety classification.""" import json import pytest from unittest.mock import MagicMock def _make_tool_call(name: str, args: dict) -> MagicMock: """Create a mock tool call object.""" tc = MagicMock() tc.function.name = name tc.function.arguments = json.dumps(args) tc.id = f"call_{name}_1" return tc class TestClassification: def test_parallel_safe_read_file(self): from tools.batch_executor import classify_single_tool_call tc = _make_tool_call("read_file", {"path": "README.md"}) result = classify_single_tool_call(tc) assert result.tier == "parallel_safe" def test_parallel_safe_web_search(self): from tools.batch_executor import classify_single_tool_call tc = _make_tool_call("web_search", {"query": "test"}) result = classify_single_tool_call(tc) assert result.tier == "parallel_safe" def test_parallel_safe_search_files(self): from tools.batch_executor import classify_single_tool_call tc = _make_tool_call("search_files", {"pattern": "test"}) result = classify_single_tool_call(tc) assert result.tier == "parallel_safe" def test_never_parallel_clarify(self): from tools.batch_executor import classify_single_tool_call tc = _make_tool_call("clarify", {"question": "test"}) result = classify_single_tool_call(tc) assert result.tier == "never_parallel" def test_terminal_is_sequential(self): from tools.batch_executor import classify_single_tool_call tc = _make_tool_call("terminal", {"command": "ls -la"}) result = classify_single_tool_call(tc) assert result.tier == "sequential" def test_terminal_destructive_rm(self): from tools.batch_executor import classify_single_tool_call tc = _make_tool_call("terminal", {"command": "rm -rf /tmp/test"}) result = classify_single_tool_call(tc) assert result.tier == "sequential" assert "Destructive" in result.reason def test_write_file_is_path_scoped(self): from tools.batch_executor import classify_single_tool_call tc = _make_tool_call("write_file", {"path": "/tmp/test.txt", "content": "hello"}) result = classify_single_tool_call(tc) assert result.tier == "path_scoped" def test_delegate_is_sequential(self): from tools.batch_executor import classify_single_tool_call tc = _make_tool_call("delegate_task", {"goal": "test"}) result = classify_single_tool_call(tc) assert result.tier == "sequential" def test_unknown_tool_is_sequential(self): from tools.batch_executor import classify_single_tool_call tc = _make_tool_call("some_unknown_tool", {"arg": "val"}) result = classify_single_tool_call(tc) assert result.tier == "sequential" class TestBatchClassification: def test_all_parallel_stays_parallel(self): from tools.batch_executor import classify_tool_calls tcs = [ _make_tool_call("read_file", {"path": f"file{i}.txt"}) for i in range(5) ] plan = classify_tool_calls(tcs) assert plan.can_parallelize assert len(plan.parallel_batch) == 5 assert len(plan.sequential_batch) == 0 def test_mixed_batch(self): from tools.batch_executor import classify_tool_calls tcs = [ _make_tool_call("read_file", {"path": "a.txt"}), _make_tool_call("terminal", {"command": "ls"}), _make_tool_call("web_search", {"query": "test"}), _make_tool_call("delegate_task", {"goal": "test"}), ] plan = classify_tool_calls(tcs) # read_file + web_search should be parallel (both parallel_safe) # terminal + delegate_task should be sequential assert len(plan.parallel_batch) >= 2 assert len(plan.sequential_batch) >= 2 def test_clarify_blocks_all(self): from tools.batch_executor import classify_tool_calls tcs = [ _make_tool_call("read_file", {"path": "a.txt"}), _make_tool_call("clarify", {"question": "which one?"}), _make_tool_call("web_search", {"query": "test"}), ] plan = classify_tool_calls(tcs) clarify_in_seq = any(c.tool_name == "clarify" for c in plan.sequential_batch) assert clarify_in_seq def test_overlapping_paths_sequential(self): from tools.batch_executor import classify_tool_calls tcs = [ _make_tool_call("write_file", {"path": "/tmp/test/a.txt", "content": "hello"}), _make_tool_call("patch", {"path": "/tmp/test/a.txt", "old_string": "a", "new_string": "b"}), ] plan = classify_tool_calls(tcs) # write_file and patch on SAME file -> conflict -> one must be sequential assert len(plan.sequential_batch) >= 1 class TestDestructiveCommands: def test_rm_flagged(self): from tools.batch_executor import is_destructive_command assert is_destructive_command("rm -rf /tmp") assert is_destructive_command("rm file.txt") def test_mv_flagged(self): from tools.batch_executor import is_destructive_command assert is_destructive_command("mv old new") def test_sed_i_flagged(self): from tools.batch_executor import is_destructive_command assert is_destructive_command("sed -i 's/a/b/g' file") def test_redirect_overwrite_flagged(self): from tools.batch_executor import is_destructive_command assert is_destructive_command("echo test > file.txt") def test_safe_commands_not_flagged(self): from tools.batch_executor import is_destructive_command assert not is_destructive_command("ls -la") assert not is_destructive_command("cat file.txt") assert not is_destructive_command("echo test >> file.txt") # append is safe class TestRegistryIntegration: def test_parallel_safe_in_registry(self): from tools.registry import registry safe = registry.get_parallel_safe_tools() assert isinstance(safe, set)