diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py
index 91ac5d30c..f103fb8b9 100644
--- a/gateway/platforms/base.py
+++ b/gateway/platforms/base.py
@@ -356,6 +356,10 @@ class BasePlatformAdapter(ABC):
         # Key: session_key (e.g., chat_id), Value: (event, asyncio.Event for interrupt)
         self._active_sessions: Dict[str, asyncio.Event] = {}
         self._pending_messages: Dict[str, MessageEvent] = {}
+        # Background message-processing tasks spawned by handle_message().
+        # Gateway shutdown cancels these so an old gateway instance doesn't keep
+        # working on a task after --replace or manual restarts.
+        self._background_tasks: set[asyncio.Task] = set()
         # Chats where auto-TTS on voice input is disabled (set by /voice off)
         self._auto_tts_disabled_chats: set = set()
 
@@ -778,7 +782,15 @@ class BasePlatformAdapter(ABC):
             return  # Don't process now - will be handled after current task finishes
 
         # Spawn background task to process this message
-        asyncio.create_task(self._process_message_background(event, session_key))
+        task = asyncio.create_task(self._process_message_background(event, session_key))
+        try:
+            self._background_tasks.add(task)
+        except TypeError:
+            # Some tests stub create_task() with lightweight sentinels that are not
+            # hashable and do not support lifecycle callbacks.
+            return
+        if hasattr(task, "add_done_callback"):
+            task.add_done_callback(self._background_tasks.discard)
 
     @staticmethod
     def _get_human_delay() -> float:
@@ -988,6 +1000,23 @@ class BasePlatformAdapter(ABC):
             if session_key in self._active_sessions:
                 del self._active_sessions[session_key]
 
+    async def cancel_background_tasks(self) -> None:
+        """Cancel any in-flight background message-processing tasks.
+
+        Used during gateway shutdown/replacement so active sessions from the old
+        process do not keep running while adapters are being torn down.
+        """
+        tasks = [task for task in self._background_tasks if not task.done()]
+        for task in tasks:
+            task.cancel()
+        if tasks:
+            # return_exceptions=True collects each task's CancelledError (or other
+            # failure) as a result instead of re-raising it here.
+            await asyncio.gather(*tasks, return_exceptions=True)
+        self._background_tasks.clear()
+        self._pending_messages.clear()
+        self._active_sessions.clear()
+
     def has_pending_interrupt(self, session_key: str) -> bool:
         """Check if there's a pending interrupt for a session."""
         return session_key in self._active_sessions and self._active_sessions[session_key].is_set()
diff --git a/gateway/run.py b/gateway/run.py
index 8508e0f8a..716e981f2 100644
--- a/gateway/run.py
+++ b/gateway/run.py
@@ -900,8 +900,19 @@ class GatewayRunner:
         """Stop the gateway and disconnect all adapters."""
         logger.info("Stopping gateway...")
         self._running = False
-        
+
+        for session_key, agent in list(self._running_agents.items()):
+            try:
+                agent.interrupt("Gateway shutting down")
+                logger.debug("Interrupted running agent for session %s during shutdown", session_key[:20])
+            except Exception as e:
+                logger.debug("Failed interrupting agent during shutdown: %s", e)
+
         for platform, adapter in list(self.adapters.items()):
+            try:
+                await adapter.cancel_background_tasks()
+            except Exception as e:
+                logger.debug("✗ %s background-task cancel error: %s", platform.value, e)
             try:
                 await adapter.disconnect()
                 logger.info("✓ %s disconnected", platform.value)
@@ -909,6 +920,9 @@ class GatewayRunner:
                 logger.error("✗ %s disconnect error: %s", platform.value, e)
 
         self.adapters.clear()
+        self._running_agents.clear()
+        self._pending_messages.clear()
+        self._pending_approvals.clear()
         self._shutdown_all_gateway_honcho()
         self._shutdown_event.set()
 
diff --git a/tests/gateway/test_gateway_shutdown.py b/tests/gateway/test_gateway_shutdown.py
new file mode 100644
index 000000000..15e2e6634
--- /dev/null
+++ b/tests/gateway/test_gateway_shutdown.py
@@ -0,0 +1,111 @@
+import asyncio
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from gateway.config import GatewayConfig, Platform, PlatformConfig
+from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
+from gateway.run import GatewayRunner
+from gateway.session import SessionSource, build_session_key
+
+
+class StubAdapter(BasePlatformAdapter):
+    """Minimal concrete adapter whose transport methods are no-op stubs."""
+
+    def __init__(self):
+        super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
+
+    async def connect(self):
+        return True
+
+    async def disconnect(self):
+        return None
+
+    async def send(self, chat_id, content, reply_to=None, metadata=None):
+        return SendResult(success=True, message_id="1")
+
+    async def send_typing(self, chat_id, metadata=None):
+        return None
+
+    async def get_chat_info(self, chat_id):
+        return {"id": chat_id}
+
+
+def _source(chat_id="123456", chat_type="dm"):
+    """Build a Telegram SessionSource for the given chat."""
+    return SessionSource(
+        platform=Platform.TELEGRAM,
+        chat_id=chat_id,
+        chat_type=chat_type,
+    )
+
+
+@pytest.mark.asyncio
+async def test_cancel_background_tasks_cancels_inflight_message_processing():
+    """cancel_background_tasks() empties task, session and pending-message state."""
+    adapter = StubAdapter()
+    release = asyncio.Event()
+
+    async def block_forever(_event):
+        await release.wait()
+        return None
+
+    adapter.set_message_handler(block_forever)
+    event = MessageEvent(text="work", source=_source(), message_id="1")
+
+    await adapter.handle_message(event)
+    await asyncio.sleep(0)
+
+    session_key = build_session_key(event.source)
+    assert session_key in adapter._active_sessions
+    assert adapter._background_tasks
+
+    await adapter.cancel_background_tasks()
+
+    assert adapter._background_tasks == set()
+    assert adapter._active_sessions == {}
+    assert adapter._pending_messages == {}
+
+
+@pytest.mark.asyncio
+async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks():
+    """stop() interrupts running agents, disconnects adapters and clears runner state."""
+    runner = object.__new__(GatewayRunner)
+    runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")})
+    runner._running = True
+    runner._shutdown_event = asyncio.Event()
+    runner._exit_reason = None
+    runner._pending_messages = {"session": "pending text"}
+    runner._pending_approvals = {"session": {"command": "rm -rf /tmp/x"}}
+    runner._shutdown_all_gateway_honcho = lambda: None
+
+    adapter = StubAdapter()
+    release = asyncio.Event()
+
+    async def block_forever(_event):
+        await release.wait()
+        return None
+
+    adapter.set_message_handler(block_forever)
+    event = MessageEvent(text="work", source=_source(), message_id="1")
+    await adapter.handle_message(event)
+    await asyncio.sleep(0)
+
+    disconnect_mock = AsyncMock()
+    adapter.disconnect = disconnect_mock
+
+    session_key = build_session_key(event.source)
+    running_agent = MagicMock()
+    runner._running_agents = {session_key: running_agent}
+    runner.adapters = {Platform.TELEGRAM: adapter}
+
+    with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
+        await runner.stop()
+
+    running_agent.interrupt.assert_called_once_with("Gateway shutting down")
+    disconnect_mock.assert_awaited_once()
+    assert runner.adapters == {}
+    assert runner._running_agents == {}
+    assert runner._pending_messages == {}
+    assert runner._pending_approvals == {}
+    assert runner._shutdown_event.is_set() is True