#!/usr/bin/env python3
"""
test_parallel_tool_calling.py — Tests for parallel tool calling (2+ tools per response).

Verifies that hermes-agent correctly handles multiple tool calls in a single
response, including ordering, dependency resolution, and parallel safety.

Issue #798: Gemma 4 Tool Calling Hardening
"""
import json
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from unittest.mock import MagicMock, patch, call

import pytest

# Add project root to path so run_agent is importable when tests run from anywhere.
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

from run_agent import (
    _NEVER_PARALLEL_TOOLS,
    _PARALLEL_SAFE_TOOLS,
    _PATH_SCOPED_TOOLS,
    _extract_parallel_scope_path,
    _is_destructive_command,
    _should_parallelize_tool_batch,
)
# ── Mock Tool Call Structure ──────────────────────────────────────────────────

@dataclass
class MockFunction:
    """Stand-in for the ``.function`` attribute of an OpenAI-style tool call."""

    name: str       # tool name, e.g. "read_file"
    arguments: str  # JSON-encoded argument payload


@dataclass
class MockToolCall:
    """Stand-in for a single tool-call entry in a model response."""

    id: str
    function: MockFunction

    @classmethod
    def make(cls, name: str, args: dict, idx: int = 0):
        """Build a tool call with id ``call_<idx>`` and JSON-serialized args."""
        payload = json.dumps(args)
        return cls(id=f"call_{idx}", function=MockFunction(name=name, arguments=payload))
# ── Test: _should_parallelize_tool_batch ──────────────────────────────────────

class TestParallelizationDecision:
    """Test whether tool batches are correctly identified as parallel-safe."""

    @staticmethod
    def _batch(*specs):
        """Build MockToolCalls from (tool_name, args) pairs, ids call_0..call_n."""
        return [
            MockToolCall.make(tool, args, position)
            for position, (tool, args) in enumerate(specs)
        ]

    def test_single_tool_not_parallel(self):
        """A single tool call should never be parallelized."""
        batch = self._batch(("read_file", {"path": "a.txt"}))
        assert _should_parallelize_tool_batch(batch) is False

    def test_two_read_files_different_paths(self):
        """Two read_file calls on different paths should parallelize."""
        batch = self._batch(
            ("read_file", {"path": "a.txt"}),
            ("read_file", {"path": "b.txt"}),
        )
        assert _should_parallelize_tool_batch(batch) is True

    def test_two_read_files_same_path(self):
        """Two read_file calls on the same path should NOT parallelize."""
        batch = self._batch(
            ("read_file", {"path": "a.txt"}),
            ("read_file", {"path": "a.txt"}),
        )
        assert _should_parallelize_tool_batch(batch) is False

    def test_read_plus_search_parallel(self):
        """read_file + search_files should parallelize (both safe, different scopes)."""
        batch = self._batch(
            ("read_file", {"path": "a.txt"}),
            ("search_files", {"pattern": "foo"}),
        )
        assert _should_parallelize_tool_batch(batch) is True

    def test_clarify_never_parallel(self):
        """clarify tool should block parallelization."""
        batch = self._batch(
            ("read_file", {"path": "a.txt"}),
            ("clarify", {"question": "what?"}),
        )
        assert _should_parallelize_tool_batch(batch) is False

    def test_three_read_files_all_different(self):
        """Three read_file calls on different paths should parallelize."""
        batch = self._batch(
            *(("read_file", {"path": f"file{i}.txt"}) for i in range(3))
        )
        assert _should_parallelize_tool_batch(batch) is True

    def test_write_plus_read_same_path(self):
        """write_file + read_file on same path should NOT parallelize."""
        batch = self._batch(
            ("read_file", {"path": "a.txt"}),
            ("write_file", {"path": "a.txt", "content": "new"}),
        )
        assert _should_parallelize_tool_batch(batch) is False

    def test_write_plus_read_different_paths(self):
        """write_file + read_file on different paths should parallelize."""
        batch = self._batch(
            ("read_file", {"path": "a.txt"}),
            ("write_file", {"path": "b.txt", "content": "new"}),
        )
        assert _should_parallelize_tool_batch(batch) is True

    def test_unsafe_tool_blocks_parallel(self):
        """A tool not in _PARALLEL_SAFE_TOOLS or _PATH_SCOPED_TOOLS blocks parallel."""
        batch = self._batch(
            ("read_file", {"path": "a.txt"}),
            ("some_unknown_tool", {"param": "value"}),
        )
        assert _should_parallelize_tool_batch(batch) is False

    def test_all_safe_tools(self):
        """All tools in _PARALLEL_SAFE_TOOLS should parallelize together."""
        batch = self._batch(
            ("web_search", {"query": "test"}),
            ("session_search", {"query": "test"}),
            ("skills_list", {}),
        )
        assert _should_parallelize_tool_batch(batch) is True

    def test_malformed_json_args(self):
        """Malformed JSON arguments should block parallelization."""
        broken = MockToolCall(
            id="call_0",
            function=MockFunction(name="read_file", arguments="not json"),
        )
        batch = [MockToolCall.make("read_file", {"path": "a.txt"}, 1), broken]
        assert _should_parallelize_tool_batch(batch) is False

    def test_non_dict_args(self):
        """Non-dict arguments should block parallelization."""
        broken = MockToolCall(
            id="call_0",
            function=MockFunction(name="read_file", arguments='"just a string"'),
        )
        batch = [MockToolCall.make("read_file", {"path": "a.txt"}, 1), broken]
        assert _should_parallelize_tool_batch(batch) is False
