From be322efdf2a00cb66d6fff91fe8f6b4f5f978ccb Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Sat, 28 Mar 2026 12:13:35 -0700 Subject: [PATCH] fix(matrix): harden e2ee access-token handling (#3562) * fix(matrix): harden e2ee access-token handling * fix: patch nio mock in e2ee maintenance sync loop test The sync_loop now imports nio for SyncError checking (from PR #3280), so the test needs to inject a fake nio module via sys.modules. --------- Co-authored-by: Cortana --- gateway/platforms/matrix.py | 143 +++++++++++++++++++++---- tests/gateway/test_matrix.py | 197 +++++++++++++++++++++++++++++++++++ 2 files changed, 320 insertions(+), 20 deletions(-) diff --git a/gateway/platforms/matrix.py b/gateway/platforms/matrix.py index 79ac82396..b12d3b97c 100644 --- a/gateway/platforms/matrix.py +++ b/gateway/platforms/matrix.py @@ -161,22 +161,49 @@ class MatrixAdapter(BasePlatformAdapter): # Authenticate. if self._access_token: client.access_token = self._access_token - # Resolve user_id if not set. - if not self._user_id: - resp = await client.whoami() - if isinstance(resp, nio.WhoamiResponse): - self._user_id = resp.user_id - client.user_id = resp.user_id - logger.info("Matrix: authenticated as %s", self._user_id) - else: - logger.error( - "Matrix: whoami failed — check MATRIX_ACCESS_TOKEN and MATRIX_HOMESERVER" + + # With access-token auth, always resolve whoami so we validate the + # token and learn the device_id. The device_id matters for E2EE: + # without it, matrix-nio can send plain messages but may fail to + # decrypt inbound encrypted events or encrypt outbound room sends. + resp = await client.whoami() + if isinstance(resp, nio.WhoamiResponse): + resolved_user_id = getattr(resp, "user_id", "") or self._user_id + resolved_device_id = getattr(resp, "device_id", "") + if resolved_user_id: + self._user_id = resolved_user_id + + # restore_login() is the matrix-nio path that binds the access + # token to a specific device and loads the crypto store. + if resolved_device_id and hasattr(client, "restore_login"): + client.restore_login( + self._user_id or resolved_user_id, + resolved_device_id, + self._access_token, ) - await client.close() - return False + else: + if self._user_id: + client.user_id = self._user_id + if resolved_device_id: + client.device_id = resolved_device_id + client.access_token = self._access_token + if self._encryption: + logger.warning( + "Matrix: access-token login did not restore E2EE state; " + "encrypted rooms may fail until a device_id is available" + ) + + logger.info( + "Matrix: using access token for %s%s", + self._user_id or "(unknown user)", + f" (device {resolved_device_id})" if resolved_device_id else "", + ) else: - client.user_id = self._user_id - logger.info("Matrix: using access token for %s", self._user_id) + logger.error( + "Matrix: whoami failed — check MATRIX_ACCESS_TOKEN and MATRIX_HOMESERVER" + ) + await client.close() + return False elif self._password and self._user_id: resp = await client.login( self._password, @@ -194,13 +221,18 @@ class MatrixAdapter(BasePlatformAdapter): return False # If E2EE is enabled, load the crypto store. - if self._encryption and hasattr(client, "olm"): + if self._encryption and getattr(client, "olm", None): try: if client.should_upload_keys: await client.keys_upload() logger.info("Matrix: E2EE crypto initialized") except Exception as exc: logger.warning("Matrix: crypto init issue: %s", exc) + elif self._encryption: + logger.warning( + "Matrix: E2EE requested but crypto store is not loaded; " + "encrypted rooms may fail" + ) # Register event callbacks. client.add_event_callback(self._on_room_message, nio.RoomMessageText) @@ -230,6 +262,7 @@ class MatrixAdapter(BasePlatformAdapter): ) # Build DM room cache from m.direct account data. await self._refresh_dm_cache() + await self._run_e2ee_maintenance() else: logger.warning("Matrix: initial sync returned %s", type(resp).__name__) @@ -301,13 +334,48 @@ class MatrixAdapter(BasePlatformAdapter): relates_to["m.in_reply_to"] = {"event_id": reply_to} msg_content["m.relates_to"] = relates_to - resp = await self._client.room_send( - chat_id, - "m.room.message", - msg_content, - ) + async def _room_send_once(*, ignore_unverified_devices: bool = False): + return await asyncio.wait_for( + self._client.room_send( + chat_id, + "m.room.message", + msg_content, + ignore_unverified_devices=ignore_unverified_devices, + ), + timeout=45, + ) + + try: + resp = await _room_send_once(ignore_unverified_devices=False) + except Exception as exc: + retryable = isinstance(exc, asyncio.TimeoutError) + olm_unverified = getattr(nio, "OlmUnverifiedDeviceError", None) + send_retry = getattr(nio, "SendRetryError", None) + if isinstance(olm_unverified, type) and isinstance(exc, olm_unverified): + retryable = True + if isinstance(send_retry, type) and isinstance(exc, send_retry): + retryable = True + + if not retryable: + logger.error("Matrix: failed to send to %s: %s", chat_id, exc) + return SendResult(success=False, error=str(exc)) + + logger.warning( + "Matrix: initial encrypted send to %s failed (%s); " + "retrying after E2EE maintenance with ignored unverified devices", + chat_id, + exc, + ) + await self._run_e2ee_maintenance() + try: + resp = await _room_send_once(ignore_unverified_devices=True) + except Exception as retry_exc: + logger.error("Matrix: failed to send to %s after retry: %s", chat_id, retry_exc) + return SendResult(success=False, error=str(retry_exc)) + if isinstance(resp, nio.RoomSendResponse): last_event_id = resp.event_id + logger.info("Matrix: sent event %s to %s", last_event_id, chat_id) else: err = getattr(resp, "message", str(resp)) logger.error("Matrix: failed to send to %s: %s", chat_id, err) @@ -565,6 +633,9 @@ class MatrixAdapter(BasePlatformAdapter): getattr(resp, "message", resp), ) await asyncio.sleep(5) + continue + + await self._run_e2ee_maintenance() except asyncio.CancelledError: return except Exception as exc: @@ -573,6 +644,38 @@ class MatrixAdapter(BasePlatformAdapter): logger.warning("Matrix: sync error: %s — retrying in 5s", exc) await asyncio.sleep(5) + async def _run_e2ee_maintenance(self) -> None: + """Run matrix-nio E2EE housekeeping between syncs. + + Hermes uses a custom sync loop instead of matrix-nio's sync_forever(), + so we need to explicitly drive the key management work that sync_forever() + normally handles for encrypted rooms. + """ + client = self._client + if not client or not self._encryption or not getattr(client, "olm", None): + return + + tasks = [asyncio.create_task(client.send_to_device_messages())] + + if client.should_upload_keys: + tasks.append(asyncio.create_task(client.keys_upload())) + + if client.should_query_keys: + tasks.append(asyncio.create_task(client.keys_query())) + + if client.should_claim_keys: + users = client.get_users_for_key_claiming() + if users: + tasks.append(asyncio.create_task(client.keys_claim(users))) + + for task in asyncio.as_completed(tasks): + try: + await task + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning("Matrix: E2EE maintenance task failed: %s", exc) + # ------------------------------------------------------------------ # Event callbacks # ------------------------------------------------------------------ diff --git a/tests/gateway/test_matrix.py b/tests/gateway/test_matrix.py index 31e59caeb..5a9879f60 100644 --- a/tests/gateway/test_matrix.py +++ b/tests/gateway/test_matrix.py @@ -1,4 +1,5 @@ """Tests for Matrix platform adapter.""" +import asyncio import json import re import pytest @@ -446,3 +447,199 @@ class TestMatrixRequirements: monkeypatch.delenv("MATRIX_HOMESERVER", raising=False) from gateway.platforms.matrix import check_matrix_requirements assert check_matrix_requirements() is False + + +# --------------------------------------------------------------------------- +# Access-token auth / E2EE bootstrap +# --------------------------------------------------------------------------- + +class TestMatrixAccessTokenAuth: + @pytest.mark.asyncio + async def test_connect_fetches_device_id_from_whoami_for_access_token(self): + from gateway.platforms.matrix import MatrixAdapter + + config = PlatformConfig( + enabled=True, + token="syt_test_access_token", + extra={ + "homeserver": "https://matrix.example.org", + "user_id": "@bot:example.org", + "encryption": True, + }, + ) + adapter = MatrixAdapter(config) + + class FakeWhoamiResponse: + def __init__(self, user_id, device_id): + self.user_id = user_id + self.device_id = device_id + + class FakeSyncResponse: + def __init__(self): + self.rooms = MagicMock(join={}) + + fake_client = MagicMock() + fake_client.whoami = AsyncMock(return_value=FakeWhoamiResponse("@bot:example.org", "DEV123")) + fake_client.sync = AsyncMock(return_value=FakeSyncResponse()) + fake_client.keys_upload = AsyncMock() + fake_client.keys_query = AsyncMock() + fake_client.keys_claim = AsyncMock() + fake_client.send_to_device_messages = AsyncMock(return_value=[]) + fake_client.get_users_for_key_claiming = MagicMock(return_value={}) + fake_client.close = AsyncMock() + fake_client.add_event_callback = MagicMock() + fake_client.rooms = {} + fake_client.account_data = {} + fake_client.olm = object() + fake_client.should_upload_keys = False + fake_client.should_query_keys = False + fake_client.should_claim_keys = False + + def _restore_login(user_id, device_id, access_token): + fake_client.user_id = user_id + fake_client.device_id = device_id + fake_client.access_token = access_token + fake_client.olm = object() + + fake_client.restore_login = MagicMock(side_effect=_restore_login) + + fake_nio = MagicMock() + fake_nio.AsyncClient = MagicMock(return_value=fake_client) + fake_nio.WhoamiResponse = FakeWhoamiResponse + fake_nio.SyncResponse = FakeSyncResponse + fake_nio.LoginResponse = type("LoginResponse", (), {}) + fake_nio.RoomMessageText = type("RoomMessageText", (), {}) + fake_nio.RoomMessageImage = type("RoomMessageImage", (), {}) + fake_nio.RoomMessageAudio = type("RoomMessageAudio", (), {}) + fake_nio.RoomMessageVideo = type("RoomMessageVideo", (), {}) + fake_nio.RoomMessageFile = type("RoomMessageFile", (), {}) + fake_nio.InviteMemberEvent = type("InviteMemberEvent", (), {}) + fake_nio.MegolmEvent = type("MegolmEvent", (), {}) + + with patch.dict("sys.modules", {"nio": fake_nio}): + with patch.object(adapter, "_refresh_dm_cache", AsyncMock()): + with patch.object(adapter, "_sync_loop", AsyncMock(return_value=None)): + assert await adapter.connect() is True + + fake_client.restore_login.assert_called_once_with( + "@bot:example.org", "DEV123", "syt_test_access_token" + ) + assert fake_client.access_token == "syt_test_access_token" + assert fake_client.user_id == "@bot:example.org" + assert fake_client.device_id == "DEV123" + fake_client.whoami.assert_awaited_once() + + await adapter.disconnect() + + +class TestMatrixE2EEMaintenance: + @pytest.mark.asyncio + async def test_sync_loop_runs_e2ee_maintenance_requests(self): + adapter = _make_adapter() + adapter._encryption = True + adapter._closing = False + + class FakeSyncError: + pass + + async def _sync_once(timeout=30000): + adapter._closing = True + return MagicMock() + + fake_client = MagicMock() + fake_client.sync = AsyncMock(side_effect=_sync_once) + fake_client.send_to_device_messages = AsyncMock(return_value=[]) + fake_client.keys_upload = AsyncMock() + fake_client.keys_query = AsyncMock() + fake_client.get_users_for_key_claiming = MagicMock( + return_value={"@alice:example.org": ["DEVICE1"]} + ) + fake_client.keys_claim = AsyncMock() + fake_client.olm = object() + fake_client.should_upload_keys = True + fake_client.should_query_keys = True + fake_client.should_claim_keys = True + + adapter._client = fake_client + + fake_nio = MagicMock() + fake_nio.SyncError = FakeSyncError + + with patch.dict("sys.modules", {"nio": fake_nio}): + await adapter._sync_loop() + + fake_client.sync.assert_awaited_once_with(timeout=30000) + fake_client.send_to_device_messages.assert_awaited_once() + fake_client.keys_upload.assert_awaited_once() + fake_client.keys_query.assert_awaited_once() + fake_client.keys_claim.assert_awaited_once_with( + {"@alice:example.org": ["DEVICE1"]} + ) + + +class TestMatrixEncryptedSendFallback: + @pytest.mark.asyncio + async def test_send_retries_with_ignored_unverified_devices(self): + adapter = _make_adapter() + adapter._encryption = True + + class FakeRoomSendResponse: + def __init__(self, event_id): + self.event_id = event_id + + class FakeOlmUnverifiedDeviceError(Exception): + pass + + fake_client = MagicMock() + fake_client.room_send = AsyncMock(side_effect=[ + FakeOlmUnverifiedDeviceError("unverified"), + FakeRoomSendResponse("$event123"), + ]) + adapter._client = fake_client + adapter._run_e2ee_maintenance = AsyncMock() + + fake_nio = MagicMock() + fake_nio.RoomSendResponse = FakeRoomSendResponse + fake_nio.OlmUnverifiedDeviceError = FakeOlmUnverifiedDeviceError + + with patch.dict("sys.modules", {"nio": fake_nio}): + result = await adapter.send("!room:example.org", "hello") + + assert result.success is True + assert result.message_id == "$event123" + adapter._run_e2ee_maintenance.assert_awaited_once() + assert fake_client.room_send.await_count == 2 + first_call = fake_client.room_send.await_args_list[0] + second_call = fake_client.room_send.await_args_list[1] + assert first_call.kwargs.get("ignore_unverified_devices") is False + assert second_call.kwargs.get("ignore_unverified_devices") is True + + @pytest.mark.asyncio + async def test_send_retries_after_timeout_in_encrypted_room(self): + adapter = _make_adapter() + adapter._encryption = True + + class FakeRoomSendResponse: + def __init__(self, event_id): + self.event_id = event_id + + fake_client = MagicMock() + fake_client.room_send = AsyncMock(side_effect=[ + asyncio.TimeoutError(), + FakeRoomSendResponse("$event456"), + ]) + adapter._client = fake_client + adapter._run_e2ee_maintenance = AsyncMock() + + fake_nio = MagicMock() + fake_nio.RoomSendResponse = FakeRoomSendResponse + + with patch.dict("sys.modules", {"nio": fake_nio}): + result = await adapter.send("!room:example.org", "hello") + + assert result.success is True + assert result.message_id == "$event456" + adapter._run_e2ee_maintenance.assert_awaited_once() + assert fake_client.room_send.await_count == 2 + second_call = fake_client.room_send.await_args_list[1] + assert second_call.kwargs.get("ignore_unverified_devices") is True