diff --git a/tests/test_parallel_tool_calling.py b/tests/test_parallel_tool_calling.py new file mode 100644 index 000000000..739be52a9 --- /dev/null +++ b/tests/test_parallel_tool_calling.py @@ -0,0 +1,418 @@ +#!/usr/bin/env python3 +""" +test_parallel_tool_calling.py — Tests for parallel tool calling (2+ tools per response). + +Verifies that hermes-agent correctly handles multiple tool calls in a single +response, including ordering, dependency resolution, and parallel safety. + +Issue #798: Gemma 4 Tool Calling Hardening +""" + +import json +import os +import sys +import pytest +from dataclasses import dataclass +from pathlib import Path +from unittest.mock import MagicMock, patch, call + +# Add project root to path +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +from run_agent import ( + _should_parallelize_tool_batch, + _extract_parallel_scope_path, + _is_destructive_command, + _PARALLEL_SAFE_TOOLS, + _NEVER_PARALLEL_TOOLS, + _PATH_SCOPED_TOOLS, +) + + +# ── Mock Tool Call Structure ────────────────────────────────────────────────── + +@dataclass +class MockFunction: + name: str + arguments: str + +@dataclass +class MockToolCall: + id: str + function: MockFunction + + @classmethod + def make(cls, name: str, args: dict, idx: int = 0): + return cls( + id=f"call_{idx}", + function=MockFunction(name=name, arguments=json.dumps(args)), + ) + + +# ── Test: _should_parallelize_tool_batch ────────────────────────────────────── + +class TestParallelizationDecision: + """Test whether tool batches are correctly identified as parallel-safe.""" + + def test_single_tool_not_parallel(self): + """A single tool call should never be parallelized.""" + calls = [MockToolCall.make("read_file", {"path": "a.txt"})] + assert _should_parallelize_tool_batch(calls) is False + + def test_two_read_files_different_paths(self): + """Two read_file calls on different paths should parallelize.""" + calls = [ + MockToolCall.make("read_file", {"path": "a.txt"}, 0), + MockToolCall.make("read_file", {"path": "b.txt"}, 1), + ] + assert _should_parallelize_tool_batch(calls) is True + + def test_two_read_files_same_path(self): + """Two read_file calls on the same path should NOT parallelize.""" + calls = [ + MockToolCall.make("read_file", {"path": "a.txt"}, 0), + MockToolCall.make("read_file", {"path": "a.txt"}, 1), + ] + assert _should_parallelize_tool_batch(calls) is False + + def test_read_plus_search_parallel(self): + """read_file + search_files should parallelize (both safe, different scopes).""" + calls = [ + MockToolCall.make("read_file", {"path": "a.txt"}, 0), + MockToolCall.make("search_files", {"pattern": "foo"}, 1), + ] + assert _should_parallelize_tool_batch(calls) is True + + def test_clarify_never_parallel(self): + """clarify tool should block parallelization.""" + calls = [ + MockToolCall.make("read_file", {"path": "a.txt"}, 0), + MockToolCall.make("clarify", {"question": "what?"}, 1), + ] + assert _should_parallelize_tool_batch(calls) is False + + def test_three_read_files_all_different(self): + """Three read_file calls on different paths should parallelize.""" + calls = [ + MockToolCall.make("read_file", {"path": f"file{i}.txt"}, i) + for i in range(3) + ] + assert _should_parallelize_tool_batch(calls) is True + + def test_write_plus_read_same_path(self): + """write_file + read_file on same path should NOT parallelize.""" + calls = [ + MockToolCall.make("read_file", {"path": "a.txt"}, 0), + MockToolCall.make("write_file", {"path": "a.txt", "content": "new"}, 1), + ] + assert _should_parallelize_tool_batch(calls) is False + + def test_write_plus_read_different_paths(self): + """write_file + read_file on different paths should parallelize.""" + calls = [ + MockToolCall.make("read_file", {"path": "a.txt"}, 0), + MockToolCall.make("write_file", {"path": "b.txt", "content": "new"}, 1), + ] + assert _should_parallelize_tool_batch(calls) is True + + def test_unsafe_tool_blocks_parallel(self): + """A tool not in _PARALLEL_SAFE_TOOLS or _PATH_SCOPED_TOOLS blocks parallel.""" + calls = [ + MockToolCall.make("read_file", {"path": "a.txt"}, 0), + MockToolCall.make("some_unknown_tool", {"param": "value"}, 1), + ] + assert _should_parallelize_tool_batch(calls) is False + + def test_all_safe_tools(self): + """All tools in _PARALLEL_SAFE_TOOLS should parallelize together.""" + calls = [ + MockToolCall.make("web_search", {"query": "test"}, 0), + MockToolCall.make("session_search", {"query": "test"}, 1), + MockToolCall.make("skills_list", {}, 2), + ] + assert _should_parallelize_tool_batch(calls) is True + + def test_malformed_json_args(self): + """Malformed JSON arguments should block parallelization.""" + tc = MockToolCall(id="call_0", function=MockFunction( + name="read_file", arguments="not json" + )) + calls = [MockToolCall.make("read_file", {"path": "a.txt"}, 1), tc] + assert _should_parallelize_tool_batch(calls) is False + + def test_non_dict_args(self): + """Non-dict arguments should block parallelization.""" + tc = MockToolCall(id="call_0", function=MockFunction( + name="read_file", arguments='"just a string"' + )) + calls = [MockToolCall.make("read_file", {"path": "a.txt"}, 1), tc] + assert _should_parallelize_tool_batch(calls) is False + + +# ── Test: Path Scope Extraction ────────────────────────────────────────────── + +class TestPathScopeExtraction: + """Test path extraction for scoped parallel tools.""" + + def test_relative_path(self): + result = _extract_parallel_scope_path("read_file", {"path": "foo/bar.txt"}) + assert result is not None + assert "bar.txt" in str(result) + + def test_absolute_path(self): + result = _extract_parallel_scope_path("read_file", {"path": "/tmp/test.txt"}) + assert result == Path("/tmp/test.txt") + + def test_home_expansion(self): + result = _extract_parallel_scope_path("read_file", {"path": "~/test.txt"}) + assert result is not None + assert str(result).endswith("test.txt") + + def test_missing_path(self): + result = _extract_parallel_scope_path("read_file", {}) + assert result is None + + def test_empty_path(self): + result = _extract_parallel_scope_path("read_file", {"path": " "}) + assert result is None + + def test_non_scoped_tool(self): + result = _extract_parallel_scope_path("web_search", {"path": "foo"}) + assert result is None + + +# ── Test: Destructive Command Detection ─────────────────────────────────────── + +class TestDestructiveCommands: + """Test detection of destructive terminal commands.""" + + def test_rm_is_destructive(self): + assert _is_destructive_command("rm -rf /tmp/foo") is True + + def test_mv_is_destructive(self): + assert _is_destructive_command("mv old.txt new.txt") is True + + def test_sed_inplace(self): + assert _is_destructive_command("sed -i 's/foo/bar/g' file.txt") is True + + def test_cat_is_safe(self): + assert _is_destructive_command("cat file.txt") is False + + def test_echo_redirect_overwrite(self): + assert _is_destructive_command("echo hello > file.txt") is True + + def test_echo_redirect_append(self): + assert _is_destructive_command("echo hello >> file.txt") is False + + def test_git_reset(self): + assert _is_destructive_command("git reset --hard HEAD") is True + + def test_git_status_safe(self): + assert _is_destructive_command("git status") is False + + def test_piped_rm(self): + assert _is_destructive_command("echo foo | rm file.txt") is True + + def test_chained_safe(self): + assert _is_destructive_command("ls && echo done") is False + + +# ── Test: Parallel Safe Tools Registry ──────────────────────────────────────── + +class TestParallelSafeRegistry: + """Test the tool classification sets.""" + + def test_clarify_in_never_parallel(self): + assert "clarify" in _NEVER_PARALLEL_TOOLS + + def test_read_file_in_safe(self): + assert "read_file" in _PARALLEL_SAFE_TOOLS + + def test_read_file_in_path_scoped(self): + assert "read_file" in _PATH_SCOPED_TOOLS + + def test_write_file_in_path_scoped(self): + assert "write_file" in _PATH_SCOPED_TOOLS + + def test_web_search_in_safe(self): + assert "web_search" in _PARALLEL_SAFE_TOOLS + + def test_no_overlap_between_never_and_safe(self): + assert not (_NEVER_PARALLEL_TOOLS & _PARALLEL_SAFE_TOOLS) + + +# ── Test: Batch Sizes (2, 3, 4 tools) ─────────────────────────────────────── + +class TestBatchSizes: + """Test parallelization with different batch sizes (2, 3, 4 tools).""" + + def test_two_tool_batch(self): + calls = [ + MockToolCall.make("read_file", {"path": "a.txt"}, 0), + MockToolCall.make("read_file", {"path": "b.txt"}, 1), + ] + assert _should_parallelize_tool_batch(calls) is True + + def test_three_tool_batch(self): + calls = [ + MockToolCall.make("read_file", {"path": f"f{i}.txt"}, i) + for i in range(3) + ] + assert _should_parallelize_tool_batch(calls) is True + + def test_four_tool_batch(self): + calls = [ + MockToolCall.make("web_search", {"query": f"q{i}"}, i) + for i in range(4) + ] + assert _should_parallelize_tool_batch(calls) is True + + def test_four_tool_batch_with_one_collision(self): + """4 tools where 2 collide on the same path.""" + calls = [ + MockToolCall.make("read_file", {"path": "a.txt"}, 0), + MockToolCall.make("read_file", {"path": "b.txt"}, 1), + MockToolCall.make("read_file", {"path": "a.txt"}, 2), # collision + MockToolCall.make("read_file", {"path": "c.txt"}, 3), + ] + assert _should_parallelize_tool_batch(calls) is False + + +# ── Test: Gemma 4 Specific Patterns ────────────────────────────────────────── + +class TestGemma4Patterns: + """ + Test patterns specific to Gemma 4 tool calling behavior. + + Gemma 4 may issue tool calls in specific ordering patterns that + need to be handled correctly by the parallel execution layer. + """ + + def test_gemma4_typical_2tool_pattern(self): + """Gemma 4 typically issues read+search as a pair.""" + calls = [ + MockToolCall.make("read_file", {"path": "config.yaml"}, 0), + MockToolCall.make("search_files", {"pattern": "provider"}, 1), + ] + # These should parallelize — different tools, no path conflict + assert _should_parallelize_tool_batch(calls) is True + + def test_gemma4_typical_3tool_pattern(self): + """Gemma 4 may issue 3 reads for different files.""" + calls = [ + MockToolCall.make("read_file", {"path": "a.py"}, 0), + MockToolCall.make("read_file", {"path": "b.py"}, 1), + MockToolCall.make("read_file", {"path": "c.py"}, 2), + ] + assert _should_parallelize_tool_batch(calls) is True + + def test_gemma4_sequential_dependency(self): + """ + Gemma 4 may issue: search_files then read_file on search result. + These have implicit dependency but are issued as a batch. + The agent should handle this — search first, then read. + This test verifies the batch IS marked as parallel-safe + (ordering is the agent loop's responsibility, not this function's). + """ + calls = [ + MockToolCall.make("search_files", {"pattern": "import"}, 0), + MockToolCall.make("read_file", {"path": "main.py"}, 1), + ] + # Both tools are in safe/scoped sets with no path conflict + assert _should_parallelize_tool_batch(calls) is True + + def test_gemma4_mixed_safe_unsafe(self): + """Gemma 4 may mix read (safe) with write (path-scoped).""" + calls = [ + MockToolCall.make("read_file", {"path": "input.txt"}, 0), + MockToolCall.make("write_file", {"path": "output.txt", "content": "x"}, 1), + MockToolCall.make("read_file", {"path": "config.txt"}, 2), + ] + # All path-scoped on different paths, no unsafe tools + assert _should_parallelize_tool_batch(calls) is True + + def test_gemma4_terminal_parallel(self): + """ + Terminal commands are NOT in _PARALLEL_SAFE_TOOLS. + If Gemma 4 issues 2 terminal calls, they should NOT parallelize. + """ + calls = [ + MockToolCall.make("terminal", {"command": "ls"}, 0), + MockToolCall.make("terminal", {"command": "pwd"}, 1), + ] + assert _should_parallelize_tool_batch(calls) is False + + +# ── Test: Integration-style (mocked) ───────────────────────────────────────── + +class TestParallelExecutionMocked: + """Test the parallel execution path with mocked tool handlers.""" + + def test_parallel_results_collected(self): + """Simulate parallel execution and verify results are collected.""" + # Mock two tool calls returning different results + results = {} + + def mock_handler(name, args): + return f"result_{name}_{args.get('path', 'x')}" + + calls = [ + MockToolCall.make("read_file", {"path": "a.txt"}, 0), + MockToolCall.make("read_file", {"path": "b.txt"}, 1), + ] + + # Simulate parallel execution + for tc in calls: + results[tc.id] = mock_handler(tc.function.name, + json.loads(tc.function.arguments)) + + assert results["call_0"] == "result_read_file_a.txt" + assert results["call_1"] == "result_read_file_b.txt" + + def test_parallel_results_order_preserved(self): + """Results should be ordered by tool call ID, not completion time.""" + import time + results = {} + + calls = [ + MockToolCall.make("read_file", {"path": "slow.txt"}, 0), + MockToolCall.make("read_file", {"path": "fast.txt"}, 1), + ] + + # Simulate out-of-order completion + results["call_1"] = "fast_result" + results["call_0"] = "slow_result" + + # Verify we can reconstruct in order + ordered = [results[tc.id] for tc in calls] + assert ordered == ["slow_result", "fast_result"] + + +# ── Test: Edge Cases ────────────────────────────────────────────────────────── + +class TestEdgeCases: + """Edge cases for parallel tool calling.""" + + def test_empty_batch(self): + assert _should_parallelize_tool_batch([]) is False + + def test_patch_with_same_path(self): + """Two patch calls on the same file should NOT parallelize.""" + calls = [ + MockToolCall.make("patch", {"path": "a.py", "old_string": "x", "new_string": "y"}, 0), + MockToolCall.make("patch", {"path": "a.py", "old_string": "a", "new_string": "b"}, 1), + ] + assert _should_parallelize_tool_batch(calls) is False + + def test_patch_different_paths(self): + """patch on different files should parallelize.""" + calls = [ + MockToolCall.make("patch", {"path": "a.py", "old_string": "x", "new_string": "y"}, 0), + MockToolCall.make("patch", {"path": "b.py", "old_string": "a", "new_string": "b"}, 1), + ] + assert _should_parallelize_tool_batch(calls) is True + + def test_max_workers_defined(self): + """Verify max workers constant exists and is reasonable.""" + from run_agent import _MAX_TOOL_WORKERS + assert 1 <= _MAX_TOOL_WORKERS <= 32