diff --git a/gateway/platforms/whatsapp.py b/gateway/platforms/whatsapp.py index d9ac0077c..02448a6dd 100644 --- a/gateway/platforms/whatsapp.py +++ b/gateway/platforms/whatsapp.py @@ -142,6 +142,7 @@ class WhatsAppAdapter(BasePlatformAdapter): self._bridge_log_fh = None self._bridge_log: Optional[Path] = None self._poll_task: Optional[asyncio.Task] = None + self._http_session: Optional["aiohttp.ClientSession"] = None self._session_lock_identity: Optional[str] = None async def connect(self) -> bool: @@ -224,6 +225,7 @@ class WhatsAppAdapter(BasePlatformAdapter): print(f"[{self.name}] Using existing bridge (status: {bridge_status})") self._mark_connected() self._bridge_process = None # Not managed by us + self._http_session = aiohttp.ClientSession() self._poll_task = asyncio.create_task(self._poll_messages()) return True else: @@ -329,6 +331,9 @@ class WhatsAppAdapter(BasePlatformAdapter): print(f"[{self.name}] Bridge log: {self._bridge_log}") print(f"[{self.name}] If session expired, re-pair: hermes whatsapp") + # Create a persistent HTTP session for all bridge communication + self._http_session = aiohttp.ClientSession() + # Start message polling task self._poll_task = asyncio.create_task(self._poll_messages()) @@ -400,7 +405,21 @@ class WhatsAppAdapter(BasePlatformAdapter): else: # Bridge was not started by us, don't kill it print(f"[{self.name}] Disconnecting (external bridge left running)") - + + # Cancel the poll task explicitly + if self._poll_task and not self._poll_task.done(): + self._poll_task.cancel() + try: + await self._poll_task + except (asyncio.CancelledError, Exception): + pass + self._poll_task = None + + # Close the persistent HTTP session + if self._http_session and not self._http_session.closed: + await self._http_session.close() + self._http_session = None + if self._session_lock_identity: try: from gateway.status import release_scoped_lock @@ -422,7 +441,7 @@ class WhatsAppAdapter(BasePlatformAdapter): metadata: Optional[Dict[str, Any]] = None ) -> SendResult: """Send a message via the WhatsApp bridge.""" - if not self._running: + if not self._running or not self._http_session: return SendResult(success=False, error="Not connected") bridge_exit = await self._check_managed_bridge_exit() if bridge_exit: @@ -430,36 +449,29 @@ class WhatsAppAdapter(BasePlatformAdapter): try: import aiohttp + + payload = { + "chatId": chat_id, + "message": content, + } + if reply_to: + payload["replyTo"] = reply_to - async with aiohttp.ClientSession() as session: - payload = { - "chatId": chat_id, - "message": content, - } - if reply_to: - payload["replyTo"] = reply_to - - async with session.post( - f"http://127.0.0.1:{self._bridge_port}/send", - json=payload, - timeout=aiohttp.ClientTimeout(total=30) - ) as resp: - if resp.status == 200: - data = await resp.json() - return SendResult( - success=True, - message_id=data.get("messageId"), - raw_response=data - ) - else: - error = await resp.text() - return SendResult(success=False, error=error) - - except ImportError: - return SendResult( - success=False, - error="aiohttp not installed. Run: pip install aiohttp" - ) + async with self._http_session.post( + f"http://127.0.0.1:{self._bridge_port}/send", + json=payload, + timeout=aiohttp.ClientTimeout(total=30) + ) as resp: + if resp.status == 200: + data = await resp.json() + return SendResult( + success=True, + message_id=data.get("messageId"), + raw_response=data + ) + else: + error = await resp.text() + return SendResult(success=False, error=error) except Exception as e: return SendResult(success=False, error=str(e)) @@ -470,28 +482,27 @@ class WhatsAppAdapter(BasePlatformAdapter): content: str, ) -> SendResult: """Edit a previously sent message via the WhatsApp bridge.""" - if not self._running: + if not self._running or not self._http_session: return SendResult(success=False, error="Not connected") bridge_exit = await self._check_managed_bridge_exit() if bridge_exit: return SendResult(success=False, error=bridge_exit) try: import aiohttp - async with aiohttp.ClientSession() as session: - async with session.post( - f"http://127.0.0.1:{self._bridge_port}/edit", - json={ - "chatId": chat_id, - "messageId": message_id, - "message": content, - }, - timeout=aiohttp.ClientTimeout(total=15) - ) as resp: - if resp.status == 200: - return SendResult(success=True, message_id=message_id) - else: - error = await resp.text() - return SendResult(success=False, error=error) + async with self._http_session.post( + f"http://127.0.0.1:{self._bridge_port}/edit", + json={ + "chatId": chat_id, + "messageId": message_id, + "message": content, + }, + timeout=aiohttp.ClientTimeout(total=15) + ) as resp: + if resp.status == 200: + return SendResult(success=True, message_id=message_id) + else: + error = await resp.text() + return SendResult(success=False, error=error) except Exception as e: return SendResult(success=False, error=str(e)) @@ -504,7 +515,7 @@ class WhatsAppAdapter(BasePlatformAdapter): file_name: Optional[str] = None, ) -> SendResult: """Send any media file via bridge /send-media endpoint.""" - if not self._running: + if not self._running or not self._http_session: return SendResult(success=False, error="Not connected") bridge_exit = await self._check_managed_bridge_exit() if bridge_exit: @@ -525,22 +536,21 @@ class WhatsAppAdapter(BasePlatformAdapter): if file_name: payload["fileName"] = file_name - async with aiohttp.ClientSession() as session: - async with session.post( - f"http://127.0.0.1:{self._bridge_port}/send-media", - json=payload, - timeout=aiohttp.ClientTimeout(total=120), - ) as resp: - if resp.status == 200: - data = await resp.json() - return SendResult( - success=True, - message_id=data.get("messageId"), - raw_response=data, - ) - else: - error = await resp.text() - return SendResult(success=False, error=error) + async with self._http_session.post( + f"http://127.0.0.1:{self._bridge_port}/send-media", + json=payload, + timeout=aiohttp.ClientTimeout(total=120), + ) as resp: + if resp.status == 200: + data = await resp.json() + return SendResult( + success=True, + message_id=data.get("messageId"), + raw_response=data, + ) + else: + error = await resp.text() + return SendResult(success=False, error=error) except Exception as e: return SendResult(success=False, error=str(e)) @@ -598,45 +608,43 @@ class WhatsAppAdapter(BasePlatformAdapter): async def send_typing(self, chat_id: str, metadata=None) -> None: """Send typing indicator via bridge.""" - if not self._running: + if not self._running or not self._http_session: return if await self._check_managed_bridge_exit(): return try: import aiohttp - - async with aiohttp.ClientSession() as session: - await session.post( - f"http://127.0.0.1:{self._bridge_port}/typing", - json={"chatId": chat_id}, - timeout=aiohttp.ClientTimeout(total=5) - ) + + await self._http_session.post( + f"http://127.0.0.1:{self._bridge_port}/typing", + json={"chatId": chat_id}, + timeout=aiohttp.ClientTimeout(total=5) + ) except Exception: pass # Ignore typing indicator failures async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: """Get information about a WhatsApp chat.""" - if not self._running: + if not self._running or not self._http_session: return {"name": "Unknown", "type": "dm"} if await self._check_managed_bridge_exit(): return {"name": chat_id, "type": "dm"} try: import aiohttp - - async with aiohttp.ClientSession() as session: - async with session.get( - f"http://127.0.0.1:{self._bridge_port}/chat/{chat_id}", - timeout=aiohttp.ClientTimeout(total=10) - ) as resp: - if resp.status == 200: - data = await resp.json() - return { - "name": data.get("name", chat_id), - "type": "group" if data.get("isGroup") else "dm", - "participants": data.get("participants", []), - } + + async with self._http_session.get( + f"http://127.0.0.1:{self._bridge_port}/chat/{chat_id}", + timeout=aiohttp.ClientTimeout(total=10) + ) as resp: + if resp.status == 200: + data = await resp.json() + return { + "name": data.get("name", chat_id), + "type": "group" if data.get("isGroup") else "dm", + "participants": data.get("participants", []), + } except Exception as e: logger.debug("Could not get WhatsApp chat info for %s: %s", chat_id, e) @@ -644,29 +652,26 @@ class WhatsAppAdapter(BasePlatformAdapter): async def _poll_messages(self) -> None: """Poll the bridge for incoming messages.""" - try: - import aiohttp - except ImportError: - print(f"[{self.name}] aiohttp not installed, message polling disabled") - return - + import aiohttp + while self._running: + if not self._http_session: + break bridge_exit = await self._check_managed_bridge_exit() if bridge_exit: print(f"[{self.name}] {bridge_exit}") break try: - async with aiohttp.ClientSession() as session: - async with session.get( - f"http://127.0.0.1:{self._bridge_port}/messages", - timeout=aiohttp.ClientTimeout(total=30) - ) as resp: - if resp.status == 200: - messages = await resp.json() - for msg_data in messages: - event = await self._build_message_event(msg_data) - if event: - await self.handle_message(event) + async with self._http_session.get( + f"http://127.0.0.1:{self._bridge_port}/messages", + timeout=aiohttp.ClientTimeout(total=30) + ) as resp: + if resp.status == 200: + messages = await resp.json() + for msg_data in messages: + event = await self._build_message_event(msg_data) + if event: + await self.handle_message(event) except asyncio.CancelledError: break except Exception as e: diff --git a/tests/gateway/test_whatsapp_connect.py b/tests/gateway/test_whatsapp_connect.py index 7a2126bb8..61ff8f361 100644 --- a/tests/gateway/test_whatsapp_connect.py +++ b/tests/gateway/test_whatsapp_connect.py @@ -63,6 +63,7 @@ def _make_adapter(): adapter._background_tasks = set() adapter._auto_tts_disabled_chats = set() adapter._message_queue = asyncio.Queue() + adapter._http_session = None return adapter @@ -219,6 +220,7 @@ class TestBridgeRuntimeFailure: fatal_handler = AsyncMock() adapter.set_fatal_error_handler(fatal_handler) adapter._running = True + adapter._http_session = MagicMock() # Persistent session active mock_fh = MagicMock() adapter._bridge_log_fh = mock_fh @@ -242,6 +244,7 @@ class TestBridgeRuntimeFailure: fatal_handler = AsyncMock() adapter.set_fatal_error_handler(fatal_handler) adapter._running = True + adapter._http_session = MagicMock() # Persistent session active mock_fh = MagicMock() adapter._bridge_log_fh = mock_fh @@ -417,3 +420,83 @@ class TestKillPortProcess: with patch("gateway.platforms.whatsapp._IS_WINDOWS", True), \ patch("gateway.platforms.whatsapp.subprocess.run", side_effect=OSError("no netstat")): _kill_port_process(3000) # must not raise + + +# --------------------------------------------------------------------------- +# Persistent HTTP session lifecycle +# --------------------------------------------------------------------------- + +class TestHttpSessionLifecycle: + """Verify persistent aiohttp.ClientSession is created and cleaned up.""" + + @pytest.mark.asyncio + async def test_session_closed_on_disconnect(self): + """disconnect() should close self._http_session.""" + adapter = _make_adapter() + mock_session = AsyncMock() + mock_session.closed = False + adapter._http_session = mock_session + adapter._poll_task = None + adapter._bridge_process = None + adapter._running = True + adapter._session_lock_identity = None + + await adapter.disconnect() + + mock_session.close.assert_called_once() + assert adapter._http_session is None + + @pytest.mark.asyncio + async def test_session_not_closed_when_already_closed(self): + """disconnect() should skip close() when session is already closed.""" + adapter = _make_adapter() + mock_session = AsyncMock() + mock_session.closed = True + adapter._http_session = mock_session + adapter._poll_task = None + adapter._bridge_process = None + adapter._running = True + adapter._session_lock_identity = None + + await adapter.disconnect() + + mock_session.close.assert_not_called() + assert adapter._http_session is None + + @pytest.mark.asyncio + async def test_poll_task_cancelled_on_disconnect(self): + """disconnect() should cancel the poll task.""" + adapter = _make_adapter() + mock_task = MagicMock() + mock_task.done.return_value = False + mock_task.cancel = MagicMock() + mock_future = asyncio.Future() + mock_future.set_exception(asyncio.CancelledError()) + mock_task.__await__ = mock_future.__await__ + adapter._poll_task = mock_task + adapter._http_session = None + adapter._bridge_process = None + adapter._running = True + adapter._session_lock_identity = None + + await adapter.disconnect() + + mock_task.cancel.assert_called_once() + assert adapter._poll_task is None + + @pytest.mark.asyncio + async def test_disconnect_skips_done_poll_task(self): + """disconnect() should not cancel an already-done poll task.""" + adapter = _make_adapter() + mock_task = MagicMock() + mock_task.done.return_value = True + adapter._poll_task = mock_task + adapter._http_session = None + adapter._bridge_process = None + adapter._running = True + adapter._session_lock_identity = None + + await adapter.disconnect() + + mock_task.cancel.assert_not_called() + assert adapter._poll_task is None