feat: concurrent tool execution with ThreadPoolExecutor
When the model returns multiple tool calls in a single response, they are now executed concurrently using a thread pool instead of sequentially. This significantly reduces wall-clock time when multiple independent tools are batched (e.g. parallel web_search, read_file, terminal calls). Architecture: - _execute_tool_calls() dispatches to sequential or concurrent path - Single tool calls and batches containing 'clarify' use sequential path - Multiple non-interactive tools use ThreadPoolExecutor (max 8 workers) - Results are collected and appended to messages in original order - _invoke_tool() extracted as shared tool invocation helper Safety: - Pre-flight interrupt check skips all tools if interrupted - Per-tool exception handling: one failure doesn't crash the batch - Result truncation (100k char limit) applied per tool - Budget pressure injection after all tools complete - Checkpoints taken before file-mutating tools - CLI spinner shows batch progress, then per-tool completion messages Tests: 10 new tests covering dispatch logic, ordering, error handling, interrupt behavior, truncation, and _invoke_tool routing.
This commit is contained in:
264
run_agent.py
264
run_agent.py
@@ -21,6 +21,7 @@ Usage:
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import concurrent.futures
|
||||
import copy
|
||||
import hashlib
|
||||
import json
|
||||
@@ -193,6 +194,14 @@ class IterationBudget:
|
||||
return max(0, self.max_total - self._used)
|
||||
|
||||
|
||||
# Tools that must never run concurrently (interactive / user-facing).
# When any of these appear in a batch, we fall back to sequential execution.
# NOTE: "clarify" blocks on user input via a callback, so running it in a
# worker thread alongside other tools would interleave prompts with output.
_NEVER_PARALLEL_TOOLS = frozenset({"clarify"})

