Files
hermes-agent/tests/test_batch_executor.py
Alexander Whitestone 9f0c410481
Some checks failed
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Contributor Attribution Check / check-attribution (pull_request) Successful in 35s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 37s
Tests / e2e (pull_request) Successful in 1m48s
Tests / test (pull_request) Failing after 36m13s
feat: batch tool execution with parallel safety checks (#749)
Centralized safety classification for tool call batches:

tools/batch_executor.py (new):
- classify_tool_calls() — classifies a batch of tool calls into the
  parallel_safe, path_scoped, sequential, and never_parallel tiers
- BatchExecutionPlan — structured plan with parallel and sequential batches
- Path conflict detection — write_file and patch calls that target the same file are run sequentially
- Destructive command detection — flags rm, mv, sed -i, and overwriting redirects (`>`; appending with `>>` is not flagged)
- execute_parallel_batch() — uses a ThreadPoolExecutor to execute calls concurrently

tools/registry.py (enhanced):
- ToolEntry.parallel_safe field — tools can declare parallel safety
- registry.register() accepts parallel_safe=True parameter
- registry.get_parallel_safe_tools() — query registry-declared safe tools

Safety tiers:
- parallel_safe: read_file, web_search, search_files, etc.
- path_scoped: write_file, patch (concurrent when paths don't overlap)
- sequential: terminal, delegate_task, unknown tools
- never_parallel: clarify (requires user interaction)

19 tests passing.
2026-04-15 22:17:16 -04:00

151 lines
6.0 KiB
Python

"""Tests for batch tool execution safety classification."""
import itertools
import json
from unittest.mock import MagicMock

import pytest
def _make_tool_call(name: str, args: dict) -> MagicMock:
"""Create a mock tool call object."""
tc = MagicMock()
tc.function.name = name
tc.function.arguments = json.dumps(args)
tc.id = f"call_{name}_1"
return tc
class TestClassification:
    """Tier assigned by ``classify_single_tool_call`` to individual calls."""

    @staticmethod
    def _classify(tool_name, arguments):
        """Build a mock call for *tool_name* and return its classification."""
        from tools.batch_executor import classify_single_tool_call

        return classify_single_tool_call(_make_tool_call(tool_name, arguments))

    def test_parallel_safe_read_file(self):
        verdict = self._classify("read_file", {"path": "README.md"})
        assert verdict.tier == "parallel_safe"

    def test_parallel_safe_web_search(self):
        verdict = self._classify("web_search", {"query": "test"})
        assert verdict.tier == "parallel_safe"

    def test_parallel_safe_search_files(self):
        verdict = self._classify("search_files", {"pattern": "test"})
        assert verdict.tier == "parallel_safe"

    def test_never_parallel_clarify(self):
        verdict = self._classify("clarify", {"question": "test"})
        assert verdict.tier == "never_parallel"

    def test_terminal_is_sequential(self):
        verdict = self._classify("terminal", {"command": "ls -la"})
        assert verdict.tier == "sequential"

    def test_terminal_destructive_rm(self):
        verdict = self._classify("terminal", {"command": "rm -rf /tmp/test"})
        assert verdict.tier == "sequential"
        # The classifier should explain why the call was serialized.
        assert "Destructive" in verdict.reason

    def test_write_file_is_path_scoped(self):
        verdict = self._classify(
            "write_file", {"path": "/tmp/test.txt", "content": "hello"}
        )
        assert verdict.tier == "path_scoped"

    def test_delegate_is_sequential(self):
        verdict = self._classify("delegate_task", {"goal": "test"})
        assert verdict.tier == "sequential"

    def test_unknown_tool_is_sequential(self):
        verdict = self._classify("some_unknown_tool", {"arg": "val"})
        assert verdict.tier == "sequential"
class TestBatchClassification:
    """Whole-batch planning performed by ``classify_tool_calls``."""

    def test_all_parallel_stays_parallel(self):
        """A batch made only of read_file calls is fully parallel."""
        from tools.batch_executor import classify_tool_calls

        tcs = [
            _make_tool_call("read_file", {"path": f"file{i}.txt"})
            for i in range(5)
        ]
        plan = classify_tool_calls(tcs)
        assert plan.can_parallelize
        assert len(plan.parallel_batch) == 5
        assert len(plan.sequential_batch) == 0

    def test_mixed_batch(self):
        """parallel_safe calls run together; the rest are serialized."""
        from tools.batch_executor import classify_tool_calls

        tcs = [
            _make_tool_call("read_file", {"path": "a.txt"}),
            _make_tool_call("terminal", {"command": "ls"}),
            _make_tool_call("web_search", {"query": "test"}),
            _make_tool_call("delegate_task", {"goal": "test"}),
        ]
        plan = classify_tool_calls(tcs)
        # read_file + web_search should be parallel (both parallel_safe);
        # terminal + delegate_task should be sequential. Plain ">= 2" length
        # checks would pass even if the wrong calls landed in each batch, so
        # pin the exact membership of the sequential batch.
        assert len(plan.parallel_batch) == 2
        assert sorted(c.tool_name for c in plan.sequential_batch) == [
            "delegate_task",
            "terminal",
        ]

    def test_clarify_blocks_all(self):
        """clarify must be placed in the sequential batch."""
        from tools.batch_executor import classify_tool_calls

        tcs = [
            _make_tool_call("read_file", {"path": "a.txt"}),
            _make_tool_call("clarify", {"question": "which one?"}),
            _make_tool_call("web_search", {"query": "test"}),
        ]
        plan = classify_tool_calls(tcs)
        # NOTE(review): despite the test name, this only asserts that clarify
        # itself is sequential; it does not check that the other calls are
        # blocked. Tighten once the intended behavior is confirmed.
        clarify_in_seq = any(c.tool_name == "clarify" for c in plan.sequential_batch)
        assert clarify_in_seq

    def test_overlapping_paths_sequential(self):
        """Two writers targeting the same file must not run concurrently."""
        from tools.batch_executor import classify_tool_calls

        tcs = [
            _make_tool_call("write_file", {"path": "/tmp/test/a.txt", "content": "hello"}),
            _make_tool_call("patch", {"path": "/tmp/test/a.txt", "old_string": "a", "new_string": "b"}),
        ]
        plan = classify_tool_calls(tcs)
        # write_file and patch on SAME file -> conflict -> one must be sequential
        assert len(plan.sequential_batch) >= 1
        # ...which means at most one of them may remain in the parallel batch.
        assert len(plan.parallel_batch) <= 1
class TestDestructiveCommands:
    """Shell-command screening performed by ``is_destructive_command``."""

    @staticmethod
    def _flagged(command):
        """Return the destructive-command verdict for *command*."""
        from tools.batch_executor import is_destructive_command

        return is_destructive_command(command)

    def test_rm_flagged(self):
        for command in ("rm -rf /tmp", "rm file.txt"):
            assert self._flagged(command)

    def test_mv_flagged(self):
        assert self._flagged("mv old new")

    def test_sed_i_flagged(self):
        assert self._flagged("sed -i 's/a/b/g' file")

    def test_redirect_overwrite_flagged(self):
        assert self._flagged("echo test > file.txt")

    def test_safe_commands_not_flagged(self):
        benign = (
            "ls -la",
            "cat file.txt",
            "echo test >> file.txt",  # append is safe
        )
        for command in benign:
            assert not self._flagged(command)
class TestRegistryIntegration:
    """Registry query for tools that declared themselves parallel-safe."""

    def test_parallel_safe_in_registry(self):
        from tools.registry import registry

        assert isinstance(registry.get_parallel_safe_tools(), set)