diff --git a/gateway/platforms/matrix.py b/gateway/platforms/matrix.py
index 309baeee7..c9bcd945a 100644
--- a/gateway/platforms/matrix.py
+++ b/gateway/platforms/matrix.py
@@ -49,6 +49,17 @@ _STORE_DIR = _get_hermes_dir("platforms/matrix/store", "matrix/store")
 # Grace period: ignore messages older than this many seconds before startup.
 _STARTUP_GRACE_SECONDS = 5
 
+# E2EE key export file for persistence across restarts.
+# NOTE: the passphrase below is a constant in the source, so it only satisfies
+# the export format and does not protect the file; confidentiality of the
+# exported keys depends on the permissions of the store directory.
+_KEY_EXPORT_FILE = _STORE_DIR / "exported_keys.txt"
+_KEY_EXPORT_PASSPHRASE = "hermes-matrix-e2ee-keys"
+
+# Pending undecrypted events: cap and TTL for retry buffer.
+_MAX_PENDING_EVENTS = 100
+_PENDING_EVENT_TTL = 300  # seconds — stop retrying after 5 min
+
 
 def check_matrix_requirements() -> bool:
     """Return True if the Matrix adapter can be used."""
@@ -111,6 +122,10 @@ class MatrixAdapter(BasePlatformAdapter):
         self._processed_events: deque = deque(maxlen=1000)
         self._processed_events_set: set = set()
 
+        # Buffer for undecrypted events pending key receipt.
+        # Each entry: (room, event, timestamp)
+        self._pending_megolm: list = []
+
     def _is_duplicate_event(self, event_id) -> bool:
         """Return True if this event was already processed. Tracks the ID otherwise."""
         if not event_id:
@@ -232,6 +247,20 @@ class MatrixAdapter(BasePlatformAdapter):
                     logger.info("Matrix: E2EE crypto initialized")
                 except Exception as exc:
                     logger.warning("Matrix: crypto init issue: %s", exc)
