Some checks failed
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Contributor Attribution Check / check-attribution (pull_request) Successful in 35s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 37s
Tests / e2e (pull_request) Successful in 1m48s
Tests / test (pull_request) Failing after 36m13s
Centralized safety classification for tool call batches.

tools/batch_executor.py (new):
- classify_tool_calls() — classifies a batch into parallel_safe, path_scoped, sequential, and never_parallel tiers
- BatchExecutionPlan — structured plan with parallel and sequential batches
- Path conflict detection — write_file + patch on the same file go sequential
- Destructive command detection — rm, mv, sed -i, redirects
- execute_parallel_batch() — ThreadPoolExecutor for concurrent execution

tools/registry.py (enhanced):
- ToolEntry.parallel_safe field — tools can declare parallel safety
- registry.register() accepts a parallel_safe=True parameter
- registry.get_parallel_safe_tools() — query registry-declared safe tools

Safety tiers:
- parallel_safe: read_file, web_search, search_files, etc.
- path_scoped: write_file, patch (concurrent when paths don't overlap)
- sequential: terminal, delegate_task, unknown tools
- never_parallel: clarify (requires user interaction)

19 tests passing.
151 lines
6.0 KiB
Python
151 lines
6.0 KiB
Python
"""Tests for batch tool execution safety classification."""
|
|
import json
|
|
import pytest
|
|
from unittest.mock import MagicMock
|
|
|
|
|
|
def _make_tool_call(name: str, args: dict) -> MagicMock:
    """Build a MagicMock that stands in for a single tool call.

    The mock exposes ``function.name`` (the tool name),
    ``function.arguments`` (``args`` serialized as a JSON string) and
    ``id`` (``"call_<name>_1"``).
    """
    mock_call = MagicMock()
    mock_call.id = f"call_{name}_1"

    # ``function`` is auto-created by MagicMock on first attribute access.
    function_stub = mock_call.function
    function_stub.name = name
    function_stub.arguments = json.dumps(args)
    return mock_call
|
|
|
|
|
|
class TestClassification:
    """Tier checks for single tool calls passed to classify_single_tool_call."""

    @staticmethod
    def _classify(tool_name, arguments):
        """Classify one mocked tool call and return the classifier's result."""
        from tools.batch_executor import classify_single_tool_call as classify

        return classify(_make_tool_call(tool_name, arguments))

    def test_parallel_safe_read_file(self):
        """A read_file call is expected in the parallel_safe tier."""
        outcome = self._classify("read_file", {"path": "README.md"})
        assert outcome.tier == "parallel_safe"

    def test_parallel_safe_web_search(self):
        """A web_search call is expected in the parallel_safe tier."""
        outcome = self._classify("web_search", {"query": "test"})
        assert outcome.tier == "parallel_safe"

    def test_parallel_safe_search_files(self):
        """A search_files call is expected in the parallel_safe tier."""
        outcome = self._classify("search_files", {"pattern": "test"})
        assert outcome.tier == "parallel_safe"

    def test_never_parallel_clarify(self):
        """A clarify call is expected in the never_parallel tier."""
        outcome = self._classify("clarify", {"question": "test"})
        assert outcome.tier == "never_parallel"

    def test_terminal_is_sequential(self):
        """A terminal call running ``ls -la`` is expected to be sequential."""
        outcome = self._classify("terminal", {"command": "ls -la"})
        assert outcome.tier == "sequential"

    def test_terminal_destructive_rm(self):
        """A terminal ``rm -rf`` is sequential and its reason mentions 'Destructive'."""
        outcome = self._classify("terminal", {"command": "rm -rf /tmp/test"})
        assert outcome.tier == "sequential"
        assert "Destructive" in outcome.reason

    def test_write_file_is_path_scoped(self):
        """A write_file call is expected in the path_scoped tier."""
        outcome = self._classify(
            "write_file",
            {"path": "/tmp/test.txt", "content": "hello"},
        )
        assert outcome.tier == "path_scoped"

    def test_delegate_is_sequential(self):
        """A delegate_task call is expected to be sequential."""
        outcome = self._classify("delegate_task", {"goal": "test"})
        assert outcome.tier == "sequential"

    def test_unknown_tool_is_sequential(self):
        """A tool name the classifier does not know is expected to be sequential."""
        outcome = self._classify("some_unknown_tool", {"arg": "val"})
        assert outcome.tier == "sequential"
|
|
|
|
|
|
class TestBatchClassification:
    """Checks on the plan returned by classify_tool_calls for whole batches."""

    def test_all_parallel_stays_parallel(self):
        """Five read_file calls yield a parallelizable plan with nothing sequential."""
        from tools.batch_executor import classify_tool_calls as build_plan

        file_names = [f"file{i}.txt" for i in range(5)]
        calls = [
            _make_tool_call("read_file", {"path": file_name})
            for file_name in file_names
        ]
        plan = build_plan(calls)

        assert plan.can_parallelize
        assert len(plan.parallel_batch) == 5
        assert len(plan.sequential_batch) == 0

    def test_mixed_batch(self):
        """A mixed batch puts at least two calls in each of the two batches."""
        from tools.batch_executor import classify_tool_calls as build_plan

        specs = (
            ("read_file", {"path": "a.txt"}),
            ("terminal", {"command": "ls"}),
            ("web_search", {"query": "test"}),
            ("delegate_task", {"goal": "test"}),
        )
        calls = [_make_tool_call(tool, arguments) for tool, arguments in specs]
        plan = build_plan(calls)

        # read_file + web_search should be parallel (both parallel_safe)
        # terminal + delegate_task should be sequential
        assert len(plan.parallel_batch) >= 2
        assert len(plan.sequential_batch) >= 2

    def test_clarify_blocks_all(self):
        """A clarify call mixed with safe calls shows up in the sequential batch."""
        from tools.batch_executor import classify_tool_calls as build_plan

        specs = (
            ("read_file", {"path": "a.txt"}),
            ("clarify", {"question": "which one?"}),
            ("web_search", {"query": "test"}),
        )
        calls = [_make_tool_call(tool, arguments) for tool, arguments in specs]
        plan = build_plan(calls)

        sequential_names = [entry.tool_name for entry in plan.sequential_batch]
        assert "clarify" in sequential_names

    def test_overlapping_paths_sequential(self):
        """write_file and patch on the same path leave at least one call sequential."""
        from tools.batch_executor import classify_tool_calls as build_plan

        shared_path = "/tmp/test/a.txt"
        calls = [
            _make_tool_call(
                "write_file",
                {"path": shared_path, "content": "hello"},
            ),
            _make_tool_call(
                "patch",
                {"path": shared_path, "old_string": "a", "new_string": "b"},
            ),
        ]
        plan = build_plan(calls)

        # write_file and patch on SAME file -> conflict -> one must be sequential
        assert len(plan.sequential_batch) >= 1
|
|
|
|
|
|
class TestDestructiveCommands:
    """Checks on which shell command strings is_destructive_command flags."""

    def test_rm_flagged(self):
        """Both recursive and plain ``rm`` invocations are flagged."""
        from tools.batch_executor import is_destructive_command as is_destructive

        for command in ("rm -rf /tmp", "rm file.txt"):
            assert is_destructive(command)

    def test_mv_flagged(self):
        """``mv`` is flagged."""
        from tools.batch_executor import is_destructive_command as is_destructive

        assert is_destructive("mv old new")

    def test_sed_i_flagged(self):
        """In-place ``sed -i`` is flagged."""
        from tools.batch_executor import is_destructive_command as is_destructive

        assert is_destructive("sed -i 's/a/b/g' file")

    def test_redirect_overwrite_flagged(self):
        """An overwriting ``>`` redirect is flagged."""
        from tools.batch_executor import is_destructive_command as is_destructive

        assert is_destructive("echo test > file.txt")

    def test_safe_commands_not_flagged(self):
        """Read-only commands and an appending ``>>`` redirect are not flagged."""
        from tools.batch_executor import is_destructive_command as is_destructive

        harmless_commands = (
            "ls -la",
            "cat file.txt",
            "echo test >> file.txt",  # append is safe
        )
        for command in harmless_commands:
            assert not is_destructive(command)
|
|
|
|
|
|
class TestRegistryIntegration:
    """Check on the registry's parallel-safety query."""

    def test_parallel_safe_in_registry(self):
        """registry.get_parallel_safe_tools() returns a set."""
        from tools.registry import registry as tool_registry

        assert isinstance(tool_registry.get_parallel_safe_tools(), set)
|