fix(matrix): harden e2ee access-token handling (#3562)
* fix(matrix): harden e2ee access-token handling

* fix: patch nio mock in e2ee maintenance sync loop test

  The `_sync_loop` method now imports nio for SyncError checking (from PR #3280), so the test needs to inject a fake nio module via sys.modules.

---------

Co-authored-by: Cortana <andrew+cortana@chalkley.org>
This commit is contained in:
@@ -161,22 +161,49 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
# Authenticate.
|
||||
if self._access_token:
|
||||
client.access_token = self._access_token
|
||||
# Resolve user_id if not set.
|
||||
if not self._user_id:
|
||||
resp = await client.whoami()
|
||||
if isinstance(resp, nio.WhoamiResponse):
|
||||
self._user_id = resp.user_id
|
||||
client.user_id = resp.user_id
|
||||
logger.info("Matrix: authenticated as %s", self._user_id)
|
||||
else:
|
||||
logger.error(
|
||||
"Matrix: whoami failed — check MATRIX_ACCESS_TOKEN and MATRIX_HOMESERVER"
|
||||
|
||||
# With access-token auth, always resolve whoami so we validate the
|
||||
# token and learn the device_id. The device_id matters for E2EE:
|
||||
# without it, matrix-nio can send plain messages but may fail to
|
||||
# decrypt inbound encrypted events or encrypt outbound room sends.
|
||||
resp = await client.whoami()
|
||||
if isinstance(resp, nio.WhoamiResponse):
|
||||
resolved_user_id = getattr(resp, "user_id", "") or self._user_id
|
||||
resolved_device_id = getattr(resp, "device_id", "")
|
||||
if resolved_user_id:
|
||||
self._user_id = resolved_user_id
|
||||
|
||||
# restore_login() is the matrix-nio path that binds the access
|
||||
# token to a specific device and loads the crypto store.
|
||||
if resolved_device_id and hasattr(client, "restore_login"):
|
||||
client.restore_login(
|
||||
self._user_id or resolved_user_id,
|
||||
resolved_device_id,
|
||||
self._access_token,
|
||||
)
|
||||
await client.close()
|
||||
return False
|
||||
else:
|
||||
if self._user_id:
|
||||
client.user_id = self._user_id
|
||||
if resolved_device_id:
|
||||
client.device_id = resolved_device_id
|
||||
client.access_token = self._access_token
|
||||
if self._encryption:
|
||||
logger.warning(
|
||||
"Matrix: access-token login did not restore E2EE state; "
|
||||
"encrypted rooms may fail until a device_id is available"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Matrix: using access token for %s%s",
|
||||
self._user_id or "(unknown user)",
|
||||
f" (device {resolved_device_id})" if resolved_device_id else "",
|
||||
)
|
||||
else:
|
||||
client.user_id = self._user_id
|
||||
logger.info("Matrix: using access token for %s", self._user_id)
|
||||
logger.error(
|
||||
"Matrix: whoami failed — check MATRIX_ACCESS_TOKEN and MATRIX_HOMESERVER"
|
||||
)
|
||||
await client.close()
|
||||
return False
|
||||
elif self._password and self._user_id:
|
||||
resp = await client.login(
|
||||
self._password,
|
||||
@@ -194,13 +221,18 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
return False
|
||||
|
||||
# If E2EE is enabled, load the crypto store.
|
||||
if self._encryption and hasattr(client, "olm"):
|
||||
if self._encryption and getattr(client, "olm", None):
|
||||
try:
|
||||
if client.should_upload_keys:
|
||||
await client.keys_upload()
|
||||
logger.info("Matrix: E2EE crypto initialized")
|
||||
except Exception as exc:
|
||||
logger.warning("Matrix: crypto init issue: %s", exc)
|
||||
elif self._encryption:
|
||||
logger.warning(
|
||||
"Matrix: E2EE requested but crypto store is not loaded; "
|
||||
"encrypted rooms may fail"
|
||||
)
|
||||
|
||||
# Register event callbacks.
|
||||
client.add_event_callback(self._on_room_message, nio.RoomMessageText)
|
||||
@@ -230,6 +262,7 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
)
|
||||
# Build DM room cache from m.direct account data.
|
||||
await self._refresh_dm_cache()
|
||||
await self._run_e2ee_maintenance()
|
||||
else:
|
||||
logger.warning("Matrix: initial sync returned %s", type(resp).__name__)
|
||||
|
||||
@@ -301,13 +334,48 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
relates_to["m.in_reply_to"] = {"event_id": reply_to}
|
||||
msg_content["m.relates_to"] = relates_to
|
||||
|
||||
resp = await self._client.room_send(
|
||||
chat_id,
|
||||
"m.room.message",
|
||||
msg_content,
|
||||
)
|
||||
async def _room_send_once(*, ignore_unverified_devices: bool = False):
|
||||
return await asyncio.wait_for(
|
||||
self._client.room_send(
|
||||
chat_id,
|
||||
"m.room.message",
|
||||
msg_content,
|
||||
ignore_unverified_devices=ignore_unverified_devices,
|
||||
),
|
||||
timeout=45,
|
||||
)
|
||||
|
||||
try:
|
||||
resp = await _room_send_once(ignore_unverified_devices=False)
|
||||
except Exception as exc:
|
||||
retryable = isinstance(exc, asyncio.TimeoutError)
|
||||
olm_unverified = getattr(nio, "OlmUnverifiedDeviceError", None)
|
||||
send_retry = getattr(nio, "SendRetryError", None)
|
||||
if isinstance(olm_unverified, type) and isinstance(exc, olm_unverified):
|
||||
retryable = True
|
||||
if isinstance(send_retry, type) and isinstance(exc, send_retry):
|
||||
retryable = True
|
||||
|
||||
if not retryable:
|
||||
logger.error("Matrix: failed to send to %s: %s", chat_id, exc)
|
||||
return SendResult(success=False, error=str(exc))
|
||||
|
||||
logger.warning(
|
||||
"Matrix: initial encrypted send to %s failed (%s); "
|
||||
"retrying after E2EE maintenance with ignored unverified devices",
|
||||
chat_id,
|
||||
exc,
|
||||
)
|
||||
await self._run_e2ee_maintenance()
|
||||
try:
|
||||
resp = await _room_send_once(ignore_unverified_devices=True)
|
||||
except Exception as retry_exc:
|
||||
logger.error("Matrix: failed to send to %s after retry: %s", chat_id, retry_exc)
|
||||
return SendResult(success=False, error=str(retry_exc))
|
||||
|
||||
if isinstance(resp, nio.RoomSendResponse):
|
||||
last_event_id = resp.event_id
|
||||
logger.info("Matrix: sent event %s to %s", last_event_id, chat_id)
|
||||
else:
|
||||
err = getattr(resp, "message", str(resp))
|
||||
logger.error("Matrix: failed to send to %s: %s", chat_id, err)
|
||||
@@ -565,6 +633,9 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
getattr(resp, "message", resp),
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
continue
|
||||
|
||||
await self._run_e2ee_maintenance()
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as exc:
|
||||
@@ -573,6 +644,38 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
logger.warning("Matrix: sync error: %s — retrying in 5s", exc)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def _run_e2ee_maintenance(self) -> None:
    """Run matrix-nio E2EE housekeeping between syncs.

    Hermes uses a custom sync loop instead of matrix-nio's sync_forever(),
    so we need to explicitly drive the key management work that sync_forever()
    normally handles for encrypted rooms.

    The call is a no-op unless a client is set, encryption is enabled and the
    client exposes a truthy ``olm`` attribute. Otherwise it always flushes
    pending to-device messages and, depending on the client's
    ``should_upload_keys`` / ``should_query_keys`` / ``should_claim_keys``
    flags, uploads, queries and claims keys. The requests run concurrently;
    a failure in one is logged as a warning and does not stop the others.

    Raises:
        asyncio.CancelledError: re-raised when this coroutine (or one of the
            requests) is cancelled. Requests that are still in flight are
            cancelled before the error propagates.
    """
    client = self._client
    if not client or not self._encryption or not getattr(client, "olm", None):
        return

    tasks = []
    try:
        tasks.append(asyncio.create_task(client.send_to_device_messages()))

        if client.should_upload_keys:
            tasks.append(asyncio.create_task(client.keys_upload()))

        if client.should_query_keys:
            tasks.append(asyncio.create_task(client.keys_query()))

        if client.should_claim_keys:
            users = client.get_users_for_key_claiming()
            # Only claim when there is actually someone to claim keys for.
            if users:
                tasks.append(asyncio.create_task(client.keys_claim(users)))

        for finished in asyncio.as_completed(tasks):
            try:
                await finished
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                # Best effort: one failed request must not block the rest.
                logger.warning("Matrix: E2EE maintenance task failed: %s", exc)
    finally:
        # On the normal path every task is already done and this is a no-op.
        # If we were cancelled (or building a later task raised), make sure
        # the requests we already spawned do not keep running as orphans.
        for task in tasks:
            if not task.done():
                task.cancel()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Event callbacks
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Tests for Matrix platform adapter."""
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import pytest
|
||||
@@ -446,3 +447,199 @@ class TestMatrixRequirements:
|
||||
monkeypatch.delenv("MATRIX_HOMESERVER", raising=False)
|
||||
from gateway.platforms.matrix import check_matrix_requirements
|
||||
assert check_matrix_requirements() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Access-token auth / E2EE bootstrap
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMatrixAccessTokenAuth:
    """Access-token login: whoami supplies the device that restore_login binds."""

    @pytest.mark.asyncio
    async def test_connect_fetches_device_id_from_whoami_for_access_token(self):
        """connect() awaits whoami once and restores the login with its device_id."""
        from gateway.platforms.matrix import MatrixAdapter

        bot_user = "@bot:example.org"
        bot_device = "DEV123"
        bot_token = "syt_test_access_token"

        class FakeWhoamiResponse:
            def __init__(self, user_id, device_id):
                self.user_id = user_id
                self.device_id = device_id

        class FakeSyncResponse:
            def __init__(self):
                self.rooms = MagicMock(join={})

        adapter = MatrixAdapter(
            PlatformConfig(
                enabled=True,
                token=bot_token,
                extra={
                    "homeserver": "https://matrix.example.org",
                    "user_id": bot_user,
                    "encryption": True,
                },
            )
        )

        # Stand-in for nio.AsyncClient: async network calls are AsyncMocks,
        # plain state is set as ordinary attributes.
        fake_client = MagicMock()
        fake_client.whoami = AsyncMock(
            return_value=FakeWhoamiResponse(bot_user, bot_device)
        )
        fake_client.sync = AsyncMock(return_value=FakeSyncResponse())
        fake_client.keys_upload = AsyncMock()
        fake_client.keys_query = AsyncMock()
        fake_client.keys_claim = AsyncMock()
        fake_client.send_to_device_messages = AsyncMock(return_value=[])
        fake_client.get_users_for_key_claiming = MagicMock(return_value={})
        fake_client.close = AsyncMock()
        fake_client.add_event_callback = MagicMock()
        fake_client.rooms = {}
        fake_client.account_data = {}
        fake_client.olm = object()
        fake_client.should_upload_keys = False
        fake_client.should_query_keys = False
        fake_client.should_claim_keys = False

        def _restore_login(user_id, device_id, access_token):
            # Copy the credentials onto the fake client so they can be asserted on.
            fake_client.user_id = user_id
            fake_client.device_id = device_id
            fake_client.access_token = access_token
            fake_client.olm = object()

        fake_client.restore_login = MagicMock(side_effect=_restore_login)

        fake_nio = MagicMock()
        fake_nio.AsyncClient = MagicMock(return_value=fake_client)
        fake_nio.WhoamiResponse = FakeWhoamiResponse
        fake_nio.SyncResponse = FakeSyncResponse
        # The remaining nio names only need to be real (empty) classes.
        for stub_name in (
            "LoginResponse",
            "RoomMessageText",
            "RoomMessageImage",
            "RoomMessageAudio",
            "RoomMessageVideo",
            "RoomMessageFile",
            "InviteMemberEvent",
            "MegolmEvent",
        ):
            setattr(fake_nio, stub_name, type(stub_name, (), {}))

        with patch.dict("sys.modules", {"nio": fake_nio}), \
                patch.object(adapter, "_refresh_dm_cache", AsyncMock()), \
                patch.object(adapter, "_sync_loop", AsyncMock(return_value=None)):
            connected = await adapter.connect()
            assert connected is True

        fake_client.restore_login.assert_called_once_with(
            "@bot:example.org", "DEV123", "syt_test_access_token"
        )
        assert fake_client.access_token == "syt_test_access_token"
        assert fake_client.user_id == "@bot:example.org"
        assert fake_client.device_id == "DEV123"
        fake_client.whoami.assert_awaited_once()

        await adapter.disconnect()
|
||||
|
||||
|
||||
class TestMatrixE2EEMaintenance:
    """The custom sync loop must drive matrix-nio's key-management requests."""

    @pytest.mark.asyncio
    async def test_sync_loop_runs_e2ee_maintenance_requests(self):
        """One sync iteration triggers to-device, upload, query and claim calls."""
        adapter = _make_adapter()
        adapter._closing = False
        adapter._encryption = True

        async def _sync_once(timeout=30000):
            # Ask the loop to stop after this first iteration.
            adapter._closing = True
            return MagicMock()

        fake_client = MagicMock()
        fake_client.olm = object()
        fake_client.should_upload_keys = True
        fake_client.should_query_keys = True
        fake_client.should_claim_keys = True
        fake_client.sync = AsyncMock(side_effect=_sync_once)
        fake_client.send_to_device_messages = AsyncMock(return_value=[])
        fake_client.keys_upload = AsyncMock()
        fake_client.keys_query = AsyncMock()
        fake_client.keys_claim = AsyncMock()
        fake_client.get_users_for_key_claiming = MagicMock(
            return_value={"@alice:example.org": ["DEVICE1"]}
        )
        adapter._client = fake_client

        # _sync_loop imports nio to check for SyncError, so inject a fake module.
        fake_nio = MagicMock()
        fake_nio.SyncError = type("FakeSyncError", (), {})

        with patch.dict("sys.modules", {"nio": fake_nio}):
            await adapter._sync_loop()

        fake_client.sync.assert_awaited_once_with(timeout=30000)
        for maintenance_call in (
            fake_client.send_to_device_messages,
            fake_client.keys_upload,
            fake_client.keys_query,
        ):
            maintenance_call.assert_awaited_once()
        fake_client.keys_claim.assert_awaited_once_with(
            {"@alice:example.org": ["DEVICE1"]}
        )
|
||||
|
||||
|
||||
class TestMatrixEncryptedSendFallback:
    """send() retries once, ignoring unverified devices, after a retryable failure."""

    @pytest.mark.asyncio
    async def test_send_retries_with_ignored_unverified_devices(self):
        """An OlmUnverifiedDeviceError on the first attempt triggers the retry."""

        class FakeRoomSendResponse:
            def __init__(self, event_id):
                self.event_id = event_id

        class FakeOlmUnverifiedDeviceError(Exception):
            pass

        adapter = _make_adapter()
        adapter._encryption = True
        adapter._run_e2ee_maintenance = AsyncMock()

        # First attempt raises, second attempt succeeds.
        outcomes = [
            FakeOlmUnverifiedDeviceError("unverified"),
            FakeRoomSendResponse("$event123"),
        ]
        fake_client = MagicMock()
        fake_client.room_send = AsyncMock(side_effect=outcomes)
        adapter._client = fake_client

        fake_nio = MagicMock()
        fake_nio.OlmUnverifiedDeviceError = FakeOlmUnverifiedDeviceError
        fake_nio.RoomSendResponse = FakeRoomSendResponse

        with patch.dict("sys.modules", {"nio": fake_nio}):
            result = await adapter.send("!room:example.org", "hello")

        assert result.success is True
        assert result.message_id == "$event123"
        adapter._run_e2ee_maintenance.assert_awaited_once()
        assert fake_client.room_send.await_count == 2
        attempts = fake_client.room_send.await_args_list
        assert attempts[0].kwargs.get("ignore_unverified_devices") is False
        assert attempts[1].kwargs.get("ignore_unverified_devices") is True

    @pytest.mark.asyncio
    async def test_send_retries_after_timeout_in_encrypted_room(self):
        """A timeout on the first attempt is also treated as retryable."""

        class FakeRoomSendResponse:
            def __init__(self, event_id):
                self.event_id = event_id

        adapter = _make_adapter()
        adapter._encryption = True
        adapter._run_e2ee_maintenance = AsyncMock()

        # First attempt times out, second attempt succeeds.
        outcomes = [
            asyncio.TimeoutError(),
            FakeRoomSendResponse("$event456"),
        ]
        fake_client = MagicMock()
        fake_client.room_send = AsyncMock(side_effect=outcomes)
        adapter._client = fake_client

        fake_nio = MagicMock()
        fake_nio.RoomSendResponse = FakeRoomSendResponse

        with patch.dict("sys.modules", {"nio": fake_nio}):
            result = await adapter.send("!room:example.org", "hello")

        assert result.success is True
        assert result.message_id == "$event456"
        adapter._run_e2ee_maintenance.assert_awaited_once()
        assert fake_client.room_send.await_count == 2
        retry_attempt = fake_client.room_send.await_args_list[1]
        assert retry_attempt.kwargs.get("ignore_unverified_devices") is True
|
||||
|
||||
Reference in New Issue
Block a user