+
+                # Import previously exported Megolm keys (survives restarts).
+                if _KEY_EXPORT_FILE.exists():
+                    try:
+                        # matrix-nio's import_keys is a plain method; await
+                        # only when the client returns a coroutine.
+                        result = client.import_keys(
+                            str(_KEY_EXPORT_FILE), _KEY_EXPORT_PASSPHRASE,
+                        )
+                        if asyncio.iscoroutine(result):
+                            await result
+                        logger.info("Matrix: imported Megolm keys from backup")
+                    except Exception as exc:
+                        logger.warning("Matrix: could not import keys: %s", exc)
             elif self._encryption:
                 logger.warning(
                     "Matrix: E2EE requested but crypto store is not loaded; "
@@ -286,6 +315,22 @@ class MatrixAdapter(BasePlatformAdapter):
             except (asyncio.CancelledError, Exception):
                 pass
 
+        # Export Megolm keys before closing so the next restart can decrypt
+        # events that used sessions from this run.
+        if self._client and self._encryption and getattr(self._client, "olm", None):
+            try:
+                _STORE_DIR.mkdir(parents=True, exist_ok=True)
+                try:
+                    await self._export_megolm_keys()
+                except FileExistsError:
+                    # The exporter may refuse to overwrite the file left by
+                    # the previous run; replace it with the fresh export.
+                    _KEY_EXPORT_FILE.unlink()
+                    await self._export_megolm_keys()
+                logger.info("Matrix: exported Megolm keys for next restart")
+            except Exception as exc:
+                logger.warning("Matrix: could not export keys on disconnect: %s", exc)
+
         if self._client:
             await self._client.close()
             self._client = None
@@ -665,17 +710,23 @@ class MatrixAdapter(BasePlatformAdapter):
         Hermes uses a custom sync loop instead of matrix-nio's sync_forever(), so
         we need to explicitly drive the key management work that sync_forever()
         normally handles for encrypted rooms.
+
+        Also auto-trusts all known devices (so messages can be encrypted for
+        them) and retries decryption for any buffered MegolmEvents.
         """
         client = self._client
         if not client or not self._encryption or not getattr(client, "olm", None):
             return
 
+        # Snapshot the flag now; presumably keys_query() clears it on completion.
+        did_query_keys = client.should_query_keys
+
         tasks = [asyncio.create_task(client.send_to_device_messages())]
 
         if client.should_upload_keys:
             tasks.append(asyncio.create_task(client.keys_upload()))
 
-        if client.should_query_keys:
+        if did_query_keys:
             tasks.append(asyncio.create_task(client.keys_query()))
 
         if client.should_claim_keys:
@@ -691,6 +742,130 @@ class MatrixAdapter(BasePlatformAdapter):
             except Exception as exc:
                 logger.warning("Matrix: E2EE maintenance task failed: %s", exc)
 
+        # After key queries, auto-trust all known devices so outbound messages
+        # can be encrypted for them. For a bot this is the right default — we
+        # want conversations to work, not to enforce manual verification.
+        if did_query_keys:
+            self._auto_trust_devices()
+
+        # Retry any buffered undecrypted events now that new keys may have
+        # arrived (from key requests, key queries, or to-device forwarding).
+        if self._pending_megolm:
+            await self._retry_pending_decryptions()
+
+    async def _export_megolm_keys(self) -> None:
+        """Export this device's Megolm sessions to ``_KEY_EXPORT_FILE``.
+
+        matrix-nio implements ``export_keys`` as a plain method, so the
+        result is awaited only when the client returns a coroutine (e.g. an
+        async wrapper).  Errors propagate to the caller.
+        """
+        result = self._client.export_keys(
+            str(_KEY_EXPORT_FILE), _KEY_EXPORT_PASSPHRASE,
+        )
+        if asyncio.iscoroutine(result):
+            await result
+
+    def _auto_trust_devices(self) -> None:
+        """Mark every known, not-yet-verified device as verified.
+
+        ``verify_device`` records the trust decision in our own store, which
+        lets this client share its outbound Megolm sessions with those
+        devices.  It does not change how other clients treat this bot's
+        device.  Our own device is skipped; errors are logged, not raised.
+        """
+        client = self._client
+        if not client:
+            return
+
+        device_store = getattr(client, "device_store", None)
+        if not device_store:
+            return
+
+        own_device = getattr(client, "device_id", None)
+        trusted_count = 0
+
+        try:
+            # DeviceStore.__iter__ yields OlmDevice objects directly.
+            for device in device_store:
+                if getattr(device, "device_id", None) == own_device:
+                    continue
+                if not getattr(device, "verified", False):
+                    client.verify_device(device)
+                    trusted_count += 1
+        except Exception as exc:
+            logger.debug("Matrix: auto-trust error: %s", exc)
+
+        if trusted_count:
+            logger.info("Matrix: auto-trusted %d new device(s)", trusted_count)
+
+    async def _retry_pending_decryptions(self) -> None:
+        """Retry decrypting buffered MegolmEvents after new keys arrive.
+
+        Entries older than ``_PENDING_EVENT_TTL`` are dropped, entries that
+        still cannot be decrypted stay buffered, and decrypted text or media
+        events are passed on to the regular message handlers.
+        """
+        import nio
+
+        client = self._client
+        if not client or not self._pending_megolm:
+            return
+
+        now = time.time()
+        still_pending: list = []
+
+        for room, event, ts in self._pending_megolm:
+            # Drop events that have aged past the TTL.
+            if now - ts > _PENDING_EVENT_TTL:
+                logger.debug(
+                    "Matrix: dropping expired pending event %s (age %.0fs)",
+                    getattr(event, "event_id", "?"), now - ts,
+                )
+                continue
+
+            try:
+                decrypted = client.decrypt_event(event)
+            except Exception:
+                # Still missing the key — keep in buffer.
+                still_pending.append((room, event, ts))
+                continue
+
+            if isinstance(decrypted, nio.MegolmEvent):
+                # decrypt_event returned the same undecryptable event.
+                still_pending.append((room, event, ts))
+                continue
+
+            logger.info(
+                "Matrix: decrypted buffered event %s (%s)",
+                getattr(event, "event_id", "?"),
+                type(decrypted).__name__,
+            )
+
+            # Route to the appropriate handler based on decrypted type.
+            try:
+                if isinstance(decrypted, nio.RoomMessageText):
+                    await self._on_room_message(room, decrypted)
+                elif isinstance(
+                    decrypted,
+                    (nio.RoomMessageImage, nio.RoomMessageAudio,
+                     nio.RoomMessageVideo, nio.RoomMessageFile),
+                ):
+                    await self._on_room_message_media(room, decrypted)
+                else:
+                    logger.debug(
+                        "Matrix: decrypted event %s has unhandled type %s",
+                        getattr(event, "event_id", "?"),
+                        type(decrypted).__name__,
+                    )
+            except Exception as exc:
+                logger.warning(
+                    "Matrix: error processing decrypted event %s: %s",
+                    getattr(event, "event_id", "?"), exc,
+                )
+
+        self._pending_megolm = still_pending
+
     # ------------------------------------------------------------------
     # Event callbacks
     # ------------------------------------------------------------------
@@ -712,13 +887,29 @@ class MatrixAdapter(BasePlatformAdapter):
         if event_ts and event_ts < self._startup_ts - _STARTUP_GRACE_SECONDS:
             return
 
-        # Handle decrypted MegolmEvents — extract the inner event.
+        # Handle undecryptable MegolmEvents: request the missing session key
+        # and buffer the event for retry once the key arrives.
         if isinstance(event, nio.MegolmEvent):
-            # Failed to decrypt.
             logger.warning(
-                "Matrix: could not decrypt event %s in %s",
+                "Matrix: could not decrypt event %s in %s — requesting key",
                 event.event_id, room.room_id,
             )
+
+            # Ask other devices in the room to forward the session key.
+            try:
+                resp = await self._client.request_room_key(event)
+                if hasattr(resp, "event_id") or not isinstance(resp, Exception):
+                    logger.debug(
+                        "Matrix: room key request sent for session %s",
+                        getattr(event, "session_id", "?"),
+                    )
+            except Exception as exc:
+                logger.debug("Matrix: room key request failed: %s", exc)
+
+            # Buffer for retry on next maintenance cycle.
+            self._pending_megolm.append((room, event, time.time()))
+            if len(self._pending_megolm) > _MAX_PENDING_EVENTS:
+                self._pending_megolm = self._pending_megolm[-_MAX_PENDING_EVENTS:]
             return
 
         # Skip edits (m.replace relation).
diff --git a/tests/gateway/test_matrix.py b/tests/gateway/test_matrix.py
index 5a9879f60..9912eef00 100644
--- a/tests/gateway/test_matrix.py
+++ b/tests/gateway/test_matrix.py
@@ -643,3 +643,378 @@ class TestMatrixEncryptedSendFallback:
         assert fake_client.room_send.await_count == 2
         second_call = fake_client.room_send.await_args_list[1]
         assert second_call.kwargs.get("ignore_unverified_devices") is True
+
+
+# ---------------------------------------------------------------------------
+# E2EE: Auto-trust devices
+# ---------------------------------------------------------------------------
+
+class TestMatrixAutoTrustDevices:
+    """_auto_trust_devices verifies every unverified device except our own."""
+    def test_auto_trust_verifies_unverified_devices(self):
+        adapter = _make_adapter()
+
+        # DeviceStore.__iter__ yields OlmDevice objects directly.
+        device_a = MagicMock()
+        device_a.device_id = "DEVICE_A"
+        device_a.verified = False
+        device_b = MagicMock()
+        device_b.device_id = "DEVICE_B"
+        device_b.verified = True  # already trusted
+        device_c = MagicMock()
+        device_c.device_id = "DEVICE_C"
+        device_c.verified = False
+
+        fake_client = MagicMock()
+        fake_client.device_id = "OWN_DEVICE"
+        fake_client.verify_device = MagicMock()
+
+        # Simulate DeviceStore iteration (yields OlmDevice objects)
+        fake_client.device_store = MagicMock()
+        fake_client.device_store.__iter__ = MagicMock(
+            return_value=iter([device_a, device_b, device_c])
+        )
+
+        adapter._client = fake_client
+        adapter._auto_trust_devices()
+
+        # Should have verified device_a and device_c (not device_b, already verified)
+        assert fake_client.verify_device.call_count == 2
+        verified_devices = [call.args[0] for call in fake_client.verify_device.call_args_list]
+        assert device_a in verified_devices
+        assert device_c in verified_devices
+        assert device_b not in verified_devices
+
+    def test_auto_trust_skips_own_device(self):
+        adapter = _make_adapter()
+
+        own_device = MagicMock()
+        own_device.device_id = "MY_DEVICE"
+        own_device.verified = False
+
+        fake_client = MagicMock()
+        fake_client.device_id = "MY_DEVICE"
+        fake_client.verify_device = MagicMock()
+
+        fake_client.device_store = MagicMock()
+        fake_client.device_store.__iter__ = MagicMock(
+            return_value=iter([own_device])
+        )
+
+        adapter._client = fake_client
+        adapter._auto_trust_devices()
+
+        fake_client.verify_device.assert_not_called()
+
+    def test_auto_trust_handles_missing_device_store(self):
+        adapter = _make_adapter()
+        fake_client = MagicMock(spec=[])  # empty spec — no attributes
+        adapter._client = fake_client
+        # Should not raise
+        adapter._auto_trust_devices()
+
+
+# ---------------------------------------------------------------------------
+# E2EE: MegolmEvent key request + buffering
+# ---------------------------------------------------------------------------
+
+class TestMatrixMegolmEventHandling:
+    """Undecryptable MegolmEvents trigger a key request and are buffered."""
+    @pytest.mark.asyncio
+    async def test_megolm_event_requests_room_key_and_buffers(self):
+        adapter = _make_adapter()
+        adapter._user_id = "@bot:example.org"
+        adapter._startup_ts = 0.0
+        adapter._dm_rooms = {}
+
+        fake_megolm = MagicMock()
+        fake_megolm.sender = "@alice:example.org"
+        fake_megolm.event_id = "$encrypted_event"
+        fake_megolm.server_timestamp = 9999999999000  # future
+        fake_megolm.session_id = "SESSION123"
+
+        fake_room = MagicMock()
+        fake_room.room_id = "!room:example.org"
+
+        fake_client = MagicMock()
+        fake_client.request_room_key = AsyncMock(return_value=MagicMock())
+        adapter._client = fake_client
+
+        # Create a MegolmEvent class for isinstance check
+        fake_nio = MagicMock()
+        FakeMegolmEvent = type("MegolmEvent", (), {})
+        fake_megolm.__class__ = FakeMegolmEvent
+        fake_nio.MegolmEvent = FakeMegolmEvent
+
+        with patch.dict("sys.modules", {"nio": fake_nio}):
+            await adapter._on_room_message(fake_room, fake_megolm)
+
+        # Should have requested the room key
+        fake_client.request_room_key.assert_awaited_once_with(fake_megolm)
+
+        # Should have buffered the event
+        assert len(adapter._pending_megolm) == 1
+        room, event, ts = adapter._pending_megolm[0]
+        assert room is fake_room
+        assert event is fake_megolm
+
+    @pytest.mark.asyncio
+    async def test_megolm_buffer_capped(self):
+        adapter = _make_adapter()
+        adapter._user_id = "@bot:example.org"
+        adapter._startup_ts = 0.0
+        adapter._dm_rooms = {}
+
+        fake_client = MagicMock()
+        fake_client.request_room_key = AsyncMock(return_value=MagicMock())
+        adapter._client = fake_client
+
+        FakeMegolmEvent = type("MegolmEvent", (), {})
+        fake_nio = MagicMock()
+        fake_nio.MegolmEvent = FakeMegolmEvent
+
+        # Fill the buffer past max
+        from gateway.platforms.matrix import _MAX_PENDING_EVENTS
+        with patch.dict("sys.modules", {"nio": fake_nio}):
+            for i in range(_MAX_PENDING_EVENTS + 10):
+                evt = MagicMock()
+                evt.__class__ = FakeMegolmEvent
+                evt.sender = "@alice:example.org"
+                evt.event_id = f"$event_{i}"
+                evt.server_timestamp = 9999999999000
+                evt.session_id = f"SESSION_{i}"
+                room = MagicMock()
+                room.room_id = "!room:example.org"
+                await adapter._on_room_message(room, evt)
+
+        assert len(adapter._pending_megolm) == _MAX_PENDING_EVENTS
+
+
+# ---------------------------------------------------------------------------
+# E2EE: Retry pending decryptions
+# ---------------------------------------------------------------------------
+
+class TestMatrixRetryPendingDecryptions:
+    """Buffered events are retried, routed, kept, or expired as appropriate."""
+    @pytest.mark.asyncio
+    async def test_successful_decryption_routes_to_text_handler(self):
+        import time as _time
+
+        adapter = _make_adapter()
+        adapter._user_id = "@bot:example.org"
+        adapter._startup_ts = 0.0
+        adapter._dm_rooms = {}
+
+        # Create types
+        FakeMegolmEvent = type("MegolmEvent", (), {})
+        FakeRoomMessageText = type("RoomMessageText", (), {})
+
+        decrypted_event = MagicMock()
+        decrypted_event.__class__ = FakeRoomMessageText
+
+        fake_megolm = MagicMock()
+        fake_megolm.__class__ = FakeMegolmEvent
+        fake_megolm.event_id = "$encrypted"
+
+        fake_room = MagicMock()
+        now = _time.time()
+
+        adapter._pending_megolm = [(fake_room, fake_megolm, now)]
+
+        fake_client = MagicMock()
+        fake_client.decrypt_event = MagicMock(return_value=decrypted_event)
+        adapter._client = fake_client
+
+        fake_nio = MagicMock()
+        fake_nio.MegolmEvent = FakeMegolmEvent
+        fake_nio.RoomMessageText = FakeRoomMessageText
+        fake_nio.RoomMessageImage = type("RoomMessageImage", (), {})
+        fake_nio.RoomMessageAudio = type("RoomMessageAudio", (), {})
+        fake_nio.RoomMessageVideo = type("RoomMessageVideo", (), {})
+        fake_nio.RoomMessageFile = type("RoomMessageFile", (), {})
+
+        with patch.dict("sys.modules", {"nio": fake_nio}):
+            with patch.object(adapter, "_on_room_message", AsyncMock()) as mock_handler:
+                await adapter._retry_pending_decryptions()
+                mock_handler.assert_awaited_once_with(fake_room, decrypted_event)
+
+        # Buffer should be empty now
+        assert len(adapter._pending_megolm) == 0
+
+    @pytest.mark.asyncio
+    async def test_still_undecryptable_stays_in_buffer(self):
+        import time as _time
+
+        adapter = _make_adapter()
+
+        FakeMegolmEvent = type("MegolmEvent", (), {})
+
+        fake_megolm = MagicMock()
+        fake_megolm.__class__ = FakeMegolmEvent
+        fake_megolm.event_id = "$still_encrypted"
+
+        now = _time.time()
+        adapter._pending_megolm = [(MagicMock(), fake_megolm, now)]
+
+        fake_client = MagicMock()
+        # decrypt_event raises when key is still missing
+        fake_client.decrypt_event = MagicMock(side_effect=Exception("missing key"))
+        adapter._client = fake_client
+
+        fake_nio = MagicMock()
+        fake_nio.MegolmEvent = FakeMegolmEvent
+
+        with patch.dict("sys.modules", {"nio": fake_nio}):
+            await adapter._retry_pending_decryptions()
+
+        assert len(adapter._pending_megolm) == 1
+
+    @pytest.mark.asyncio
+    async def test_expired_events_dropped(self):
+        import time as _time
+
+        adapter = _make_adapter()
+
+        from gateway.platforms.matrix import _PENDING_EVENT_TTL
+
+        fake_megolm = MagicMock()
+        fake_megolm.event_id = "$old_event"
+        fake_megolm.__class__ = type("MegolmEvent", (), {})
+
+        # Timestamp well past TTL
+        old_ts = _time.time() - _PENDING_EVENT_TTL - 60
+        adapter._pending_megolm = [(MagicMock(), fake_megolm, old_ts)]
+
+        fake_client = MagicMock()
+        adapter._client = fake_client
+
+        fake_nio = MagicMock()
+        fake_nio.MegolmEvent = type("MegolmEvent", (), {})
+
+        with patch.dict("sys.modules", {"nio": fake_nio}):
+            await adapter._retry_pending_decryptions()
+
+        # Should have been dropped
+        assert len(adapter._pending_megolm) == 0
+        # Should NOT have tried to decrypt
+        fake_client.decrypt_event.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_media_event_routes_to_media_handler(self):
+        import time as _time
+
+        adapter = _make_adapter()
+        adapter._user_id = "@bot:example.org"
+        adapter._startup_ts = 0.0
+
+        FakeMegolmEvent = type("MegolmEvent", (), {})
+        FakeRoomMessageImage = type("RoomMessageImage", (), {})
+
+        decrypted_image = MagicMock()
+        decrypted_image.__class__ = FakeRoomMessageImage
+
+        fake_megolm = MagicMock()
+        fake_megolm.__class__ = FakeMegolmEvent
+        fake_megolm.event_id = "$encrypted_image"
+
+        fake_room = MagicMock()
+        now = _time.time()
+        adapter._pending_megolm = [(fake_room, fake_megolm, now)]
+
+        fake_client = MagicMock()
+        fake_client.decrypt_event = MagicMock(return_value=decrypted_image)
+        adapter._client = fake_client
+
+        fake_nio = MagicMock()
+        fake_nio.MegolmEvent = FakeMegolmEvent
+        fake_nio.RoomMessageText = type("RoomMessageText", (), {})
+        fake_nio.RoomMessageImage = FakeRoomMessageImage
+        fake_nio.RoomMessageAudio = type("RoomMessageAudio", (), {})
+        fake_nio.RoomMessageVideo = type("RoomMessageVideo", (), {})
+        fake_nio.RoomMessageFile = type("RoomMessageFile", (), {})
+
+        with patch.dict("sys.modules", {"nio": fake_nio}):
+            with patch.object(adapter, "_on_room_message_media", AsyncMock()) as mock_media:
+                await adapter._retry_pending_decryptions()
+                mock_media.assert_awaited_once_with(fake_room, decrypted_image)
+
+        assert len(adapter._pending_megolm) == 0
+
+
+# ---------------------------------------------------------------------------
+# E2EE: Key export / import
+# ---------------------------------------------------------------------------
+
+class TestMatrixKeyExportImport:
+    """disconnect() exports Megolm keys only when E2EE is active."""
+    @pytest.mark.asyncio
+    async def test_disconnect_exports_keys(self):
+        adapter = _make_adapter()
+        adapter._encryption = True
+        adapter._sync_task = None
+
+        fake_client = MagicMock()
+        fake_client.olm = object()
+        fake_client.export_keys = AsyncMock()
+        fake_client.close = AsyncMock()
+        adapter._client = fake_client
+
+        from gateway.platforms.matrix import _KEY_EXPORT_FILE, _KEY_EXPORT_PASSPHRASE
+
+        await adapter.disconnect()
+
+        fake_client.export_keys.assert_awaited_once_with(
+            str(_KEY_EXPORT_FILE), _KEY_EXPORT_PASSPHRASE,
+        )
+
+    @pytest.mark.asyncio
+    async def test_disconnect_exports_keys_with_sync_client(self):
+        """matrix-nio exposes export_keys as a plain method, not a coroutine."""
+        adapter = _make_adapter()
+        adapter._encryption = True
+        adapter._sync_task = None
+
+        fake_client = MagicMock()
+        fake_client.olm = object()
+        fake_client.export_keys = MagicMock(return_value=None)
+        fake_client.close = AsyncMock()
+        adapter._client = fake_client
+
+        from gateway.platforms.matrix import _KEY_EXPORT_FILE, _KEY_EXPORT_PASSPHRASE
+
+        await adapter.disconnect()
+
+        fake_client.export_keys.assert_called_once_with(
+            str(_KEY_EXPORT_FILE), _KEY_EXPORT_PASSPHRASE,
+        )
+        assert adapter._client is None  # still cleaned up
+
+    @pytest.mark.asyncio
+    async def test_disconnect_handles_export_failure(self):
+        adapter = _make_adapter()
+        adapter._encryption = True
+        adapter._sync_task = None
+
+        fake_client = MagicMock()
+        fake_client.olm = object()
+        fake_client.export_keys = AsyncMock(side_effect=Exception("export failed"))
+        fake_client.close = AsyncMock()
+        adapter._client = fake_client
+
+        # Should not raise
+        await adapter.disconnect()
+        assert adapter._client is None  # still cleaned up
+
+    @pytest.mark.asyncio
+    async def test_disconnect_skips_export_when_no_encryption(self):
+        adapter = _make_adapter()
+        adapter._encryption = False
+        adapter._sync_task = None
+
+        fake_client = MagicMock()
+        fake_client.close = AsyncMock()
+        adapter._client = fake_client
+
+        await adapter.disconnect()
+        # Should not have tried to export
+        fake_client.export_keys.assert_not_called()