# Source: hermes-agent/tests/test_batch_executor.py
# (151 lines, 6.0 KiB, Python)

"""Tests for batch tool execution safety classification."""
import json
import pytest
from unittest.mock import MagicMock
def _make_tool_call(name: str, args: dict, *, call_id: "str | None" = None) -> MagicMock:
    """Create a mock tool call exposing ``function.name``, ``function.arguments`` and ``id``.

    Args:
        name: Tool name stored on ``tc.function.name``.
        args: Tool arguments, JSON-encoded into ``tc.function.arguments``.
        call_id: Explicit value for ``tc.id``. When omitted, an id of the form
            ``call_<name>_<hex>`` is derived from the mock's identity, so that
            several calls to the same tool in one batch get distinct ids.

    Returns:
        A ``MagicMock`` with the attributes above set.
    """
    tc = MagicMock()
    # ``name`` is a constructor argument of Mock (it sets the mock's repr
    # name), so the ``.name`` attribute has to be assigned after creation.
    tc.function.name = name
    tc.function.arguments = json.dumps(args)
    # A fixed suffix would give every call to the same tool an identical id.
    # id(tc) is unique among live objects, which covers all mocks built for a
    # single batch.
    tc.id = call_id if call_id is not None else f"call_{name}_{id(tc):x}"
    return tc
class TestClassification:
    """Tier assigned by ``classify_single_tool_call`` to individual tool calls."""

    @staticmethod
    def _classify(name: str, args: dict):
        """Build a mock call for *name*/*args* and return its classification result."""
        from tools.batch_executor import classify_single_tool_call

        return classify_single_tool_call(_make_tool_call(name, args))

    def test_parallel_safe_read_file(self):
        verdict = self._classify("read_file", {"path": "README.md"})
        assert verdict.tier == "parallel_safe"

    def test_parallel_safe_web_search(self):
        verdict = self._classify("web_search", {"query": "test"})
        assert verdict.tier == "parallel_safe"

    def test_parallel_safe_search_files(self):
        verdict = self._classify("search_files", {"pattern": "test"})
        assert verdict.tier == "parallel_safe"

    def test_never_parallel_clarify(self):
        verdict = self._classify("clarify", {"question": "test"})
        assert verdict.tier == "never_parallel"

    def test_terminal_is_sequential(self):
        verdict = self._classify("terminal", {"command": "ls -la"})
        assert verdict.tier == "sequential"

    def test_terminal_destructive_rm(self):
        verdict = self._classify("terminal", {"command": "rm -rf /tmp/test"})
        assert verdict.tier == "sequential"
        # The reason string should say why the call was forced sequential.
        assert "Destructive" in verdict.reason

    def test_write_file_is_path_scoped(self):
        verdict = self._classify("write_file", {"path": "/tmp/test.txt", "content": "hello"})
        assert verdict.tier == "path_scoped"

    def test_delegate_is_sequential(self):
        verdict = self._classify("delegate_task", {"goal": "test"})
        assert verdict.tier == "sequential"

    def test_unknown_tool_is_sequential(self):
        # Tools the classifier does not know about fall back to the safe tier.
        verdict = self._classify("some_unknown_tool", {"arg": "val"})
        assert verdict.tier == "sequential"
class TestBatchClassification:
    """Batch planning by ``classify_tool_calls``: splitting calls into parallel/sequential batches."""

    def test_all_parallel_stays_parallel(self):
        from tools.batch_executor import classify_tool_calls

        tcs = [
            _make_tool_call("read_file", {"path": f"file{i}.txt"})
            for i in range(5)
        ]
        plan = classify_tool_calls(tcs)
        assert plan.can_parallelize
        assert len(plan.parallel_batch) == 5
        assert len(plan.sequential_batch) == 0

    def test_mixed_batch(self):
        from tools.batch_executor import classify_tool_calls

        tcs = [
            _make_tool_call("read_file", {"path": "a.txt"}),
            _make_tool_call("terminal", {"command": "ls"}),
            _make_tool_call("web_search", {"query": "test"}),
            _make_tool_call("delegate_task", {"goal": "test"}),
        ]
        plan = classify_tool_calls(tcs)
        # read_file + web_search are both parallel_safe -> parallel batch;
        # terminal + delegate_task are sequential-tier -> sequential batch.
        # Pin the exact split (not just lower bounds) so that a call placed in
        # the wrong batch is caught. Assumes the two batches partition the input.
        assert len(plan.parallel_batch) == 2
        assert sorted(c.tool_name for c in plan.sequential_batch) == [
            "delegate_task",
            "terminal",
        ]

    def test_clarify_blocks_all(self):
        from tools.batch_executor import classify_tool_calls

        tcs = [
            _make_tool_call("read_file", {"path": "a.txt"}),
            _make_tool_call("clarify", {"question": "which one?"}),
            _make_tool_call("web_search", {"query": "test"}),
        ]
        plan = classify_tool_calls(tcs)
        # Only checks that clarify itself ends up in the sequential batch.
        clarify_in_seq = any(c.tool_name == "clarify" for c in plan.sequential_batch)
        assert clarify_in_seq

    def test_overlapping_paths_sequential(self):
        from tools.batch_executor import classify_tool_calls

        tcs = [
            _make_tool_call("write_file", {"path": "/tmp/test/a.txt", "content": "hello"}),
            _make_tool_call("patch", {"path": "/tmp/test/a.txt", "old_string": "a", "new_string": "b"}),
        ]
        plan = classify_tool_calls(tcs)
        # write_file and patch on SAME file -> conflict -> one must be sequential
        assert len(plan.sequential_batch) >= 1
class TestDestructiveCommands:
    """Screening of shell command strings by ``is_destructive_command``."""

    @staticmethod
    def _flags(command):
        """Return whatever ``is_destructive_command`` reports for *command*."""
        from tools.batch_executor import is_destructive_command

        return is_destructive_command(command)

    def test_rm_flagged(self):
        for command in ("rm -rf /tmp", "rm file.txt"):
            assert self._flags(command), command

    def test_mv_flagged(self):
        assert self._flags("mv old new")

    def test_sed_i_flagged(self):
        # `sed -i` edits the target file in place.
        assert self._flags("sed -i 's/a/b/g' file")

    def test_redirect_overwrite_flagged(self):
        # A single `>` truncates the target file before writing.
        assert self._flags("echo test > file.txt")

    def test_safe_commands_not_flagged(self):
        harmless = (
            "ls -la",
            "cat file.txt",
            "echo test >> file.txt",  # append is safe
        )
        for command in harmless:
            assert not self._flags(command), command
class TestRegistryIntegration:
    """Integration with the shared ``registry`` object from ``tools.registry``."""

    def test_parallel_safe_in_registry(self):
        # Imported lazily, like the other tests in this module (presumably so
        # an import failure fails this test rather than module collection).
        from tools.registry import registry

        parallel_safe_tools = registry.get_parallel_safe_tools()
        assert isinstance(parallel_safe_tools, set)