diff --git a/run_agent.py b/run_agent.py
index e8fb7bc0f..e17fcc7af 100644
--- a/run_agent.py
+++ b/run_agent.py
@@ -21,6 +21,7 @@ Usage:
 """
 
 import atexit
+import concurrent.futures
 import copy
 import hashlib
 import json
@@ -193,6 +194,14 @@ class IterationBudget:
         return max(0, self.max_total - self._used)
 
 
+# Tools that must never run concurrently (interactive / user-facing).
+# When any of these appear in a batch, we fall back to sequential execution.
+_NEVER_PARALLEL_TOOLS = frozenset({"clarify"})
+
+# Maximum number of concurrent worker threads for parallel tool execution.
+_MAX_TOOL_WORKERS = 8
+
+
 class AIAgent:
     """
     AI Agent with tool calling capabilities.
@@ -3119,7 +3128,271 @@ class AIAgent:
         return compressed, new_system_prompt
 
     def _execute_tool_calls(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
-        """Execute tool calls from the assistant message and append results to messages."""
+        """Execute tool calls from the assistant message and append results to messages.
+
+        Dispatches to concurrent execution when multiple independent tool calls
+        are present, falling back to sequential execution for single calls or
+        when interactive tools (e.g. clarify) are in the batch.
+
+        Args:
+            assistant_message: Assistant message whose ``tool_calls`` are run.
+            messages: Conversation list that receives the tool results.
+            effective_task_id: Task id forwarded to the selected execution path.
+            api_call_count: API-call counter forwarded to the selected path.
+        """
+        tool_calls = assistant_message.tool_calls
+
+        # Single tool call or interactive tool present → sequential
+        if (len(tool_calls) <= 1
+                or any(tc.function.name in _NEVER_PARALLEL_TOOLS for tc in tool_calls)):
+            return self._execute_tool_calls_sequential(
+                assistant_message, messages, effective_task_id, api_call_count
+            )
+
+        # Multiple non-interactive tools → concurrent
+        return self._execute_tool_calls_concurrent(
+            assistant_message, messages, effective_task_id, api_call_count
+        )
+
+    def _invoke_tool(self, function_name: str, function_args: dict, effective_task_id: str) -> str:
+        """Invoke a single tool and return the result string. No display logic.
+
+        Handles both agent-level tools (todo, memory, etc.) and registry-dispatched
+        tools. Used by the concurrent execution path; the sequential path retains
+        its own inline invocation for backward-compatible display handling.
+
+        NOTE(review): the concurrent path calls this from worker threads, so the
+        todo/memory stores and session DB passed to the tools may be used by
+        several threads at once — confirm they tolerate concurrent access.
+        """
+        if function_name == "todo":
+            from tools.todo_tool import todo_tool as _todo_tool
+            return _todo_tool(
+                todos=function_args.get("todos"),
+                merge=function_args.get("merge", False),
+                store=self._todo_store,
+            )
+        elif function_name == "session_search":
+            if not self._session_db:
+                return json.dumps({"success": False, "error": "Session database not available."})
+            from tools.session_search_tool import session_search as _session_search
+            return _session_search(
+                query=function_args.get("query", ""),
+                role_filter=function_args.get("role_filter"),
+                limit=function_args.get("limit", 3),
+                db=self._session_db,
+                current_session_id=self.session_id,
+            )
+        elif function_name == "memory":
+            target = function_args.get("target", "memory")
+            from tools.memory_tool import memory_tool as _memory_tool
+            result = _memory_tool(
+                action=function_args.get("action"),
+                target=target,
+                content=function_args.get("content"),
+                old_text=function_args.get("old_text"),
+                store=self._memory_store,
+            )
+            # Also send user observations to Honcho when active
+            if self._honcho and target == "user" and function_args.get("action") == "add":
+                self._honcho_save_user_observation(function_args.get("content", ""))
+            return result
+        elif function_name == "clarify":
+            from tools.clarify_tool import clarify_tool as _clarify_tool
+            return _clarify_tool(
+                question=function_args.get("question", ""),
+                choices=function_args.get("choices"),
+                callback=self.clarify_callback,
+            )
+        elif function_name == "delegate_task":
+            from tools.delegate_tool import delegate_task as _delegate_task
+            return _delegate_task(
+                goal=function_args.get("goal"),
+                context=function_args.get("context"),
+                toolsets=function_args.get("toolsets"),
+                tasks=function_args.get("tasks"),
+                max_iterations=function_args.get("max_iterations"),
+                parent_agent=self,
+            )
+        else:
+            return handle_function_call(
+                function_name, function_args, effective_task_id,
+                enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None,
+            )
+
+    def _execute_tool_calls_concurrent(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
+        """Execute multiple tool calls concurrently using a thread pool.
+
+        Results are collected in the original tool-call order and appended to
+        messages so the API sees them in the expected sequence.
+
+        The interrupt flag is only checked before the batch starts; once the
+        workers are submitted this method waits for all of them to finish.
+        """
+        tool_calls = assistant_message.tool_calls
+        num_tools = len(tool_calls)
+
+        # ── Pre-flight: interrupt check ──────────────────────────────────
+        if self._interrupt_requested:
+            print(f"{self.log_prefix}⚡ Interrupt: skipping {num_tools} tool call(s)")
+            for tc in tool_calls:
+                messages.append({
+                    "role": "tool",
+                    "content": f"[Tool execution cancelled — {tc.function.name} was skipped due to user interrupt]",
+                    "tool_call_id": tc.id,
+                })
+            return
+
+        # ── Parse args + pre-execution bookkeeping ───────────────────────
+        parsed_calls = []  # list of (tool_call, function_name, function_args)
+        for tool_call in tool_calls:
+            function_name = tool_call.function.name
+
+            # Reset nudge counters
+            if function_name == "memory":
+                self._turns_since_memory = 0
+            elif function_name == "skill_manage":
+                self._iters_since_skill = 0
+
+            try:
+                function_args = json.loads(tool_call.function.arguments)
+            except (json.JSONDecodeError, TypeError):
+                function_args = {}
+            if not isinstance(function_args, dict):
+                function_args = {}
+
+            # Checkpoint for file-mutating tools
+            if function_name in ("write_file", "patch") and self._checkpoint_mgr.enabled:
+                try:
+                    file_path = function_args.get("path", "")
+                    if file_path:
+                        work_dir = self._checkpoint_mgr.get_working_dir_for_path(file_path)
+                        self._checkpoint_mgr.ensure_checkpoint(work_dir, f"before {function_name}")
+                except Exception as ckpt_err:
+                    # Checkpointing is best-effort: never block the tool call,
+                    # but leave a trace instead of failing silently.
+                    logger.debug("Checkpoint before %s failed: %s", function_name, ckpt_err)
+
+            parsed_calls.append((tool_call, function_name, function_args))
+
+        # ── Logging / callbacks ──────────────────────────────────────────
+        tool_names_str = ", ".join(name for _, name, _ in parsed_calls)
+        if not self.quiet_mode:
+            print(f" ⚡ Concurrent: {num_tools} tool calls — {tool_names_str}")
+            for i, (_, name, args) in enumerate(parsed_calls, 1):
+                args_str = json.dumps(args, ensure_ascii=False)
+                args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
+                print(f" 📞 Tool {i}: {name}({list(args.keys())}) - {args_preview}")
+
+        for _, name, args in parsed_calls:
+            if self.tool_progress_callback:
+                try:
+                    preview = _build_tool_preview(name, args)
+                    self.tool_progress_callback(name, preview, args)
+                except Exception as cb_err:
+                    logging.debug(f"Tool progress callback error: {cb_err}")
+
+        # ── Concurrent execution ─────────────────────────────────────────
+        # Each slot holds (function_result, duration, error_flag) for the call at the
+        # same index of parsed_calls; None means the worker never stored a result.
+        results = [None] * num_tools
+
+        def _run_tool(index, function_name, function_args):
+            """Worker function executed in a thread."""
+            start = time.time()
+            try:
+                result = self._invoke_tool(function_name, function_args, effective_task_id)
+            except Exception as tool_error:
+                result = f"Error executing tool '{function_name}': {tool_error}"
+                logger.error("_invoke_tool raised for %s: %s", function_name, tool_error, exc_info=True)
+            duration = time.time() - start
+            is_error, _ = _detect_tool_failure(function_name, result)
+            results[index] = (result, duration, is_error)
+
+        # Start spinner for CLI mode
+        spinner = None
+        if self.quiet_mode:
+            face = random.choice(KawaiiSpinner.KAWAII_WAITING)
+            spinner = KawaiiSpinner(f"{face} ⚡ running {num_tools} tools concurrently", spinner_type='dots')
+            spinner.start()
+
+        try:
+            max_workers = min(num_tools, _MAX_TOOL_WORKERS)
+            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+                for i, (_, name, args) in enumerate(parsed_calls):
+                    executor.submit(_run_tool, i, name, args)
+                # Leaving the ``with`` block waits for every worker. A worker that
+                # raises outside its own try/except leaves its slot as None.
+        finally:
+            if spinner:
+                # Build a summary message for the spinner stop (durations are summed per tool)
+                completed = sum(1 for r in results if r is not None)
+                total_dur = sum(r[1] for r in results if r is not None)
+                spinner.stop(f"⚡ {completed}/{num_tools} tools completed in {total_dur:.1f}s total")
+
+        # ── Post-execution: display per-tool results ─────────────────────
+        MAX_TOOL_RESULT_CHARS = 100_000
+        for i, (tc, name, args) in enumerate(parsed_calls):
+            r = results[i]
+            if r is None:
+                # Shouldn't happen, but safety fallback
+                function_result = f"Error executing tool '{name}': thread did not return a result"
+                tool_duration = 0.0
+                is_error = True
+            else:
+                function_result, tool_duration, is_error = r
+
+            if is_error:
+                logger.warning("Tool %s returned error (%.2fs): %s", name, tool_duration, function_result[:200])
+
+            if self.verbose_logging:
+                logging.debug(f"Tool {name} completed in {tool_duration:.2f}s")
+                logging.debug(f"Tool result preview: {function_result[:200]}...")
+
+            # Print cute message per tool
+            if self.quiet_mode:
+                cute_msg = _get_cute_tool_message_impl(name, args, tool_duration, result=function_result)
+                print(f" {cute_msg}")
+            else:
+                response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
+                print(f" ✅ Tool {i+1} completed in {tool_duration:.2f}s - {response_preview}")
+
+            # Truncate oversized results
+            if len(function_result) > MAX_TOOL_RESULT_CHARS:
+                original_len = len(function_result)
+                function_result = (
+                    function_result[:MAX_TOOL_RESULT_CHARS]
+                    + f"\n\n[Truncated: tool response was {original_len:,} chars, "
+                    f"exceeding the {MAX_TOOL_RESULT_CHARS:,} char limit]"
+                )
+
+            # Append tool result message in order
+            messages.append({
+                "role": "tool",
+                "content": function_result,
+                "tool_call_id": tc.id,
+            })
+
+        # ── Budget pressure injection ────────────────────────────────────
+        budget_warning = self._get_budget_warning(api_call_count)
+        if budget_warning and messages and messages[-1].get("role") == "tool":
+            last_content = messages[-1]["content"]
+            try:
+                parsed = json.loads(last_content)
+                if isinstance(parsed, dict):
+                    parsed["_budget_warning"] = budget_warning
+                    messages[-1]["content"] = json.dumps(parsed, ensure_ascii=False)
+                else:
+                    messages[-1]["content"] = last_content + f"\n\n{budget_warning}"
+            except (json.JSONDecodeError, TypeError):
+                messages[-1]["content"] = last_content + f"\n\n{budget_warning}"
+            if not self.quiet_mode:
+                remaining = self.max_iterations - api_call_count
+                tier = "⚠️ WARNING" if remaining <= self.max_iterations * 0.1 else "💡 CAUTION"
+                print(f"{self.log_prefix}{tier}: {remaining} iterations remaining")
+
+    def _execute_tool_calls_sequential(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
+        """Execute tool calls sequentially (original behavior). Used for single calls or interactive tools."""
         for i, tool_call in enumerate(assistant_message.tool_calls, 1):
             # SAFETY: check interrupt BEFORE starting each tool.
             # If the user sent "stop" during a previous tool's execution,
diff --git a/tests/test_run_agent.py b/tests/test_run_agent.py
index 45680d976..61a24f98b 100644
--- a/tests/test_run_agent.py
+++ b/tests/test_run_agent.py
@@ -668,6 +668,169 @@ class TestExecuteToolCalls:
         assert "Truncated" in messages[0]["content"]
 
 
+class TestConcurrentToolExecution:
+    """Tests for _execute_tool_calls_concurrent and dispatch logic."""
+
+    def test_single_tool_uses_sequential_path(self, agent):
+        """Single tool call should use sequential path, not concurrent."""
+        tc = _mock_tool_call(name="web_search", arguments='{"q":"test"}', call_id="c1")
+        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc])
+        messages = []
+        with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
+            with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con:
+                agent._execute_tool_calls(mock_msg, messages, "task-1")
+                mock_seq.assert_called_once()
+                mock_con.assert_not_called()
+
+    def test_clarify_forces_sequential(self, agent):
+        """Batch containing clarify should use sequential path."""
+        tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
+        tc2 = _mock_tool_call(name="clarify", arguments='{"question":"ok?"}', call_id="c2")
+        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
+        messages = []
+        with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
+            with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con:
+                agent._execute_tool_calls(mock_msg, messages, "task-1")
+                mock_seq.assert_called_once()
+                mock_con.assert_not_called()
+
+    def test_multiple_tools_uses_concurrent_path(self, agent):
+        """Multiple non-interactive tools should use concurrent path."""
+        tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
+        tc2 = _mock_tool_call(name="read_file", arguments='{"path":"x.py"}', call_id="c2")
+        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
+        messages = []
+        with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
+            with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con:
+                agent._execute_tool_calls(mock_msg, messages, "task-1")
+                mock_con.assert_called_once()
+                mock_seq.assert_not_called()
+
+    def test_concurrent_executes_all_tools(self, agent):
+        """Concurrent path should execute all tools and append results in order."""
+        tc1 = _mock_tool_call(name="web_search", arguments='{"q":"alpha"}', call_id="c1")
+        tc2 = _mock_tool_call(name="web_search", arguments='{"q":"beta"}', call_id="c2")
+        tc3 = _mock_tool_call(name="web_search", arguments='{"q":"gamma"}', call_id="c3")
+        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2, tc3])
+        messages = []
+
+        call_log = []
+
+        def fake_handle(name, args, task_id, **kwargs):
+            call_log.append(name)
+            return json.dumps({"result": args.get("q", "")})
+
+        with patch("run_agent.handle_function_call", side_effect=fake_handle):
+            agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
+
+        assert len(messages) == 3
+        assert len(call_log) == 3
+        # Results must be in original order
+        assert messages[0]["tool_call_id"] == "c1"
+        assert messages[1]["tool_call_id"] == "c2"
+        assert messages[2]["tool_call_id"] == "c3"
+        # All should be tool messages
+        assert all(m["role"] == "tool" for m in messages)
+        # Content should contain the query results
+        assert "alpha" in messages[0]["content"]
+        assert "beta" in messages[1]["content"]
+        assert "gamma" in messages[2]["content"]
+
+    def test_concurrent_preserves_order_despite_timing(self, agent):
+        """Even if tools finish in different order, messages should be in original order."""
+        import time as _time
+
+        tc1 = _mock_tool_call(name="web_search", arguments='{"q":"slow"}', call_id="c1")
+        tc2 = _mock_tool_call(name="web_search", arguments='{"q":"fast"}', call_id="c2")
+        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
+        messages = []
+
+        def fake_handle(name, args, task_id, **kwargs):
+            q = args.get("q", "")
+            if q == "slow":
+                _time.sleep(0.1)  # Slow tool
+            return f"result_{q}"
+
+        with patch("run_agent.handle_function_call", side_effect=fake_handle):
+            agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
+
+        assert messages[0]["tool_call_id"] == "c1"
+        assert "result_slow" in messages[0]["content"]
+        assert messages[1]["tool_call_id"] == "c2"
+        assert "result_fast" in messages[1]["content"]
+
+    def test_concurrent_handles_tool_error(self, agent):
+        """If one tool raises, others should still complete."""
+        tc1 = _mock_tool_call(name="web_search", arguments='{"q":"bad"}', call_id="c1")
+        tc2 = _mock_tool_call(name="web_search", arguments='{"q":"good"}', call_id="c2")
+        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
+        messages = []
+
+        def fake_handle(name, args, task_id, **kwargs):
+            # Key the failure on the arguments, not on call order: worker
+            # threads may reach this function in either order.
+            if args.get("q") == "bad":
+                raise RuntimeError("boom")
+            return "success"
+
+        with patch("run_agent.handle_function_call", side_effect=fake_handle):
+            agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
+
+        assert len(messages) == 2
+        # First tool should have error
+        assert "Error" in messages[0]["content"] or "boom" in messages[0]["content"]
+        # Second tool should succeed
+        assert "success" in messages[1]["content"]
+
+    def test_concurrent_interrupt_before_start(self, agent):
+        """If interrupt is requested before concurrent execution, all tools are skipped."""
+        tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
+        tc2 = _mock_tool_call(name="read_file", arguments='{}', call_id="c2")
+        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
+        messages = []
+
+        with patch("run_agent._set_interrupt"):
+            agent.interrupt()
+
+        agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
+        assert len(messages) == 2
+        assert "cancelled" in messages[0]["content"].lower() or "skipped" in messages[0]["content"].lower()
+        assert "cancelled" in messages[1]["content"].lower() or "skipped" in messages[1]["content"].lower()
+
+    def test_concurrent_truncates_large_results(self, agent):
+        """Concurrent path should truncate results over 100k chars."""
+        tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
+        tc2 = _mock_tool_call(name="web_search", arguments='{}', call_id="c2")
+        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
+        messages = []
+        big_result = "x" * 150_000
+
+        with patch("run_agent.handle_function_call", return_value=big_result):
+            agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
+
+        assert len(messages) == 2
+        for m in messages:
+            assert len(m["content"]) < 150_000
+            assert "Truncated" in m["content"]
+
+    def test_invoke_tool_dispatches_to_handle_function_call(self, agent):
+        """_invoke_tool should route regular tools through handle_function_call."""
+        with patch("run_agent.handle_function_call", return_value="result") as mock_hfc:
+            result = agent._invoke_tool("web_search", {"q": "test"}, "task-1")
+            mock_hfc.assert_called_once_with(
+                "web_search", {"q": "test"}, "task-1",
+                enabled_tools=list(agent.valid_tool_names),
+            )
+            assert result == "result"
+
+    def test_invoke_tool_handles_agent_level_tools(self, agent):
+        """_invoke_tool should handle todo tool directly."""
+        with patch("tools.todo_tool.todo_tool", return_value='{"ok":true}') as mock_todo:
+            result = agent._invoke_tool("todo", {"todos": []}, "task-1")
+            mock_todo.assert_called_once()
+            assert "ok" in result
+
+
 class TestHandleMaxIterations:
     def test_returns_summary(self, agent):
         resp = _mock_response(content="Here is a summary of what I did.")
diff --git a/tests/tools/test_interrupt.py b/tests/tools/test_interrupt.py
index 6165deaaf..dc0ab4599 100644
--- a/tests/tools/test_interrupt.py
+++ b/tests/tools/test_interrupt.py
@@ -91,8 +91,11 @@ class TestPreToolCheck:
         agent._persist_session = MagicMock()
 
         # Import and call the method
+        import types
         from run_agent import AIAgent
-        # Bind the real method to our mock
+        # Bind the real methods to our mock so dispatch works correctly
+        agent._execute_tool_calls_sequential = types.MethodType(AIAgent._execute_tool_calls_sequential, agent)
+        agent._execute_tool_calls_concurrent = types.MethodType(AIAgent._execute_tool_calls_concurrent, agent)
         AIAgent._execute_tool_calls(agent, assistant_msg, messages, "default")
 
         # All 3 should be skipped