diff --git a/tests/gateway/test_matrix.py b/tests/gateway/test_matrix.py index 1a480570e..5c79e476b 100644 --- a/tests/gateway/test_matrix.py +++ b/tests/gateway/test_matrix.py @@ -1,8 +1,9 @@ -"""Tests for Matrix platform adapter.""" +"""Tests for Matrix platform adapter (mautrix-python backend).""" import asyncio import json import re import sys +import time import types import pytest from unittest.mock import MagicMock, patch, AsyncMock @@ -10,44 +11,165 @@ from unittest.mock import MagicMock, patch, AsyncMock from gateway.config import Platform, PlatformConfig -def _make_fake_nio(): - """Create a lightweight fake ``nio`` module with real response classes. +def _make_fake_mautrix(): + """Create a lightweight set of fake ``mautrix`` modules. - Tests that call production methods doing ``import nio`` / ``isinstance(resp, nio.XxxResponse)`` - need real classes (not MagicMock auto-attributes) to satisfy isinstance checks. - Use via ``patch.dict("sys.modules", {"nio": _make_fake_nio()})``. + The adapter does ``from mautrix.api import HTTPAPI``, + ``from mautrix.client import Client``, ``from mautrix.types import ...`` + at import time and inside methods. We provide just enough stubs for + tests that need to mock the mautrix import chain. + + Use via ``patch.dict("sys.modules", _make_fake_mautrix())``. """ - mod = types.ModuleType("nio") + # --- mautrix (root) --- + mautrix = types.ModuleType("mautrix") - class RoomSendResponse: - def __init__(self, event_id="$fake"): - self.event_id = event_id + # --- mautrix.api --- + mautrix_api = types.ModuleType("mautrix.api") - class RoomRedactResponse: + class HTTPAPI: + def __init__(self, base_url="", token="", **kwargs): + self.base_url = base_url + self.token = token + self.session = MagicMock() + self.session.close = AsyncMock() + + mautrix_api.HTTPAPI = HTTPAPI + mautrix.api = mautrix_api + + # --- mautrix.types --- + mautrix_types = types.ModuleType("mautrix.types") + + class EventType: + ROOM_MESSAGE = "m.room.message" + REACTION = "m.reaction" + ROOM_ENCRYPTED = "m.room.encrypted" + ROOM_NAME = "m.room.name" + + class UserID(str): pass - class RoomCreateResponse: - def __init__(self, room_id="!fake:example.org"): - self.room_id = room_id - - class RoomInviteResponse: + class RoomID(str): pass - class UploadResponse: - def __init__(self, content_uri="mxc://example.org/fake"): - self.content_uri = content_uri - - # Minimal Api stub for code that checks nio.Api.RoomPreset - class _Api: + class EventID(str): pass - mod.Api = _Api - mod.RoomSendResponse = RoomSendResponse - mod.RoomRedactResponse = RoomRedactResponse - mod.RoomCreateResponse = RoomCreateResponse - mod.RoomInviteResponse = RoomInviteResponse - mod.UploadResponse = UploadResponse - return mod + class ContentURI(str): + pass + + class SyncToken(str): + pass + + class RoomCreatePreset: + PRIVATE = "private_chat" + PUBLIC = "public_chat" + TRUSTED_PRIVATE = "trusted_private_chat" + + class PresenceState: + ONLINE = "online" + OFFLINE = "offline" + UNAVAILABLE = "unavailable" + + class TrustState: + UNVERIFIED = 0 + VERIFIED = 1 + + class PaginationDirection: + BACKWARD = "b" + FORWARD = "f" + + mautrix_types.EventType = EventType + mautrix_types.UserID = UserID + mautrix_types.RoomID = RoomID + mautrix_types.EventID = EventID + mautrix_types.ContentURI = ContentURI + mautrix_types.SyncToken = SyncToken + mautrix_types.RoomCreatePreset = RoomCreatePreset + mautrix_types.PresenceState = PresenceState + mautrix_types.TrustState = TrustState + mautrix_types.PaginationDirection = PaginationDirection + mautrix.types = mautrix_types + + # --- mautrix.client --- + mautrix_client = types.ModuleType("mautrix.client") + + class Client: + def __init__(self, mxid=None, device_id=None, api=None, + state_store=None, sync_store=None, **kwargs): + self.mxid = mxid + self.device_id = device_id + self.api = api + self.state_store = state_store + self.sync_store = sync_store + self.crypto = None + self._event_handlers = {} + + def add_event_handler(self, event_type, handler): + self._event_handlers.setdefault(event_type, []).append(handler) + + class InternalEventType: + INVITE = "internal.invite" + + mautrix_client.Client = Client + mautrix_client.InternalEventType = InternalEventType + mautrix.client = mautrix_client + + # --- mautrix.client.state_store --- + mautrix_client_state_store = types.ModuleType("mautrix.client.state_store") + + class MemoryStateStore: + async def get_member(self, room_id, user_id): + return None + + async def get_members(self, room_id): + return [] + + async def get_member_profiles(self, room_id): + return {} + + class MemorySyncStore: + pass + + mautrix_client_state_store.MemoryStateStore = MemoryStateStore + mautrix_client_state_store.MemorySyncStore = MemorySyncStore + + # --- mautrix.crypto --- + mautrix_crypto = types.ModuleType("mautrix.crypto") + + class OlmMachine: + def __init__(self, client=None, crypto_store=None, state_store=None): + self.share_keys_min_trust = None + self.send_keys_min_trust = None + + async def load(self): + pass + + async def share_keys(self): + pass + + async def decrypt_megolm_event(self, event): + return event + + mautrix_crypto.OlmMachine = OlmMachine + + # --- mautrix.crypto.store --- + mautrix_crypto_store = types.ModuleType("mautrix.crypto.store") + + class MemoryCryptoStore: + pass + + mautrix_crypto_store.MemoryCryptoStore = MemoryCryptoStore + + return { + "mautrix": mautrix, + "mautrix.api": mautrix_api, + "mautrix.types": mautrix_types, + "mautrix.client": mautrix_client, + "mautrix.client.state_store": mautrix_client_state_store, + "mautrix.crypto": mautrix_crypto, + "mautrix.crypto.store": mautrix_crypto_store, + } # --------------------------------------------------------------------------- @@ -438,27 +560,40 @@ class TestMatrixDisplayName: def setup_method(self): self.adapter = _make_adapter() - def test_get_display_name_from_room_users(self): - """Should get display name from room's users dict.""" - mock_room = MagicMock() - mock_user = MagicMock() - mock_user.display_name = "Alice" - mock_room.users = {"@alice:ex.org": mock_user} + @pytest.mark.asyncio + async def test_get_display_name_from_state_store(self): + """Should get display name from state_store.get_member().""" + mock_member = MagicMock() + mock_member.displayname = "Alice" - name = self.adapter._get_display_name(mock_room, "@alice:ex.org") + mock_state_store = MagicMock() + mock_state_store.get_member = AsyncMock(return_value=mock_member) + + mock_client = MagicMock() + mock_client.state_store = mock_state_store + self.adapter._client = mock_client + + name = await self.adapter._get_display_name("!room:ex.org", "@alice:ex.org") assert name == "Alice" - def test_get_display_name_fallback_to_localpart(self): + @pytest.mark.asyncio + async def test_get_display_name_fallback_to_localpart(self): """Should extract localpart from @user:server format.""" - mock_room = MagicMock() - mock_room.users = {} + mock_state_store = MagicMock() + mock_state_store.get_member = AsyncMock(return_value=None) - name = self.adapter._get_display_name(mock_room, "@bob:example.org") + mock_client = MagicMock() + mock_client.state_store = mock_state_store + self.adapter._client = mock_client + + name = await self.adapter._get_display_name("!room:ex.org", "@bob:example.org") assert name == "bob" - def test_get_display_name_no_room(self): - """Should handle None room gracefully.""" - name = self.adapter._get_display_name(None, "@charlie:ex.org") + @pytest.mark.asyncio + async def test_get_display_name_no_client(self): + """Should handle None client gracefully.""" + self.adapter._client = None + name = await self.adapter._get_display_name("!room:ex.org", "@charlie:ex.org") assert name == "charlie" @@ -473,7 +608,7 @@ class TestMatrixRequirements: monkeypatch.delenv("MATRIX_ENCRYPTION", raising=False) from gateway.platforms.matrix import check_matrix_requirements try: - import nio # noqa: F401 + import mautrix # noqa: F401 assert check_matrix_requirements() is True except ImportError: assert check_matrix_requirements() is False @@ -509,9 +644,9 @@ class TestMatrixRequirements: from gateway.platforms import matrix as matrix_mod with patch.object(matrix_mod, "_check_e2ee_deps", return_value=False): - # Still needs nio itself to be importable + # Still needs mautrix itself to be importable try: - import nio # noqa: F401 + import mautrix # noqa: F401 assert matrix_mod.check_matrix_requirements() is True except ImportError: assert matrix_mod.check_matrix_requirements() is False @@ -525,7 +660,7 @@ class TestMatrixRequirements: from gateway.platforms import matrix as matrix_mod with patch.object(matrix_mod, "_check_e2ee_deps", return_value=True): try: - import nio # noqa: F401 + import mautrix # noqa: F401 assert matrix_mod.check_matrix_requirements() is True except ImportError: assert matrix_mod.check_matrix_requirements() is False @@ -537,7 +672,8 @@ class TestMatrixRequirements: class TestMatrixAccessTokenAuth: @pytest.mark.asyncio - async def test_connect_fetches_device_id_from_whoami_for_access_token(self): + async def test_connect_with_access_token_and_encryption(self): + """connect() should call whoami, set user_id/device_id, set up crypto.""" from gateway.platforms.matrix import MatrixAdapter config = PlatformConfig( @@ -556,62 +692,43 @@ class TestMatrixAccessTokenAuth: self.user_id = user_id self.device_id = device_id - class FakeSyncResponse: - def __init__(self): - self.rooms = MagicMock(join={}) + fake_mautrix_mods = _make_fake_mautrix() - fake_client = MagicMock() - fake_client.whoami = AsyncMock(return_value=FakeWhoamiResponse("@bot:example.org", "DEV123")) - fake_client.sync = AsyncMock(return_value=FakeSyncResponse()) - fake_client.keys_upload = AsyncMock() - fake_client.keys_query = AsyncMock() - fake_client.keys_claim = AsyncMock() - fake_client.send_to_device_messages = AsyncMock(return_value=[]) - fake_client.get_users_for_key_claiming = MagicMock(return_value={}) - fake_client.close = AsyncMock() - fake_client.add_event_callback = MagicMock() - fake_client.rooms = {} - fake_client.account_data = {} - fake_client.olm = object() - fake_client.should_upload_keys = False - fake_client.should_query_keys = False - fake_client.should_claim_keys = False + # Create a mock client that returns from the mautrix.client.Client constructor + mock_client = MagicMock() + mock_client.mxid = "@bot:example.org" + mock_client.device_id = None + mock_client.state_store = MagicMock() + mock_client.sync_store = MagicMock() + mock_client.crypto = None + mock_client.whoami = AsyncMock(return_value=FakeWhoamiResponse("@bot:example.org", "DEV123")) + mock_client.sync = AsyncMock(return_value={"rooms": {"join": {"!room:server": {}}}}) + mock_client.add_event_handler = MagicMock() + mock_client.api = MagicMock() + mock_client.api.token = "syt_test_access_token" + mock_client.api.session = MagicMock() + mock_client.api.session.close = AsyncMock() - def _restore_login(user_id, device_id, access_token): - fake_client.user_id = user_id - fake_client.device_id = device_id - fake_client.access_token = access_token - fake_client.olm = object() + # Mock the crypto setup + mock_olm = MagicMock() + mock_olm.load = AsyncMock() + mock_olm.share_keys = AsyncMock() + mock_olm.share_keys_min_trust = None + mock_olm.send_keys_min_trust = None - fake_client.restore_login = MagicMock(side_effect=_restore_login) - - fake_nio = MagicMock() - fake_nio.AsyncClient = MagicMock(return_value=fake_client) - fake_nio.WhoamiResponse = FakeWhoamiResponse - fake_nio.SyncResponse = FakeSyncResponse - fake_nio.LoginResponse = type("LoginResponse", (), {}) - fake_nio.RoomMessageText = type("RoomMessageText", (), {}) - fake_nio.RoomMessageImage = type("RoomMessageImage", (), {}) - fake_nio.RoomMessageAudio = type("RoomMessageAudio", (), {}) - fake_nio.RoomMessageVideo = type("RoomMessageVideo", (), {}) - fake_nio.RoomMessageFile = type("RoomMessageFile", (), {}) - fake_nio.InviteMemberEvent = type("InviteMemberEvent", (), {}) - fake_nio.MegolmEvent = type("MegolmEvent", (), {}) + # Patch Client constructor to return our mock + fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client) + fake_mautrix_mods["mautrix.crypto"].OlmMachine = MagicMock(return_value=mock_olm) from gateway.platforms import matrix as matrix_mod with patch.object(matrix_mod, "_check_e2ee_deps", return_value=True): - with patch.dict("sys.modules", {"nio": fake_nio}): + with patch.dict("sys.modules", fake_mautrix_mods): with patch.object(adapter, "_refresh_dm_cache", AsyncMock()): with patch.object(adapter, "_sync_loop", AsyncMock(return_value=None)): assert await adapter.connect() is True - fake_client.restore_login.assert_called_once_with( - "@bot:example.org", "DEV123", "syt_test_access_token" - ) - assert fake_client.access_token == "syt_test_access_token" - assert fake_client.user_id == "@bot:example.org" - assert fake_client.device_id == "DEV123" - fake_client.whoami.assert_awaited_once() + mock_client.whoami.assert_awaited_once() + assert adapter._user_id == "@bot:example.org" await adapter.disconnect() @@ -634,19 +751,30 @@ class TestMatrixE2EEHardFail: ) adapter = MatrixAdapter(config) - fake_nio = MagicMock() - fake_nio.AsyncClient = MagicMock() + fake_mautrix_mods = _make_fake_mautrix() + + mock_client = MagicMock() + mock_client.whoami = AsyncMock(return_value=MagicMock(user_id="@bot:example.org", device_id="DEV123")) + mock_client.api = MagicMock() + mock_client.api.token = "syt_test_access_token" + mock_client.api.session = MagicMock() + mock_client.api.session.close = AsyncMock() + mock_client.mxid = "@bot:example.org" + mock_client.device_id = None + mock_client.crypto = None + + fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client) from gateway.platforms import matrix as matrix_mod with patch.object(matrix_mod, "_check_e2ee_deps", return_value=False): - with patch.dict("sys.modules", {"nio": fake_nio}): + with patch.dict("sys.modules", fake_mautrix_mods): result = await adapter.connect() assert result is False @pytest.mark.asyncio - async def test_connect_fails_when_olm_not_loaded_after_login(self): - """Even if _check_e2ee_deps passes, if olm is None after auth, hard-fail.""" + async def test_connect_fails_when_crypto_setup_raises(self): + """Even if _check_e2ee_deps passes, if OlmMachine raises, hard-fail.""" from gateway.platforms.matrix import MatrixAdapter config = PlatformConfig( @@ -660,36 +788,27 @@ class TestMatrixE2EEHardFail: ) adapter = MatrixAdapter(config) - class FakeWhoamiResponse: - def __init__(self, user_id, device_id): - self.user_id = user_id - self.device_id = device_id + fake_mautrix_mods = _make_fake_mautrix() - fake_client = MagicMock() - fake_client.whoami = AsyncMock(return_value=FakeWhoamiResponse("@bot:example.org", "DEV123")) - fake_client.close = AsyncMock() - # olm is None — crypto store not loaded - fake_client.olm = None - fake_client.should_upload_keys = False + mock_client = MagicMock() + mock_client.whoami = AsyncMock(return_value=MagicMock(user_id="@bot:example.org", device_id="DEV123")) + mock_client.api = MagicMock() + mock_client.api.token = "syt_test_access_token" + mock_client.api.session = MagicMock() + mock_client.api.session.close = AsyncMock() + mock_client.mxid = "@bot:example.org" + mock_client.device_id = None + mock_client.crypto = None - def _restore_login(user_id, device_id, access_token): - fake_client.user_id = user_id - fake_client.device_id = device_id - fake_client.access_token = access_token - - fake_client.restore_login = MagicMock(side_effect=_restore_login) - - fake_nio = MagicMock() - fake_nio.AsyncClient = MagicMock(return_value=fake_client) - fake_nio.WhoamiResponse = FakeWhoamiResponse + fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client) + fake_mautrix_mods["mautrix.crypto"].OlmMachine = MagicMock(side_effect=Exception("olm init failed")) from gateway.platforms import matrix as matrix_mod with patch.object(matrix_mod, "_check_e2ee_deps", return_value=True): - with patch.dict("sys.modules", {"nio": fake_nio}): + with patch.dict("sys.modules", fake_mautrix_mods): result = await adapter.connect() assert result is False - fake_client.close.assert_awaited_once() class TestMatrixDeviceId: @@ -757,106 +876,50 @@ class TestMatrixDeviceId: ) adapter = MatrixAdapter(config) - class FakeWhoamiResponse: - def __init__(self, user_id, device_id): - self.user_id = user_id - self.device_id = device_id + fake_mautrix_mods = _make_fake_mautrix() - class FakeSyncResponse: - def __init__(self): - self.rooms = MagicMock(join={}) + mock_client = MagicMock() + mock_client.mxid = "@bot:example.org" + mock_client.device_id = None + mock_client.state_store = MagicMock() + mock_client.sync_store = MagicMock() + mock_client.crypto = None + mock_client.whoami = AsyncMock(return_value=MagicMock(user_id="@bot:example.org", device_id="WHOAMI_DEV")) + mock_client.sync = AsyncMock(return_value={"rooms": {"join": {"!room:server": {}}}}) + mock_client.add_event_handler = MagicMock() + mock_client.api = MagicMock() + mock_client.api.token = "syt_test_access_token" + mock_client.api.session = MagicMock() + mock_client.api.session.close = AsyncMock() - fake_client = MagicMock() - fake_client.whoami = AsyncMock(return_value=FakeWhoamiResponse("@bot:example.org", "WHOAMI_DEV")) - fake_client.sync = AsyncMock(return_value=FakeSyncResponse()) - fake_client.keys_upload = AsyncMock() - fake_client.keys_query = AsyncMock() - fake_client.keys_claim = AsyncMock() - fake_client.send_to_device_messages = AsyncMock(return_value=[]) - fake_client.get_users_for_key_claiming = MagicMock(return_value={}) - fake_client.close = AsyncMock() - fake_client.add_event_callback = MagicMock() - fake_client.rooms = {} - fake_client.account_data = {} - fake_client.olm = object() - fake_client.should_upload_keys = False - fake_client.should_query_keys = False - fake_client.should_claim_keys = False + mock_olm = MagicMock() + mock_olm.load = AsyncMock() + mock_olm.share_keys = AsyncMock() + mock_olm.share_keys_min_trust = None + mock_olm.send_keys_min_trust = None - def _restore_login(user_id, device_id, access_token): - fake_client.user_id = user_id - fake_client.device_id = device_id - fake_client.access_token = access_token - - fake_client.restore_login = MagicMock(side_effect=_restore_login) - - fake_nio = MagicMock() - fake_nio.AsyncClient = MagicMock(return_value=fake_client) - fake_nio.WhoamiResponse = FakeWhoamiResponse - fake_nio.SyncResponse = FakeSyncResponse - fake_nio.LoginResponse = type("LoginResponse", (), {}) - fake_nio.RoomMessageText = type("RoomMessageText", (), {}) - fake_nio.RoomMessageImage = type("RoomMessageImage", (), {}) - fake_nio.RoomMessageAudio = type("RoomMessageAudio", (), {}) - fake_nio.RoomMessageVideo = type("RoomMessageVideo", (), {}) - fake_nio.RoomMessageFile = type("RoomMessageFile", (), {}) - fake_nio.InviteMemberEvent = type("InviteMemberEvent", (), {}) - fake_nio.MegolmEvent = type("MegolmEvent", (), {}) + fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client) + fake_mautrix_mods["mautrix.crypto"].OlmMachine = MagicMock(return_value=mock_olm) from gateway.platforms import matrix as matrix_mod with patch.object(matrix_mod, "_check_e2ee_deps", return_value=True): - with patch.dict("sys.modules", {"nio": fake_nio}): + with patch.dict("sys.modules", fake_mautrix_mods): with patch.object(adapter, "_refresh_dm_cache", AsyncMock()): with patch.object(adapter, "_sync_loop", AsyncMock(return_value=None)): assert await adapter.connect() is True - # The configured device_id should override the whoami device_id - fake_client.restore_login.assert_called_once_with( - "@bot:example.org", "MY_STABLE_DEVICE", "syt_test_access_token" - ) - assert fake_client.device_id == "MY_STABLE_DEVICE" - - # Verify device_id was passed to nio.AsyncClient constructor - ctor_call = fake_nio.AsyncClient.call_args - assert ctor_call.kwargs.get("device_id") == "MY_STABLE_DEVICE" + # The configured device_id should override the whoami device_id. + # In mautrix, the adapter sets client.device_id directly. + assert adapter._device_id == "MY_STABLE_DEVICE" await adapter.disconnect() -class TestMatrixE2EEClientConstructorFailure: - """connect() should hard-fail if nio.AsyncClient() raises when encryption is on.""" - - @pytest.mark.asyncio - async def test_connect_fails_when_e2ee_client_constructor_raises(self): - from gateway.platforms.matrix import MatrixAdapter - - config = PlatformConfig( - enabled=True, - token="syt_test_access_token", - extra={ - "homeserver": "https://matrix.example.org", - "user_id": "@bot:example.org", - "encryption": True, - }, - ) - adapter = MatrixAdapter(config) - - fake_nio = MagicMock() - fake_nio.AsyncClient = MagicMock(side_effect=Exception("olm init failed")) - - from gateway.platforms import matrix as matrix_mod - with patch.object(matrix_mod, "_check_e2ee_deps", return_value=True): - with patch.dict("sys.modules", {"nio": fake_nio}): - result = await adapter.connect() - - assert result is False - - class TestMatrixPasswordLoginDeviceId: - """MATRIX_DEVICE_ID should be passed to nio.AsyncClient even with password login.""" + """MATRIX_DEVICE_ID should be passed to mautrix Client even with password login.""" @pytest.mark.asyncio - async def test_password_login_passes_device_id_to_constructor(self): + async def test_password_login_uses_device_id(self): from gateway.platforms.matrix import MatrixAdapter config = PlatformConfig( @@ -870,40 +933,32 @@ class TestMatrixPasswordLoginDeviceId: ) adapter = MatrixAdapter(config) - class FakeLoginResponse: - pass + fake_mautrix_mods = _make_fake_mautrix() - class FakeSyncResponse: - def __init__(self): - self.rooms = MagicMock(join={}) + mock_client = MagicMock() + mock_client.mxid = "@bot:example.org" + mock_client.device_id = None + mock_client.state_store = MagicMock() + mock_client.sync_store = MagicMock() + mock_client.crypto = None + mock_client.login = AsyncMock(return_value=MagicMock(device_id="STABLE_PW_DEVICE", access_token="tok")) + mock_client.sync = AsyncMock(return_value={"rooms": {"join": {}}}) + mock_client.add_event_handler = MagicMock() + mock_client.api = MagicMock() + mock_client.api.token = "" + mock_client.api.session = MagicMock() + mock_client.api.session.close = AsyncMock() - fake_client = MagicMock() - fake_client.login = AsyncMock(return_value=FakeLoginResponse()) - fake_client.sync = AsyncMock(return_value=FakeSyncResponse()) - fake_client.close = AsyncMock() - fake_client.add_event_callback = MagicMock() - fake_client.rooms = {} - fake_client.account_data = {} + fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client) - fake_nio = MagicMock() - fake_nio.AsyncClient = MagicMock(return_value=fake_client) - fake_nio.LoginResponse = FakeLoginResponse - fake_nio.SyncResponse = FakeSyncResponse - fake_nio.RoomMessageText = type("RoomMessageText", (), {}) - fake_nio.RoomMessageImage = type("RoomMessageImage", (), {}) - fake_nio.RoomMessageAudio = type("RoomMessageAudio", (), {}) - fake_nio.RoomMessageVideo = type("RoomMessageVideo", (), {}) - fake_nio.RoomMessageFile = type("RoomMessageFile", (), {}) - fake_nio.InviteMemberEvent = type("InviteMemberEvent", (), {}) - - with patch.dict("sys.modules", {"nio": fake_nio}): + from gateway.platforms import matrix as matrix_mod + with patch.dict("sys.modules", fake_mautrix_mods): with patch.object(adapter, "_refresh_dm_cache", AsyncMock()): with patch.object(adapter, "_sync_loop", AsyncMock(return_value=None)): assert await adapter.connect() is True - # Verify device_id was passed to the nio.AsyncClient constructor - ctor_call = fake_nio.AsyncClient.call_args - assert ctor_call.kwargs.get("device_id") == "STABLE_PW_DEVICE" + mock_client.login.assert_awaited_once() + assert adapter._device_id == "STABLE_PW_DEVICE" await adapter.disconnect() @@ -936,258 +991,104 @@ class TestMatrixDeviceIdConfig: assert "device_id" not in mc.extra -class TestMatrixE2EEMaintenance: +class TestMatrixSyncLoop: @pytest.mark.asyncio - async def test_sync_loop_runs_e2ee_maintenance_requests(self): + async def test_sync_loop_shares_keys_when_encryption_enabled(self): + """_sync_loop should call crypto.share_keys() after each sync.""" adapter = _make_adapter() adapter._encryption = True adapter._closing = False - class FakeSyncError: - pass + call_count = 0 - async def _sync_once(timeout=30000): - adapter._closing = True - return MagicMock() + async def _sync_once(**kwargs): + nonlocal call_count + call_count += 1 + if call_count >= 1: + adapter._closing = True + return {"rooms": {"join": {"!room:example.org": {}}}} + + mock_crypto = MagicMock() + mock_crypto.share_keys = AsyncMock() fake_client = MagicMock() fake_client.sync = AsyncMock(side_effect=_sync_once) - fake_client.send_to_device_messages = AsyncMock(return_value=[]) - fake_client.keys_upload = AsyncMock() - fake_client.keys_query = AsyncMock() - fake_client.get_users_for_key_claiming = MagicMock( - return_value={"@alice:example.org": ["DEVICE1"]} - ) - fake_client.keys_claim = AsyncMock() - fake_client.olm = object() - fake_client.should_upload_keys = True - fake_client.should_query_keys = True - fake_client.should_claim_keys = True - + fake_client.crypto = mock_crypto adapter._client = fake_client - fake_nio = MagicMock() - fake_nio.SyncError = FakeSyncError + await adapter._sync_loop() - with patch.dict("sys.modules", {"nio": fake_nio}): - await adapter._sync_loop() - - fake_client.sync.assert_awaited_once_with(timeout=30000) - fake_client.send_to_device_messages.assert_awaited_once() - fake_client.keys_upload.assert_awaited_once() - fake_client.keys_query.assert_awaited_once() - fake_client.keys_claim.assert_awaited_once_with( - {"@alice:example.org": ["DEVICE1"]} - ) + fake_client.sync.assert_awaited_once() + mock_crypto.share_keys.assert_awaited_once() class TestMatrixEncryptedSendFallback: @pytest.mark.asyncio - async def test_send_retries_with_ignored_unverified_devices(self): + async def test_send_retries_after_e2ee_error(self): + """send() should retry with crypto.share_keys() on E2EE errors.""" adapter = _make_adapter() adapter._encryption = True - class FakeRoomSendResponse: - def __init__(self, event_id): - self.event_id = event_id - - class FakeOlmUnverifiedDeviceError(Exception): - pass - fake_client = MagicMock() - fake_client.room_send = AsyncMock(side_effect=[ - FakeOlmUnverifiedDeviceError("unverified"), - FakeRoomSendResponse("$event123"), + fake_client.send_message_event = AsyncMock(side_effect=[ + Exception("encryption error"), + "$event123", # mautrix returns EventID string directly ]) + mock_crypto = MagicMock() + mock_crypto.share_keys = AsyncMock() + fake_client.crypto = mock_crypto adapter._client = fake_client - adapter._run_e2ee_maintenance = AsyncMock() - fake_nio = MagicMock() - fake_nio.RoomSendResponse = FakeRoomSendResponse - fake_nio.OlmUnverifiedDeviceError = FakeOlmUnverifiedDeviceError - - with patch.dict("sys.modules", {"nio": fake_nio}): - result = await adapter.send("!room:example.org", "hello") + result = await adapter.send("!room:example.org", "hello") assert result.success is True assert result.message_id == "$event123" - adapter._run_e2ee_maintenance.assert_awaited_once() - assert fake_client.room_send.await_count == 2 - first_call = fake_client.room_send.await_args_list[0] - second_call = fake_client.room_send.await_args_list[1] - assert first_call.kwargs.get("ignore_unverified_devices") is False - assert second_call.kwargs.get("ignore_unverified_devices") is True - - @pytest.mark.asyncio - async def test_send_retries_after_timeout_in_encrypted_room(self): - adapter = _make_adapter() - adapter._encryption = True - - class FakeRoomSendResponse: - def __init__(self, event_id): - self.event_id = event_id - - fake_client = MagicMock() - fake_client.room_send = AsyncMock(side_effect=[ - asyncio.TimeoutError(), - FakeRoomSendResponse("$event456"), - ]) - adapter._client = fake_client - adapter._run_e2ee_maintenance = AsyncMock() - - fake_nio = MagicMock() - fake_nio.RoomSendResponse = FakeRoomSendResponse - - with patch.dict("sys.modules", {"nio": fake_nio}): - result = await adapter.send("!room:example.org", "hello") - - assert result.success is True - assert result.message_id == "$event456" - adapter._run_e2ee_maintenance.assert_awaited_once() - assert fake_client.room_send.await_count == 2 - second_call = fake_client.room_send.await_args_list[1] - assert second_call.kwargs.get("ignore_unverified_devices") is True + mock_crypto.share_keys.assert_awaited_once() + assert fake_client.send_message_event.await_count == 2 # --------------------------------------------------------------------------- -# E2EE: Auto-trust devices -# --------------------------------------------------------------------------- - -class TestMatrixAutoTrustDevices: - def test_auto_trust_verifies_unverified_devices(self): - adapter = _make_adapter() - - # DeviceStore.__iter__ yields OlmDevice objects directly. - device_a = MagicMock() - device_a.device_id = "DEVICE_A" - device_a.verified = False - device_b = MagicMock() - device_b.device_id = "DEVICE_B" - device_b.verified = True # already trusted - device_c = MagicMock() - device_c.device_id = "DEVICE_C" - device_c.verified = False - - fake_client = MagicMock() - fake_client.device_id = "OWN_DEVICE" - fake_client.verify_device = MagicMock() - - # Simulate DeviceStore iteration (yields OlmDevice objects) - fake_client.device_store = MagicMock() - fake_client.device_store.__iter__ = MagicMock( - return_value=iter([device_a, device_b, device_c]) - ) - - adapter._client = fake_client - adapter._auto_trust_devices() - - # Should have verified device_a and device_c (not device_b, already verified) - assert fake_client.verify_device.call_count == 2 - verified_devices = [call.args[0] for call in fake_client.verify_device.call_args_list] - assert device_a in verified_devices - assert device_c in verified_devices - assert device_b not in verified_devices - - def test_auto_trust_skips_own_device(self): - adapter = _make_adapter() - - own_device = MagicMock() - own_device.device_id = "MY_DEVICE" - own_device.verified = False - - fake_client = MagicMock() - fake_client.device_id = "MY_DEVICE" - fake_client.verify_device = MagicMock() - - fake_client.device_store = MagicMock() - fake_client.device_store.__iter__ = MagicMock( - return_value=iter([own_device]) - ) - - adapter._client = fake_client - adapter._auto_trust_devices() - - fake_client.verify_device.assert_not_called() - - def test_auto_trust_handles_missing_device_store(self): - adapter = _make_adapter() - fake_client = MagicMock(spec=[]) # empty spec — no attributes - adapter._client = fake_client - # Should not raise - adapter._auto_trust_devices() - - -# --------------------------------------------------------------------------- -# E2EE: MegolmEvent key request + buffering +# E2EE: MegolmEvent key request + buffering via _on_encrypted_event # --------------------------------------------------------------------------- class TestMatrixMegolmEventHandling: @pytest.mark.asyncio - async def test_megolm_event_requests_room_key_and_buffers(self): + async def test_encrypted_event_buffers_for_retry(self): + """_on_encrypted_event should buffer undecrypted events for retry.""" adapter = _make_adapter() adapter._user_id = "@bot:example.org" adapter._startup_ts = 0.0 adapter._dm_rooms = {} - fake_megolm = MagicMock() - fake_megolm.sender = "@alice:example.org" - fake_megolm.event_id = "$encrypted_event" - fake_megolm.server_timestamp = 9999999999000 # future - fake_megolm.session_id = "SESSION123" + fake_event = MagicMock() + fake_event.room_id = "!room:example.org" + fake_event.event_id = "$encrypted_event" + fake_event.sender = "@alice:example.org" - fake_room = MagicMock() - fake_room.room_id = "!room:example.org" - - fake_client = MagicMock() - fake_client.request_room_key = AsyncMock(return_value=MagicMock()) - adapter._client = fake_client - - # Create a MegolmEvent class for isinstance check - fake_nio = MagicMock() - FakeMegolmEvent = type("MegolmEvent", (), {}) - fake_megolm.__class__ = FakeMegolmEvent - fake_nio.MegolmEvent = FakeMegolmEvent - - with patch.dict("sys.modules", {"nio": fake_nio}): - await adapter._on_room_message(fake_room, fake_megolm) - - # Should have requested the room key - fake_client.request_room_key.assert_awaited_once_with(fake_megolm) + await adapter._on_encrypted_event(fake_event) # Should have buffered the event assert len(adapter._pending_megolm) == 1 - room, event, ts = adapter._pending_megolm[0] - assert room is fake_room - assert event is fake_megolm + room_id, event, ts = adapter._pending_megolm[0] + assert room_id == "!room:example.org" + assert event is fake_event @pytest.mark.asyncio - async def test_megolm_buffer_capped(self): + async def test_encrypted_event_buffer_capped(self): + """Buffer should not grow past _MAX_PENDING_EVENTS.""" adapter = _make_adapter() adapter._user_id = "@bot:example.org" adapter._startup_ts = 0.0 adapter._dm_rooms = {} - fake_client = MagicMock() - fake_client.request_room_key = AsyncMock(return_value=MagicMock()) - adapter._client = fake_client - - FakeMegolmEvent = type("MegolmEvent", (), {}) - fake_nio = MagicMock() - fake_nio.MegolmEvent = FakeMegolmEvent - - # Fill the buffer past max from gateway.platforms.matrix import _MAX_PENDING_EVENTS - with patch.dict("sys.modules", {"nio": fake_nio}): - for i in range(_MAX_PENDING_EVENTS + 10): - evt = MagicMock() - evt.__class__ = FakeMegolmEvent - evt.sender = "@alice:example.org" - evt.event_id = f"$event_{i}" - evt.server_timestamp = 9999999999000 - evt.session_id = f"SESSION_{i}" - room = MagicMock() - room.room_id = "!room:example.org" - await adapter._on_room_message(room, evt) + + for i in range(_MAX_PENDING_EVENTS + 10): + evt = MagicMock() + evt.room_id = "!room:example.org" + evt.event_id = f"$event_{i}" + evt.sender = "@alice:example.org" + await adapter._on_encrypted_event(evt) assert len(adapter._pending_megolm) == _MAX_PENDING_EVENTS @@ -1198,219 +1099,91 @@ class TestMatrixMegolmEventHandling: class TestMatrixRetryPendingDecryptions: @pytest.mark.asyncio - async def test_successful_decryption_routes_to_text_handler(self): - import time as _time - + async def test_successful_decryption_routes_to_handler(self): adapter = _make_adapter() adapter._user_id = "@bot:example.org" adapter._startup_ts = 0.0 adapter._dm_rooms = {} - # Create types - FakeMegolmEvent = type("MegolmEvent", (), {}) - FakeRoomMessageText = type("RoomMessageText", (), {}) + fake_encrypted = MagicMock() + fake_encrypted.event_id = "$encrypted" decrypted_event = MagicMock() - decrypted_event.__class__ = FakeRoomMessageText - fake_megolm = MagicMock() - fake_megolm.__class__ = FakeMegolmEvent - fake_megolm.event_id = "$encrypted" - - fake_room = MagicMock() - now = _time.time() - - adapter._pending_megolm = [(fake_room, fake_megolm, now)] + mock_crypto = MagicMock() + mock_crypto.decrypt_megolm_event = AsyncMock(return_value=decrypted_event) fake_client = MagicMock() - fake_client.decrypt_event = MagicMock(return_value=decrypted_event) + fake_client.crypto = mock_crypto adapter._client = fake_client - fake_nio = MagicMock() - fake_nio.MegolmEvent = FakeMegolmEvent - fake_nio.RoomMessageText = FakeRoomMessageText - fake_nio.RoomMessageImage = type("RoomMessageImage", (), {}) - fake_nio.RoomMessageAudio = type("RoomMessageAudio", (), {}) - fake_nio.RoomMessageVideo = type("RoomMessageVideo", (), {}) - fake_nio.RoomMessageFile = type("RoomMessageFile", (), {}) + now = time.time() + adapter._pending_megolm = [("!room:ex.org", fake_encrypted, now)] - with patch.dict("sys.modules", {"nio": fake_nio}): - with patch.object(adapter, "_on_room_message", AsyncMock()) as mock_handler: - await adapter._retry_pending_decryptions() - mock_handler.assert_awaited_once_with(fake_room, decrypted_event) + with patch.object(adapter, "_on_room_message", AsyncMock()) as mock_handler: + await adapter._retry_pending_decryptions() + mock_handler.assert_awaited_once_with(decrypted_event) # Buffer should be empty now assert len(adapter._pending_megolm) == 0 @pytest.mark.asyncio async def test_still_undecryptable_stays_in_buffer(self): - import time as _time - adapter = _make_adapter() - FakeMegolmEvent = type("MegolmEvent", (), {}) + fake_encrypted = MagicMock() + fake_encrypted.event_id = "$still_encrypted" - fake_megolm = MagicMock() - fake_megolm.__class__ = FakeMegolmEvent - fake_megolm.event_id = "$still_encrypted" - - now = _time.time() - adapter._pending_megolm = [(MagicMock(), fake_megolm, now)] + mock_crypto = MagicMock() + mock_crypto.decrypt_megolm_event = AsyncMock(side_effect=Exception("missing key")) fake_client = MagicMock() - # decrypt_event raises when key is still missing - fake_client.decrypt_event = MagicMock(side_effect=Exception("missing key")) + fake_client.crypto = mock_crypto adapter._client = fake_client - fake_nio = MagicMock() - fake_nio.MegolmEvent = FakeMegolmEvent + now = time.time() + adapter._pending_megolm = [("!room:ex.org", fake_encrypted, now)] - with patch.dict("sys.modules", {"nio": fake_nio}): - await adapter._retry_pending_decryptions() + await adapter._retry_pending_decryptions() assert len(adapter._pending_megolm) == 1 @pytest.mark.asyncio async def test_expired_events_dropped(self): - import time as _time - adapter = _make_adapter() from gateway.platforms.matrix import _PENDING_EVENT_TTL - fake_megolm = MagicMock() - fake_megolm.event_id = "$old_event" - fake_megolm.__class__ = type("MegolmEvent", (), {}) - - # Timestamp well past TTL - old_ts = _time.time() - _PENDING_EVENT_TTL - 60 - adapter._pending_megolm = [(MagicMock(), fake_megolm, old_ts)] + fake_event = MagicMock() + fake_event.event_id = "$old_event" + mock_crypto = MagicMock() fake_client = MagicMock() + fake_client.crypto = mock_crypto adapter._client = fake_client - fake_nio = MagicMock() - fake_nio.MegolmEvent = type("MegolmEvent", (), {}) + # Timestamp well past TTL + old_ts = time.time() - _PENDING_EVENT_TTL - 60 + adapter._pending_megolm = [("!room:ex.org", fake_event, old_ts)] - with patch.dict("sys.modules", {"nio": fake_nio}): - await adapter._retry_pending_decryptions() + await adapter._retry_pending_decryptions() # Should have been dropped assert len(adapter._pending_megolm) == 0 - # Should NOT have tried to decrypt - fake_client.decrypt_event.assert_not_called() - - @pytest.mark.asyncio - async def test_media_event_routes_to_media_handler(self): - import time as _time - - adapter = _make_adapter() - adapter._user_id = "@bot:example.org" - adapter._startup_ts = 0.0 - - FakeMegolmEvent = type("MegolmEvent", (), {}) - FakeRoomMessageImage = type("RoomMessageImage", (), {}) - - decrypted_image = MagicMock() - decrypted_image.__class__ = FakeRoomMessageImage - - fake_megolm = MagicMock() - fake_megolm.__class__ = FakeMegolmEvent - fake_megolm.event_id = "$encrypted_image" - - fake_room = MagicMock() - now = _time.time() - adapter._pending_megolm = [(fake_room, fake_megolm, now)] - - fake_client = MagicMock() - fake_client.decrypt_event = MagicMock(return_value=decrypted_image) - adapter._client = fake_client - - fake_nio = MagicMock() - fake_nio.MegolmEvent = FakeMegolmEvent - fake_nio.RoomMessageText = type("RoomMessageText", (), {}) - fake_nio.RoomMessageImage = FakeRoomMessageImage - fake_nio.RoomMessageAudio = type("RoomMessageAudio", (), {}) - fake_nio.RoomMessageVideo = type("RoomMessageVideo", (), {}) - fake_nio.RoomMessageFile = type("RoomMessageFile", (), {}) - - with patch.dict("sys.modules", {"nio": fake_nio}): - with patch.object(adapter, "_on_room_message_media", AsyncMock()) as mock_media: - await adapter._retry_pending_decryptions() - mock_media.assert_awaited_once_with(fake_room, decrypted_image) - - assert len(adapter._pending_megolm) == 0 # --------------------------------------------------------------------------- -# E2EE: Key export / import +# E2EE: connect registers encrypted event handler # --------------------------------------------------------------------------- -class TestMatrixKeyExportImport: +class TestMatrixEncryptedEventHandler: @pytest.mark.asyncio - async def test_disconnect_exports_keys(self): - adapter = _make_adapter() - adapter._encryption = True - adapter._sync_task = None - - fake_client = MagicMock() - fake_client.olm = object() - fake_client.export_keys = AsyncMock() - fake_client.close = AsyncMock() - adapter._client = fake_client - - from gateway.platforms.matrix import _KEY_EXPORT_FILE, _KEY_EXPORT_PASSPHRASE - - await adapter.disconnect() - - fake_client.export_keys.assert_awaited_once_with( - str(_KEY_EXPORT_FILE), _KEY_EXPORT_PASSPHRASE, - ) - - @pytest.mark.asyncio - async def test_disconnect_handles_export_failure(self): - adapter = _make_adapter() - adapter._encryption = True - adapter._sync_task = None - - fake_client = MagicMock() - fake_client.olm = object() - fake_client.export_keys = AsyncMock(side_effect=Exception("export failed")) - fake_client.close = AsyncMock() - adapter._client = fake_client - - # Should not raise - await adapter.disconnect() - assert adapter._client is None # still cleaned up - - @pytest.mark.asyncio - async def test_disconnect_skips_export_when_no_encryption(self): - adapter = _make_adapter() - adapter._encryption = False - adapter._sync_task = None - - fake_client = MagicMock() - fake_client.close = AsyncMock() - adapter._client = fake_client - - await adapter.disconnect() - # Should not have tried to export - assert not hasattr(fake_client, "export_keys") or \ - not fake_client.export_keys.called - - -# --------------------------------------------------------------------------- -# E2EE: Encrypted media -# --------------------------------------------------------------------------- - -class TestMatrixEncryptedMedia: - @pytest.mark.asyncio - async def test_connect_registers_callbacks_for_encrypted_media_events(self): + async def test_connect_registers_encrypted_event_handler_when_encryption_on(self): from gateway.platforms.matrix import MatrixAdapter config = PlatformConfig( enabled=True, - token="syt_te...oken", + token="syt_test_token", extra={ "homeserver": "https://matrix.example.org", "user_id": "@bot:example.org", @@ -1419,350 +1192,104 @@ class TestMatrixEncryptedMedia: ) adapter = MatrixAdapter(config) - class FakeWhoamiResponse: - def __init__(self, user_id, device_id): - self.user_id = user_id - self.device_id = device_id + fake_mautrix_mods = _make_fake_mautrix() - class FakeSyncResponse: - def __init__(self): - self.rooms = MagicMock(join={}) + mock_client = MagicMock() + mock_client.mxid = "@bot:example.org" + mock_client.device_id = None + mock_client.state_store = MagicMock() + mock_client.sync_store = MagicMock() + mock_client.crypto = None # Will be set during connect + mock_client.whoami = AsyncMock(return_value=MagicMock(user_id="@bot:example.org", device_id="DEV123")) + mock_client.sync = AsyncMock(return_value={"rooms": {"join": {"!room:server": {}}}}) + mock_client.add_event_handler = MagicMock() + mock_client.api = MagicMock() + mock_client.api.token = "syt_test_token" + mock_client.api.session = MagicMock() + mock_client.api.session.close = AsyncMock() - class FakeRoomMessageText: ... - class FakeRoomMessageImage: ... - class FakeRoomMessageAudio: ... - class FakeRoomMessageVideo: ... - class FakeRoomMessageFile: ... - class FakeRoomEncryptedImage: ... - class FakeRoomEncryptedAudio: ... - class FakeRoomEncryptedVideo: ... - class FakeRoomEncryptedFile: ... - class FakeInviteMemberEvent: ... - class FakeMegolmEvent: ... + mock_olm = MagicMock() + mock_olm.load = AsyncMock() + mock_olm.share_keys = AsyncMock() + mock_olm.share_keys_min_trust = None + mock_olm.send_keys_min_trust = None - fake_client = MagicMock() - fake_client.whoami = AsyncMock(return_value=FakeWhoamiResponse("@bot:example.org", "DEV123")) - fake_client.sync = AsyncMock(return_value=FakeSyncResponse()) - fake_client.keys_upload = AsyncMock() - fake_client.keys_query = AsyncMock() - fake_client.keys_claim = AsyncMock() - fake_client.send_to_device_messages = AsyncMock(return_value=[]) - fake_client.get_users_for_key_claiming = MagicMock(return_value={}) - fake_client.close = AsyncMock() - fake_client.add_event_callback = MagicMock() - fake_client.rooms = {} - fake_client.account_data = {} - fake_client.olm = object() - fake_client.should_upload_keys = False - fake_client.should_query_keys = False - fake_client.should_claim_keys = False - fake_client.restore_login = MagicMock(side_effect=lambda u, d, t: None) - - fake_nio = MagicMock() - fake_nio.AsyncClient = MagicMock(return_value=fake_client) - fake_nio.WhoamiResponse = FakeWhoamiResponse - fake_nio.SyncResponse = FakeSyncResponse - fake_nio.LoginResponse = type("LoginResponse", (), {}) - fake_nio.RoomMessageText = FakeRoomMessageText - fake_nio.RoomMessageImage = FakeRoomMessageImage - fake_nio.RoomMessageAudio = FakeRoomMessageAudio - fake_nio.RoomMessageVideo = FakeRoomMessageVideo - fake_nio.RoomMessageFile = FakeRoomMessageFile - fake_nio.RoomEncryptedImage = FakeRoomEncryptedImage - fake_nio.RoomEncryptedAudio = FakeRoomEncryptedAudio - fake_nio.RoomEncryptedVideo = FakeRoomEncryptedVideo - fake_nio.RoomEncryptedFile = FakeRoomEncryptedFile - fake_nio.InviteMemberEvent = FakeInviteMemberEvent - fake_nio.MegolmEvent = FakeMegolmEvent + fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client) + fake_mautrix_mods["mautrix.crypto"].OlmMachine = MagicMock(return_value=mock_olm) from gateway.platforms import matrix as matrix_mod with patch.object(matrix_mod, "_check_e2ee_deps", return_value=True): - with patch.dict("sys.modules", {"nio": fake_nio}): + with patch.dict("sys.modules", fake_mautrix_mods): with patch.object(adapter, "_refresh_dm_cache", AsyncMock()): with patch.object(adapter, "_sync_loop", AsyncMock(return_value=None)): assert await adapter.connect() is True - callback_classes = [call.args[1] for call in fake_client.add_event_callback.call_args_list] - assert FakeRoomEncryptedImage in callback_classes - assert FakeRoomEncryptedAudio in callback_classes - assert FakeRoomEncryptedVideo in callback_classes - assert FakeRoomEncryptedFile in callback_classes + # Verify event handlers were registered. + # In mautrix the order is: add_event_handler(EventType, callback) + handler_calls = mock_client.add_event_handler.call_args_list + registered_types = [call.args[0] for call in handler_calls] + + # Should have registered handlers for ROOM_MESSAGE, REACTION, INVITE, and ROOM_ENCRYPTED + assert len(handler_calls) >= 4 # At minimum these four await adapter.disconnect() + +# --------------------------------------------------------------------------- +# Disconnect +# --------------------------------------------------------------------------- + +class TestMatrixDisconnect: @pytest.mark.asyncio - async def test_on_room_message_media_decrypts_encrypted_image_and_passes_local_path(self): - try: - from nio.crypto.attachments import encrypt_attachment - except (ImportError, ModuleNotFoundError): - pytest.skip("matrix-nio[e2e] required for encryption tests") - + async def test_disconnect_closes_api_session(self): + """disconnect() should close client.api.session.""" adapter = _make_adapter() - adapter._user_id = "@bot:example.org" - adapter._startup_ts = 0.0 - adapter._dm_rooms = {} - adapter.handle_message = AsyncMock() + adapter._sync_task = None - plaintext = b"\x89PNG\r\n\x1a\n" + b"\x00" * 32 - ciphertext, keys = encrypt_attachment(plaintext) + mock_session = MagicMock() + mock_session.close = AsyncMock() - class FakeRoomEncryptedImage: - def __init__(self): - self.sender = "@alice:example.org" - self.event_id = "$img1" - self.server_timestamp = 0 - self.body = "screenshot.png" - self.url = "mxc://example.org/media123" - self.key = keys["key"]["k"] - self.hashes = keys["hashes"] - self.iv = keys["iv"] - self.mimetype = "image/png" - self.source = { - "content": { - "body": "screenshot.png", - "info": {"mimetype": "image/png"}, - "file": { - "url": self.url, - "key": keys["key"], - "hashes": keys["hashes"], - "iv": keys["iv"], - }, - } - } - - class FakeDownloadResponse: - def __init__(self, body): - self.body = body + mock_api = MagicMock() + mock_api.session = mock_session fake_client = MagicMock() - fake_client.download = AsyncMock(return_value=FakeDownloadResponse(ciphertext)) + fake_client.api = mock_api adapter._client = fake_client - fake_nio = MagicMock() - fake_nio.RoomMessageImage = type("RoomMessageImage", (), {}) - fake_nio.RoomMessageAudio = type("RoomMessageAudio", (), {}) - fake_nio.RoomMessageVideo = type("RoomMessageVideo", (), {}) - fake_nio.RoomMessageFile = type("RoomMessageFile", (), {}) - fake_nio.RoomEncryptedImage = FakeRoomEncryptedImage - fake_nio.RoomEncryptedAudio = type("RoomEncryptedAudio", (), {}) - fake_nio.RoomEncryptedVideo = type("RoomEncryptedVideo", (), {}) - fake_nio.RoomEncryptedFile = type("RoomEncryptedFile", (), {}) + await adapter.disconnect() - room = MagicMock(room_id="!room:example.org", member_count=2, users={}) - event = FakeRoomEncryptedImage() - - with patch.dict("sys.modules", {"nio": fake_nio}): - with patch("gateway.platforms.base.cache_image_from_bytes", return_value="/tmp/cached-image.png") as cache_mock: - await adapter._on_room_message_media(room, event) - - cache_mock.assert_called_once_with(plaintext, ext=".png") - msg_event = adapter.handle_message.await_args.args[0] - assert msg_event.message_type.name == "PHOTO" - assert msg_event.media_urls == ["/tmp/cached-image.png"] - assert msg_event.media_types == ["image/png"] + mock_session.close.assert_awaited_once() + assert adapter._client is None @pytest.mark.asyncio - async def test_on_room_message_media_decrypts_encrypted_voice_and_caches_audio(self): - try: - from nio.crypto.attachments import encrypt_attachment - except (ImportError, ModuleNotFoundError): - pytest.skip("matrix-nio[e2e] required for encryption tests") - + async def test_disconnect_handles_session_close_failure(self): + """disconnect() should not raise if session close fails.""" adapter = _make_adapter() - adapter._user_id = "@bot:example.org" - adapter._startup_ts = 0.0 - adapter._dm_rooms = {} - adapter.handle_message = AsyncMock() + adapter._sync_task = None - plaintext = b"OggS" + b"\x00" * 32 - ciphertext, keys = encrypt_attachment(plaintext) + mock_session = MagicMock() + mock_session.close = AsyncMock(side_effect=Exception("close failed")) - class FakeRoomEncryptedAudio: - def __init__(self): - self.sender = "@alice:example.org" - self.event_id = "$voice1" - self.server_timestamp = 0 - self.body = "voice.ogg" - self.url = "mxc://example.org/voice123" - self.key = keys["key"]["k"] - self.hashes = keys["hashes"] - self.iv = keys["iv"] - self.mimetype = "audio/ogg" - self.source = { - "content": { - "body": "voice.ogg", - "info": {"mimetype": "audio/ogg"}, - "org.matrix.msc3245.voice": {}, - "file": { - "url": self.url, - "key": keys["key"], - "hashes": keys["hashes"], - "iv": keys["iv"], - }, - } - } - - class FakeDownloadResponse: - def __init__(self, body): - self.body = body + mock_api = MagicMock() + mock_api.session = mock_session fake_client = MagicMock() - fake_client.download = AsyncMock(return_value=FakeDownloadResponse(ciphertext)) + fake_client.api = mock_api adapter._client = fake_client - fake_nio = MagicMock() - fake_nio.RoomMessageImage = type("RoomMessageImage", (), {}) - fake_nio.RoomMessageAudio = type("RoomMessageAudio", (), {}) - fake_nio.RoomMessageVideo = type("RoomMessageVideo", (), {}) - fake_nio.RoomMessageFile = type("RoomMessageFile", (), {}) - fake_nio.RoomEncryptedImage = type("RoomEncryptedImage", (), {}) - fake_nio.RoomEncryptedAudio = FakeRoomEncryptedAudio - fake_nio.RoomEncryptedVideo = type("RoomEncryptedVideo", (), {}) - fake_nio.RoomEncryptedFile = type("RoomEncryptedFile", (), {}) - - room = MagicMock(room_id="!room:example.org", member_count=2, users={}) - event = FakeRoomEncryptedAudio() - - with patch.dict("sys.modules", {"nio": fake_nio}): - with patch("gateway.platforms.base.cache_audio_from_bytes", return_value="/tmp/cached-voice.ogg") as cache_mock: - await adapter._on_room_message_media(room, event) - - cache_mock.assert_called_once_with(plaintext, ext=".ogg") - msg_event = adapter.handle_message.await_args.args[0] - assert msg_event.message_type.name == "VOICE" - assert msg_event.media_urls == ["/tmp/cached-voice.ogg"] - assert msg_event.media_types == ["audio/ogg"] + # Should not raise + await adapter.disconnect() + assert adapter._client is None @pytest.mark.asyncio - async def test_on_room_message_media_decrypts_encrypted_file_and_caches_document(self): - try: - from nio.crypto.attachments import encrypt_attachment - except (ImportError, ModuleNotFoundError): - pytest.skip("matrix-nio[e2e] required for encryption tests") - + async def test_disconnect_without_client(self): + """disconnect() should handle None client gracefully.""" adapter = _make_adapter() - adapter._user_id = "@bot:example.org" - adapter._startup_ts = 0.0 - adapter._dm_rooms = {} - adapter.handle_message = AsyncMock() + adapter._sync_task = None + adapter._client = None - plaintext = b"hello from encrypted document" - ciphertext, keys = encrypt_attachment(plaintext) - - class FakeRoomEncryptedFile: - def __init__(self): - self.sender = "@alice:example.org" - self.event_id = "$file1" - self.server_timestamp = 0 - self.body = "notes.txt" - self.url = "mxc://example.org/file123" - self.key = keys["key"] - self.hashes = keys["hashes"] - self.iv = keys["iv"] - self.mimetype = "text/plain" - self.source = { - "content": { - "body": "notes.txt", - "info": {"mimetype": "text/plain"}, - "file": { - "url": self.url, - "key": keys["key"], - "hashes": keys["hashes"], - "iv": keys["iv"], - }, - } - } - - class FakeDownloadResponse: - def __init__(self, body): - self.body = body - - fake_client = MagicMock() - fake_client.download = AsyncMock(return_value=FakeDownloadResponse(ciphertext)) - adapter._client = fake_client - - fake_nio = MagicMock() - fake_nio.RoomMessageImage = type("RoomMessageImage", (), {}) - fake_nio.RoomMessageAudio = type("RoomMessageAudio", (), {}) - fake_nio.RoomMessageVideo = type("RoomMessageVideo", (), {}) - fake_nio.RoomMessageFile = type("RoomMessageFile", (), {}) - fake_nio.RoomEncryptedImage = type("RoomEncryptedImage", (), {}) - fake_nio.RoomEncryptedAudio = type("RoomEncryptedAudio", (), {}) - fake_nio.RoomEncryptedVideo = type("RoomEncryptedVideo", (), {}) - fake_nio.RoomEncryptedFile = FakeRoomEncryptedFile - - room = MagicMock(room_id="!room:example.org", member_count=2, users={}) - event = FakeRoomEncryptedFile() - - with patch.dict("sys.modules", {"nio": fake_nio}): - with patch("gateway.platforms.base.cache_document_from_bytes", return_value="/tmp/cached-notes.txt") as cache_mock: - await adapter._on_room_message_media(room, event) - - cache_mock.assert_called_once_with(plaintext, "notes.txt") - msg_event = adapter.handle_message.await_args.args[0] - assert msg_event.message_type.name == "DOCUMENT" - assert msg_event.media_urls == ["/tmp/cached-notes.txt"] - assert msg_event.media_types == ["text/plain"] - - @pytest.mark.asyncio - async def test_on_room_message_media_does_not_emit_ciphertext_url_when_encrypted_media_decryption_fails(self): - adapter = _make_adapter() - adapter._user_id = "@bot:example.org" - adapter._startup_ts = 0.0 - adapter._dm_rooms = {} - adapter.handle_message = AsyncMock() - - class FakeRoomEncryptedImage: - def __init__(self): - self.sender = "@alice:example.org" - self.event_id = "$img2" - self.server_timestamp = 0 - self.body = "broken.png" - self.url = "mxc://example.org/media999" - self.key = {"k": "broken"} - self.hashes = {"sha256": "broken"} - self.iv = "broken" - self.mimetype = "image/png" - self.source = { - "content": { - "body": "broken.png", - "info": {"mimetype": "image/png"}, - "file": { - "url": self.url, - "key": self.key, - "hashes": self.hashes, - "iv": self.iv, - }, - } - } - - class FakeDownloadResponse: - def __init__(self, body): - self.body = body - - fake_client = MagicMock() - fake_client.download = AsyncMock(return_value=FakeDownloadResponse(b"ciphertext")) - adapter._client = fake_client - - fake_nio = MagicMock() - fake_nio.RoomMessageImage = type("RoomMessageImage", (), {}) - fake_nio.RoomMessageAudio = type("RoomMessageAudio", (), {}) - fake_nio.RoomMessageVideo = type("RoomMessageVideo", (), {}) - fake_nio.RoomMessageFile = type("RoomMessageFile", (), {}) - fake_nio.RoomEncryptedImage = FakeRoomEncryptedImage - fake_nio.RoomEncryptedAudio = type("RoomEncryptedAudio", (), {}) - fake_nio.RoomEncryptedVideo = type("RoomEncryptedVideo", (), {}) - fake_nio.RoomEncryptedFile = type("RoomEncryptedFile", (), {}) - - room = MagicMock(room_id="!room:example.org", member_count=2, users={}) - event = FakeRoomEncryptedImage() - - with patch.dict("sys.modules", {"nio": fake_nio}): - await adapter._on_room_message_media(room, event) - - msg_event = adapter.handle_message.await_args.args[0] - assert not msg_event.media_urls - assert not msg_event.media_types + await adapter.disconnect() + assert adapter._client is None # --------------------------------------------------------------------------- @@ -1933,34 +1460,29 @@ class TestMatrixReactions: @pytest.mark.asyncio async def test_send_reaction(self): - """_send_reaction should call room_send with m.reaction.""" - fake_nio = _make_fake_nio() + """_send_reaction should call send_message_event with m.reaction.""" mock_client = MagicMock() - mock_client.room_send = AsyncMock( - return_value=fake_nio.RoomSendResponse("$reaction1") - ) + # mautrix send_message_event returns EventID string directly + mock_client.send_message_event = AsyncMock(return_value="$reaction1") self.adapter._client = mock_client - with patch.dict("sys.modules", {"nio": fake_nio}): - result = await self.adapter._send_reaction("!room:ex", "$event1", "👍") + result = await self.adapter._send_reaction("!room:ex", "$event1", "\U0001f44d") assert result == "$reaction1" - mock_client.room_send.assert_called_once() - args = mock_client.room_send.call_args - assert args[0][1] == "m.reaction" - content = args[0][2] + mock_client.send_message_event.assert_called_once() + call_args = mock_client.send_message_event.call_args + content = call_args.args[2] if len(call_args.args) > 2 else call_args.kwargs.get("content") assert content["m.relates_to"]["rel_type"] == "m.annotation" - assert content["m.relates_to"]["key"] == "👍" + assert content["m.relates_to"]["key"] == "\U0001f44d" @pytest.mark.asyncio async def test_send_reaction_no_client(self): self.adapter._client = None - with patch.dict("sys.modules", {"nio": _make_fake_nio()}): - result = await self.adapter._send_reaction("!room:ex", "$ev", "👍") + result = await self.adapter._send_reaction("!room:ex", "$ev", "\U0001f44d") assert result is None @pytest.mark.asyncio async def test_on_processing_start_sends_eyes(self): - """on_processing_start should send 👀 reaction.""" + """on_processing_start should send eyes reaction.""" from gateway.platforms.base import MessageEvent, MessageType self.adapter._reactions_enabled = True @@ -1976,7 +1498,7 @@ class TestMatrixReactions: message_id="$msg1", ) await self.adapter.on_processing_start(event) - self.adapter._send_reaction.assert_called_once_with("!room:ex", "$msg1", "👀") + self.adapter._send_reaction.assert_called_once_with("!room:ex", "$msg1", "\U0001f440") assert self.adapter._pending_reactions == {("!room:ex", "$msg1"): "$reaction_event_123"} @pytest.mark.asyncio @@ -1999,7 +1521,7 @@ class TestMatrixReactions: ) await self.adapter.on_processing_complete(event, ProcessingOutcome.SUCCESS) self.adapter._redact_reaction.assert_called_once_with("!room:ex", "$eyes_reaction_123") - self.adapter._send_reaction.assert_called_once_with("!room:ex", "$msg1", "✅") + self.adapter._send_reaction.assert_called_once_with("!room:ex", "$msg1", "\u2705") @pytest.mark.asyncio async def test_on_processing_complete_sends_cross_on_failure(self): @@ -2021,7 +1543,7 @@ class TestMatrixReactions: ) await self.adapter.on_processing_complete(event, ProcessingOutcome.FAILURE) self.adapter._redact_reaction.assert_called_once_with("!room:ex", "$eyes_reaction_123") - self.adapter._send_reaction.assert_called_once_with("!room:ex", "$msg1", "❌") + self.adapter._send_reaction.assert_called_once_with("!room:ex", "$msg1", "\u274c") @pytest.mark.asyncio async def test_on_processing_complete_cancelled_sends_no_terminal_reaction(self): @@ -2063,7 +1585,7 @@ class TestMatrixReactions: ) await self.adapter.on_processing_complete(event, ProcessingOutcome.SUCCESS) self.adapter._redact_reaction.assert_not_called() - self.adapter._send_reaction.assert_called_once_with("!room:ex", "$msg1", "✅") + self.adapter._send_reaction.assert_called_once_with("!room:ex", "$msg1", "\u2705") @pytest.mark.asyncio async def test_reactions_disabled(self): @@ -2095,13 +1617,14 @@ class TestMatrixReadReceipts: @pytest.mark.asyncio async def test_send_read_receipt(self): + """send_read_receipt should call client.set_read_markers.""" mock_client = MagicMock() - mock_client.room_read_markers = AsyncMock(return_value=MagicMock()) + mock_client.set_read_markers = AsyncMock(return_value=None) self.adapter._client = mock_client result = await self.adapter.send_read_receipt("!room:ex", "$event1") assert result is True - mock_client.room_read_markers.assert_called_once() + mock_client.set_read_markers.assert_called_once() @pytest.mark.asyncio async def test_read_receipt_no_client(self): @@ -2120,23 +1643,20 @@ class TestMatrixRedaction: @pytest.mark.asyncio async def test_redact_message(self): - fake_nio = _make_fake_nio() + """redact_message should call client.redact().""" mock_client = MagicMock() - mock_client.room_redact = AsyncMock( - return_value=fake_nio.RoomRedactResponse() - ) + # mautrix redact() returns EventID string + mock_client.redact = AsyncMock(return_value="$redact_event") self.adapter._client = mock_client - with patch.dict("sys.modules", {"nio": fake_nio}): - result = await self.adapter.redact_message("!room:ex", "$ev1", "oops") + result = await self.adapter.redact_message("!room:ex", "$ev1", "oops") assert result is True - mock_client.room_redact.assert_called_once() + mock_client.redact.assert_called_once() @pytest.mark.asyncio async def test_redact_no_client(self): self.adapter._client = None - with patch.dict("sys.modules", {"nio": _make_fake_nio()}): - result = await self.adapter.redact_message("!room:ex", "$ev1") + result = await self.adapter.redact_message("!room:ex", "$ev1") assert result is False @@ -2150,35 +1670,30 @@ class TestMatrixRoomManagement: @pytest.mark.asyncio async def test_create_room(self): - fake_nio = _make_fake_nio() - mock_resp = fake_nio.RoomCreateResponse(room_id="!new:example.org") + """create_room should call client.create_room() returning RoomID string.""" mock_client = MagicMock() - mock_client.room_create = AsyncMock(return_value=mock_resp) + # mautrix create_room returns RoomID string directly + mock_client.create_room = AsyncMock(return_value="!new:example.org") self.adapter._client = mock_client - with patch.dict("sys.modules", {"nio": fake_nio}): - room_id = await self.adapter.create_room(name="Test Room", topic="A test") + room_id = await self.adapter.create_room(name="Test Room", topic="A test") assert room_id == "!new:example.org" assert "!new:example.org" in self.adapter._joined_rooms @pytest.mark.asyncio async def test_invite_user(self): - fake_nio = _make_fake_nio() + """invite_user should call client.invite_user().""" mock_client = MagicMock() - mock_client.room_invite = AsyncMock( - return_value=fake_nio.RoomInviteResponse() - ) + mock_client.invite_user = AsyncMock(return_value=None) self.adapter._client = mock_client - with patch.dict("sys.modules", {"nio": fake_nio}): - result = await self.adapter.invite_user("!room:ex", "@user:ex") + result = await self.adapter.invite_user("!room:ex", "@user:ex") assert result is True @pytest.mark.asyncio async def test_create_room_no_client(self): self.adapter._client = None - with patch.dict("sys.modules", {"nio": _make_fake_nio()}): - result = await self.adapter.create_room() + result = await self.adapter.create_room() assert result is None @@ -2224,35 +1739,35 @@ class TestMatrixMessageTypes: @pytest.mark.asyncio async def test_send_emote(self): - fake_nio = _make_fake_nio() + """send_emote should call send_message_event with m.emote.""" mock_client = MagicMock() - mock_resp = fake_nio.RoomSendResponse(event_id="$emote1") - mock_client.room_send = AsyncMock(return_value=mock_resp) + # mautrix returns EventID string directly + mock_client.send_message_event = AsyncMock(return_value="$emote1") self.adapter._client = mock_client - with patch.dict("sys.modules", {"nio": fake_nio}): - result = await self.adapter.send_emote("!room:ex", "waves hello") + result = await self.adapter.send_emote("!room:ex", "waves hello") assert result.success is True - call_args = mock_client.room_send.call_args[0] - assert call_args[2]["msgtype"] == "m.emote" + assert result.message_id == "$emote1" + call_args = mock_client.send_message_event.call_args + content = call_args.args[2] if len(call_args.args) > 2 else call_args.kwargs.get("content") + assert content["msgtype"] == "m.emote" @pytest.mark.asyncio async def test_send_notice(self): - fake_nio = _make_fake_nio() + """send_notice should call send_message_event with m.notice.""" mock_client = MagicMock() - mock_resp = fake_nio.RoomSendResponse(event_id="$notice1") - mock_client.room_send = AsyncMock(return_value=mock_resp) + mock_client.send_message_event = AsyncMock(return_value="$notice1") self.adapter._client = mock_client - with patch.dict("sys.modules", {"nio": fake_nio}): - result = await self.adapter.send_notice("!room:ex", "System message") + result = await self.adapter.send_notice("!room:ex", "System message") assert result.success is True - call_args = mock_client.room_send.call_args[0] - assert call_args[2]["msgtype"] == "m.notice" + assert result.message_id == "$notice1" + call_args = mock_client.send_message_event.call_args + content = call_args.args[2] if len(call_args.args) > 2 else call_args.kwargs.get("content") + assert content["msgtype"] == "m.notice" @pytest.mark.asyncio async def test_send_emote_empty_text(self): self.adapter._client = MagicMock() - with patch.dict("sys.modules", {"nio": _make_fake_nio()}): - result = await self.adapter.send_emote("!room:ex", "") + result = await self.adapter.send_emote("!room:ex", "") assert result.success is False diff --git a/tests/gateway/test_matrix_mention.py b/tests/gateway/test_matrix_mention.py index 215d8ab52..c0533741a 100644 --- a/tests/gateway/test_matrix_mention.py +++ b/tests/gateway/test_matrix_mention.py @@ -11,24 +11,59 @@ import pytest from gateway.config import PlatformConfig -def _ensure_nio_mock(): - """Install a mock nio module when matrix-nio isn't available.""" - if "nio" in sys.modules and hasattr(sys.modules["nio"], "__file__"): +def _ensure_mautrix_mock(): + """Install mock mautrix modules when mautrix-python isn't available.""" + if "mautrix" in sys.modules and hasattr(sys.modules["mautrix"], "__file__"): return - nio_mod = MagicMock() - nio_mod.MegolmEvent = type("MegolmEvent", (), {}) - nio_mod.RoomMessageText = type("RoomMessageText", (), {}) - nio_mod.RoomMessageImage = type("RoomMessageImage", (), {}) - nio_mod.RoomMessageAudio = type("RoomMessageAudio", (), {}) - nio_mod.RoomMessageVideo = type("RoomMessageVideo", (), {}) - nio_mod.RoomMessageFile = type("RoomMessageFile", (), {}) - nio_mod.DownloadResponse = type("DownloadResponse", (), {}) - nio_mod.MemoryDownloadResponse = type("MemoryDownloadResponse", (), {}) - nio_mod.InviteMemberEvent = type("InviteMemberEvent", (), {}) - sys.modules.setdefault("nio", nio_mod) + + # Root module + mautrix_mod = MagicMock() + + # mautrix.types — commonly imported types + types_mod = MagicMock() + types_mod.EventType = MagicMock() + types_mod.RoomID = str + types_mod.UserID = str + types_mod.EventID = str + types_mod.ContentURI = str + types_mod.RoomCreatePreset = MagicMock() + types_mod.PresenceState = MagicMock() + types_mod.PaginationDirection = MagicMock() + types_mod.SyncToken = str + types_mod.TrustState = MagicMock() + + # mautrix.client + client_mod = MagicMock() + client_mod.Client = MagicMock() + client_mod.InternalEventType = MagicMock() + + # mautrix.client.state_store + state_store_mod = MagicMock() + state_store_mod.MemoryStateStore = MagicMock() + state_store_mod.MemorySyncStore = MagicMock() + + # mautrix.api + api_mod = MagicMock() + api_mod.HTTPAPI = MagicMock() + + # mautrix.crypto + crypto_mod = MagicMock() + crypto_mod.OlmMachine = MagicMock() + crypto_store_mod = MagicMock() + crypto_store_mod.MemoryCryptoStore = MagicMock() + crypto_attachments_mod = MagicMock() + + sys.modules.setdefault("mautrix", mautrix_mod) + sys.modules.setdefault("mautrix.types", types_mod) + sys.modules.setdefault("mautrix.client", client_mod) + sys.modules.setdefault("mautrix.client.state_store", state_store_mod) + sys.modules.setdefault("mautrix.api", api_mod) + sys.modules.setdefault("mautrix.crypto", crypto_mod) + sys.modules.setdefault("mautrix.crypto.store", crypto_store_mod) + sys.modules.setdefault("mautrix.crypto.attachments", crypto_attachments_mod) -_ensure_nio_mock() +_ensure_mautrix_mock() def _make_adapter(tmp_path=None): @@ -50,24 +85,25 @@ def _make_adapter(tmp_path=None): return adapter -def _make_room(room_id="!room1:example.org", member_count=5, is_dm=False): - """Create a fake Matrix room.""" - room = SimpleNamespace( - room_id=room_id, - member_count=member_count, - users={}, - ) - return room +def _set_dm(adapter, room_id="!room1:example.org", is_dm=True): + """Mark a room as DM (or not) in the adapter's cache.""" + adapter._dm_rooms[room_id] = is_dm def _make_event( body, sender="@alice:example.org", event_id="$evt1", + room_id="!room1:example.org", formatted_body=None, thread_id=None, ): - """Create a fake RoomMessageText event.""" + """Create a fake room message event. + + The mautrix adapter reads ``event.room_id``, ``event.sender``, + ``event.event_id``, ``event.timestamp``, and ``event.content`` + (a dict with ``msgtype``, ``body``, etc.). + """ content = {"body": body, "msgtype": "m.text"} if formatted_body: content["formatted_body"] = formatted_body @@ -83,9 +119,9 @@ def _make_event( return SimpleNamespace( sender=sender, event_id=event_id, - server_timestamp=int(time.time() * 1000), - body=body, - source={"content": content}, + room_id=room_id, + timestamp=int(time.time() * 1000), + content=content, ) @@ -152,10 +188,9 @@ async def test_require_mention_default_ignores_unmentioned(monkeypatch): monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False) adapter = _make_adapter() - room = _make_room() event = _make_event("hello everyone") - await adapter._on_room_message(room, event) + await adapter._on_room_message(event) adapter.handle_message.assert_not_awaited() @@ -167,10 +202,9 @@ async def test_require_mention_default_processes_mentioned(monkeypatch): monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") adapter = _make_adapter() - room = _make_room() event = _make_event("@hermes:example.org help me") - await adapter._on_room_message(room, event) + await adapter._on_room_message(event) adapter.handle_message.assert_awaited_once() msg = adapter.handle_message.await_args.args[0] assert msg.text == "help me" @@ -184,11 +218,10 @@ async def test_require_mention_html_pill(monkeypatch): monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") adapter = _make_adapter() - room = _make_room() formatted = 'Hermes help' event = _make_event("Hermes help", formatted_body=formatted) - await adapter._on_room_message(room, event) + await adapter._on_room_message(event) adapter.handle_message.assert_awaited_once() @@ -200,11 +233,11 @@ async def test_require_mention_dm_always_responds(monkeypatch): monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") adapter = _make_adapter() - # member_count=2 triggers DM detection - room = _make_room(member_count=2) + # Mark the room as a DM via the adapter's cache. + _set_dm(adapter) event = _make_event("hello without mention") - await adapter._on_room_message(room, event) + await adapter._on_room_message(event) adapter.handle_message.assert_awaited_once() @@ -216,10 +249,10 @@ async def test_dm_strips_mention(monkeypatch): monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") adapter = _make_adapter() - room = _make_room(member_count=2) + _set_dm(adapter) event = _make_event("@hermes:example.org help me") - await adapter._on_room_message(room, event) + await adapter._on_room_message(event) adapter.handle_message.assert_awaited_once() msg = adapter.handle_message.await_args.args[0] assert msg.text == "help me" @@ -233,10 +266,9 @@ async def test_bare_mention_passes_empty_string(monkeypatch): monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") adapter = _make_adapter() - room = _make_room() event = _make_event("@hermes:example.org") - await adapter._on_room_message(room, event) + await adapter._on_room_message(event) adapter.handle_message.assert_awaited_once() msg = adapter.handle_message.await_args.args[0] assert msg.text == "" @@ -250,10 +282,9 @@ async def test_require_mention_free_response_room(monkeypatch): monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") adapter = _make_adapter() - room = _make_room(room_id="!room1:example.org") - event = _make_event("hello without mention") + event = _make_event("hello without mention", room_id="!room1:example.org") - await adapter._on_room_message(room, event) + await adapter._on_room_message(event) adapter.handle_message.assert_awaited_once() @@ -267,10 +298,9 @@ async def test_require_mention_bot_participated_thread(monkeypatch): adapter = _make_adapter() adapter._bot_participated_threads.add("$thread1") - room = _make_room() event = _make_event("hello without mention", thread_id="$thread1") - await adapter._on_room_message(room, event) + await adapter._on_room_message(event) adapter.handle_message.assert_awaited_once() @@ -282,10 +312,9 @@ async def test_require_mention_disabled(monkeypatch): monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") adapter = _make_adapter() - room = _make_room() event = _make_event("hello without mention") - await adapter._on_room_message(room, event) + await adapter._on_room_message(event) adapter.handle_message.assert_awaited_once() msg = adapter.handle_message.await_args.args[0] assert msg.text == "hello without mention" @@ -303,10 +332,9 @@ async def test_auto_thread_default_creates_thread(monkeypatch): monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False) adapter = _make_adapter() - room = _make_room() event = _make_event("hello", event_id="$msg1") - await adapter._on_room_message(room, event) + await adapter._on_room_message(event) adapter.handle_message.assert_awaited_once() msg = adapter.handle_message.await_args.args[0] assert msg.source.thread_id == "$msg1" @@ -320,10 +348,9 @@ async def test_auto_thread_preserves_existing_thread(monkeypatch): adapter = _make_adapter() adapter._bot_participated_threads.add("$thread_root") - room = _make_room() event = _make_event("reply in thread", thread_id="$thread_root") - await adapter._on_room_message(room, event) + await adapter._on_room_message(event) adapter.handle_message.assert_awaited_once() msg = adapter.handle_message.await_args.args[0] assert msg.source.thread_id == "$thread_root" @@ -336,10 +363,10 @@ async def test_auto_thread_skips_dm(monkeypatch): monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False) adapter = _make_adapter() - room = _make_room(member_count=2) + _set_dm(adapter) event = _make_event("hello dm", event_id="$dm1") - await adapter._on_room_message(room, event) + await adapter._on_room_message(event) adapter.handle_message.assert_awaited_once() msg = adapter.handle_message.await_args.args[0] assert msg.source.thread_id is None @@ -352,10 +379,9 @@ async def test_auto_thread_disabled(monkeypatch): monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") adapter = _make_adapter() - room = _make_room() event = _make_event("hello", event_id="$msg1") - await adapter._on_room_message(room, event) + await adapter._on_room_message(event) adapter.handle_message.assert_awaited_once() msg = adapter.handle_message.await_args.args[0] assert msg.source.thread_id is None @@ -368,11 +394,10 @@ async def test_auto_thread_tracks_participation(monkeypatch): monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False) adapter = _make_adapter() - room = _make_room() event = _make_event("hello", event_id="$msg1") with patch.object(adapter, "_save_participated_threads"): - await adapter._on_room_message(room, event) + await adapter._on_room_message(event) assert "$msg1" in adapter._bot_participated_threads @@ -448,10 +473,10 @@ async def test_dm_mention_thread_disabled_by_default(monkeypatch): monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") adapter = _make_adapter() - room = _make_room(member_count=2) + _set_dm(adapter) event = _make_event("@hermes:example.org help me", event_id="$dm1") - await adapter._on_room_message(room, event) + await adapter._on_room_message(event) adapter.handle_message.assert_awaited_once() msg = adapter.handle_message.await_args.args[0] assert msg.source.thread_id is None @@ -464,11 +489,11 @@ async def test_dm_mention_thread_creates_thread(monkeypatch): monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") adapter = _make_adapter() - room = _make_room(member_count=2) + _set_dm(adapter) event = _make_event("@hermes:example.org help me", event_id="$dm1") with patch.object(adapter, "_save_participated_threads"): - await adapter._on_room_message(room, event) + await adapter._on_room_message(event) adapter.handle_message.assert_awaited_once() msg = adapter.handle_message.await_args.args[0] @@ -483,10 +508,10 @@ async def test_dm_mention_thread_no_mention_no_thread(monkeypatch): monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") adapter = _make_adapter() - room = _make_room(member_count=2) + _set_dm(adapter) event = _make_event("hello without mention", event_id="$dm1") - await adapter._on_room_message(room, event) + await adapter._on_room_message(event) adapter.handle_message.assert_awaited_once() msg = adapter.handle_message.await_args.args[0] assert msg.source.thread_id is None @@ -499,11 +524,11 @@ async def test_dm_mention_thread_preserves_existing_thread(monkeypatch): monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") adapter = _make_adapter() + _set_dm(adapter) adapter._bot_participated_threads.add("$existing_thread") - room = _make_room(member_count=2) event = _make_event("@hermes:example.org help me", thread_id="$existing_thread") - await adapter._on_room_message(room, event) + await adapter._on_room_message(event) adapter.handle_message.assert_awaited_once() msg = adapter.handle_message.await_args.args[0] assert msg.source.thread_id == "$existing_thread" @@ -516,11 +541,11 @@ async def test_dm_mention_thread_tracks_participation(monkeypatch): monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") adapter = _make_adapter() - room = _make_room(member_count=2) + _set_dm(adapter) event = _make_event("@hermes:example.org help", event_id="$dm1") with patch.object(adapter, "_save_participated_threads"): - await adapter._on_room_message(room, event) + await adapter._on_room_message(event) assert "$dm1" in adapter._bot_participated_threads diff --git a/tests/gateway/test_matrix_voice.py b/tests/gateway/test_matrix_voice.py index 93d56caf1..dab113c5d 100644 --- a/tests/gateway/test_matrix_voice.py +++ b/tests/gateway/test_matrix_voice.py @@ -1,18 +1,23 @@ -"""Tests for Matrix voice message support (MSC3245).""" +"""Tests for Matrix voice message support (MSC3245). + +Updated for the mautrix-python SDK (no more matrix-nio / nio imports). +""" import io +import os +import tempfile import types +from types import SimpleNamespace import pytest from unittest.mock import AsyncMock, MagicMock, patch -# Try importing real nio; skip entire file if not available. -# A MagicMock in sys.modules (from another test) is not the real package. +# Try importing mautrix; skip entire file if not available. try: - import nio as _nio_probe - if not isinstance(_nio_probe, types.ModuleType) or not hasattr(_nio_probe, "__file__"): - pytest.skip("nio in sys.modules is a mock, not the real package", allow_module_level=True) + import mautrix as _mautrix_probe + if not isinstance(_mautrix_probe, types.ModuleType) or not hasattr(_mautrix_probe, "__file__"): + pytest.skip("mautrix in sys.modules is a mock, not the real package", allow_module_level=True) except ImportError: - pytest.skip("matrix-nio not installed", allow_module_level=True) + pytest.skip("mautrix not installed", allow_module_level=True) from gateway.platforms.base import MessageType @@ -25,7 +30,7 @@ def _make_adapter(): """Create a MatrixAdapter with mocked config.""" from gateway.platforms.matrix import MatrixAdapter from gateway.config import PlatformConfig - + config = PlatformConfig( enabled=True, token="***", @@ -38,32 +43,26 @@ def _make_adapter(): return adapter -def _make_room(room_id: str = "!test:example.org", member_count: int = 2): - """Create a mock Matrix room.""" - room = MagicMock() - room.room_id = room_id - room.member_count = member_count - return room - - def _make_audio_event( event_id: str = "$audio_event", sender: str = "@alice:example.org", + room_id: str = "!test:example.org", body: str = "Voice message", url: str = "mxc://example.org/abc123", is_voice: bool = False, mimetype: str = "audio/ogg", - timestamp: float = 9999999999000, # ms + timestamp: int = 9999999999000, # ms ): """ - Create a mock RoomMessageAudio event that passes isinstance checks. - + Create a mock mautrix room message event. + + In mautrix, the handler receives a single event object with attributes + ``room_id``, ``sender``, ``event_id``, ``timestamp``, and ``content`` + (a dict-like or serializable object). + Args: - is_voice: If True, adds org.matrix.msc3245.voice field to content + is_voice: If True, adds org.matrix.msc3245.voice field to content. """ - import nio - - # Build the source dict that nio events expose via .source content = { "msgtype": "m.audio", "body": body, @@ -72,39 +71,35 @@ def _make_audio_event( "mimetype": mimetype, }, } - + if is_voice: content["org.matrix.msc3245.voice"] = {} - - # Create a real nio RoomMessageAudio-like object - # We use MagicMock but configure __class__ to pass isinstance check - event = MagicMock(spec=nio.RoomMessageAudio) - event.event_id = event_id - event.sender = sender - event.body = body - event.url = url - event.server_timestamp = timestamp - event.source = { - "type": "m.room.message", - "content": content, - } - # For MIME type extraction - needs to be a dict - event.content = content - + + event = SimpleNamespace( + event_id=event_id, + sender=sender, + room_id=room_id, + timestamp=timestamp, + content=content, + ) return event -def _make_download_response(body: bytes = b"fake audio data"): - """Create a mock nio.MemoryDownloadResponse.""" - import nio - resp = MagicMock() - resp.body = body - resp.__class__ = nio.MemoryDownloadResponse - return resp +def _make_state_store(member_count: int = 2): + """Create a mock state store with get_members/get_member support.""" + store = MagicMock() + # get_members returns a list of member user IDs + members = [MagicMock() for _ in range(member_count)] + store.get_members = AsyncMock(return_value=members) + # get_member returns a single member info object + member = MagicMock() + member.displayname = "Alice" + store.get_member = AsyncMock(return_value=member) + return store # --------------------------------------------------------------------------- -# Tests: MSC3245 Voice Detection (RED -> GREEN) +# Tests: MSC3245 Voice Detection # --------------------------------------------------------------------------- class TestMatrixVoiceMessageDetection: @@ -118,27 +113,28 @@ class TestMatrixVoiceMessageDetection: self.adapter._message_handler = AsyncMock() # Mock _mxc_to_http to return a fake HTTP URL self.adapter._mxc_to_http = lambda url: f"https://matrix.example.org/_matrix/media/v3/download/{url[6:]}" - # Mock client for authenticated download + # Mock client for authenticated download — download_media returns bytes directly self.adapter._client = MagicMock() - self.adapter._client.download = AsyncMock(return_value=_make_download_response()) + self.adapter._client.download_media = AsyncMock(return_value=b"fake audio data") + # State store for DM detection + self.adapter._client.state_store = _make_state_store() @pytest.mark.asyncio async def test_voice_message_has_type_voice(self): """Voice messages (with MSC3245 field) should be MessageType.VOICE.""" - room = _make_room() event = _make_audio_event(is_voice=True) - + # Capture the MessageEvent passed to handle_message captured_event = None - + async def capture(msg_event): nonlocal captured_event captured_event = msg_event - + self.adapter.handle_message = capture - - await self.adapter._on_room_message_media(room, event) - + + await self.adapter._on_room_message(event) + assert captured_event is not None, "No event was captured" assert captured_event.message_type == MessageType.VOICE, \ f"Expected MessageType.VOICE, got {captured_event.message_type}" @@ -146,44 +142,43 @@ class TestMatrixVoiceMessageDetection: @pytest.mark.asyncio async def test_voice_message_has_local_path(self): """Voice messages should have a local cached path in media_urls.""" - room = _make_room() event = _make_audio_event(is_voice=True) - + captured_event = None - + async def capture(msg_event): nonlocal captured_event captured_event = msg_event - + self.adapter.handle_message = capture - - await self.adapter._on_room_message_media(room, event) - + + await self.adapter._on_room_message(event) + assert captured_event is not None assert captured_event.media_urls is not None assert len(captured_event.media_urls) > 0 # Should be a local path, not an HTTP URL assert not captured_event.media_urls[0].startswith("http"), \ f"media_urls should contain local path, got {captured_event.media_urls[0]}" - self.adapter._client.download.assert_awaited_once_with(mxc=event.url) + # download_media is called with a ContentURI wrapping the mxc URL + self.adapter._client.download_media.assert_awaited_once() assert captured_event.media_types == ["audio/ogg"] @pytest.mark.asyncio async def test_audio_without_msc3245_stays_audio_type(self): """Regular audio uploads (no MSC3245 field) should remain MessageType.AUDIO.""" - room = _make_room() event = _make_audio_event(is_voice=False) # NOT a voice message - + captured_event = None - + async def capture(msg_event): nonlocal captured_event captured_event = msg_event - + self.adapter.handle_message = capture - - await self.adapter._on_room_message_media(room, event) - + + await self.adapter._on_room_message(event) + assert captured_event is not None assert captured_event.message_type == MessageType.AUDIO, \ f"Expected MessageType.AUDIO for non-voice, got {captured_event.message_type}" @@ -191,25 +186,24 @@ class TestMatrixVoiceMessageDetection: @pytest.mark.asyncio async def test_regular_audio_has_http_url(self): """Regular audio uploads should keep HTTP URL (not cached locally).""" - room = _make_room() event = _make_audio_event(is_voice=False) - + captured_event = None - + async def capture(msg_event): nonlocal captured_event captured_event = msg_event - + self.adapter.handle_message = capture - - await self.adapter._on_room_message_media(room, event) - + + await self.adapter._on_room_message(event) + assert captured_event is not None assert captured_event.media_urls is not None # Should be HTTP URL, not local path assert captured_event.media_urls[0].startswith("http"), \ f"Non-voice audio should have HTTP URL, got {captured_event.media_urls[0]}" - self.adapter._client.download.assert_not_awaited() + self.adapter._client.download_media.assert_not_awaited() assert captured_event.media_types == ["audio/ogg"] @@ -224,29 +218,26 @@ class TestMatrixVoiceCacheFallback: self.adapter._message_handler = AsyncMock() self.adapter._mxc_to_http = lambda url: f"https://matrix.example.org/_matrix/media/v3/download/{url[6:]}" self.adapter._client = MagicMock() + self.adapter._client.state_store = _make_state_store() @pytest.mark.asyncio async def test_voice_cache_failure_falls_back_to_http_url(self): - """If caching fails, voice message should still be delivered with HTTP URL.""" - room = _make_room() + """If caching fails (download returns None), voice message should still be delivered with HTTP URL.""" event = _make_audio_event(is_voice=True) - - # Make download fail - import nio - error_resp = MagicMock() - error_resp.__class__ = nio.DownloadError - self.adapter._client.download = AsyncMock(return_value=error_resp) - + + # download_media returns None on failure + self.adapter._client.download_media = AsyncMock(return_value=None) + captured_event = None - + async def capture(msg_event): nonlocal captured_event captured_event = msg_event - + self.adapter.handle_message = capture - - await self.adapter._on_room_message_media(room, event) - + + await self.adapter._on_room_message(event) + assert captured_event is not None assert captured_event.media_urls is not None # Should fall back to HTTP URL @@ -256,10 +247,9 @@ class TestMatrixVoiceCacheFallback: @pytest.mark.asyncio async def test_voice_cache_exception_falls_back_to_http_url(self): """Unexpected download exceptions should also fall back to HTTP URL.""" - room = _make_room() event = _make_audio_event(is_voice=True) - self.adapter._client.download = AsyncMock(side_effect=RuntimeError("boom")) + self.adapter._client.download_media = AsyncMock(side_effect=RuntimeError("boom")) captured_event = None @@ -269,7 +259,7 @@ class TestMatrixVoiceCacheFallback: self.adapter.handle_message = capture - await self.adapter._on_room_message_media(room, event) + await self.adapter._on_room_message(event) assert captured_event is not None assert captured_event.media_urls is not None @@ -278,7 +268,7 @@ class TestMatrixVoiceCacheFallback: # --------------------------------------------------------------------------- -# Tests: send_voice includes MSC3245 field (RED -> GREEN) +# Tests: send_voice includes MSC3245 field # --------------------------------------------------------------------------- class TestMatrixSendVoiceMSC3245: @@ -287,62 +277,52 @@ class TestMatrixSendVoiceMSC3245: def setup_method(self): self.adapter = _make_adapter() self.adapter._user_id = "@bot:example.org" - # Mock client with successful upload + # Mock client — upload_media returns a ContentURI string self.adapter._client = MagicMock() self.upload_call = None - async def mock_upload(*args, **kwargs): - self.upload_call = (args, kwargs) - import nio - resp = MagicMock() - resp.content_uri = "mxc://example.org/uploaded" - resp.__class__ = nio.UploadResponse - return resp, None + async def mock_upload_media(data, mime_type=None, filename=None, **kwargs): + self.upload_call = {"data": data, "mime_type": mime_type, "filename": filename} + return "mxc://example.org/uploaded" - self.adapter._client.upload = mock_upload + self.adapter._client.upload_media = mock_upload_media @pytest.mark.asyncio - async def test_send_voice_includes_msc3245_field(self): + @patch("mimetypes.guess_type", return_value=("audio/ogg", None)) + async def test_send_voice_includes_msc3245_field(self, _mock_guess): """send_voice should include org.matrix.msc3245.voice in message content.""" - import tempfile - import os - # Create a temp audio file with tempfile.NamedTemporaryFile(suffix=".ogg", delete=False) as f: f.write(b"fake audio data") temp_path = f.name - + try: - # Capture the message content sent to room_send + # Capture the message content sent via send_message_event sent_content = None - - async def mock_room_send(room_id, event_type, content): + + async def mock_send_message_event(room_id, event_type, content): nonlocal sent_content sent_content = content - resp = MagicMock() - resp.event_id = "$sent_event" - import nio - resp.__class__ = nio.RoomSendResponse - return resp - - self.adapter._client.room_send = mock_room_send - + # send_message_event returns an EventID string + return "$sent_event" + + self.adapter._client.send_message_event = mock_send_message_event + await self.adapter.send_voice( chat_id="!room:example.org", audio_path=temp_path, caption="Test voice", ) - + assert sent_content is not None, "No message was sent" assert "org.matrix.msc3245.voice" in sent_content, \ f"MSC3245 voice field missing from content: {sent_content.keys()}" assert sent_content["msgtype"] == "m.audio" assert sent_content["info"]["mimetype"] == "audio/ogg" - assert self.upload_call is not None, "Expected upload() to be called" - args, kwargs = self.upload_call - assert isinstance(args[0], io.BytesIO) - assert kwargs["content_type"] == "audio/ogg" - assert kwargs["filename"].endswith(".ogg") + assert self.upload_call is not None, "Expected upload_media() to be called" + assert isinstance(self.upload_call["data"], bytes) + assert self.upload_call["mime_type"] == "audio/ogg" + assert self.upload_call["filename"].endswith(".ogg") finally: os.unlink(temp_path)