test(#798): Strengthen parallel tool calling tests + fix flaky concurrent tests
All checks were successful
Lint / lint (pull_request) Successful in 10s

- Add TestAIAgentConcurrentExecution with 8 integration tests exercising
  _execute_tool_calls_concurrent through AIAgent for 2/3/4-tool batches,
  pass-rate reporting, and Gemma 4-style read patterns.
- Fix test_malformed_json_args_forces_sequential: use JSON array '[1,2,3]'
  instead of unrepairable garbage now that repair_and_load_json handles
  most malformed input.
- Fix test_concurrent_handles_tool_error: replace racy call_count list
  with deterministic failure based on tool_call_id to eliminate flaky
  failures under ThreadPoolExecutor.

Closes #798
This commit is contained in:
Alexander Whitestone
2026-04-22 01:34:24 -04:00
parent 16eab5d503
commit ed250b1ca8
2 changed files with 220 additions and 5 deletions

View File

@@ -416,3 +416,219 @@ class TestEdgeCases:
"""Verify max workers constant exists and is reasonable."""
from run_agent import _MAX_TOOL_WORKERS
assert 1 <= _MAX_TOOL_WORKERS <= 32
# ── Integration Tests: AIAgent Concurrent Execution ───────────────────────────
class TestAIAgentConcurrentExecution:
"""Exercise _execute_tool_calls_concurrent through an AIAgent instance."""
@pytest.fixture
def agent(self):
"""Minimal AIAgent with mocked OpenAI client and tool loading."""
from types import SimpleNamespace
from unittest.mock import patch
from run_agent import AIAgent
def _make_tool_defs(*names):
return [
{
"type": "function",
"function": {
"name": n,
"description": f"{n} tool",
"parameters": {"type": "object", "properties": {}},
},
}
for n in names
]
with (
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search", "read_file")),
patch("run_agent.check_toolset_requirements", return_value={}),
patch("run_agent.OpenAI"),
):
a = AIAgent(
api_key="test-key-1234567890",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
)
a.client = MagicMock()
return a
def _mock_assistant_msg(self, tool_calls=None):
from types import SimpleNamespace
return SimpleNamespace(content="", tool_calls=tool_calls)
def _mock_tool_call(self, name, arguments, call_id):
from types import SimpleNamespace
return SimpleNamespace(
id=call_id,
type="function",
function=SimpleNamespace(name=name, arguments=json.dumps(arguments)),
)
def test_two_tool_batch_executes_concurrently(self, agent):
"""2-tool parallel batch: all execute, results ordered, 100% pass."""
tc1 = self._mock_tool_call("read_file", {"path": "a.txt"}, "c1")
tc2 = self._mock_tool_call("read_file", {"path": "b.txt"}, "c2")
mock_msg = self._mock_assistant_msg(tool_calls=[tc1, tc2])
messages = []
def fake_handle(name, args, task_id, **kwargs):
return json.dumps({"file": args.get("path", ""), "content": f"content_of_{args.get('path', '')}"})
with patch("run_agent.handle_function_call", side_effect=fake_handle):
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
assert len(messages) == 2
assert messages[0]["tool_call_id"] == "c1"
assert messages[1]["tool_call_id"] == "c2"
assert "a.txt" in messages[0]["content"]
assert "b.txt" in messages[1]["content"]
def test_three_tool_batch_executes_concurrently(self, agent):
"""3-tool parallel batch: all execute, results ordered, 100% pass."""
tcs = [
self._mock_tool_call("web_search", {"query": f"q{i}"}, f"c{i}")
for i in range(3)
]
mock_msg = self._mock_assistant_msg(tool_calls=tcs)
messages = []
def fake_handle(name, args, task_id, **kwargs):
return json.dumps({"query": args.get("query", ""), "results": [f"result_{args.get('query', '')}"]})
with patch("run_agent.handle_function_call", side_effect=fake_handle):
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
assert len(messages) == 3
for i, tc in enumerate(tcs):
assert messages[i]["tool_call_id"] == tc.id
assert f"q{i}" in messages[i]["content"]
def test_four_tool_batch_executes_concurrently(self, agent):
"""4-tool parallel batch: all execute, results ordered, 100% pass."""
tcs = [
self._mock_tool_call("read_file", {"path": f"file{i}.txt"}, f"c{i}")
for i in range(4)
]
mock_msg = self._mock_assistant_msg(tool_calls=tcs)
messages = []
def fake_handle(name, args, task_id, **kwargs):
return json.dumps({"path": args.get("path", ""), "size": 100})
with patch("run_agent.handle_function_call", side_effect=fake_handle):
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
assert len(messages) == 4
for i, tc in enumerate(tcs):
assert messages[i]["tool_call_id"] == tc.id
assert f"file{i}.txt" in messages[i]["content"]
def test_mixed_read_and_search_batch(self, agent):
"""read_file + search_files: safe parallel, different scopes."""
tc1 = self._mock_tool_call("read_file", {"path": "config.yaml"}, "c1")
tc2 = self._mock_tool_call("web_search", {"query": "provider"}, "c2")
mock_msg = self._mock_assistant_msg(tool_calls=[tc1, tc2])
messages = []
def fake_handle(name, args, task_id, **kwargs):
return json.dumps({"tool": name, "args": args})
with patch("run_agent.handle_function_call", side_effect=fake_handle):
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
assert len(messages) == 2
assert messages[0]["tool_call_id"] == "c1"
assert messages[1]["tool_call_id"] == "c2"
assert "config.yaml" in messages[0]["content"]
assert "provider" in messages[1]["content"]
def test_concurrent_pass_rate_report(self, agent):
"""Simulate 2/3/4-tool batches and report pass rate."""
batch_sizes = [2, 3, 4]
pass_rates = {}
for size in batch_sizes:
tcs = [
self._mock_tool_call("web_search", {"query": f"q{i}"}, f"c{i}")
for i in range(size)
]
mock_msg = self._mock_assistant_msg(tool_calls=tcs)
messages = []
def fake_handle(name, args, task_id, **kwargs):
return json.dumps({"ok": True, "query": args.get("query", "")})
with patch("run_agent.handle_function_call", side_effect=fake_handle):
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
passed = sum(1 for m in messages if "ok" in m.get("content", ""))
pass_rates[size] = passed / size if size > 0 else 0.0
for size, rate in pass_rates.items():
assert rate == 1.0, f"Expected 100% pass rate for {size}-tool batch, got {rate:.0%}"
def test_gemma4_style_two_read_files(self, agent):
"""Gemma 4 may issue two reads simultaneously — verify both returned."""
tc1 = self._mock_tool_call("read_file", {"path": "src/main.py"}, "c1")
tc2 = self._mock_tool_call("read_file", {"path": "src/utils.py"}, "c2")
mock_msg = self._mock_assistant_msg(tool_calls=[tc1, tc2])
messages = []
def fake_handle(name, args, task_id, **kwargs):
return json.dumps({"content": f"# {args['path']}\nprint('hello')"})
with patch("run_agent.handle_function_call", side_effect=fake_handle):
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
assert len(messages) == 2
assert "main.py" in messages[0]["content"]
assert "utils.py" in messages[1]["content"]
def test_gemma4_style_three_reads(self, agent):
"""Gemma 4 may issue 3 reads for different files — all returned."""
tcs = [
self._mock_tool_call("read_file", {"path": f"mod{i}.py"}, f"c{i}")
for i in range(3)
]
mock_msg = self._mock_assistant_msg(tool_calls=tcs)
messages = []
def fake_handle(name, args, task_id, **kwargs):
return json.dumps({"content": f"# {args['path']}"})
with patch("run_agent.handle_function_call", side_effect=fake_handle):
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
assert len(messages) == 3
for i in range(3):
assert f"mod{i}.py" in messages[i]["content"]
def test_mixed_safe_and_write_tools_parallel(self, agent):
"""Mix of read (safe) and write (path-scoped) on different paths — parallel."""
tc1 = self._mock_tool_call("read_file", {"path": "input.txt"}, "c1")
tc2 = self._mock_tool_call("write_file", {"path": "output.txt", "content": "x"}, "c2")
tc3 = self._mock_tool_call("read_file", {"path": "config.txt"}, "c3")
mock_msg = self._mock_assistant_msg(tool_calls=[tc1, tc2, tc3])
messages = []
call_order = []
def fake_handle(name, args, task_id, **kwargs):
call_order.append(name)
return json.dumps({"tool": name, "path": args.get("path", "")})
with patch("run_agent.handle_function_call", side_effect=fake_handle):
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
assert len(messages) == 3
# Results ordered by tool call ID, not completion order
assert messages[0]["tool_call_id"] == "c1"
assert messages[1]["tool_call_id"] == "c2"
assert messages[2]["tool_call_id"] == "c3"
# All three should have executed
assert len(call_order) == 3