# ── Test: Path Scope Extraction ──────────────────────────────────────────────

class TestPathScopeExtraction:
    """Test path extraction for scoped parallel tools."""

    def test_relative_path(self):
        """A relative path resolves to a non-None scope containing the filename."""
        scoped = _extract_parallel_scope_path("read_file", {"path": "foo/bar.txt"})
        assert scoped is not None
        assert "bar.txt" in str(scoped)

    def test_absolute_path(self):
        """An absolute path is returned as-is."""
        scoped = _extract_parallel_scope_path("read_file", {"path": "/tmp/test.txt"})
        assert scoped == Path("/tmp/test.txt")

    def test_home_expansion(self):
        """A ~-prefixed path expands but keeps the trailing filename."""
        scoped = _extract_parallel_scope_path("read_file", {"path": "~/test.txt"})
        assert scoped is not None
        assert str(scoped).endswith("test.txt")

    def test_missing_path(self):
        """No "path" key yields no scope."""
        assert _extract_parallel_scope_path("read_file", {}) is None

    def test_empty_path(self):
        """A whitespace-only path yields no scope."""
        assert _extract_parallel_scope_path("read_file", {"path": " "}) is None

    def test_non_scoped_tool(self):
        """Tools outside the path-scoped set never yield a scope."""
        assert _extract_parallel_scope_path("web_search", {"path": "foo"}) is None
# ── Test: Destructive Command Detection ───────────────────────────────────────

class TestDestructiveCommands:
    """Test detection of destructive terminal commands."""

    @staticmethod
    def _check(command, expected):
        """Assert destructive-command detection returns exactly `expected`."""
        assert _is_destructive_command(command) is expected

    def test_rm_is_destructive(self):
        self._check("rm -rf /tmp/foo", True)

    def test_mv_is_destructive(self):
        self._check("mv old.txt new.txt", True)

    def test_sed_inplace(self):
        self._check("sed -i 's/foo/bar/g' file.txt", True)

    def test_cat_is_safe(self):
        self._check("cat file.txt", False)

    def test_echo_redirect_overwrite(self):
        self._check("echo hello > file.txt", True)

    def test_echo_redirect_append(self):
        self._check("echo hello >> file.txt", False)

    def test_git_reset(self):
        self._check("git reset --hard HEAD", True)

    def test_git_status_safe(self):
        self._check("git status", False)

    def test_piped_rm(self):
        self._check("echo foo | rm file.txt", True)

    def test_chained_safe(self):
        self._check("ls && echo done", False)
