fix(gateway): cancel active runs during shutdown

Track adapter background message-processing tasks, cancel them during gateway shutdown, and interrupt running agents before disconnecting adapters. This prevents old gateway instances from continuing in-flight work after stop/replace, which was contributing to the restart-time task continuation/flicker behavior reported in #1414. Adds regression coverage for adapter task cancellation and shutdown interrupts.
This commit is contained in:
teknium1
2026-03-15 04:21:50 -07:00
parent 621fd80b1e
commit 21c20aeaa5
3 changed files with 149 additions and 2 deletions

View File

@@ -356,6 +356,10 @@ class BasePlatformAdapter(ABC):
# Key: session_key (e.g., chat_id), Value: asyncio.Event used to signal an interrupt
self._active_sessions: Dict[str, asyncio.Event] = {}
self._pending_messages: Dict[str, MessageEvent] = {}
# Background message-processing tasks spawned by handle_message().
# Gateway shutdown cancels these so an old gateway instance doesn't keep
# working on a task after --replace or manual restarts.
self._background_tasks: set[asyncio.Task] = set()
# Chats where auto-TTS on voice input is disabled (set by /voice off)
self._auto_tts_disabled_chats: set = set()
@@ -778,7 +782,15 @@ class BasePlatformAdapter(ABC):
return # Don't process now - will be handled after current task finishes
# Spawn background task to process this message
asyncio.create_task(self._process_message_background(event, session_key))
task = asyncio.create_task(self._process_message_background(event, session_key))
try:
self._background_tasks.add(task)
except TypeError:
# Some tests stub create_task() with lightweight sentinels that are not
# hashable and do not support lifecycle callbacks.
return
if hasattr(task, "add_done_callback"):
task.add_done_callback(self._background_tasks.discard)
@staticmethod
def _get_human_delay() -> float:
@@ -988,6 +1000,21 @@ class BasePlatformAdapter(ABC):
if session_key in self._active_sessions:
del self._active_sessions[session_key]
async def cancel_background_tasks(self) -> None:
    """Cancel all in-flight background message-processing tasks.

    Used during gateway shutdown/replacement so that sessions owned by the
    old process stop running while its adapters are torn down.

    Cancellation is repeated until no unfinished task is left in
    ``self._background_tasks``. Awaiting the cancelled tasks yields to the
    event loop, and a task registered during that window would otherwise be
    dropped by the final ``clear()`` without ever being cancelled.

    Afterwards ``self._background_tasks``, ``self._pending_messages`` and
    ``self._active_sessions`` are all emptied.
    """
    while True:
        # Snapshot the unfinished tasks; the set itself may change while we
        # are suspended in the await below.
        unfinished = [task for task in self._background_tasks if not task.done()]
        if not unfinished:
            break
        for task in unfinished:
            task.cancel()
        # return_exceptions=True collects each task's CancelledError (or
        # other failure) as a result instead of raising it here.
        await asyncio.gather(*unfinished, return_exceptions=True)
    self._background_tasks.clear()
    self._pending_messages.clear()
    self._active_sessions.clear()
def has_pending_interrupt(self, session_key: str) -> bool:
    """Return whether the interrupt event for ``session_key`` is set.

    A session with no entry in ``self._active_sessions`` has no pending
    interrupt, so the result is ``False``.
    """
    try:
        interrupt_flag = self._active_sessions[session_key]
    except KeyError:
        return False
    return interrupt_flag.is_set()

View File

@@ -900,8 +900,19 @@ class GatewayRunner:
"""Stop the gateway and disconnect all adapters."""
logger.info("Stopping gateway...")
self._running = False
for session_key, agent in list(self._running_agents.items()):
try:
agent.interrupt("Gateway shutting down")
logger.debug("Interrupted running agent for session %s during shutdown", session_key[:20])
except Exception as e:
logger.debug("Failed interrupting agent during shutdown: %s", e)
for platform, adapter in list(self.adapters.items()):
try:
await adapter.cancel_background_tasks()
except Exception as e:
logger.debug("%s background-task cancel error: %s", platform.value, e)
try:
await adapter.disconnect()
logger.info("%s disconnected", platform.value)
@@ -909,6 +920,9 @@ class GatewayRunner:
logger.error("%s disconnect error: %s", platform.value, e)
self.adapters.clear()
self._running_agents.clear()
self._pending_messages.clear()
self._pending_approvals.clear()
self._shutdown_all_gateway_honcho()
self._shutdown_event.set()

