diff --git a/tests/run_agent/test_run_agent.py b/tests/run_agent/test_run_agent.py
index d71e6a625..b3f3ac827 100644
--- a/tests/run_agent/test_run_agent.py
+++ b/tests/run_agent/test_run_agent.py
@@ -1302,9 +1302,9 @@ class TestConcurrentToolExecution:
         mock_con.assert_not_called()
 
     def test_malformed_json_args_forces_sequential(self, agent):
-        """Unparseable tool arguments should fall back to sequential."""
+        """Non-dict tool arguments (e.g. JSON array) should fall back to sequential."""
         tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
-        tc2 = _mock_tool_call(name="web_search", arguments="NOT JSON {{{", call_id="c2")
+        tc2 = _mock_tool_call(name="web_search", arguments='[1, 2, 3]', call_id="c2")
         mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
         messages = []
         with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
@@ -1384,10 +1384,9 @@ class TestConcurrentToolExecution:
         mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
         messages = []
 
-        call_count = [0]
 
         def fake_handle(name, args, task_id, **kwargs):
-            call_count[0] += 1
-            if call_count[0] == 1:
+            # Deterministic failure based on tool_call_id to avoid race conditions
+            if kwargs.get("tool_call_id") == "c1":
                 raise RuntimeError("boom")
             return "success"
diff --git a/tests/test_parallel_tool_calling.py b/tests/test_parallel_tool_calling.py
index 739be52a9..58bb1a03e 100644
--- a/tests/test_parallel_tool_calling.py
+++ b/tests/test_parallel_tool_calling.py
@@ -416,3 +416,219 @@ class TestEdgeCases:
         """Verify max workers constant exists and is reasonable."""
         from run_agent import _MAX_TOOL_WORKERS
         assert 1 <= _MAX_TOOL_WORKERS <= 32
