diff --git a/run_agent.py b/run_agent.py index f0e8f25db..fee04f396 100644 --- a/run_agent.py +++ b/run_agent.py @@ -203,6 +203,27 @@ class IterationBudget: # When any of these appear in a batch, we fall back to sequential execution. _NEVER_PARALLEL_TOOLS = frozenset({"clarify"}) +# Read-only tools with no shared mutable session state. +_PARALLEL_SAFE_TOOLS = frozenset({ + "ha_get_state", + "ha_list_entities", + "ha_list_services", + "honcho_context", + "honcho_profile", + "honcho_search", + "read_file", + "search_files", + "session_search", + "skill_view", + "skills_list", + "vision_analyze", + "web_extract", + "web_search", +}) + +# File tools can run concurrently when they target independent paths. +_PATH_SCOPED_TOOLS = frozenset({"read_file", "write_file", "patch"}) + # Maximum number of concurrent worker threads for parallel tool execution. _MAX_TOOL_WORKERS = 8 @@ -234,6 +255,74 @@ def _is_destructive_command(cmd: str) -> bool: return False +def _should_parallelize_tool_batch(tool_calls) -> bool: + """Return True when a tool-call batch is safe to run concurrently.""" + if len(tool_calls) <= 1: + return False + + tool_names = [tc.function.name for tc in tool_calls] + if any(name in _NEVER_PARALLEL_TOOLS for name in tool_names): + return False + + reserved_paths: list[Path] = [] + for tool_call in tool_calls: + tool_name = tool_call.function.name + try: + function_args = json.loads(tool_call.function.arguments) + except Exception: + logging.debug( + "Could not parse args for %s — defaulting to sequential; raw=%s", + tool_name, + tool_call.function.arguments[:200], + ) + return False + if not isinstance(function_args, dict): + logging.debug( + "Non-dict args for %s (%s) — defaulting to sequential", + tool_name, + type(function_args).__name__, + ) + return False + + if tool_name in _PATH_SCOPED_TOOLS: + scoped_path = _extract_parallel_scope_path(tool_name, function_args) + if scoped_path is None: + return False + if any(_paths_overlap(scoped_path, existing) for existing in reserved_paths): + return False + reserved_paths.append(scoped_path) + continue + + if tool_name not in _PARALLEL_SAFE_TOOLS: + return False + + return True + + +def _extract_parallel_scope_path(tool_name: str, function_args: dict) -> Path | None: + """Return the normalized file target for path-scoped tools.""" + if tool_name not in _PATH_SCOPED_TOOLS: + return None + + raw_path = function_args.get("path") + if not isinstance(raw_path, str) or not raw_path.strip(): + return None + + # Avoid resolve(); the file may not exist yet. + return Path(raw_path).expanduser() + + +def _paths_overlap(left: Path, right: Path) -> bool: + """Return True when two paths may refer to the same subtree.""" + left_parts = left.parts + right_parts = right.parts + if not left_parts or not right_parts: + # Empty paths shouldn't reach here (guarded upstream), but be safe. + return bool(left_parts) == bool(right_parts) and bool(left_parts) + common_len = min(len(left_parts), len(right_parts)) + return left_parts[:common_len] == right_parts[:common_len] + + def _inject_honcho_turn_context(content, turn_context: str): """Append Honcho recall to the current-turn user message without mutating history. @@ -4078,20 +4167,17 @@ class AIAgent: def _execute_tool_calls(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None: """Execute tool calls from the assistant message and append results to messages. - Dispatches to concurrent execution when multiple independent tool calls - are present, falling back to sequential execution for single calls or - when interactive tools (e.g. clarify) are in the batch. + Dispatches to concurrent execution only for batches that look + independent: read-only tools may always share the parallel path, while + file reads/writes may do so only when their target paths do not overlap. """ tool_calls = assistant_message.tool_calls - # Single tool call or interactive tool present → sequential - if (len(tool_calls) <= 1 - or any(tc.function.name in _NEVER_PARALLEL_TOOLS for tc in tool_calls)): + if not _should_parallelize_tool_batch(tool_calls): return self._execute_tool_calls_sequential( assistant_message, messages, effective_task_id, api_call_count ) - # Multiple non-interactive tools → concurrent return self._execute_tool_calls_concurrent( assistant_message, messages, effective_task_id, api_call_count ) diff --git a/tests/test_run_agent.py b/tests/test_run_agent.py index ec9b26f3a..50b3a5092 100644 --- a/tests/test_run_agent.py +++ b/tests/test_run_agent.py @@ -806,7 +806,7 @@ class TestConcurrentToolExecution: mock_con.assert_not_called() def test_multiple_tools_uses_concurrent_path(self, agent): - """Multiple non-interactive tools should use concurrent path.""" + """Multiple read-only tools should use concurrent path.""" tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1") tc2 = _mock_tool_call(name="read_file", arguments='{"path":"x.py"}', call_id="c2") mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2]) @@ -817,6 +817,94 @@ class TestConcurrentToolExecution: mock_con.assert_called_once() mock_seq.assert_not_called() + def test_terminal_batch_forces_sequential(self, agent): + """Stateful tools should not share the concurrent execution path.""" + tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1") + tc2 = _mock_tool_call(name="terminal", arguments='{"command":"pwd"}', call_id="c2") + mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2]) + messages = [] + with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq: + with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con: + agent._execute_tool_calls(mock_msg, messages, "task-1") + mock_seq.assert_called_once() + mock_con.assert_not_called() + + def test_write_batch_forces_sequential(self, agent): + """File mutations should stay ordered within a turn.""" + tc1 = _mock_tool_call(name="read_file", arguments='{"path":"x.py"}', call_id="c1") + tc2 = _mock_tool_call(name="write_file", arguments='{"path":"x.py","content":"print(1)"}', call_id="c2") + mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2]) + messages = [] + with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq: + with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con: + agent._execute_tool_calls(mock_msg, messages, "task-1") + mock_seq.assert_called_once() + mock_con.assert_not_called() + + def test_disjoint_write_batch_uses_concurrent_path(self, agent): + """Independent file writes should still run concurrently.""" + tc1 = _mock_tool_call( + name="write_file", + arguments='{"path":"src/a.py","content":"print(1)"}', + call_id="c1", + ) + tc2 = _mock_tool_call( + name="write_file", + arguments='{"path":"src/b.py","content":"print(2)"}', + call_id="c2", + ) + mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2]) + messages = [] + with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq: + with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con: + agent._execute_tool_calls(mock_msg, messages, "task-1") + mock_con.assert_called_once() + mock_seq.assert_not_called() + + def test_overlapping_write_batch_forces_sequential(self, agent): + """Writes to the same file must stay ordered.""" + tc1 = _mock_tool_call( + name="write_file", + arguments='{"path":"src/a.py","content":"print(1)"}', + call_id="c1", + ) + tc2 = _mock_tool_call( + name="patch", + arguments='{"path":"src/a.py","old_string":"1","new_string":"2"}', + call_id="c2", + ) + mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2]) + messages = [] + with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq: + with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con: + agent._execute_tool_calls(mock_msg, messages, "task-1") + mock_seq.assert_called_once() + mock_con.assert_not_called() + + def test_malformed_json_args_forces_sequential(self, agent): + """Unparseable tool arguments should fall back to sequential.""" + tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1") + tc2 = _mock_tool_call(name="web_search", arguments="NOT JSON {{{", call_id="c2") + mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2]) + messages = [] + with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq: + with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con: + agent._execute_tool_calls(mock_msg, messages, "task-1") + mock_seq.assert_called_once() + mock_con.assert_not_called() + + def test_non_dict_args_forces_sequential(self, agent): + """Tool arguments that parse to a non-dict type should fall back to sequential.""" + tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1") + tc2 = _mock_tool_call(name="web_search", arguments='"just a string"', call_id="c2") + mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2]) + messages = [] + with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq: + with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con: + agent._execute_tool_calls(mock_msg, messages, "task-1") + mock_seq.assert_called_once() + mock_con.assert_not_called() + def test_concurrent_executes_all_tools(self, agent): """Concurrent path should execute all tools and append results in order.""" tc1 = _mock_tool_call(name="web_search", arguments='{"q":"alpha"}', call_id="c1") @@ -943,6 +1031,39 @@ class TestConcurrentToolExecution: assert "ok" in result +class TestPathsOverlap: + """Unit tests for the _paths_overlap helper.""" + + def test_same_path_overlaps(self): + from run_agent import _paths_overlap + assert _paths_overlap(Path("src/a.py"), Path("src/a.py")) + + def test_siblings_do_not_overlap(self): + from run_agent import _paths_overlap + assert not _paths_overlap(Path("src/a.py"), Path("src/b.py")) + + def test_parent_child_overlap(self): + from run_agent import _paths_overlap + assert _paths_overlap(Path("src"), Path("src/sub/a.py")) + + def test_different_roots_do_not_overlap(self): + from run_agent import _paths_overlap + assert not _paths_overlap(Path("src/a.py"), Path("other/a.py")) + + def test_nested_vs_flat_do_not_overlap(self): + from run_agent import _paths_overlap + assert not _paths_overlap(Path("src/sub/a.py"), Path("src/a.py")) + + def test_empty_paths_do_not_overlap(self): + from run_agent import _paths_overlap + assert not _paths_overlap(Path(""), Path("")) + + def test_one_empty_path_does_not_overlap(self): + from run_agent import _paths_overlap + assert not _paths_overlap(Path(""), Path("src/a.py")) + assert not _paths_overlap(Path("src/a.py"), Path("")) + + class TestHandleMaxIterations: def test_returns_summary(self, agent): resp = _mock_response(content="Here is a summary of what I did.")