Compare commits
1 Commits
claude/iss
...
burn/798-1
| Author | SHA1 | Date | |
|---|---|---|---|
| 7c71b7e73a |
418
tests/test_parallel_tool_calling.py
Normal file
418
tests/test_parallel_tool_calling.py
Normal file
@@ -0,0 +1,418 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
test_parallel_tool_calling.py — Tests for parallel tool calling (2+ tools per response).
|
||||
|
||||
Verifies that hermes-agent correctly handles multiple tool calls in a single
|
||||
response, including ordering, dependency resolution, and parallel safety.
|
||||
|
||||
Issue #798: Gemma 4 Tool Calling Hardening
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from run_agent import (
|
||||
_should_parallelize_tool_batch,
|
||||
_extract_parallel_scope_path,
|
||||
_is_destructive_command,
|
||||
_PARALLEL_SAFE_TOOLS,
|
||||
_NEVER_PARALLEL_TOOLS,
|
||||
_PATH_SCOPED_TOOLS,
|
||||
)
|
||||
|
||||
|
||||
# ── Mock Tool Call Structure ──────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class MockFunction:
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
@dataclass
|
||||
class MockToolCall:
|
||||
id: str
|
||||
function: MockFunction
|
||||
|
||||
@classmethod
|
||||
def make(cls, name: str, args: dict, idx: int = 0):
|
||||
return cls(
|
||||
id=f"call_{idx}",
|
||||
function=MockFunction(name=name, arguments=json.dumps(args)),
|
||||
)
|
||||
|
||||
|
||||
# ── Test: _should_parallelize_tool_batch ──────────────────────────────────────
|
||||
|
||||
class TestParallelizationDecision:
|
||||
"""Test whether tool batches are correctly identified as parallel-safe."""
|
||||
|
||||
def test_single_tool_not_parallel(self):
|
||||
"""A single tool call should never be parallelized."""
|
||||
calls = [MockToolCall.make("read_file", {"path": "a.txt"})]
|
||||
assert _should_parallelize_tool_batch(calls) is False
|
||||
|
||||
def test_two_read_files_different_paths(self):
|
||||
"""Two read_file calls on different paths should parallelize."""
|
||||
calls = [
|
||||
MockToolCall.make("read_file", {"path": "a.txt"}, 0),
|
||||
MockToolCall.make("read_file", {"path": "b.txt"}, 1),
|
||||
]
|
||||
assert _should_parallelize_tool_batch(calls) is True
|
||||
|
||||
def test_two_read_files_same_path(self):
|
||||
"""Two read_file calls on the same path should NOT parallelize."""
|
||||
calls = [
|
||||
MockToolCall.make("read_file", {"path": "a.txt"}, 0),
|
||||
MockToolCall.make("read_file", {"path": "a.txt"}, 1),
|
||||
]
|
||||
assert _should_parallelize_tool_batch(calls) is False
|
||||
|
||||
def test_read_plus_search_parallel(self):
|
||||
"""read_file + search_files should parallelize (both safe, different scopes)."""
|
||||
calls = [
|
||||
MockToolCall.make("read_file", {"path": "a.txt"}, 0),
|
||||
MockToolCall.make("search_files", {"pattern": "foo"}, 1),
|
||||
]
|
||||
assert _should_parallelize_tool_batch(calls) is True
|
||||
|
||||
def test_clarify_never_parallel(self):
|
||||
"""clarify tool should block parallelization."""
|
||||
calls = [
|
||||
MockToolCall.make("read_file", {"path": "a.txt"}, 0),
|
||||
MockToolCall.make("clarify", {"question": "what?"}, 1),
|
||||
]
|
||||
assert _should_parallelize_tool_batch(calls) is False
|
||||
|
||||
def test_three_read_files_all_different(self):
|
||||
"""Three read_file calls on different paths should parallelize."""
|
||||
calls = [
|
||||
MockToolCall.make("read_file", {"path": f"file{i}.txt"}, i)
|
||||
for i in range(3)
|
||||
]
|
||||
assert _should_parallelize_tool_batch(calls) is True
|
||||
|
||||
def test_write_plus_read_same_path(self):
|
||||
"""write_file + read_file on same path should NOT parallelize."""
|
||||
calls = [
|
||||
MockToolCall.make("read_file", {"path": "a.txt"}, 0),
|
||||
MockToolCall.make("write_file", {"path": "a.txt", "content": "new"}, 1),
|
||||
]
|
||||
assert _should_parallelize_tool_batch(calls) is False
|
||||
|
||||
def test_write_plus_read_different_paths(self):
|
||||
"""write_file + read_file on different paths should parallelize."""
|
||||
calls = [
|
||||
MockToolCall.make("read_file", {"path": "a.txt"}, 0),
|
||||
MockToolCall.make("write_file", {"path": "b.txt", "content": "new"}, 1),
|
||||
]
|
||||
assert _should_parallelize_tool_batch(calls) is True
|
||||
|
||||
def test_unsafe_tool_blocks_parallel(self):
|
||||
"""A tool not in _PARALLEL_SAFE_TOOLS or _PATH_SCOPED_TOOLS blocks parallel."""
|
||||
calls = [
|
||||
MockToolCall.make("read_file", {"path": "a.txt"}, 0),
|
||||
MockToolCall.make("some_unknown_tool", {"param": "value"}, 1),
|
||||
]
|
||||
assert _should_parallelize_tool_batch(calls) is False
|
||||
|
||||
def test_all_safe_tools(self):
|
||||
"""All tools in _PARALLEL_SAFE_TOOLS should parallelize together."""
|
||||
calls = [
|
||||
MockToolCall.make("web_search", {"query": "test"}, 0),
|
||||
MockToolCall.make("session_search", {"query": "test"}, 1),
|
||||
MockToolCall.make("skills_list", {}, 2),
|
||||
]
|
||||
assert _should_parallelize_tool_batch(calls) is True
|
||||
|
||||
def test_malformed_json_args(self):
|
||||
"""Malformed JSON arguments should block parallelization."""
|
||||
tc = MockToolCall(id="call_0", function=MockFunction(
|
||||
name="read_file", arguments="not json"
|
||||
))
|
||||
calls = [MockToolCall.make("read_file", {"path": "a.txt"}, 1), tc]
|
||||
assert _should_parallelize_tool_batch(calls) is False
|
||||
|
||||
def test_non_dict_args(self):
|
||||
"""Non-dict arguments should block parallelization."""
|
||||
tc = MockToolCall(id="call_0", function=MockFunction(
|
||||
name="read_file", arguments='"just a string"'
|
||||
))
|
||||
calls = [MockToolCall.make("read_file", {"path": "a.txt"}, 1), tc]
|
||||
assert _should_parallelize_tool_batch(calls) is False
|
||||
|
||||
|
||||
# ── Test: Path Scope Extraction ──────────────────────────────────────────────
|
||||
|
||||
class TestPathScopeExtraction:
|
||||
"""Test path extraction for scoped parallel tools."""
|
||||
|
||||
def test_relative_path(self):
|
||||
result = _extract_parallel_scope_path("read_file", {"path": "foo/bar.txt"})
|
||||
assert result is not None
|
||||
assert "bar.txt" in str(result)
|
||||
|
||||
def test_absolute_path(self):
|
||||
result = _extract_parallel_scope_path("read_file", {"path": "/tmp/test.txt"})
|
||||
assert result == Path("/tmp/test.txt")
|
||||
|
||||
def test_home_expansion(self):
|
||||
result = _extract_parallel_scope_path("read_file", {"path": "~/test.txt"})
|
||||
assert result is not None
|
||||
assert str(result).endswith("test.txt")
|
||||
|
||||
def test_missing_path(self):
|
||||
result = _extract_parallel_scope_path("read_file", {})
|
||||
assert result is None
|
||||
|
||||
def test_empty_path(self):
|
||||
result = _extract_parallel_scope_path("read_file", {"path": " "})
|
||||
assert result is None
|
||||
|
||||
def test_non_scoped_tool(self):
|
||||
result = _extract_parallel_scope_path("web_search", {"path": "foo"})
|
||||
assert result is None
|
||||
|
||||
|
||||
# ── Test: Destructive Command Detection ───────────────────────────────────────
|
||||
|
||||
class TestDestructiveCommands:
|
||||
"""Test detection of destructive terminal commands."""
|
||||
|
||||
def test_rm_is_destructive(self):
|
||||
assert _is_destructive_command("rm -rf /tmp/foo") is True
|
||||
|
||||
def test_mv_is_destructive(self):
|
||||
assert _is_destructive_command("mv old.txt new.txt") is True
|
||||
|
||||
def test_sed_inplace(self):
|
||||
assert _is_destructive_command("sed -i 's/foo/bar/g' file.txt") is True
|
||||
|
||||
def test_cat_is_safe(self):
|
||||
assert _is_destructive_command("cat file.txt") is False
|
||||
|
||||
def test_echo_redirect_overwrite(self):
|
||||
assert _is_destructive_command("echo hello > file.txt") is True
|
||||
|
||||
def test_echo_redirect_append(self):
|
||||
assert _is_destructive_command("echo hello >> file.txt") is False
|
||||
|
||||
def test_git_reset(self):
|
||||
assert _is_destructive_command("git reset --hard HEAD") is True
|
||||
|
||||
def test_git_status_safe(self):
|
||||
assert _is_destructive_command("git status") is False
|
||||
|
||||
def test_piped_rm(self):
|
||||
assert _is_destructive_command("echo foo | rm file.txt") is True
|
||||
|
||||
def test_chained_safe(self):
|
||||
assert _is_destructive_command("ls && echo done") is False
|
||||
|
||||
|
||||
# ── Test: Parallel Safe Tools Registry ────────────────────────────────────────
|
||||
|
||||
class TestParallelSafeRegistry:
|
||||
"""Test the tool classification sets."""
|
||||
|
||||
def test_clarify_in_never_parallel(self):
|
||||
assert "clarify" in _NEVER_PARALLEL_TOOLS
|
||||
|
||||
def test_read_file_in_safe(self):
|
||||
assert "read_file" in _PARALLEL_SAFE_TOOLS
|
||||
|
||||
def test_read_file_in_path_scoped(self):
|
||||
assert "read_file" in _PATH_SCOPED_TOOLS
|
||||
|
||||
def test_write_file_in_path_scoped(self):
|
||||
assert "write_file" in _PATH_SCOPED_TOOLS
|
||||
|
||||
def test_web_search_in_safe(self):
|
||||
assert "web_search" in _PARALLEL_SAFE_TOOLS
|
||||
|
||||
def test_no_overlap_between_never_and_safe(self):
|
||||
assert not (_NEVER_PARALLEL_TOOLS & _PARALLEL_SAFE_TOOLS)
|
||||
|
||||
|
||||
# ── Test: Batch Sizes (2, 3, 4 tools) ───────────────────────────────────────
|
||||
|
||||
class TestBatchSizes:
|
||||
"""Test parallelization with different batch sizes (2, 3, 4 tools)."""
|
||||
|
||||
def test_two_tool_batch(self):
|
||||
calls = [
|
||||
MockToolCall.make("read_file", {"path": "a.txt"}, 0),
|
||||
MockToolCall.make("read_file", {"path": "b.txt"}, 1),
|
||||
]
|
||||
assert _should_parallelize_tool_batch(calls) is True
|
||||
|
||||
def test_three_tool_batch(self):
|
||||
calls = [
|
||||
MockToolCall.make("read_file", {"path": f"f{i}.txt"}, i)
|
||||
for i in range(3)
|
||||
]
|
||||
assert _should_parallelize_tool_batch(calls) is True
|
||||
|
||||
def test_four_tool_batch(self):
|
||||
calls = [
|
||||
MockToolCall.make("web_search", {"query": f"q{i}"}, i)
|
||||
for i in range(4)
|
||||
]
|
||||
assert _should_parallelize_tool_batch(calls) is True
|
||||
|
||||
def test_four_tool_batch_with_one_collision(self):
|
||||
"""4 tools where 2 collide on the same path."""
|
||||
calls = [
|
||||
MockToolCall.make("read_file", {"path": "a.txt"}, 0),
|
||||
MockToolCall.make("read_file", {"path": "b.txt"}, 1),
|
||||
MockToolCall.make("read_file", {"path": "a.txt"}, 2), # collision
|
||||
MockToolCall.make("read_file", {"path": "c.txt"}, 3),
|
||||
]
|
||||
assert _should_parallelize_tool_batch(calls) is False
|
||||
|
||||
|
||||
# ── Test: Gemma 4 Specific Patterns ──────────────────────────────────────────
|
||||
|
||||
class TestGemma4Patterns:
|
||||
"""
|
||||
Test patterns specific to Gemma 4 tool calling behavior.
|
||||
|
||||
Gemma 4 may issue tool calls in specific ordering patterns that
|
||||
need to be handled correctly by the parallel execution layer.
|
||||
"""
|
||||
|
||||
def test_gemma4_typical_2tool_pattern(self):
|
||||
"""Gemma 4 typically issues read+search as a pair."""
|
||||
calls = [
|
||||
MockToolCall.make("read_file", {"path": "config.yaml"}, 0),
|
||||
MockToolCall.make("search_files", {"pattern": "provider"}, 1),
|
||||
]
|
||||
# These should parallelize — different tools, no path conflict
|
||||
assert _should_parallelize_tool_batch(calls) is True
|
||||
|
||||
def test_gemma4_typical_3tool_pattern(self):
|
||||
"""Gemma 4 may issue 3 reads for different files."""
|
||||
calls = [
|
||||
MockToolCall.make("read_file", {"path": "a.py"}, 0),
|
||||
MockToolCall.make("read_file", {"path": "b.py"}, 1),
|
||||
MockToolCall.make("read_file", {"path": "c.py"}, 2),
|
||||
]
|
||||
assert _should_parallelize_tool_batch(calls) is True
|
||||
|
||||
def test_gemma4_sequential_dependency(self):
|
||||
"""
|
||||
Gemma 4 may issue: search_files then read_file on search result.
|
||||
These have implicit dependency but are issued as a batch.
|
||||
The agent should handle this — search first, then read.
|
||||
This test verifies the batch IS marked as parallel-safe
|
||||
(ordering is the agent loop's responsibility, not this function's).
|
||||
"""
|
||||
calls = [
|
||||
MockToolCall.make("search_files", {"pattern": "import"}, 0),
|
||||
MockToolCall.make("read_file", {"path": "main.py"}, 1),
|
||||
]
|
||||
# Both tools are in safe/scoped sets with no path conflict
|
||||
assert _should_parallelize_tool_batch(calls) is True
|
||||
|
||||
def test_gemma4_mixed_safe_unsafe(self):
|
||||
"""Gemma 4 may mix read (safe) with write (path-scoped)."""
|
||||
calls = [
|
||||
MockToolCall.make("read_file", {"path": "input.txt"}, 0),
|
||||
MockToolCall.make("write_file", {"path": "output.txt", "content": "x"}, 1),
|
||||
MockToolCall.make("read_file", {"path": "config.txt"}, 2),
|
||||
]
|
||||
# All path-scoped on different paths, no unsafe tools
|
||||
assert _should_parallelize_tool_batch(calls) is True
|
||||
|
||||
def test_gemma4_terminal_parallel(self):
|
||||
"""
|
||||
Terminal commands are NOT in _PARALLEL_SAFE_TOOLS.
|
||||
If Gemma 4 issues 2 terminal calls, they should NOT parallelize.
|
||||
"""
|
||||
calls = [
|
||||
MockToolCall.make("terminal", {"command": "ls"}, 0),
|
||||
MockToolCall.make("terminal", {"command": "pwd"}, 1),
|
||||
]
|
||||
assert _should_parallelize_tool_batch(calls) is False
|
||||
|
||||
|
||||
# ── Test: Integration-style (mocked) ─────────────────────────────────────────
|
||||
|
||||
class TestParallelExecutionMocked:
|
||||
"""Test the parallel execution path with mocked tool handlers."""
|
||||
|
||||
def test_parallel_results_collected(self):
|
||||
"""Simulate parallel execution and verify results are collected."""
|
||||
# Mock two tool calls returning different results
|
||||
results = {}
|
||||
|
||||
def mock_handler(name, args):
|
||||
return f"result_{name}_{args.get('path', 'x')}"
|
||||
|
||||
calls = [
|
||||
MockToolCall.make("read_file", {"path": "a.txt"}, 0),
|
||||
MockToolCall.make("read_file", {"path": "b.txt"}, 1),
|
||||
]
|
||||
|
||||
# Simulate parallel execution
|
||||
for tc in calls:
|
||||
results[tc.id] = mock_handler(tc.function.name,
|
||||
json.loads(tc.function.arguments))
|
||||
|
||||
assert results["call_0"] == "result_read_file_a.txt"
|
||||
assert results["call_1"] == "result_read_file_b.txt"
|
||||
|
||||
def test_parallel_results_order_preserved(self):
|
||||
"""Results should be ordered by tool call ID, not completion time."""
|
||||
import time
|
||||
results = {}
|
||||
|
||||
calls = [
|
||||
MockToolCall.make("read_file", {"path": "slow.txt"}, 0),
|
||||
MockToolCall.make("read_file", {"path": "fast.txt"}, 1),
|
||||
]
|
||||
|
||||
# Simulate out-of-order completion
|
||||
results["call_1"] = "fast_result"
|
||||
results["call_0"] = "slow_result"
|
||||
|
||||
# Verify we can reconstruct in order
|
||||
ordered = [results[tc.id] for tc in calls]
|
||||
assert ordered == ["slow_result", "fast_result"]
|
||||
|
||||
|
||||
# ── Test: Edge Cases ──────────────────────────────────────────────────────────
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Edge cases for parallel tool calling."""
|
||||
|
||||
def test_empty_batch(self):
|
||||
assert _should_parallelize_tool_batch([]) is False
|
||||
|
||||
def test_patch_with_same_path(self):
|
||||
"""Two patch calls on the same file should NOT parallelize."""
|
||||
calls = [
|
||||
MockToolCall.make("patch", {"path": "a.py", "old_string": "x", "new_string": "y"}, 0),
|
||||
MockToolCall.make("patch", {"path": "a.py", "old_string": "a", "new_string": "b"}, 1),
|
||||
]
|
||||
assert _should_parallelize_tool_batch(calls) is False
|
||||
|
||||
def test_patch_different_paths(self):
|
||||
"""patch on different files should parallelize."""
|
||||
calls = [
|
||||
MockToolCall.make("patch", {"path": "a.py", "old_string": "x", "new_string": "y"}, 0),
|
||||
MockToolCall.make("patch", {"path": "b.py", "old_string": "a", "new_string": "b"}, 1),
|
||||
]
|
||||
assert _should_parallelize_tool_batch(calls) is True
|
||||
|
||||
def test_max_workers_defined(self):
|
||||
"""Verify max workers constant exists and is reasonable."""
|
||||
from run_agent import _MAX_TOOL_WORKERS
|
||||
assert 1 <= _MAX_TOOL_WORKERS <= 32
|
||||
Reference in New Issue
Block a user