+
+
+# ── Integration Tests: AIAgent Concurrent Execution ───────────────────────────
+
+class TestAIAgentConcurrentExecution:
+    """Exercise _execute_tool_calls_concurrent through an AIAgent instance."""
+
+    @pytest.fixture
+    def agent(self):
+        """Minimal AIAgent with mocked OpenAI client and tool loading."""
+        from types import SimpleNamespace
+        from unittest.mock import patch
+        from run_agent import AIAgent
+
+        def _make_tool_defs(*names):
+            return [
+                {
+                    "type": "function",
+                    "function": {
+                        "name": n,
+                        "description": f"{n} tool",
+                        "parameters": {"type": "object", "properties": {}},
+                    },
+                }
+                for n in names
+            ]
+
+        with (
+            patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search", "read_file")),
+            patch("run_agent.check_toolset_requirements", return_value={}),
+            patch("run_agent.OpenAI"),
+        ):
+            a = AIAgent(
+                api_key="test-key-1234567890",
+                quiet_mode=True,
+                skip_context_files=True,
+                skip_memory=True,
+            )
+        a.client = MagicMock()
+        return a
+
+    def _mock_assistant_msg(self, tool_calls=None):
+        from types import SimpleNamespace
+        return SimpleNamespace(content="", tool_calls=tool_calls)
+
+    def _mock_tool_call(self, name, arguments, call_id):
+        from types import SimpleNamespace
+        return SimpleNamespace(
+            id=call_id,
+            type="function",
+            function=SimpleNamespace(name=name, arguments=json.dumps(arguments)),
+        )
+
+    def test_two_tool_batch_executes_concurrently(self, agent):
+        """2-tool parallel batch: all execute, results ordered, 100% pass."""
+        tc1 = self._mock_tool_call("read_file", {"path": "a.txt"}, "c1")
+        tc2 = self._mock_tool_call("read_file", {"path": "b.txt"}, "c2")
+        mock_msg = self._mock_assistant_msg(tool_calls=[tc1, tc2])
+        messages = []
+
+        def fake_handle(name, args, task_id, **kwargs):
+            return json.dumps({"file": args.get("path", ""), "content": f"content_of_{args.get('path', '')}"})
+
+        with patch("run_agent.handle_function_call", side_effect=fake_handle):
+            agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
+
+        assert len(messages) == 2
+        assert messages[0]["tool_call_id"] == "c1"
+        assert messages[1]["tool_call_id"] == "c2"
+        assert "a.txt" in messages[0]["content"]
+        assert "b.txt" in messages[1]["content"]
+
+    def test_three_tool_batch_executes_concurrently(self, agent):
+        """3-tool parallel batch: all execute, results ordered, 100% pass."""
+        tcs = [
+            self._mock_tool_call("web_search", {"query": f"q{i}"}, f"c{i}")
+            for i in range(3)
+        ]
+        mock_msg = self._mock_assistant_msg(tool_calls=tcs)
+        messages = []
+
+        def fake_handle(name, args, task_id, **kwargs):
+            return json.dumps({"query": args.get("query", ""), "results": [f"result_{args.get('query', '')}"]})
+
+        with patch("run_agent.handle_function_call", side_effect=fake_handle):
+            agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
+
+        assert len(messages) == 3
+        for i, tc in enumerate(tcs):
+            assert messages[i]["tool_call_id"] == tc.id
+            assert f"q{i}" in messages[i]["content"]
+
+    def test_four_tool_batch_executes_concurrently(self, agent):
+        """4-tool parallel batch: all execute, results ordered, 100% pass."""
+        tcs = [
+            self._mock_tool_call("read_file", {"path": f"file{i}.txt"}, f"c{i}")
+            for i in range(4)
+        ]
+        mock_msg = self._mock_assistant_msg(tool_calls=tcs)
+        messages = []
+
+        def fake_handle(name, args, task_id, **kwargs):
+            return json.dumps({"path": args.get("path", ""), "size": 100})
+
+        with patch("run_agent.handle_function_call", side_effect=fake_handle):
+            agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
+
+        assert len(messages) == 4
+        for i, tc in enumerate(tcs):
+            assert messages[i]["tool_call_id"] == tc.id
+            assert f"file{i}.txt" in messages[i]["content"]
+
+    def test_mixed_read_and_search_batch(self, agent):
+        """read_file + search_files: safe parallel, different scopes."""
+        tc1 = self._mock_tool_call("read_file", {"path": "config.yaml"}, "c1")
+        tc2 = self._mock_tool_call("web_search", {"query": "provider"}, "c2")
+        mock_msg = self._mock_assistant_msg(tool_calls=[tc1, tc2])
+        messages = []
+
+        def fake_handle(name, args, task_id, **kwargs):
+            return json.dumps({"tool": name, "args": args})
+
+        with patch("run_agent.handle_function_call", side_effect=fake_handle):
+            agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
+
+        assert len(messages) == 2
+        assert messages[0]["tool_call_id"] == "c1"
+        assert messages[1]["tool_call_id"] == "c2"
+        assert "config.yaml" in messages[0]["content"]
+        assert "provider" in messages[1]["content"]
+
+    def test_concurrent_pass_rate_report(self, agent):
+        """Simulate 2/3/4-tool batches and report pass rate."""
+        batch_sizes = [2, 3, 4]
+        pass_rates = {}
+
+        for size in batch_sizes:
+            tcs = [
+                self._mock_tool_call("web_search", {"query": f"q{i}"}, f"c{i}")
+                for i in range(size)
+            ]
+            mock_msg = self._mock_assistant_msg(tool_calls=tcs)
+            messages = []
+
+            def fake_handle(name, args, task_id, **kwargs):
+                return json.dumps({"ok": True, "query": args.get("query", "")})
+
+            with patch("run_agent.handle_function_call", side_effect=fake_handle):
+                agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
+
+            passed = sum(1 for m in messages if "ok" in m.get("content", ""))
+            pass_rates[size] = passed / size if size > 0 else 0.0
+
+        for size, rate in pass_rates.items():
+            assert rate == 1.0, f"Expected 100% pass rate for {size}-tool batch, got {rate:.0%}"
+
+    def test_gemma4_style_two_read_files(self, agent):
+        """Gemma 4 may issue two reads simultaneously — verify both returned."""
+        tc1 = self._mock_tool_call("read_file", {"path": "src/main.py"}, "c1")
+        tc2 = self._mock_tool_call("read_file", {"path": "src/utils.py"}, "c2")
+        mock_msg = self._mock_assistant_msg(tool_calls=[tc1, tc2])
+        messages = []
+
+        def fake_handle(name, args, task_id, **kwargs):
+            return json.dumps({"content": f"# {args['path']}\nprint('hello')"})
+
+        with patch("run_agent.handle_function_call", side_effect=fake_handle):
+            agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
+
+        assert len(messages) == 2
+        assert "main.py" in messages[0]["content"]
+        assert "utils.py" in messages[1]["content"]
+
+    def test_gemma4_style_three_reads(self, agent):
+        """Gemma 4 may issue 3 reads for different files — all returned."""
+        tcs = [
+            self._mock_tool_call("read_file", {"path": f"mod{i}.py"}, f"c{i}")
+            for i in range(3)
+        ]
+        mock_msg = self._mock_assistant_msg(tool_calls=tcs)
+        messages = []
+
+        def fake_handle(name, args, task_id, **kwargs):
+            return json.dumps({"content": f"# {args['path']}"})
+
+        with patch("run_agent.handle_function_call", side_effect=fake_handle):
+            agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
+
+        assert len(messages) == 3
+        for i in range(3):
+            assert f"mod{i}.py" in messages[i]["content"]
+
+    def test_mixed_safe_and_write_tools_parallel(self, agent):
+        """Mix of read (safe) and write (path-scoped) on different paths — parallel."""
+        tc1 = self._mock_tool_call("read_file", {"path": "input.txt"}, "c1")
+        tc2 = self._mock_tool_call("write_file", {"path": "output.txt", "content": "x"}, "c2")
+        tc3 = self._mock_tool_call("read_file", {"path": "config.txt"}, "c3")
+        mock_msg = self._mock_assistant_msg(tool_calls=[tc1, tc2, tc3])
+        messages = []
+
+        call_order = []
+
+        def fake_handle(name, args, task_id, **kwargs):
+            call_order.append(name)
+            return json.dumps({"tool": name, "path": args.get("path", "")})
+
+        with patch("run_agent.handle_function_call", side_effect=fake_handle):
+            agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
+
+        assert len(messages) == 3
+        # Results ordered by tool call ID, not completion order
+        assert messages[0]["tool_call_id"] == "c1"
+        assert messages[1]["tool_call_id"] == "c2"
+        assert messages[2]["tool_call_id"] == "c3"
+        # All three should have executed
+        assert len(call_order) == 3