fix(matrix): E2EE decryption — request keys, auto-trust devices, retry buffered events (#4083)

When the Matrix adapter receives encrypted events it can't decrypt
(MegolmEvent), it now:

1. Requests the missing room key from other devices via
   client.request_room_key(event) instead of silently dropping the message

2. Buffers undecrypted events (bounded to 100, 5 min TTL) and retries
   decryption after each E2EE maintenance cycle when new keys arrive

3. Auto-trusts/verifies all devices after key queries so other clients
   share session keys with the bot proactively

4. Exports Megolm keys on disconnect and imports them on connect, so
   session keys survive gateway restarts

This addresses the 'could not decrypt event' warnings that caused the
bot to miss messages in encrypted rooms.
This commit is contained in:
Teknium
2026-03-30 17:16:09 -07:00
committed by GitHub
parent 7e0c2c3ce3
commit 07746dca0c
2 changed files with 514 additions and 4 deletions

View File

@@ -49,6 +49,14 @@ _STORE_DIR = _get_hermes_dir("platforms/matrix/store", "matrix/store")
# Grace period: ignore messages older than this many seconds before startup.
_STARTUP_GRACE_SECONDS = 5
# E2EE key export file for persistence across restarts.
_KEY_EXPORT_FILE = _STORE_DIR / "exported_keys.txt"
_KEY_EXPORT_PASSPHRASE = "hermes-matrix-e2ee-keys"
# Pending undecrypted events: cap and TTL for retry buffer.
_MAX_PENDING_EVENTS = 100
_PENDING_EVENT_TTL = 300 # seconds — stop retrying after 5 min
def check_matrix_requirements() -> bool:
"""Return True if the Matrix adapter can be used."""
@@ -111,6 +119,10 @@ class MatrixAdapter(BasePlatformAdapter):
self._processed_events: deque = deque(maxlen=1000)
self._processed_events_set: set = set()
# Buffer for undecrypted events pending key receipt.
# Each entry: (room, event, timestamp)
self._pending_megolm: list = []
def _is_duplicate_event(self, event_id) -> bool:
"""Return True if this event was already processed. Tracks the ID otherwise."""
if not event_id:
@@ -232,6 +244,16 @@ class MatrixAdapter(BasePlatformAdapter):
logger.info("Matrix: E2EE crypto initialized")
except Exception as exc:
logger.warning("Matrix: crypto init issue: %s", exc)
# Import previously exported Megolm keys (survives restarts).
if _KEY_EXPORT_FILE.exists():
try:
await client.import_keys(
str(_KEY_EXPORT_FILE), _KEY_EXPORT_PASSPHRASE,
)
logger.info("Matrix: imported Megolm keys from backup")
except Exception as exc:
logger.debug("Matrix: could not import keys: %s", exc)
elif self._encryption:
logger.warning(
"Matrix: E2EE requested but crypto store is not loaded; "
@@ -286,6 +308,18 @@ class MatrixAdapter(BasePlatformAdapter):
except (asyncio.CancelledError, Exception):
pass
# Export Megolm keys before closing so the next restart can decrypt
# events that used sessions from this run.
if self._client and self._encryption and getattr(self._client, "olm", None):
try:
_STORE_DIR.mkdir(parents=True, exist_ok=True)
await self._client.export_keys(
str(_KEY_EXPORT_FILE), _KEY_EXPORT_PASSPHRASE,
)
logger.info("Matrix: exported Megolm keys for next restart")
except Exception as exc:
logger.debug("Matrix: could not export keys on disconnect: %s", exc)
if self._client:
await self._client.close()
self._client = None
@@ -665,17 +699,22 @@ class MatrixAdapter(BasePlatformAdapter):
Hermes uses a custom sync loop instead of matrix-nio's sync_forever(),
so we need to explicitly drive the key management work that sync_forever()
normally handles for encrypted rooms.
Also auto-trusts all devices (so senders share session keys with us)
and retries decryption for any buffered MegolmEvents.
"""
client = self._client
if not client or not self._encryption or not getattr(client, "olm", None):
return
did_query_keys = client.should_query_keys
tasks = [asyncio.create_task(client.send_to_device_messages())]
if client.should_upload_keys:
tasks.append(asyncio.create_task(client.keys_upload()))
if client.should_query_keys:
if did_query_keys:
tasks.append(asyncio.create_task(client.keys_query()))
if client.should_claim_keys:
@@ -691,6 +730,111 @@ class MatrixAdapter(BasePlatformAdapter):
except Exception as exc:
logger.warning("Matrix: E2EE maintenance task failed: %s", exc)
# After key queries, auto-trust all devices so senders share keys with
# us. For a bot this is the right default — we want to decrypt
# everything, not enforce manual verification.
if did_query_keys:
self._auto_trust_devices()
# Retry any buffered undecrypted events now that new keys may have
# arrived (from key requests, key queries, or to-device forwarding).
if self._pending_megolm:
await self._retry_pending_decryptions()
def _auto_trust_devices(self) -> None:
"""Trust/verify all unverified devices we know about.
When other clients see our device as verified, they proactively share
Megolm session keys with us. Without this, many clients will refuse
to include an unverified device in key distributions.
"""
client = self._client
if not client:
return
device_store = getattr(client, "device_store", None)
if not device_store:
return
own_device = getattr(client, "device_id", None)
trusted_count = 0
try:
# DeviceStore.__iter__ yields OlmDevice objects directly.
for device in device_store:
if getattr(device, "device_id", None) == own_device:
continue
if not getattr(device, "verified", False):
client.verify_device(device)
trusted_count += 1
except Exception as exc:
logger.debug("Matrix: auto-trust error: %s", exc)
if trusted_count:
logger.info("Matrix: auto-trusted %d new device(s)", trusted_count)
async def _retry_pending_decryptions(self) -> None:
"""Retry decrypting buffered MegolmEvents after new keys arrive."""
import nio
client = self._client
if not client or not self._pending_megolm:
return
now = time.time()
still_pending: list = []
for room, event, ts in self._pending_megolm:
# Drop events that have aged past the TTL.
if now - ts > _PENDING_EVENT_TTL:
logger.debug(
"Matrix: dropping expired pending event %s (age %.0fs)",
getattr(event, "event_id", "?"), now - ts,
)
continue
try:
decrypted = client.decrypt_event(event)
except Exception:
# Still missing the key — keep in buffer.
still_pending.append((room, event, ts))
continue
if isinstance(decrypted, nio.MegolmEvent):
# decrypt_event returned the same undecryptable event.
still_pending.append((room, event, ts))
continue
logger.info(
"Matrix: decrypted buffered event %s (%s)",
getattr(event, "event_id", "?"),
type(decrypted).__name__,
)
# Route to the appropriate handler based on decrypted type.
try:
if isinstance(decrypted, nio.RoomMessageText):
await self._on_room_message(room, decrypted)
elif isinstance(
decrypted,
(nio.RoomMessageImage, nio.RoomMessageAudio,
nio.RoomMessageVideo, nio.RoomMessageFile),
):
await self._on_room_message_media(room, decrypted)
else:
logger.debug(
"Matrix: decrypted event %s has unhandled type %s",
getattr(event, "event_id", "?"),
type(decrypted).__name__,
)
except Exception as exc:
logger.warning(
"Matrix: error processing decrypted event %s: %s",
getattr(event, "event_id", "?"), exc,
)
self._pending_megolm = still_pending
# ------------------------------------------------------------------
# Event callbacks
# ------------------------------------------------------------------
@@ -712,13 +856,29 @@ class MatrixAdapter(BasePlatformAdapter):
if event_ts and event_ts < self._startup_ts - _STARTUP_GRACE_SECONDS:
return
# Handle decrypted MegolmEvents — extract the inner event.
# Handle undecryptable MegolmEvents: request the missing session key
# and buffer the event for retry once the key arrives.
if isinstance(event, nio.MegolmEvent):
# Failed to decrypt.
logger.warning(
"Matrix: could not decrypt event %s in %s",
"Matrix: could not decrypt event %s in %s — requesting key",
event.event_id, room.room_id,
)
# Ask other devices in the room to forward the session key.
try:
resp = await self._client.request_room_key(event)
if hasattr(resp, "event_id") or not isinstance(resp, Exception):
logger.debug(
"Matrix: room key request sent for session %s",
getattr(event, "session_id", "?"),
)
except Exception as exc:
logger.debug("Matrix: room key request failed: %s", exc)
# Buffer for retry on next maintenance cycle.
self._pending_megolm.append((room, event, time.time()))
if len(self._pending_megolm) > _MAX_PENDING_EVENTS:
self._pending_megolm = self._pending_megolm[-_MAX_PENDING_EVENTS:]
return
# Skip edits (m.replace relation).

View File

@@ -643,3 +643,353 @@ class TestMatrixEncryptedSendFallback:
assert fake_client.room_send.await_count == 2
second_call = fake_client.room_send.await_args_list[1]
assert second_call.kwargs.get("ignore_unverified_devices") is True
# ---------------------------------------------------------------------------
# E2EE: Auto-trust devices
# ---------------------------------------------------------------------------
class TestMatrixAutoTrustDevices:
def test_auto_trust_verifies_unverified_devices(self):
adapter = _make_adapter()
# DeviceStore.__iter__ yields OlmDevice objects directly.
device_a = MagicMock()
device_a.device_id = "DEVICE_A"
device_a.verified = False
device_b = MagicMock()
device_b.device_id = "DEVICE_B"
device_b.verified = True # already trusted
device_c = MagicMock()
device_c.device_id = "DEVICE_C"
device_c.verified = False
fake_client = MagicMock()
fake_client.device_id = "OWN_DEVICE"
fake_client.verify_device = MagicMock()
# Simulate DeviceStore iteration (yields OlmDevice objects)
fake_client.device_store = MagicMock()
fake_client.device_store.__iter__ = MagicMock(
return_value=iter([device_a, device_b, device_c])
)
adapter._client = fake_client
adapter._auto_trust_devices()
# Should have verified device_a and device_c (not device_b, already verified)
assert fake_client.verify_device.call_count == 2
verified_devices = [call.args[0] for call in fake_client.verify_device.call_args_list]
assert device_a in verified_devices
assert device_c in verified_devices
assert device_b not in verified_devices
def test_auto_trust_skips_own_device(self):
adapter = _make_adapter()
own_device = MagicMock()
own_device.device_id = "MY_DEVICE"
own_device.verified = False
fake_client = MagicMock()
fake_client.device_id = "MY_DEVICE"
fake_client.verify_device = MagicMock()
fake_client.device_store = MagicMock()
fake_client.device_store.__iter__ = MagicMock(
return_value=iter([own_device])
)
adapter._client = fake_client
adapter._auto_trust_devices()
fake_client.verify_device.assert_not_called()
def test_auto_trust_handles_missing_device_store(self):
adapter = _make_adapter()
fake_client = MagicMock(spec=[]) # empty spec — no attributes
adapter._client = fake_client
# Should not raise
adapter._auto_trust_devices()
# ---------------------------------------------------------------------------
# E2EE: MegolmEvent key request + buffering
# ---------------------------------------------------------------------------
class TestMatrixMegolmEventHandling:
@pytest.mark.asyncio
async def test_megolm_event_requests_room_key_and_buffers(self):
adapter = _make_adapter()
adapter._user_id = "@bot:example.org"
adapter._startup_ts = 0.0
adapter._dm_rooms = {}
fake_megolm = MagicMock()
fake_megolm.sender = "@alice:example.org"
fake_megolm.event_id = "$encrypted_event"
fake_megolm.server_timestamp = 9999999999000 # future
fake_megolm.session_id = "SESSION123"
fake_room = MagicMock()
fake_room.room_id = "!room:example.org"
fake_client = MagicMock()
fake_client.request_room_key = AsyncMock(return_value=MagicMock())
adapter._client = fake_client
# Create a MegolmEvent class for isinstance check
fake_nio = MagicMock()
FakeMegolmEvent = type("MegolmEvent", (), {})
fake_megolm.__class__ = FakeMegolmEvent
fake_nio.MegolmEvent = FakeMegolmEvent
with patch.dict("sys.modules", {"nio": fake_nio}):
await adapter._on_room_message(fake_room, fake_megolm)
# Should have requested the room key
fake_client.request_room_key.assert_awaited_once_with(fake_megolm)
# Should have buffered the event
assert len(adapter._pending_megolm) == 1
room, event, ts = adapter._pending_megolm[0]
assert room is fake_room
assert event is fake_megolm
@pytest.mark.asyncio
async def test_megolm_buffer_capped(self):
adapter = _make_adapter()
adapter._user_id = "@bot:example.org"
adapter._startup_ts = 0.0
adapter._dm_rooms = {}
fake_client = MagicMock()
fake_client.request_room_key = AsyncMock(return_value=MagicMock())
adapter._client = fake_client
FakeMegolmEvent = type("MegolmEvent", (), {})
fake_nio = MagicMock()
fake_nio.MegolmEvent = FakeMegolmEvent
# Fill the buffer past max
from gateway.platforms.matrix import _MAX_PENDING_EVENTS
with patch.dict("sys.modules", {"nio": fake_nio}):
for i in range(_MAX_PENDING_EVENTS + 10):
evt = MagicMock()
evt.__class__ = FakeMegolmEvent
evt.sender = "@alice:example.org"
evt.event_id = f"$event_{i}"
evt.server_timestamp = 9999999999000
evt.session_id = f"SESSION_{i}"
room = MagicMock()
room.room_id = "!room:example.org"
await adapter._on_room_message(room, evt)
assert len(adapter._pending_megolm) == _MAX_PENDING_EVENTS
# ---------------------------------------------------------------------------
# E2EE: Retry pending decryptions
# ---------------------------------------------------------------------------
class TestMatrixRetryPendingDecryptions:
@pytest.mark.asyncio
async def test_successful_decryption_routes_to_text_handler(self):
import time as _time
adapter = _make_adapter()
adapter._user_id = "@bot:example.org"
adapter._startup_ts = 0.0
adapter._dm_rooms = {}
# Create types
FakeMegolmEvent = type("MegolmEvent", (), {})
FakeRoomMessageText = type("RoomMessageText", (), {})
decrypted_event = MagicMock()
decrypted_event.__class__ = FakeRoomMessageText
fake_megolm = MagicMock()
fake_megolm.__class__ = FakeMegolmEvent
fake_megolm.event_id = "$encrypted"
fake_room = MagicMock()
now = _time.time()
adapter._pending_megolm = [(fake_room, fake_megolm, now)]
fake_client = MagicMock()
fake_client.decrypt_event = MagicMock(return_value=decrypted_event)
adapter._client = fake_client
fake_nio = MagicMock()
fake_nio.MegolmEvent = FakeMegolmEvent
fake_nio.RoomMessageText = FakeRoomMessageText
fake_nio.RoomMessageImage = type("RoomMessageImage", (), {})
fake_nio.RoomMessageAudio = type("RoomMessageAudio", (), {})
fake_nio.RoomMessageVideo = type("RoomMessageVideo", (), {})
fake_nio.RoomMessageFile = type("RoomMessageFile", (), {})
with patch.dict("sys.modules", {"nio": fake_nio}):
with patch.object(adapter, "_on_room_message", AsyncMock()) as mock_handler:
await adapter._retry_pending_decryptions()
mock_handler.assert_awaited_once_with(fake_room, decrypted_event)
# Buffer should be empty now
assert len(adapter._pending_megolm) == 0
@pytest.mark.asyncio
async def test_still_undecryptable_stays_in_buffer(self):
import time as _time
adapter = _make_adapter()
FakeMegolmEvent = type("MegolmEvent", (), {})
fake_megolm = MagicMock()
fake_megolm.__class__ = FakeMegolmEvent
fake_megolm.event_id = "$still_encrypted"
now = _time.time()
adapter._pending_megolm = [(MagicMock(), fake_megolm, now)]
fake_client = MagicMock()
# decrypt_event raises when key is still missing
fake_client.decrypt_event = MagicMock(side_effect=Exception("missing key"))
adapter._client = fake_client
fake_nio = MagicMock()
fake_nio.MegolmEvent = FakeMegolmEvent
with patch.dict("sys.modules", {"nio": fake_nio}):
await adapter._retry_pending_decryptions()
assert len(adapter._pending_megolm) == 1
@pytest.mark.asyncio
async def test_expired_events_dropped(self):
import time as _time
adapter = _make_adapter()
from gateway.platforms.matrix import _PENDING_EVENT_TTL
fake_megolm = MagicMock()
fake_megolm.event_id = "$old_event"
fake_megolm.__class__ = type("MegolmEvent", (), {})
# Timestamp well past TTL
old_ts = _time.time() - _PENDING_EVENT_TTL - 60
adapter._pending_megolm = [(MagicMock(), fake_megolm, old_ts)]
fake_client = MagicMock()
adapter._client = fake_client
fake_nio = MagicMock()
fake_nio.MegolmEvent = type("MegolmEvent", (), {})
with patch.dict("sys.modules", {"nio": fake_nio}):
await adapter._retry_pending_decryptions()
# Should have been dropped
assert len(adapter._pending_megolm) == 0
# Should NOT have tried to decrypt
fake_client.decrypt_event.assert_not_called()
@pytest.mark.asyncio
async def test_media_event_routes_to_media_handler(self):
import time as _time
adapter = _make_adapter()
adapter._user_id = "@bot:example.org"
adapter._startup_ts = 0.0
FakeMegolmEvent = type("MegolmEvent", (), {})
FakeRoomMessageImage = type("RoomMessageImage", (), {})
decrypted_image = MagicMock()
decrypted_image.__class__ = FakeRoomMessageImage
fake_megolm = MagicMock()
fake_megolm.__class__ = FakeMegolmEvent
fake_megolm.event_id = "$encrypted_image"
fake_room = MagicMock()
now = _time.time()
adapter._pending_megolm = [(fake_room, fake_megolm, now)]
fake_client = MagicMock()
fake_client.decrypt_event = MagicMock(return_value=decrypted_image)
adapter._client = fake_client
fake_nio = MagicMock()
fake_nio.MegolmEvent = FakeMegolmEvent
fake_nio.RoomMessageText = type("RoomMessageText", (), {})
fake_nio.RoomMessageImage = FakeRoomMessageImage
fake_nio.RoomMessageAudio = type("RoomMessageAudio", (), {})
fake_nio.RoomMessageVideo = type("RoomMessageVideo", (), {})
fake_nio.RoomMessageFile = type("RoomMessageFile", (), {})
with patch.dict("sys.modules", {"nio": fake_nio}):
with patch.object(adapter, "_on_room_message_media", AsyncMock()) as mock_media:
await adapter._retry_pending_decryptions()
mock_media.assert_awaited_once_with(fake_room, decrypted_image)
assert len(adapter._pending_megolm) == 0
# ---------------------------------------------------------------------------
# E2EE: Key export / import
# ---------------------------------------------------------------------------
class TestMatrixKeyExportImport:
@pytest.mark.asyncio
async def test_disconnect_exports_keys(self):
adapter = _make_adapter()
adapter._encryption = True
adapter._sync_task = None
fake_client = MagicMock()
fake_client.olm = object()
fake_client.export_keys = AsyncMock()
fake_client.close = AsyncMock()
adapter._client = fake_client
from gateway.platforms.matrix import _KEY_EXPORT_FILE, _KEY_EXPORT_PASSPHRASE
await adapter.disconnect()
fake_client.export_keys.assert_awaited_once_with(
str(_KEY_EXPORT_FILE), _KEY_EXPORT_PASSPHRASE,
)
@pytest.mark.asyncio
async def test_disconnect_handles_export_failure(self):
adapter = _make_adapter()
adapter._encryption = True
adapter._sync_task = None
fake_client = MagicMock()
fake_client.olm = object()
fake_client.export_keys = AsyncMock(side_effect=Exception("export failed"))
fake_client.close = AsyncMock()
adapter._client = fake_client
# Should not raise
await adapter.disconnect()
assert adapter._client is None # still cleaned up
@pytest.mark.asyncio
async def test_disconnect_skips_export_when_no_encryption(self):
adapter = _make_adapter()
adapter._encryption = False
adapter._sync_task = None
fake_client = MagicMock()
fake_client.close = AsyncMock()
adapter._client = fake_client
await adapter.disconnect()
# Should not have tried to export
assert not hasattr(fake_client, "export_keys") or \
not fake_client.export_keys.called