# ── Test: Parallel Safe Tools Registry ────────────────────────────────────────

class TestParallelSafeRegistry:
    """Test the tool classification sets."""

    def test_clarify_in_never_parallel(self):
        """clarify must be registered as never-parallel."""
        assert "clarify" in _NEVER_PARALLEL_TOOLS

    def test_read_file_in_safe(self):
        """read_file must be registered as parallel-safe."""
        assert "read_file" in _PARALLEL_SAFE_TOOLS

    def test_read_file_in_path_scoped(self):
        """read_file must also be path-scoped."""
        assert "read_file" in _PATH_SCOPED_TOOLS

    def test_write_file_in_path_scoped(self):
        """write_file must be path-scoped."""
        assert "write_file" in _PATH_SCOPED_TOOLS

    def test_web_search_in_safe(self):
        """web_search must be parallel-safe."""
        assert "web_search" in _PARALLEL_SAFE_TOOLS

    def test_no_overlap_between_never_and_safe(self):
        """A tool cannot be both never-parallel and parallel-safe."""
        overlap = _NEVER_PARALLEL_TOOLS & _PARALLEL_SAFE_TOOLS
        assert not overlap
# ── Test: Batch Sizes (2, 3, 4 tools) ───────────────────────────────────────

class TestBatchSizes:
    """Test parallelization with different batch sizes (2, 3, 4 tools)."""

    def test_two_tool_batch(self):
        """Two reads on distinct paths parallelize."""
        pair = [
            MockToolCall.make("read_file", {"path": "a.txt"}, 0),
            MockToolCall.make("read_file", {"path": "b.txt"}, 1),
        ]
        assert _should_parallelize_tool_batch(pair) is True

    def test_three_tool_batch(self):
        """Three reads on distinct paths parallelize."""
        trio = [
            MockToolCall.make("read_file", {"path": f"f{idx}.txt"}, idx)
            for idx in range(3)
        ]
        assert _should_parallelize_tool_batch(trio) is True

    def test_four_tool_batch(self):
        """Four independent web searches parallelize."""
        quad = [
            MockToolCall.make("web_search", {"query": f"q{idx}"}, idx)
            for idx in range(4)
        ]
        assert _should_parallelize_tool_batch(quad) is True

    def test_four_tool_batch_with_one_collision(self):
        """4 tools where 2 collide on the same path."""
        paths = ["a.txt", "b.txt", "a.txt", "c.txt"]  # third repeats the first
        batch = [
            MockToolCall.make("read_file", {"path": p}, idx)
            for idx, p in enumerate(paths)
        ]
        assert _should_parallelize_tool_batch(batch) is False
# ── Test: Gemma 4 Specific Patterns ──────────────────────────────────────────

class TestGemma4Patterns:
    """
    Test patterns specific to Gemma 4 tool calling behavior.

    Gemma 4 may issue tool calls in specific ordering patterns that
    need to be handled correctly by the parallel execution layer.
    """

    def test_gemma4_typical_2tool_pattern(self):
        """Gemma 4 typically issues read+search as a pair."""
        batch = [
            MockToolCall.make("read_file", {"path": "config.yaml"}, 0),
            MockToolCall.make("search_files", {"pattern": "provider"}, 1),
        ]
        # Different tools, no path conflict — should parallelize.
        assert _should_parallelize_tool_batch(batch) is True

    def test_gemma4_typical_3tool_pattern(self):
        """Gemma 4 may issue 3 reads for different files."""
        batch = [
            MockToolCall.make("read_file", {"path": f"{stem}.py"}, idx)
            for idx, stem in enumerate("abc")
        ]
        assert _should_parallelize_tool_batch(batch) is True

    def test_gemma4_sequential_dependency(self):
        """
        Gemma 4 may issue: search_files then read_file on search result.
        These have implicit dependency but are issued as a batch.
        The agent should handle this — search first, then read.
        This test verifies the batch IS marked as parallel-safe
        (ordering is the agent loop's responsibility, not this function's).
        """
        batch = [
            MockToolCall.make("search_files", {"pattern": "import"}, 0),
            MockToolCall.make("read_file", {"path": "main.py"}, 1),
        ]
        # Both tools are in safe/scoped sets with no path conflict.
        assert _should_parallelize_tool_batch(batch) is True

    def test_gemma4_mixed_safe_unsafe(self):
        """Gemma 4 may mix read (safe) with write (path-scoped)."""
        batch = [
            MockToolCall.make("read_file", {"path": "input.txt"}, 0),
            MockToolCall.make("write_file", {"path": "output.txt", "content": "x"}, 1),
            MockToolCall.make("read_file", {"path": "config.txt"}, 2),
        ]
        # All path-scoped on different paths, no unsafe tools.
        assert _should_parallelize_tool_batch(batch) is True

    def test_gemma4_terminal_parallel(self):
        """
        Terminal commands are NOT in _PARALLEL_SAFE_TOOLS.
        If Gemma 4 issues 2 terminal calls, they should NOT parallelize.
        """
        batch = [
            MockToolCall.make("terminal", {"command": "ls"}, 0),
            MockToolCall.make("terminal", {"command": "pwd"}, 1),
        ]
        assert _should_parallelize_tool_batch(batch) is False
