diff --git a/cli.py b/cli.py index 80e2e7846..1e4181770 100755 --- a/cli.py +++ b/cli.py @@ -3608,6 +3608,19 @@ class HermesCLI: continue print(f"\n⚡ New message detected, interrupting...") self.agent.interrupt(interrupt_msg) + # Debug: log to file (stdout may be devnull from redirect_stdout) + try: + import pathlib as _pl + _dbg = _pl.Path.home() / ".hermes" / "interrupt_debug.log" + with open(_dbg, "a") as _f: + import time as _t + _f.write(f"{_t.strftime('%H:%M:%S')} interrupt fired: msg={str(interrupt_msg)[:60]!r}, " + f"children={len(self.agent._active_children)}, " + f"parent._interrupt={self.agent._interrupt_requested}\n") + for _ci, _ch in enumerate(self.agent._active_children): + _f.write(f" child[{_ci}]._interrupt={_ch._interrupt_requested}\n") + except Exception: + pass break except queue.Empty: pass # Queue empty or timeout, continue waiting @@ -3877,6 +3890,16 @@ class HermesCLI: payload = (text, images) if images else text if self._agent_running and not (text and text.startswith("/")): self._interrupt_queue.put(payload) + # Debug: log to file when message enters interrupt queue + try: + import pathlib as _pl + _dbg = _pl.Path.home() / ".hermes" / "interrupt_debug.log" + with open(_dbg, "a") as _f: + import time as _t + _f.write(f"{_t.strftime('%H:%M:%S')} ENTER: queued interrupt msg={str(payload)[:60]!r}, " + f"agent_running={self._agent_running}\n") + except Exception: + pass else: self._pending_input.put(payload) event.app.current_buffer.reset(append_to_history=True) diff --git a/gateway/run.py b/gateway/run.py index aae5c6342..8c0685591 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -3418,17 +3418,19 @@ class GatewayRunner: # Monitor for interrupts from the adapter (new messages arriving) async def monitor_for_interrupt(): adapter = self.adapters.get(source.platform) - if not adapter: + if not adapter or not session_key: return - chat_id = source.chat_id while True: await asyncio.sleep(0.2) # Check every 200ms - # Check if adapter has a pending interrupt for this session - if hasattr(adapter, 'has_pending_interrupt') and adapter.has_pending_interrupt(chat_id): + # Check if adapter has a pending interrupt for this session. + # Must use session_key (build_session_key output) — NOT + # source.chat_id — because the adapter stores interrupt events + # under the full session key. + if hasattr(adapter, 'has_pending_interrupt') and adapter.has_pending_interrupt(session_key): agent = agent_holder[0] if agent: - pending_event = adapter.get_pending_message(chat_id) + pending_event = adapter.get_pending_message(session_key) pending_text = pending_event.text if pending_event else None logger.debug("Interrupt detected from adapter, signaling agent...") agent.interrupt(pending_text) @@ -3445,10 +3447,11 @@ class GatewayRunner: result = result_holder[0] adapter = self.adapters.get(source.platform) - # Get pending message from adapter if interrupted + # Get pending message from adapter if interrupted. + # Use session_key (not source.chat_id) to match adapter's storage keys. pending = None if result and result.get("interrupted") and adapter: - pending_event = adapter.get_pending_message(source.chat_id) + pending_event = adapter.get_pending_message(session_key) if session_key else None if pending_event: pending = pending_event.text elif result.get("interrupt_message"): @@ -3460,8 +3463,8 @@ class GatewayRunner: # Clear the adapter's interrupt event so the next _run_agent call # doesn't immediately re-trigger the interrupt before the new agent # even makes its first API call (this was causing an infinite loop). - if adapter and hasattr(adapter, '_active_sessions') and source.chat_id in adapter._active_sessions: - adapter._active_sessions[source.chat_id].clear() + if adapter and hasattr(adapter, '_active_sessions') and session_key and session_key in adapter._active_sessions: + adapter._active_sessions[session_key].clear() # Don't send the interrupted response to the user — it's just noise # like "Operation interrupted." They already know they sent a new diff --git a/tests/gateway/test_interrupt_key_match.py b/tests/gateway/test_interrupt_key_match.py new file mode 100644 index 000000000..f129977d4 --- /dev/null +++ b/tests/gateway/test_interrupt_key_match.py @@ -0,0 +1,124 @@ +"""Tests verifying interrupt key consistency between adapter and gateway. + +Regression test for a bug where monitor_for_interrupt() in _run_agent used +source.chat_id to query the adapter, but the adapter stores interrupts under +the full session key (build_session_key output). This mismatch meant +interrupts were never detected, causing subagents to ignore new messages. +""" + +import asyncio + +import pytest + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult +from gateway.session import SessionSource, build_session_key + + +class StubAdapter(BasePlatformAdapter): + """Minimal adapter for interrupt tests.""" + + def __init__(self): + super().__init__(PlatformConfig(enabled=True, token="test"), Platform.TELEGRAM) + + async def connect(self): + return True + + async def disconnect(self): + pass + + async def send(self, chat_id, content, reply_to=None, metadata=None): + return SendResult(success=True, message_id="1") + + async def send_typing(self, chat_id, metadata=None): + pass + + async def get_chat_info(self, chat_id): + return {"id": chat_id} + + +def _source(chat_id="123456", chat_type="dm", thread_id=None): + return SessionSource( + platform=Platform.TELEGRAM, + chat_id=chat_id, + chat_type=chat_type, + thread_id=thread_id, + ) + + +class TestInterruptKeyConsistency: + """Ensure adapter interrupt methods are queried with session_key, not chat_id.""" + + def test_session_key_differs_from_chat_id_for_dm(self): + """Session key for a DM is NOT the same as chat_id.""" + source = _source("123456", "dm") + session_key = build_session_key(source) + assert session_key != source.chat_id + assert session_key == "agent:main:telegram:dm" + + def test_session_key_differs_from_chat_id_for_group(self): + """Session key for a group chat includes prefix, unlike raw chat_id.""" + source = _source("-1001234", "group") + session_key = build_session_key(source) + assert session_key != source.chat_id + assert "agent:main:" in session_key + assert source.chat_id in session_key + + @pytest.mark.asyncio + async def test_has_pending_interrupt_requires_session_key(self): + """has_pending_interrupt returns True only when queried with session_key.""" + adapter = StubAdapter() + source = _source("123456", "dm") + session_key = build_session_key(source) + + # Simulate adapter storing interrupt under session_key + interrupt_event = asyncio.Event() + adapter._active_sessions[session_key] = interrupt_event + interrupt_event.set() + + # Using session_key → found + assert adapter.has_pending_interrupt(session_key) is True + + # Using chat_id → NOT found (this was the bug) + assert adapter.has_pending_interrupt(source.chat_id) is False + + @pytest.mark.asyncio + async def test_get_pending_message_requires_session_key(self): + """get_pending_message returns the event only with session_key.""" + adapter = StubAdapter() + source = _source("123456", "dm") + session_key = build_session_key(source) + + event = MessageEvent(text="hello", source=source, message_id="42") + adapter._pending_messages[session_key] = event + + # Using chat_id → None (the bug) + assert adapter.get_pending_message(source.chat_id) is None + + # Using session_key → found + result = adapter.get_pending_message(session_key) + assert result is event + + @pytest.mark.asyncio + async def test_handle_message_stores_under_session_key(self): + """handle_message stores pending messages under session_key, not chat_id.""" + adapter = StubAdapter() + adapter.set_message_handler(lambda event: asyncio.sleep(0, result=None)) + + source = _source("-1001234", "group") + session_key = build_session_key(source) + + # Mark session as active + adapter._active_sessions[session_key] = asyncio.Event() + + # Send a second message while session is active + event = MessageEvent(text="interrupt!", source=source, message_id="2") + await adapter.handle_message(event) + + # Stored under session_key + assert session_key in adapter._pending_messages + # NOT stored under chat_id + assert source.chat_id not in adapter._pending_messages + + # Interrupt event was set + assert adapter._active_sessions[session_key].is_set() diff --git a/tests/run_interrupt_test.py b/tests/run_interrupt_test.py new file mode 100644 index 000000000..19ff3009f --- /dev/null +++ b/tests/run_interrupt_test.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +"""Run a real interrupt test with actual AIAgent + delegate child. + +Not a pytest test — runs directly as a script for live testing. +""" + +import threading +import time +import sys +import os + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from unittest.mock import MagicMock, patch +from run_agent import AIAgent, IterationBudget +from tools.delegate_tool import _run_single_child +from tools.interrupt import set_interrupt, is_interrupted + +set_interrupt(False) + +# Create parent agent (minimal) +parent = AIAgent.__new__(AIAgent) +parent._interrupt_requested = False +parent._interrupt_message = None +parent._active_children = [] +parent.quiet_mode = True +parent.model = "test/model" +parent.base_url = "http://localhost:1" +parent.api_key = "test" +parent.provider = "test" +parent.api_mode = "chat_completions" +parent.platform = "cli" +parent.enabled_toolsets = ["terminal", "file"] +parent.providers_allowed = None +parent.providers_ignored = None +parent.providers_order = None +parent.provider_sort = None +parent.max_tokens = None +parent.reasoning_config = None +parent.prefill_messages = None +parent._session_db = None +parent._delegate_depth = 0 +parent._delegate_spinner = None +parent.tool_progress_callback = None +parent.iteration_budget = IterationBudget(max_total=100) +parent._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1"} + +child_started = threading.Event() +result_holder = [None] + + +def run_delegate(): + with patch("run_agent.OpenAI") as MockOpenAI: + mock_client = MagicMock() + + def slow_create(**kwargs): + time.sleep(3) + resp = MagicMock() + resp.choices = [MagicMock()] + resp.choices[0].message.content = "Done" + resp.choices[0].message.tool_calls = None + resp.choices[0].message.refusal = None + resp.choices[0].finish_reason = "stop" + resp.usage.prompt_tokens = 100 + resp.usage.completion_tokens = 10 + resp.usage.total_tokens = 110 + resp.usage.prompt_tokens_details = None + return resp + + mock_client.chat.completions.create = slow_create + mock_client.close = MagicMock() + MockOpenAI.return_value = mock_client + + original_init = AIAgent.__init__ + + def patched_init(self_agent, *a, **kw): + original_init(self_agent, *a, **kw) + child_started.set() + + with patch.object(AIAgent, "__init__", patched_init): + try: + result = _run_single_child( + task_index=0, + goal="Test slow task", + context=None, + toolsets=["terminal"], + model="test/model", + max_iterations=5, + parent_agent=parent, + task_count=1, + override_provider="test", + override_base_url="http://localhost:1", + override_api_key="test", + override_api_mode="chat_completions", + ) + result_holder[0] = result + except Exception as e: + print(f"ERROR in delegate: {e}") + import traceback + traceback.print_exc() + + +print("Starting agent thread...") +agent_thread = threading.Thread(target=run_delegate, daemon=True) +agent_thread.start() + +started = child_started.wait(timeout=10) +if not started: + print("ERROR: Child never started") + sys.exit(1) + +time.sleep(0.5) + +print(f"Active children: {len(parent._active_children)}") +for i, c in enumerate(parent._active_children): + print(f" Child {i}: _interrupt_requested={c._interrupt_requested}") + +t0 = time.monotonic() +parent.interrupt("User typed a new message") +print(f"Called parent.interrupt()") + +for i, c in enumerate(parent._active_children): + print(f" Child {i} after interrupt: _interrupt_requested={c._interrupt_requested}") +print(f"Global is_interrupted: {is_interrupted()}") + +agent_thread.join(timeout=10) +elapsed = time.monotonic() - t0 +print(f"Agent thread finished in {elapsed:.2f}s") + +result = result_holder[0] +if result: + print(f"Status: {result['status']}") + print(f"Duration: {result['duration_seconds']}s") + if elapsed < 2.0: + print("✅ PASS: Interrupt detected quickly!") + else: + print(f"❌ FAIL: Took {elapsed:.2f}s — interrupt was too slow or not detected") +else: + print("❌ FAIL: No result!") + +set_interrupt(False) diff --git a/tests/test_cli_interrupt_subagent.py b/tests/test_cli_interrupt_subagent.py new file mode 100644 index 000000000..b91a7b654 --- /dev/null +++ b/tests/test_cli_interrupt_subagent.py @@ -0,0 +1,171 @@ +"""End-to-end test simulating CLI interrupt during subagent execution. + +Reproduces the exact scenario: +1. Parent agent calls delegate_task +2. Child agent is running (simulated with a slow tool) +3. User "types a message" (simulated by calling parent.interrupt from another thread) +4. Child should detect the interrupt and stop + +This tests the COMPLETE path including _run_single_child, _active_children +registration, interrupt propagation, and child detection. +""" + +import json +import os +import queue +import threading +import time +import unittest +from unittest.mock import MagicMock, patch, PropertyMock + +from tools.interrupt import set_interrupt, is_interrupted + + +class TestCLISubagentInterrupt(unittest.TestCase): + """Simulate exact CLI scenario.""" + + def setUp(self): + set_interrupt(False) + + def tearDown(self): + set_interrupt(False) + + def test_full_delegate_interrupt_flow(self): + """Full integration: parent runs delegate_task, main thread interrupts.""" + from run_agent import AIAgent + + interrupt_detected = threading.Event() + child_started = threading.Event() + child_api_call_count = 0 + + # Create a real-enough parent agent + parent = AIAgent.__new__(AIAgent) + parent._interrupt_requested = False + parent._interrupt_message = None + parent._active_children = [] + parent.quiet_mode = True + parent.model = "test/model" + parent.base_url = "http://localhost:1" + parent.api_key = "test" + parent.provider = "test" + parent.api_mode = "chat_completions" + parent.platform = "cli" + parent.enabled_toolsets = ["terminal", "file"] + parent.providers_allowed = None + parent.providers_ignored = None + parent.providers_order = None + parent.provider_sort = None + parent.max_tokens = None + parent.reasoning_config = None + parent.prefill_messages = None + parent._session_db = None + parent._delegate_depth = 0 + parent._delegate_spinner = None + parent.tool_progress_callback = None + + # We'll track what happens with _active_children + original_children = parent._active_children + + # Mock the child's run_conversation to simulate a slow operation + # that checks _interrupt_requested like the real one does + def mock_child_run_conversation(user_message, **kwargs): + child_started.set() + # Find the child in parent._active_children + child = parent._active_children[-1] if parent._active_children else None + + # Simulate the agent loop: poll _interrupt_requested like run_conversation does + for i in range(100): # Up to 10 seconds (100 * 0.1s) + if child and child._interrupt_requested: + interrupt_detected.set() + return { + "final_response": "Interrupted!", + "messages": [], + "api_calls": 1, + "completed": False, + "interrupted": True, + "interrupt_message": child._interrupt_message, + } + time.sleep(0.1) + + return { + "final_response": "Finished without interrupt", + "messages": [], + "api_calls": 5, + "completed": True, + "interrupted": False, + } + + # Patch AIAgent to use our mock + from tools.delegate_tool import _run_single_child + from run_agent import IterationBudget + + parent.iteration_budget = IterationBudget(max_total=100) + + # Run delegate in a thread (simulates agent_thread) + delegate_result = [None] + delegate_error = [None] + + def run_delegate(): + try: + with patch('run_agent.AIAgent') as MockAgent: + mock_instance = MagicMock() + mock_instance._interrupt_requested = False + mock_instance._interrupt_message = None + mock_instance._active_children = [] + mock_instance.quiet_mode = True + mock_instance.run_conversation = mock_child_run_conversation + mock_instance.interrupt = lambda msg=None: setattr(mock_instance, '_interrupt_requested', True) or setattr(mock_instance, '_interrupt_message', msg) + mock_instance.tools = [] + MockAgent.return_value = mock_instance + + result = _run_single_child( + task_index=0, + goal="Do something slow", + context=None, + toolsets=["terminal"], + model=None, + max_iterations=50, + parent_agent=parent, + task_count=1, + ) + delegate_result[0] = result + except Exception as e: + delegate_error[0] = e + + agent_thread = threading.Thread(target=run_delegate, daemon=True) + agent_thread.start() + + # Wait for child to start + assert child_started.wait(timeout=5), "Child never started!" + + # Now simulate user interrupt (from main/process thread) + time.sleep(0.2) # Give child a moment to be in its loop + + print(f"Parent has {len(parent._active_children)} active children") + assert len(parent._active_children) >= 1, f"Expected child in _active_children, got {len(parent._active_children)}" + + # This is what the CLI does: + parent.interrupt("Hey stop that") + + print(f"Parent._interrupt_requested: {parent._interrupt_requested}") + for i, child in enumerate(parent._active_children): + print(f"Child {i}._interrupt_requested: {child._interrupt_requested}") + + # Wait for child to detect interrupt + detected = interrupt_detected.wait(timeout=3.0) + + # Wait for delegate to finish + agent_thread.join(timeout=5) + + if delegate_error[0]: + raise delegate_error[0] + + assert detected, "Child never detected the interrupt!" + result = delegate_result[0] + assert result is not None, "Delegate returned no result" + assert result["status"] == "interrupted", f"Expected 'interrupted', got '{result['status']}'" + print(f"✓ Interrupt detected! Result: {result}") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_interactive_interrupt.py b/tests/test_interactive_interrupt.py new file mode 100644 index 000000000..bb90c7452 --- /dev/null +++ b/tests/test_interactive_interrupt.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +"""Interactive interrupt test that mimics the exact CLI flow. + +Starts an agent in a thread with a mock delegate_task that takes a while, +then simulates the user typing a message via _interrupt_queue. + +Logs every step to stderr (which isn't affected by redirect_stdout) +so we can see exactly where the interrupt gets lost. +""" + +import contextlib +import io +import json +import logging +import queue +import sys +import threading +import time +import os + +# Force stderr logging so redirect_stdout doesn't swallow it +logging.basicConfig(level=logging.DEBUG, stream=sys.stderr, + format="%(asctime)s [%(threadName)s] %(message)s") +log = logging.getLogger("interrupt_test") + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from unittest.mock import MagicMock, patch +from run_agent import AIAgent, IterationBudget +from tools.interrupt import set_interrupt, is_interrupted + +set_interrupt(False) + +# ─── Create parent agent ─── +parent = AIAgent.__new__(AIAgent) +parent._interrupt_requested = False +parent._interrupt_message = None +parent._active_children = [] +parent.quiet_mode = True +parent.model = "test/model" +parent.base_url = "http://localhost:1" +parent.api_key = "test" +parent.provider = "test" +parent.api_mode = "chat_completions" +parent.platform = "cli" +parent.enabled_toolsets = ["terminal", "file"] +parent.providers_allowed = None +parent.providers_ignored = None +parent.providers_order = None +parent.provider_sort = None +parent.max_tokens = None +parent.reasoning_config = None +parent.prefill_messages = None +parent._session_db = None +parent._delegate_depth = 0 +parent._delegate_spinner = None +parent.tool_progress_callback = None +parent.iteration_budget = IterationBudget(max_total=100) +parent._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1"} + +# Monkey-patch parent.interrupt to log +_original_interrupt = AIAgent.interrupt +def logged_interrupt(self, message=None): + log.info(f"🔴 parent.interrupt() called with: {message!r}") + log.info(f" _active_children count: {len(self._active_children)}") + _original_interrupt(self, message) + log.info(f" After interrupt: _interrupt_requested={self._interrupt_requested}") + for i, c in enumerate(self._active_children): + log.info(f" Child {i}._interrupt_requested={c._interrupt_requested}") +parent.interrupt = lambda msg=None: logged_interrupt(parent, msg) + +# ─── Simulate the exact CLI flow ─── +interrupt_queue = queue.Queue() +child_running = threading.Event() +agent_result = [None] + +def make_slow_response(delay=2.0): + """API response that takes a while.""" + def create(**kwargs): + log.info(f" 🌐 Mock API call starting (will take {delay}s)...") + time.sleep(delay) + log.info(f" 🌐 Mock API call completed") + resp = MagicMock() + resp.choices = [MagicMock()] + resp.choices[0].message.content = "Done with the task" + resp.choices[0].message.tool_calls = None + resp.choices[0].message.refusal = None + resp.choices[0].finish_reason = "stop" + resp.usage.prompt_tokens = 100 + resp.usage.completion_tokens = 10 + resp.usage.total_tokens = 110 + resp.usage.prompt_tokens_details = None + return resp + return create + + +def agent_thread_func(): + """Simulates the agent_thread in cli.py's chat() method.""" + log.info("🟢 agent_thread starting") + + with patch("run_agent.OpenAI") as MockOpenAI: + mock_client = MagicMock() + mock_client.chat.completions.create = make_slow_response(delay=3.0) + mock_client.close = MagicMock() + MockOpenAI.return_value = mock_client + + from tools.delegate_tool import _run_single_child + + # Signal that child is about to start + original_init = AIAgent.__init__ + def patched_init(self_agent, *a, **kw): + log.info("🟡 Child AIAgent.__init__ called") + original_init(self_agent, *a, **kw) + child_running.set() + log.info(f"🟡 Child started, parent._active_children = {len(parent._active_children)}") + + with patch.object(AIAgent, "__init__", patched_init): + result = _run_single_child( + task_index=0, + goal="Do a slow thing", + context=None, + toolsets=["terminal"], + model="test/model", + max_iterations=3, + parent_agent=parent, + task_count=1, + override_provider="test", + override_base_url="http://localhost:1", + override_api_key="test", + override_api_mode="chat_completions", + ) + agent_result[0] = result + log.info(f"🟢 agent_thread finished. Result status: {result.get('status')}") + + +# ─── Start agent thread (like chat() does) ─── +agent_thread = threading.Thread(target=agent_thread_func, name="agent_thread", daemon=True) +agent_thread.start() + +# ─── Wait for child to start ─── +if not child_running.wait(timeout=10): + print("FAIL: Child never started", file=sys.stderr) + sys.exit(1) + +# Give child time to enter its main loop and start API call +time.sleep(1.0) + +# ─── Simulate user typing a message (like handle_enter does) ─── +log.info("📝 Simulating user typing 'Hey stop that'") +interrupt_queue.put("Hey stop that") + +# ─── Simulate chat() polling loop (like the real chat() method) ─── +log.info("📡 Starting interrupt queue polling (like chat())") +interrupt_msg = None +poll_count = 0 +while agent_thread.is_alive(): + try: + interrupt_msg = interrupt_queue.get(timeout=0.1) + if interrupt_msg: + log.info(f"📨 Got interrupt message from queue: {interrupt_msg!r}") + log.info(f" Calling parent.interrupt()...") + parent.interrupt(interrupt_msg) + log.info(f" parent.interrupt() returned. Breaking poll loop.") + break + except queue.Empty: + poll_count += 1 + if poll_count % 20 == 0: # Log every 2s + log.info(f" Still polling ({poll_count} iterations)...") + +# ─── Wait for agent to finish ─── +log.info("⏳ Waiting for agent_thread to join...") +t0 = time.monotonic() +agent_thread.join(timeout=10) +elapsed = time.monotonic() - t0 +log.info(f"✅ agent_thread joined after {elapsed:.2f}s") + +# ─── Check results ─── +result = agent_result[0] +if result: + log.info(f"Result status: {result['status']}") + log.info(f"Result duration: {result['duration_seconds']}s") + if result["status"] == "interrupted" and elapsed < 2.0: + print("✅ PASS: Interrupt worked correctly!", file=sys.stderr) + else: + print(f"❌ FAIL: status={result['status']}, elapsed={elapsed:.2f}s", file=sys.stderr) +else: + print("❌ FAIL: No result returned", file=sys.stderr) + +set_interrupt(False) diff --git a/tests/test_interrupt_propagation.py b/tests/test_interrupt_propagation.py new file mode 100644 index 000000000..ff1cafdc8 --- /dev/null +++ b/tests/test_interrupt_propagation.py @@ -0,0 +1,155 @@ +"""Test interrupt propagation from parent to child agents. + +Reproduces the CLI scenario: user sends a message while delegate_task is +running, main thread calls parent.interrupt(), child should stop. +""" + +import json +import threading +import time +import unittest +from unittest.mock import MagicMock, patch, PropertyMock + +from tools.interrupt import set_interrupt, is_interrupted, _interrupt_event + + +class TestInterruptPropagationToChild(unittest.TestCase): + """Verify interrupt propagates from parent to child agent.""" + + def setUp(self): + set_interrupt(False) + + def tearDown(self): + set_interrupt(False) + + def test_parent_interrupt_sets_child_flag(self): + """When parent.interrupt() is called, child._interrupt_requested should be set.""" + from run_agent import AIAgent + + parent = AIAgent.__new__(AIAgent) + parent._interrupt_requested = False + parent._interrupt_message = None + parent._active_children = [] + parent.quiet_mode = True + + child = AIAgent.__new__(AIAgent) + child._interrupt_requested = False + child._interrupt_message = None + child._active_children = [] + child.quiet_mode = True + + parent._active_children.append(child) + + parent.interrupt("new user message") + + assert parent._interrupt_requested is True + assert child._interrupt_requested is True + assert child._interrupt_message == "new user message" + assert is_interrupted() is True + + def test_child_clear_interrupt_at_start_clears_global(self): + """child.clear_interrupt() at start of run_conversation clears the GLOBAL event. + + This is the intended behavior at startup, but verify it doesn't + accidentally clear an interrupt intended for a running child. + """ + from run_agent import AIAgent + + child = AIAgent.__new__(AIAgent) + child._interrupt_requested = True + child._interrupt_message = "msg" + child.quiet_mode = True + child._active_children = [] + + # Global is set + set_interrupt(True) + assert is_interrupted() is True + + # child.clear_interrupt() clears both + child.clear_interrupt() + assert child._interrupt_requested is False + assert is_interrupted() is False + + def test_interrupt_during_child_api_call_detected(self): + """Interrupt set during _interruptible_api_call is detected within 0.5s.""" + from run_agent import AIAgent + + child = AIAgent.__new__(AIAgent) + child._interrupt_requested = False + child._interrupt_message = None + child._active_children = [] + child.quiet_mode = True + child.api_mode = "chat_completions" + child.log_prefix = "" + child._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1234"} + + # Mock a slow API call + mock_client = MagicMock() + def slow_api_call(**kwargs): + time.sleep(5) # Would take 5s normally + return MagicMock() + mock_client.chat.completions.create = slow_api_call + mock_client.close = MagicMock() + child.client = mock_client + + # Set interrupt after 0.2s from another thread + def set_interrupt_later(): + time.sleep(0.2) + child.interrupt("stop!") + t = threading.Thread(target=set_interrupt_later, daemon=True) + t.start() + + start = time.monotonic() + try: + child._interruptible_api_call({"model": "test", "messages": []}) + self.fail("Should have raised InterruptedError") + except InterruptedError: + elapsed = time.monotonic() - start + # Should detect within ~0.5s (0.2s delay + 0.3s poll interval) + assert elapsed < 1.0, f"Took {elapsed:.2f}s to detect interrupt (expected < 1.0s)" + finally: + t.join(timeout=2) + set_interrupt(False) + + def test_concurrent_interrupt_propagation(self): + """Simulates exact CLI flow: parent runs delegate in thread, main thread interrupts.""" + from run_agent import AIAgent + + parent = AIAgent.__new__(AIAgent) + parent._interrupt_requested = False + parent._interrupt_message = None + parent._active_children = [] + parent.quiet_mode = True + + child = AIAgent.__new__(AIAgent) + child._interrupt_requested = False + child._interrupt_message = None + child._active_children = [] + child.quiet_mode = True + + # Register child (simulating what _run_single_child does) + parent._active_children.append(child) + + # Simulate child running (checking flag in a loop) + child_detected = threading.Event() + def simulate_child_loop(): + while not child._interrupt_requested: + time.sleep(0.05) + child_detected.set() + + child_thread = threading.Thread(target=simulate_child_loop, daemon=True) + child_thread.start() + + # Small delay, then interrupt from "main thread" + time.sleep(0.1) + parent.interrupt("user typed something new") + + # Child should detect within 200ms + detected = child_detected.wait(timeout=1.0) + assert detected, "Child never detected the interrupt!" + child_thread.join(timeout=1) + set_interrupt(False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_real_interrupt_subagent.py b/tests/test_real_interrupt_subagent.py new file mode 100644 index 000000000..f665a006b --- /dev/null +++ b/tests/test_real_interrupt_subagent.py @@ -0,0 +1,176 @@ +"""Test real interrupt propagation through delegate_task with actual AIAgent. + +This uses a real AIAgent with mocked HTTP responses to test the complete +interrupt flow through _run_single_child → child.run_conversation(). +""" + +import json +import os +import threading +import time +import unittest +from unittest.mock import MagicMock, patch, PropertyMock + +from tools.interrupt import set_interrupt, is_interrupted + + +def _make_slow_api_response(delay=5.0): + """Create a mock that simulates a slow API response (like a real LLM call).""" + def slow_create(**kwargs): + # Simulate a slow API call + time.sleep(delay) + # Return a simple text response (no tool calls) + resp = MagicMock() + resp.choices = [MagicMock()] + resp.choices[0].message = MagicMock() + resp.choices[0].message.content = "Done" + resp.choices[0].message.tool_calls = None + resp.choices[0].message.refusal = None + resp.choices[0].finish_reason = "stop" + resp.usage = MagicMock() + resp.usage.prompt_tokens = 100 + resp.usage.completion_tokens = 10 + resp.usage.total_tokens = 110 + resp.usage.prompt_tokens_details = None + return resp + return slow_create + + +class TestRealSubagentInterrupt(unittest.TestCase): + """Test interrupt with real AIAgent child through delegate_tool.""" + + def setUp(self): + set_interrupt(False) + os.environ.setdefault("OPENAI_API_KEY", "test-key") + + def tearDown(self): + set_interrupt(False) + + def test_interrupt_child_during_api_call(self): + """Real AIAgent child interrupted while making API call.""" + from run_agent import AIAgent, IterationBudget + + # Create a real parent agent (just enough to be a parent) + parent = AIAgent.__new__(AIAgent) + parent._interrupt_requested = False + parent._interrupt_message = None + parent._active_children = [] + parent.quiet_mode = True + parent.model = "test/model" + parent.base_url = "http://localhost:1" + parent.api_key = "test" + parent.provider = "test" + parent.api_mode = "chat_completions" + parent.platform = "cli" + parent.enabled_toolsets = ["terminal", "file"] + parent.providers_allowed = None + parent.providers_ignored = None + parent.providers_order = None + parent.provider_sort = None + parent.max_tokens = None + parent.reasoning_config = None + parent.prefill_messages = None + parent._session_db = None + parent._delegate_depth = 0 + parent._delegate_spinner = None + parent.tool_progress_callback = None + parent.iteration_budget = IterationBudget(max_total=100) + parent._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1"} + + from tools.delegate_tool import _run_single_child + + child_started = threading.Event() + result_holder = [None] + error_holder = [None] + + def run_delegate(): + try: + # Patch the OpenAI client creation inside AIAgent.__init__ + with patch('run_agent.OpenAI') as MockOpenAI: + mock_client = MagicMock() + # API call takes 5 seconds — should be interrupted before that + mock_client.chat.completions.create = _make_slow_api_response(delay=5.0) + mock_client.close = MagicMock() + MockOpenAI.return_value = mock_client + + # Also need to patch the system prompt builder + with patch('run_agent.build_system_prompt', return_value="You are a test agent"): + # Signal when child starts + original_run = AIAgent.run_conversation + + def patched_run(self_agent, *args, **kwargs): + child_started.set() + return original_run(self_agent, *args, **kwargs) + + with patch.object(AIAgent, 'run_conversation', patched_run): + result = _run_single_child( + task_index=0, + goal="Test task", + context=None, + toolsets=["terminal"], + model="test/model", + max_iterations=5, + parent_agent=parent, + task_count=1, + override_provider="test", + override_base_url="http://localhost:1", + override_api_key="test", + override_api_mode="chat_completions", + ) + result_holder[0] = result + except Exception as e: + import traceback + traceback.print_exc() + error_holder[0] = e + + agent_thread = threading.Thread(target=run_delegate, daemon=True) + agent_thread.start() + + # Wait for child to start run_conversation + started = child_started.wait(timeout=10) + if not started: + agent_thread.join(timeout=1) + if error_holder[0]: + raise error_holder[0] + self.fail("Child never started run_conversation") + + # Give child time to enter main loop and start API call + time.sleep(0.5) + + # Verify child is registered + print(f"Active children: {len(parent._active_children)}") + self.assertGreaterEqual(len(parent._active_children), 1, + "Child not registered in _active_children") + + # Interrupt! (simulating what CLI does) + start = time.monotonic() + parent.interrupt("User typed a new message") + + # Check propagation + child = parent._active_children[0] if parent._active_children else None + if child: + print(f"Child._interrupt_requested after parent.interrupt(): {child._interrupt_requested}") + self.assertTrue(child._interrupt_requested, + "Interrupt did not propagate to child!") + + # Wait for delegate to finish (should be fast since interrupted) + agent_thread.join(timeout=5) + elapsed = time.monotonic() - start + + if error_holder[0]: + raise error_holder[0] + + result = result_holder[0] + self.assertIsNotNone(result, "Delegate returned no result") + print(f"Result status: {result['status']}, elapsed: {elapsed:.2f}s") + print(f"Full result: {result}") + + # The child should have been interrupted, not completed the full 5s API call + self.assertLess(elapsed, 3.0, + f"Took {elapsed:.2f}s — interrupt was not detected quickly enough") + self.assertEqual(result["status"], "interrupted", + f"Expected 'interrupted', got '{result['status']}'") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_redirect_stdout_issue.py b/tests/test_redirect_stdout_issue.py new file mode 100644 index 000000000..8501add63 --- /dev/null +++ b/tests/test_redirect_stdout_issue.py @@ -0,0 +1,54 @@ +"""Verify that redirect_stdout in _run_single_child is process-wide. + +This demonstrates that contextlib.redirect_stdout changes sys.stdout +for ALL threads, not just the current one. This means during subagent +execution, all output from other threads (including the CLI's process_thread) +is swallowed. +""" + +import contextlib +import io +import sys +import threading +import time +import unittest + + +class TestRedirectStdoutIsProcessWide(unittest.TestCase): + + def test_redirect_stdout_affects_other_threads(self): + """contextlib.redirect_stdout changes sys.stdout for ALL threads.""" + captured_from_other_thread = [] + real_stdout = sys.stdout + other_thread_saw_devnull = threading.Event() + + def other_thread_work(): + """Runs in a different thread, tries to use sys.stdout.""" + time.sleep(0.2) # Let redirect_stdout take effect + # Check what sys.stdout is + if sys.stdout is not real_stdout: + other_thread_saw_devnull.set() + # Try to print — this should go to devnull + captured_from_other_thread.append(sys.stdout) + + t = threading.Thread(target=other_thread_work, daemon=True) + t.start() + + # redirect_stdout in main thread + devnull = io.StringIO() + with contextlib.redirect_stdout(devnull): + time.sleep(0.5) # Let the other thread check during redirect + + t.join(timeout=2) + + # The other thread should have seen devnull, NOT the real stdout + self.assertTrue( + other_thread_saw_devnull.is_set(), + "redirect_stdout was NOT process-wide — other thread still saw real stdout. " + "This test's premise is wrong." + ) + print("Confirmed: redirect_stdout IS process-wide — affects all threads") + + +if __name__ == "__main__": + unittest.main()