View File

@@ -0,0 +1,106 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
from gateway.run import GatewayRunner
from gateway.session import SessionSource, build_session_key
class StubAdapter(BasePlatformAdapter):
    """Adapter stub for these tests; every hook returns a canned value."""

    def __init__(self):
        stub_config = PlatformConfig(enabled=True, token="***")
        super().__init__(stub_config, Platform.TELEGRAM)

    async def connect(self):
        """Report a successful connection."""
        return True

    async def disconnect(self):
        """Do nothing on disconnect."""
        return None

    async def send(self, chat_id, content, reply_to=None, metadata=None):
        """Return a successful ``SendResult`` with a fixed message id."""
        outcome = SendResult(success=True, message_id="1")
        return outcome

    async def send_typing(self, chat_id, metadata=None):
        """Do nothing for typing indicators."""
        return None

    async def get_chat_info(self, chat_id):
        """Return a chat-info dict containing only the chat id."""
        chat_info = {"id": chat_id}
        return chat_info
def _source(chat_id="123456", chat_type="dm"):
    """Build a Telegram ``SessionSource`` for the given chat id and chat type."""
    source_fields = {
        "platform": Platform.TELEGRAM,
        "chat_id": chat_id,
        "chat_type": chat_type,
    }
    return SessionSource(**source_fields)
@pytest.mark.asyncio
async def test_cancel_background_tasks_cancels_inflight_message_processing():
    """cancel_background_tasks() ends a blocked handler task and resets adapter state."""
    adapter = StubAdapter()
    # Never set: the handler stays blocked until something cancels it.
    release = asyncio.Event()

    async def block_forever(_event):
        await release.wait()
        return None

    adapter.set_message_handler(block_forever)
    event = MessageEvent(text="work", source=_source(), message_id="1")
    await adapter.handle_message(event)
    # Yield to the event loop so the spawned background task gets to run.
    await asyncio.sleep(0)

    session_key = build_session_key(event.source)
    assert session_key in adapter._active_sessions
    assert adapter._background_tasks
    # Keep references: the adapter drops its own after cancellation.
    inflight = list(adapter._background_tasks)

    await adapter.cancel_background_tasks()

    # The work itself must have stopped, not merely been forgotten.
    assert all(task.done() for task in inflight)
    assert adapter._background_tasks == set()
    assert adapter._active_sessions == {}
    assert adapter._pending_messages == {}
@pytest.mark.asyncio
async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks():
    """GatewayRunner.stop() interrupts agents, stops adapter tasks, and clears runner state."""
    # Bypass __init__ and set only the attributes this test populates by hand.
    runner = object.__new__(GatewayRunner)
    runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")})
    runner._running = True
    runner._shutdown_event = asyncio.Event()
    runner._exit_reason = None
    runner._pending_messages = {"session": "pending text"}
    runner._pending_approvals = {"session": {"command": "rm -rf /tmp/x"}}
    runner._shutdown_all_gateway_honcho = lambda: None

    adapter = StubAdapter()
    # Never set: the handler stays blocked until something cancels it.
    release = asyncio.Event()

    async def block_forever(_event):
        await release.wait()
        return None

    adapter.set_message_handler(block_forever)
    event = MessageEvent(text="work", source=_source(), message_id="1")
    await adapter.handle_message(event)
    # Yield to the event loop so the spawned background task gets to run.
    await asyncio.sleep(0)
    # Keep references: the adapter drops its own after cancellation.
    inflight = list(adapter._background_tasks)
    assert inflight

    disconnect_mock = AsyncMock()
    adapter.disconnect = disconnect_mock

    session_key = build_session_key(event.source)
    running_agent = MagicMock()
    runner._running_agents = {session_key: running_agent}
    runner.adapters = {Platform.TELEGRAM: adapter}

    with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
        await runner.stop()

    running_agent.interrupt.assert_called_once_with("Gateway shutting down")
    disconnect_mock.assert_awaited_once()
    # The adapter's in-flight work must have been stopped by stop().
    assert all(task.done() for task in inflight)
    assert adapter._background_tasks == set()
    assert runner.adapters == {}
    assert runner._running_agents == {}
    assert runner._pending_messages == {}
    assert runner._pending_approvals == {}
    assert runner._shutdown_event.is_set() is True