# Maximum number of concurrent worker threads for parallel tool execution.
# Batches larger than this are still accepted; the executor simply queues
# the excess calls until a worker frees up.
_MAX_TOOL_WORKERS = 8
|
||||
|
||||
|
||||
class AIAgent:
|
||||
"""
|
||||
AI Agent with tool calling capabilities.
|
||||
@@ -3119,7 +3128,260 @@ class AIAgent:
|
||||
return compressed, new_system_prompt
|
||||
|
||||
def _execute_tool_calls(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
    """Run the assistant message's tool calls and append their results to messages.

    Chooses between two execution strategies: batches of two or more
    non-interactive tools run concurrently, while a lone call — or any
    batch containing an interactive tool such as clarify — runs through
    the original sequential path.
    """
    calls = assistant_message.tool_calls

    # Sequential is required when concurrency offers no benefit (a single
    # call) or would be unsafe (an interactive tool is in the batch).
    must_be_sequential = len(calls) <= 1 or not _NEVER_PARALLEL_TOOLS.isdisjoint(
        tc.function.name for tc in calls
    )

    runner = (
        self._execute_tool_calls_sequential
        if must_be_sequential
        else self._execute_tool_calls_concurrent
    )
    return runner(assistant_message, messages, effective_task_id, api_call_count)
|
||||
|
||||
def _invoke_tool(self, function_name: str, function_args: dict, effective_task_id: str) -> str:
    """Invoke a single tool by name and return its result string. No display logic.

    Agent-level tools (todo, session_search, memory, clarify, delegate_task)
    are handled here directly via lazily-imported helpers; every other name
    is routed through the registry via handle_function_call. Used by the
    concurrent execution path; the sequential path retains its own inline
    invocation for backward-compatible display handling.
    """
    args = function_args

    def _run_todo() -> str:
        # Lazy import keeps module start-up cheap.
        from tools.todo_tool import todo_tool as _todo_tool
        return _todo_tool(
            todos=args.get("todos"),
            merge=args.get("merge", False),
            store=self._todo_store,
        )

    def _run_session_search() -> str:
        if not self._session_db:
            return json.dumps({"success": False, "error": "Session database not available."})
        from tools.session_search_tool import session_search as _session_search
        return _session_search(
            query=args.get("query", ""),
            role_filter=args.get("role_filter"),
            limit=args.get("limit", 3),
            db=self._session_db,
            current_session_id=self.session_id,
        )

    def _run_memory() -> str:
        target = args.get("target", "memory")
        from tools.memory_tool import memory_tool as _memory_tool
        outcome = _memory_tool(
            action=args.get("action"),
            target=target,
            content=args.get("content"),
            old_text=args.get("old_text"),
            store=self._memory_store,
        )
        # Also send user observations to Honcho when active
        if self._honcho and target == "user" and args.get("action") == "add":
            self._honcho_save_user_observation(args.get("content", ""))
        return outcome

    def _run_clarify() -> str:
        from tools.clarify_tool import clarify_tool as _clarify_tool
        return _clarify_tool(
            question=args.get("question", ""),
            choices=args.get("choices"),
            callback=self.clarify_callback,
        )

    def _run_delegate() -> str:
        from tools.delegate_tool import delegate_task as _delegate_task
        return _delegate_task(
            goal=args.get("goal"),
            context=args.get("context"),
            toolsets=args.get("toolsets"),
            tasks=args.get("tasks"),
            max_iterations=args.get("max_iterations"),
            parent_agent=self,
        )

    agent_level_handlers = {
        "todo": _run_todo,
        "session_search": _run_session_search,
        "memory": _run_memory,
        "clarify": _run_clarify,
        "delegate_task": _run_delegate,
    }

    handler = agent_level_handlers.get(function_name)
    if handler is not None:
        return handler()

    # Registry-dispatched tools (the common case).
    return handle_function_call(
        function_name, function_args, effective_task_id,
        enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None,
    )
|
||||
|
||||
def _execute_tool_calls_concurrent(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
    """Execute multiple tool calls concurrently using a thread pool.

    Results are collected in the original tool-call order and appended to
    messages so the API sees them in the expected sequence.

    Args:
        assistant_message: Model response carrying ``tool_calls``.
        messages: Conversation list; one ``role="tool"`` entry is appended
            per tool call, in original order.
        effective_task_id: Task identifier forwarded to tool invocations.
        api_call_count: Iterations used so far; drives budget-pressure
            injection after all tools complete.
    """
    tool_calls = assistant_message.tool_calls
    num_tools = len(tool_calls)

    # ── Pre-flight: interrupt check ──────────────────────────────────
    # If the user already interrupted, skip the whole batch but still emit
    # a tool message per call so the API sees a reply for every tool_call_id.
    if self._interrupt_requested:
        print(f"{self.log_prefix}⚡ Interrupt: skipping {num_tools} tool call(s)")
        for tc in tool_calls:
            messages.append({
                "role": "tool",
                "content": f"[Tool execution cancelled — {tc.function.name} was skipped due to user interrupt]",
                "tool_call_id": tc.id,
            })
        return

    # ── Parse args + pre-execution bookkeeping ───────────────────────
    parsed_calls = []  # list of (tool_call, function_name, function_args)
    for tool_call in tool_calls:
        function_name = tool_call.function.name

        # Reset nudge counters
        if function_name == "memory":
            self._turns_since_memory = 0
        elif function_name == "skill_manage":
            self._iters_since_skill = 0

        # Malformed or non-dict arguments degrade to an empty dict rather
        # than aborting the batch.
        try:
            function_args = json.loads(tool_call.function.arguments)
        except json.JSONDecodeError:
            function_args = {}
        if not isinstance(function_args, dict):
            function_args = {}

        # Checkpoint for file-mutating tools (best-effort: checkpoint
        # failure must never block tool execution).
        if function_name in ("write_file", "patch") and self._checkpoint_mgr.enabled:
            try:
                file_path = function_args.get("path", "")
                if file_path:
                    work_dir = self._checkpoint_mgr.get_working_dir_for_path(file_path)
                    self._checkpoint_mgr.ensure_checkpoint(work_dir, f"before {function_name}")
            except Exception:
                pass

        parsed_calls.append((tool_call, function_name, function_args))

    # ── Logging / callbacks ──────────────────────────────────────────
    tool_names_str = ", ".join(name for _, name, _ in parsed_calls)
    if not self.quiet_mode:
        print(f" ⚡ Concurrent: {num_tools} tool calls — {tool_names_str}")
        for i, (tc, name, args) in enumerate(parsed_calls, 1):
            args_str = json.dumps(args, ensure_ascii=False)
            args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
            print(f" 📞 Tool {i}: {name}({list(args.keys())}) - {args_preview}")

    # Progress callbacks are advisory; a raising callback must not stop
    # tool execution.
    for _, name, args in parsed_calls:
        if self.tool_progress_callback:
            try:
                preview = _build_tool_preview(name, args)
                self.tool_progress_callback(name, preview, args)
            except Exception as cb_err:
                logging.debug(f"Tool progress callback error: {cb_err}")

    # ── Concurrent execution ─────────────────────────────────────────
    # Each slot holds (function_name, function_args, function_result, duration, error_flag).
    # Workers write to disjoint indices, so no lock is needed.
    results = [None] * num_tools

    def _run_tool(index, tool_call, function_name, function_args):
        """Worker function executed in a thread. Captures its own exceptions."""
        start = time.time()
        try:
            result = self._invoke_tool(function_name, function_args, effective_task_id)
        except Exception as tool_error:
            result = f"Error executing tool '{function_name}': {tool_error}"
            logger.error("_invoke_tool raised for %s: %s", function_name, tool_error, exc_info=True)
        duration = time.time() - start
        is_error, _ = _detect_tool_failure(function_name, result)
        results[index] = (function_name, function_args, result, duration, is_error)

    # Start spinner for CLI mode
    spinner = None
    if self.quiet_mode:
        face = random.choice(KawaiiSpinner.KAWAII_WAITING)
        spinner = KawaiiSpinner(f"{face} ⚡ running {num_tools} tools concurrently", spinner_type='dots')
        spinner.start()

    try:
        max_workers = min(num_tools, _MAX_TOOL_WORKERS)
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = []
            for i, (tc, name, args) in enumerate(parsed_calls):
                f = executor.submit(_run_tool, i, tc, name, args)
                futures.append(f)

            # Wait for all to complete (exceptions are captured inside _run_tool)
            concurrent.futures.wait(futures)
    finally:
        if spinner:
            # Build a summary message for the spinner stop
            completed = sum(1 for r in results if r is not None)
            total_dur = sum(r[3] for r in results if r is not None)
            spinner.stop(f"⚡ {completed}/{num_tools} tools completed in {total_dur:.1f}s total")

    # ── Post-execution: display per-tool results ─────────────────────
    for i, (tc, name, args) in enumerate(parsed_calls):
        r = results[i]
        if r is None:
            # Shouldn't happen, but safety fallback.
            # FIX: previously this branch left function_name/function_args/
            # is_error unassigned, causing a NameError (first iteration) or
            # reuse of the previous iteration's stale values.
            function_name = name
            function_args = args
            function_result = f"Error executing tool '{name}': thread did not return a result"
            tool_duration = 0.0
            is_error = True
        else:
            function_name, function_args, function_result, tool_duration, is_error = r

        if is_error:
            result_preview = function_result[:200] if len(function_result) > 200 else function_result
            logger.warning("Tool %s returned error (%.2fs): %s", function_name, tool_duration, result_preview)

        if self.verbose_logging:
            result_preview = function_result[:200] if len(function_result) > 200 else function_result
            logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s")
            logging.debug(f"Tool result preview: {result_preview}...")

        # Print cute message per tool
        if self.quiet_mode:
            cute_msg = _get_cute_tool_message_impl(name, args, tool_duration, result=function_result)
            print(f" {cute_msg}")
        else:
            # FIX: was `elif not self.quiet_mode:` — a tautology after the
            # `if self.quiet_mode:` branch; plain else is equivalent and clear.
            response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
            print(f" ✅ Tool {i+1} completed in {tool_duration:.2f}s - {response_preview}")

        # Truncate oversized results so a single verbose tool cannot blow
        # up the context window.
        MAX_TOOL_RESULT_CHARS = 100_000
        if len(function_result) > MAX_TOOL_RESULT_CHARS:
            original_len = len(function_result)
            function_result = (
                function_result[:MAX_TOOL_RESULT_CHARS]
                + f"\n\n[Truncated: tool response was {original_len:,} chars, "
                f"exceeding the {MAX_TOOL_RESULT_CHARS:,} char limit]"
            )

        # Append tool result message in order
        tool_msg = {
            "role": "tool",
            "content": function_result,
            "tool_call_id": tc.id,
        }
        messages.append(tool_msg)

    # ── Budget pressure injection ────────────────────────────────────
    # Attach the warning to the last tool message: merged into the JSON
    # payload when it parses as a dict, appended as plain text otherwise.
    budget_warning = self._get_budget_warning(api_call_count)
    if budget_warning and messages and messages[-1].get("role") == "tool":
        last_content = messages[-1]["content"]
        try:
            parsed = json.loads(last_content)
            if isinstance(parsed, dict):
                parsed["_budget_warning"] = budget_warning
                messages[-1]["content"] = json.dumps(parsed, ensure_ascii=False)
            else:
                messages[-1]["content"] = last_content + f"\n\n{budget_warning}"
        except (json.JSONDecodeError, TypeError):
            messages[-1]["content"] = last_content + f"\n\n{budget_warning}"
        if not self.quiet_mode:
            remaining = self.max_iterations - api_call_count
            tier = "⚠️ WARNING" if remaining <= self.max_iterations * 0.1 else "💡 CAUTION"
            print(f"{self.log_prefix}{tier}: {remaining} iterations remaining")
|
||||
|
||||
def _execute_tool_calls_sequential(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
|
||||
"""Execute tool calls sequentially (original behavior). Used for single calls or interactive tools."""
|
||||
for i, tool_call in enumerate(assistant_message.tool_calls, 1):
|
||||
# SAFETY: check interrupt BEFORE starting each tool.
|
||||
# If the user sent "stop" during a previous tool's execution,
|
||||
|
||||
@@ -668,6 +668,168 @@ class TestExecuteToolCalls:
|
||||
assert "Truncated" in messages[0]["content"]
|
||||
|
||||
|
||||
class TestConcurrentToolExecution:
    """Tests for _execute_tool_calls_concurrent and dispatch logic."""

    def test_single_tool_uses_sequential_path(self, agent):
        """Single tool call should use sequential path, not concurrent."""
        tc = _mock_tool_call(name="web_search", arguments='{"q":"test"}', call_id="c1")
        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc])
        messages = []
        # Patch both paths so we can observe which one dispatch selects.
        with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
            with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con:
                agent._execute_tool_calls(mock_msg, messages, "task-1")
        mock_seq.assert_called_once()
        mock_con.assert_not_called()

    def test_clarify_forces_sequential(self, agent):
        """Batch containing clarify should use sequential path."""
        tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
        tc2 = _mock_tool_call(name="clarify", arguments='{"question":"ok?"}', call_id="c2")
        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
        messages = []
        with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
            with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con:
                agent._execute_tool_calls(mock_msg, messages, "task-1")
        mock_seq.assert_called_once()
        mock_con.assert_not_called()

    def test_multiple_tools_uses_concurrent_path(self, agent):
        """Multiple non-interactive tools should use concurrent path."""
        tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
        tc2 = _mock_tool_call(name="read_file", arguments='{"path":"x.py"}', call_id="c2")
        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
        messages = []
        with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
            with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con:
                agent._execute_tool_calls(mock_msg, messages, "task-1")
        mock_con.assert_called_once()
        mock_seq.assert_not_called()

    def test_concurrent_executes_all_tools(self, agent):
        """Concurrent path should execute all tools and append results in order."""
        tc1 = _mock_tool_call(name="web_search", arguments='{"q":"alpha"}', call_id="c1")
        tc2 = _mock_tool_call(name="web_search", arguments='{"q":"beta"}', call_id="c2")
        tc3 = _mock_tool_call(name="web_search", arguments='{"q":"gamma"}', call_id="c3")
        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2, tc3])
        messages = []

        call_log = []

        def fake_handle(name, args, task_id, **kwargs):
            # Record the invocation and echo the query back as the result.
            call_log.append(name)
            return json.dumps({"result": args.get("q", "")})

        with patch("run_agent.handle_function_call", side_effect=fake_handle):
            agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")

        assert len(messages) == 3
        # Results must be in original order
        assert messages[0]["tool_call_id"] == "c1"
        assert messages[1]["tool_call_id"] == "c2"
        assert messages[2]["tool_call_id"] == "c3"
        # All should be tool messages
        assert all(m["role"] == "tool" for m in messages)
        # Content should contain the query results
        assert "alpha" in messages[0]["content"]
        assert "beta" in messages[1]["content"]
        assert "gamma" in messages[2]["content"]

    def test_concurrent_preserves_order_despite_timing(self, agent):
        """Even if tools finish in different order, messages should be in original order."""
        import time as _time

        tc1 = _mock_tool_call(name="web_search", arguments='{"q":"slow"}', call_id="c1")
        tc2 = _mock_tool_call(name="web_search", arguments='{"q":"fast"}', call_id="c2")
        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
        messages = []

        def fake_handle(name, args, task_id, **kwargs):
            q = args.get("q", "")
            if q == "slow":
                _time.sleep(0.1)  # Slow tool
            return f"result_{q}"

        with patch("run_agent.handle_function_call", side_effect=fake_handle):
            agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")

        # The slow tool (submitted first) must still appear first.
        assert messages[0]["tool_call_id"] == "c1"
        assert "result_slow" in messages[0]["content"]
        assert messages[1]["tool_call_id"] == "c2"
        assert "result_fast" in messages[1]["content"]

    def test_concurrent_handles_tool_error(self, agent):
        """If one tool raises, others should still complete."""
        tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
        tc2 = _mock_tool_call(name="web_search", arguments='{}', call_id="c2")
        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
        messages = []

        call_count = [0]

        def fake_handle(name, args, task_id, **kwargs):
            # First invocation raises; subsequent ones succeed.
            call_count[0] += 1
            if call_count[0] == 1:
                raise RuntimeError("boom")
            return "success"

        with patch("run_agent.handle_function_call", side_effect=fake_handle):
            agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")

        assert len(messages) == 2
        # First tool should have error
        assert "Error" in messages[0]["content"] or "boom" in messages[0]["content"]
        # Second tool should succeed
        assert "success" in messages[1]["content"]

    def test_concurrent_interrupt_before_start(self, agent):
        """If interrupt is requested before concurrent execution, all tools are skipped."""
        tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
        tc2 = _mock_tool_call(name="read_file", arguments='{}', call_id="c2")
        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
        messages = []

        # Patch the module-level interrupt hook so agent.interrupt() only
        # flips the agent's flag without global side effects.
        with patch("run_agent._set_interrupt"):
            agent.interrupt()

        agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
        assert len(messages) == 2
        assert "cancelled" in messages[0]["content"].lower() or "skipped" in messages[0]["content"].lower()
        assert "cancelled" in messages[1]["content"].lower() or "skipped" in messages[1]["content"].lower()

    def test_concurrent_truncates_large_results(self, agent):
        """Concurrent path should truncate results over 100k chars."""
        tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
        tc2 = _mock_tool_call(name="web_search", arguments='{}', call_id="c2")
        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
        messages = []
        big_result = "x" * 150_000

        with patch("run_agent.handle_function_call", return_value=big_result):
            agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")

        assert len(messages) == 2
        for m in messages:
            assert len(m["content"]) < 150_000
            assert "Truncated" in m["content"]

    def test_invoke_tool_dispatches_to_handle_function_call(self, agent):
        """_invoke_tool should route regular tools through handle_function_call."""
        with patch("run_agent.handle_function_call", return_value="result") as mock_hfc:
            result = agent._invoke_tool("web_search", {"q": "test"}, "task-1")
        mock_hfc.assert_called_once_with(
            "web_search", {"q": "test"}, "task-1",
            enabled_tools=list(agent.valid_tool_names),
        )
        assert result == "result"

    def test_invoke_tool_handles_agent_level_tools(self, agent):
        """_invoke_tool should handle todo tool directly."""
        # Patch at the tool module since _invoke_tool imports it lazily.
        with patch("tools.todo_tool.todo_tool", return_value='{"ok":true}') as mock_todo:
            result = agent._invoke_tool("todo", {"todos": []}, "task-1")
        mock_todo.assert_called_once()
        assert "ok" in result
||||
|
||||
class TestHandleMaxIterations:
|
||||
def test_returns_summary(self, agent):
|
||||
resp = _mock_response(content="Here is a summary of what I did.")
|
||||
|
||||
@@ -91,8 +91,11 @@ class TestPreToolCheck:
|
||||
agent._persist_session = MagicMock()
|
||||
|
||||
# Import and call the method
|
||||
import types
|
||||
from run_agent import AIAgent
|
||||
# Bind the real method to our mock
|
||||
# Bind the real methods to our mock so dispatch works correctly
|
||||
agent._execute_tool_calls_sequential = types.MethodType(AIAgent._execute_tool_calls_sequential, agent)
|
||||
agent._execute_tool_calls_concurrent = types.MethodType(AIAgent._execute_tool_calls_concurrent, agent)
|
||||
AIAgent._execute_tool_calls(agent, assistant_msg, messages, "default")
|
||||
|
||||
# All 3 should be skipped
|
||||
|
||||
Reference in New Issue
Block a user