# ── Test: Integration-style (mocked) ─────────────────────────────────────────

class TestParallelExecutionMocked:
    """Test the parallel execution path with mocked tool handlers."""

    def test_parallel_results_collected(self):
        """Simulate parallel execution and verify results are collected."""

        def mock_handler(name, args):
            # Deterministic per-call result keyed on tool name + path arg.
            return f"result_{name}_{args.get('path', 'x')}"

        calls = [
            MockToolCall.make("read_file", {"path": "a.txt"}, 0),
            MockToolCall.make("read_file", {"path": "b.txt"}, 1),
        ]

        # Simulate parallel execution: one result per tool-call id.
        results = {
            tc.id: mock_handler(tc.function.name, json.loads(tc.function.arguments))
            for tc in calls
        }

        assert results["call_0"] == "result_read_file_a.txt"
        assert results["call_1"] == "result_read_file_b.txt"

    def test_parallel_results_order_preserved(self):
        """Results should be ordered by tool call ID, not completion time."""
        # NOTE: removed an unused `import time` left over from an earlier
        # version of this test — no timing is actually simulated here.
        calls = [
            MockToolCall.make("read_file", {"path": "slow.txt"}, 0),
            MockToolCall.make("read_file", {"path": "fast.txt"}, 1),
        ]

        # Simulate out-of-order completion: the "fast" call finishes first.
        results = {}
        results["call_1"] = "fast_result"
        results["call_0"] = "slow_result"

        # Verify results can be reconstructed in original call order.
        ordered = [results[tc.id] for tc in calls]
        assert ordered == ["slow_result", "fast_result"]
# ── Test: Edge Cases ──────────────────────────────────────────────────────────

class TestEdgeCases:
    """Edge cases for parallel tool calling."""

    def test_empty_batch(self):
        """An empty batch is never parallelized."""
        assert _should_parallelize_tool_batch([]) is False

    def test_patch_with_same_path(self):
        """Two patch calls on the same file should NOT parallelize."""
        first = MockToolCall.make("patch", {"path": "a.py", "old_string": "x", "new_string": "y"}, 0)
        second = MockToolCall.make("patch", {"path": "a.py", "old_string": "a", "new_string": "b"}, 1)
        assert _should_parallelize_tool_batch([first, second]) is False

    def test_patch_different_paths(self):
        """patch on different files should parallelize."""
        first = MockToolCall.make("patch", {"path": "a.py", "old_string": "x", "new_string": "y"}, 0)
        second = MockToolCall.make("patch", {"path": "b.py", "old_string": "a", "new_string": "b"}, 1)
        assert _should_parallelize_tool_batch([first, second]) is True

    def test_max_workers_defined(self):
        """Verify max workers constant exists and is reasonable."""
        from run_agent import _MAX_TOOL_WORKERS

        assert 1 <= _MAX_TOOL_WORKERS <= 32