test: add comprehensive tests for Mattermost and Matrix adapters

77 tests covering:

Mattermost (37 tests):
- Platform enum and config loading
- Message formatting (image markdown stripping)
- Message chunking at 4000 chars
- Send with mocked aiohttp (payload, threading, errors)
- WebSocket event parsing (double-encoded JSON!)
- File upload flow
- Post dedup cache (TTL, pruning)
- Requirements check

Matrix (40 tests):
- Platform enum and config loading (token + password auth, E2EE)
- mxc:// to HTTP URL conversion (authenticated v1.11+ endpoint)
- DM detection via m.direct cache
- Reply fallback stripping
- Thread detection from m.relates_to
- Message formatting and markdown to HTML
- Display name resolution
- Requirements check
This commit is contained in:
teknium1
2026-03-17 03:07:13 -07:00
parent cd67f60e01
commit c3ce6108e3
2 changed files with 1022 additions and 0 deletions

View File

@@ -0,0 +1,448 @@
"""Tests for Matrix platform adapter."""
import json
import re
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from gateway.config import Platform, PlatformConfig
# ---------------------------------------------------------------------------
# Platform & Config
# ---------------------------------------------------------------------------
class TestMatrixPlatformEnum:
def test_matrix_enum_exists(self):
assert Platform.MATRIX.value == "matrix"
def test_matrix_in_platform_list(self):
platforms = [p.value for p in Platform]
assert "matrix" in platforms
class TestMatrixConfigLoading:
def test_apply_env_overrides_with_access_token(self, monkeypatch):
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123")
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
assert Platform.MATRIX in config.platforms
mc = config.platforms[Platform.MATRIX]
assert mc.enabled is True
assert mc.token == "syt_abc123"
assert mc.extra.get("homeserver") == "https://matrix.example.org"
def test_apply_env_overrides_with_password(self, monkeypatch):
monkeypatch.delenv("MATRIX_ACCESS_TOKEN", raising=False)
monkeypatch.setenv("MATRIX_PASSWORD", "secret123")
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
monkeypatch.setenv("MATRIX_USER_ID", "@bot:example.org")
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
assert Platform.MATRIX in config.platforms
mc = config.platforms[Platform.MATRIX]
assert mc.enabled is True
assert mc.extra.get("password") == "secret123"
assert mc.extra.get("user_id") == "@bot:example.org"
def test_matrix_not_loaded_without_creds(self, monkeypatch):
monkeypatch.delenv("MATRIX_ACCESS_TOKEN", raising=False)
monkeypatch.delenv("MATRIX_PASSWORD", raising=False)
monkeypatch.delenv("MATRIX_HOMESERVER", raising=False)
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
assert Platform.MATRIX not in config.platforms
def test_matrix_encryption_flag(self, monkeypatch):
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123")
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
monkeypatch.setenv("MATRIX_ENCRYPTION", "true")
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
mc = config.platforms[Platform.MATRIX]
assert mc.extra.get("encryption") is True
def test_matrix_encryption_default_off(self, monkeypatch):
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123")
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
monkeypatch.delenv("MATRIX_ENCRYPTION", raising=False)
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
mc = config.platforms[Platform.MATRIX]
assert mc.extra.get("encryption") is False
def test_matrix_home_room(self, monkeypatch):
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123")
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
monkeypatch.setenv("MATRIX_HOME_ROOM", "!room123:example.org")
monkeypatch.setenv("MATRIX_HOME_ROOM_NAME", "Bot Room")
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
home = config.get_home_channel(Platform.MATRIX)
assert home is not None
assert home.chat_id == "!room123:example.org"
assert home.name == "Bot Room"
def test_matrix_user_id_stored_in_extra(self, monkeypatch):
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123")
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
monkeypatch.setenv("MATRIX_USER_ID", "@hermes:example.org")
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
mc = config.platforms[Platform.MATRIX]
assert mc.extra.get("user_id") == "@hermes:example.org"
# ---------------------------------------------------------------------------
# Adapter helpers
# ---------------------------------------------------------------------------
def _make_adapter():
"""Create a MatrixAdapter with mocked config."""
from gateway.platforms.matrix import MatrixAdapter
config = PlatformConfig(
enabled=True,
token="syt_test_token",
extra={
"homeserver": "https://matrix.example.org",
"user_id": "@bot:example.org",
},
)
adapter = MatrixAdapter(config)
return adapter
# ---------------------------------------------------------------------------
# mxc:// URL conversion
# ---------------------------------------------------------------------------
class TestMatrixMxcToHttp:
def setup_method(self):
self.adapter = _make_adapter()
def test_basic_mxc_conversion(self):
"""mxc://server/media_id should become an authenticated HTTP URL."""
mxc = "mxc://matrix.org/abc123"
result = self.adapter._mxc_to_http(mxc)
assert result == "https://matrix.example.org/_matrix/client/v1/media/download/matrix.org/abc123"
def test_mxc_with_different_server(self):
"""mxc:// from a different server should still use our homeserver."""
mxc = "mxc://other.server/media456"
result = self.adapter._mxc_to_http(mxc)
assert result.startswith("https://matrix.example.org/")
assert "other.server/media456" in result
def test_non_mxc_url_passthrough(self):
"""Non-mxc URLs should be returned unchanged."""
url = "https://example.com/image.png"
assert self.adapter._mxc_to_http(url) == url
def test_mxc_uses_client_v1_endpoint(self):
"""Should use /_matrix/client/v1/media/download/ not the deprecated path."""
mxc = "mxc://example.com/test123"
result = self.adapter._mxc_to_http(mxc)
assert "/_matrix/client/v1/media/download/" in result
assert "/_matrix/media/v3/download/" not in result
# ---------------------------------------------------------------------------
# DM detection
# ---------------------------------------------------------------------------
class TestMatrixDmDetection:
def setup_method(self):
self.adapter = _make_adapter()
def test_room_in_m_direct_is_dm(self):
"""A room listed in m.direct should be detected as DM."""
self.adapter._joined_rooms = {"!dm_room:ex.org", "!group_room:ex.org"}
self.adapter._dm_rooms = {
"!dm_room:ex.org": True,
"!group_room:ex.org": False,
}
assert self.adapter._dm_rooms.get("!dm_room:ex.org") is True
assert self.adapter._dm_rooms.get("!group_room:ex.org") is False
def test_unknown_room_not_in_cache(self):
"""Unknown rooms should not be in the DM cache."""
self.adapter._dm_rooms = {}
assert self.adapter._dm_rooms.get("!unknown:ex.org") is None
@pytest.mark.asyncio
async def test_refresh_dm_cache_with_m_direct(self):
"""_refresh_dm_cache should populate _dm_rooms from m.direct data."""
self.adapter._joined_rooms = {"!room_a:ex.org", "!room_b:ex.org", "!room_c:ex.org"}
mock_client = MagicMock()
mock_resp = MagicMock()
mock_resp.content = {
"@alice:ex.org": ["!room_a:ex.org"],
"@bob:ex.org": ["!room_b:ex.org"],
}
mock_client.get_account_data = AsyncMock(return_value=mock_resp)
self.adapter._client = mock_client
await self.adapter._refresh_dm_cache()
assert self.adapter._dm_rooms["!room_a:ex.org"] is True
assert self.adapter._dm_rooms["!room_b:ex.org"] is True
assert self.adapter._dm_rooms["!room_c:ex.org"] is False
# ---------------------------------------------------------------------------
# Reply fallback stripping
# ---------------------------------------------------------------------------
class TestMatrixReplyFallbackStripping:
"""Test that Matrix reply fallback lines ('> ' prefix) are stripped."""
def setup_method(self):
self.adapter = _make_adapter()
self.adapter._user_id = "@bot:example.org"
self.adapter._startup_ts = 0.0
self.adapter._dm_rooms = {}
self.adapter._message_handler = AsyncMock()
def _strip_fallback(self, body: str, has_reply: bool = True) -> str:
"""Simulate the reply fallback stripping logic from _on_room_message."""
reply_to = "some_event_id" if has_reply else None
if reply_to and body.startswith("> "):
lines = body.split("\n")
stripped = []
past_fallback = False
for line in lines:
if not past_fallback:
if line.startswith("> ") or line == ">":
continue
if line == "":
past_fallback = True
continue
past_fallback = True
stripped.append(line)
body = "\n".join(stripped) if stripped else body
return body
def test_simple_reply_fallback(self):
body = "> <@alice:ex.org> Original message\n\nActual reply"
result = self._strip_fallback(body)
assert result == "Actual reply"
def test_multiline_reply_fallback(self):
body = "> <@alice:ex.org> Line 1\n> Line 2\n\nMy response"
result = self._strip_fallback(body)
assert result == "My response"
def test_no_reply_fallback_preserved(self):
body = "Just a normal message"
result = self._strip_fallback(body, has_reply=False)
assert result == "Just a normal message"
def test_quote_without_reply_preserved(self):
"""'> ' lines without a reply_to context should be preserved."""
body = "> This is a blockquote"
result = self._strip_fallback(body, has_reply=False)
assert result == "> This is a blockquote"
def test_empty_fallback_separator(self):
"""The blank line between fallback and actual content should be stripped."""
body = "> <@alice:ex.org> hi\n>\n\nResponse"
result = self._strip_fallback(body)
assert result == "Response"
def test_multiline_response_after_fallback(self):
body = "> <@alice:ex.org> Original\n\nLine 1\nLine 2\nLine 3"
result = self._strip_fallback(body)
assert result == "Line 1\nLine 2\nLine 3"
# ---------------------------------------------------------------------------
# Thread detection
# ---------------------------------------------------------------------------
class TestMatrixThreadDetection:
def test_thread_id_from_m_relates_to(self):
"""m.relates_to with rel_type=m.thread should extract the event_id."""
relates_to = {
"rel_type": "m.thread",
"event_id": "$thread_root_event",
"is_falling_back": True,
"m.in_reply_to": {"event_id": "$some_event"},
}
# Simulate the extraction logic from _on_room_message
thread_id = None
if relates_to.get("rel_type") == "m.thread":
thread_id = relates_to.get("event_id")
assert thread_id == "$thread_root_event"
def test_no_thread_for_reply(self):
"""m.in_reply_to without m.thread should not set thread_id."""
relates_to = {
"m.in_reply_to": {"event_id": "$reply_event"},
}
thread_id = None
if relates_to.get("rel_type") == "m.thread":
thread_id = relates_to.get("event_id")
assert thread_id is None
def test_no_thread_for_edit(self):
"""m.replace relation should not set thread_id."""
relates_to = {
"rel_type": "m.replace",
"event_id": "$edited_event",
}
thread_id = None
if relates_to.get("rel_type") == "m.thread":
thread_id = relates_to.get("event_id")
assert thread_id is None
def test_empty_relates_to(self):
"""Empty m.relates_to should not set thread_id."""
relates_to = {}
thread_id = None
if relates_to.get("rel_type") == "m.thread":
thread_id = relates_to.get("event_id")
assert thread_id is None
# ---------------------------------------------------------------------------
# Format message
# ---------------------------------------------------------------------------
class TestMatrixFormatMessage:
def setup_method(self):
self.adapter = _make_adapter()
def test_image_markdown_stripped(self):
"""![alt](url) should be converted to just the URL."""
result = self.adapter.format_message("![cat](https://img.example.com/cat.png)")
assert result == "https://img.example.com/cat.png"
def test_regular_markdown_preserved(self):
"""Standard markdown should be preserved (Matrix supports it)."""
content = "**bold** and *italic* and `code`"
assert self.adapter.format_message(content) == content
def test_plain_text_unchanged(self):
content = "Hello, world!"
assert self.adapter.format_message(content) == content
def test_multiple_images_stripped(self):
content = "![a](http://a.com/1.png) and ![b](http://b.com/2.png)"
result = self.adapter.format_message(content)
assert "![" not in result
assert "http://a.com/1.png" in result
assert "http://b.com/2.png" in result
# ---------------------------------------------------------------------------
# Markdown to HTML conversion
# ---------------------------------------------------------------------------
class TestMatrixMarkdownToHtml:
def setup_method(self):
self.adapter = _make_adapter()
def test_bold_conversion(self):
"""**bold** should produce <strong> tags."""
result = self.adapter._markdown_to_html("**bold**")
assert "<strong>" in result or "<b>" in result
assert "bold" in result
def test_italic_conversion(self):
"""*italic* should produce <em> tags."""
result = self.adapter._markdown_to_html("*italic*")
assert "<em>" in result or "<i>" in result
def test_inline_code(self):
"""`code` should produce <code> tags."""
result = self.adapter._markdown_to_html("`code`")
assert "<code>" in result
def test_plain_text_returns_html(self):
"""Plain text should still be returned (possibly with <br> or <p>)."""
result = self.adapter._markdown_to_html("Hello world")
assert "Hello world" in result
# ---------------------------------------------------------------------------
# Helper: display name extraction
# ---------------------------------------------------------------------------
class TestMatrixDisplayName:
def setup_method(self):
self.adapter = _make_adapter()
def test_get_display_name_from_room_users(self):
"""Should get display name from room's users dict."""
mock_room = MagicMock()
mock_user = MagicMock()
mock_user.display_name = "Alice"
mock_room.users = {"@alice:ex.org": mock_user}
name = self.adapter._get_display_name(mock_room, "@alice:ex.org")
assert name == "Alice"
def test_get_display_name_fallback_to_localpart(self):
"""Should extract localpart from @user:server format."""
mock_room = MagicMock()
mock_room.users = {}
name = self.adapter._get_display_name(mock_room, "@bob:example.org")
assert name == "bob"
def test_get_display_name_no_room(self):
"""Should handle None room gracefully."""
name = self.adapter._get_display_name(None, "@charlie:ex.org")
assert name == "charlie"
# ---------------------------------------------------------------------------
# Requirements check
# ---------------------------------------------------------------------------
class TestMatrixRequirements:
def test_check_requirements_with_token(self, monkeypatch):
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_test")
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
from gateway.platforms.matrix import check_matrix_requirements
try:
import nio # noqa: F401
assert check_matrix_requirements() is True
except ImportError:
assert check_matrix_requirements() is False
def test_check_requirements_without_creds(self, monkeypatch):
monkeypatch.delenv("MATRIX_ACCESS_TOKEN", raising=False)
monkeypatch.delenv("MATRIX_PASSWORD", raising=False)
monkeypatch.delenv("MATRIX_HOMESERVER", raising=False)
from gateway.platforms.matrix import check_matrix_requirements
assert check_matrix_requirements() is False
def test_check_requirements_without_homeserver(self, monkeypatch):
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_test")
monkeypatch.delenv("MATRIX_HOMESERVER", raising=False)
from gateway.platforms.matrix import check_matrix_requirements
assert check_matrix_requirements() is False

View File

@@ -0,0 +1,574 @@
"""Tests for Mattermost platform adapter."""
import json
import time
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from gateway.config import Platform, PlatformConfig
# ---------------------------------------------------------------------------
# Platform & Config
# ---------------------------------------------------------------------------
class TestMattermostPlatformEnum:
def test_mattermost_enum_exists(self):
assert Platform.MATTERMOST.value == "mattermost"
def test_mattermost_in_platform_list(self):
platforms = [p.value for p in Platform]
assert "mattermost" in platforms
class TestMattermostConfigLoading:
def test_apply_env_overrides_mattermost(self, monkeypatch):
monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123")
monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com")
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
assert Platform.MATTERMOST in config.platforms
mc = config.platforms[Platform.MATTERMOST]
assert mc.enabled is True
assert mc.token == "mm-tok-abc123"
assert mc.extra.get("url") == "https://mm.example.com"
def test_mattermost_not_loaded_without_token(self, monkeypatch):
monkeypatch.delenv("MATTERMOST_TOKEN", raising=False)
monkeypatch.delenv("MATTERMOST_URL", raising=False)
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
assert Platform.MATTERMOST not in config.platforms
def test_connected_platforms_includes_mattermost(self, monkeypatch):
monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123")
monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com")
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
connected = config.get_connected_platforms()
assert Platform.MATTERMOST in connected
def test_mattermost_home_channel(self, monkeypatch):
monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123")
monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com")
monkeypatch.setenv("MATTERMOST_HOME_CHANNEL", "ch_abc123")
monkeypatch.setenv("MATTERMOST_HOME_CHANNEL_NAME", "General")
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
home = config.get_home_channel(Platform.MATTERMOST)
assert home is not None
assert home.chat_id == "ch_abc123"
assert home.name == "General"
def test_mattermost_url_warning_without_url(self, monkeypatch):
"""MATTERMOST_TOKEN set but MATTERMOST_URL missing should still load."""
monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123")
monkeypatch.delenv("MATTERMOST_URL", raising=False)
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
assert Platform.MATTERMOST in config.platforms
assert config.platforms[Platform.MATTERMOST].extra.get("url") == ""
# ---------------------------------------------------------------------------
# Adapter format / truncate
# ---------------------------------------------------------------------------
def _make_adapter():
"""Create a MattermostAdapter with mocked config."""
from gateway.platforms.mattermost import MattermostAdapter
config = PlatformConfig(
enabled=True,
token="test-token",
extra={"url": "https://mm.example.com"},
)
adapter = MattermostAdapter(config)
return adapter
class TestMattermostFormatMessage:
def setup_method(self):
self.adapter = _make_adapter()
def test_image_markdown_to_url(self):
"""![alt](url) should be converted to just the URL."""
result = self.adapter.format_message("![cat](https://img.example.com/cat.png)")
assert result == "https://img.example.com/cat.png"
def test_image_markdown_strips_alt_text(self):
result = self.adapter.format_message("Here: ![my image](https://x.com/a.jpg) done")
assert "![" not in result
assert "https://x.com/a.jpg" in result
def test_regular_markdown_preserved(self):
"""Regular markdown (bold, italic, code) should be kept as-is."""
content = "**bold** and *italic* and `code`"
assert self.adapter.format_message(content) == content
def test_regular_links_preserved(self):
"""Non-image links should be preserved."""
content = "[click](https://example.com)"
assert self.adapter.format_message(content) == content
def test_plain_text_unchanged(self):
content = "Hello, world!"
assert self.adapter.format_message(content) == content
def test_multiple_images(self):
content = "![a](http://a.com/1.png) text ![b](http://b.com/2.png)"
result = self.adapter.format_message(content)
assert "![" not in result
assert "http://a.com/1.png" in result
assert "http://b.com/2.png" in result
class TestMattermostTruncateMessage:
def setup_method(self):
self.adapter = _make_adapter()
def test_short_message_single_chunk(self):
msg = "Hello, world!"
chunks = self.adapter.truncate_message(msg, 4000)
assert len(chunks) == 1
assert chunks[0] == msg
def test_long_message_splits(self):
msg = "a " * 2500 # 5000 chars
chunks = self.adapter.truncate_message(msg, 4000)
assert len(chunks) >= 2
for chunk in chunks:
assert len(chunk) <= 4000
def test_custom_max_length(self):
msg = "Hello " * 20
chunks = self.adapter.truncate_message(msg, max_length=50)
assert all(len(c) <= 50 for c in chunks)
def test_exactly_at_limit(self):
msg = "x" * 4000
chunks = self.adapter.truncate_message(msg, 4000)
assert len(chunks) == 1
# ---------------------------------------------------------------------------
# Send
# ---------------------------------------------------------------------------
class TestMattermostSend:
def setup_method(self):
self.adapter = _make_adapter()
self.adapter._session = MagicMock()
@pytest.mark.asyncio
async def test_send_calls_api_post(self):
"""send() should POST to /api/v4/posts with channel_id and message."""
mock_resp = AsyncMock()
mock_resp.status = 200
mock_resp.json = AsyncMock(return_value={"id": "post123"})
mock_resp.text = AsyncMock(return_value="")
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False)
self.adapter._session.post = MagicMock(return_value=mock_resp)
result = await self.adapter.send("channel_1", "Hello!")
assert result.success is True
assert result.message_id == "post123"
# Verify post was called with correct URL
call_args = self.adapter._session.post.call_args
assert "/api/v4/posts" in call_args[0][0]
# Verify payload
payload = call_args[1]["json"]
assert payload["channel_id"] == "channel_1"
assert payload["message"] == "Hello!"
@pytest.mark.asyncio
async def test_send_empty_content_succeeds(self):
"""Empty content should return success without calling the API."""
result = await self.adapter.send("channel_1", "")
assert result.success is True
@pytest.mark.asyncio
async def test_send_with_thread_reply(self):
"""When reply_mode is 'thread', reply_to should become root_id."""
self.adapter._reply_mode = "thread"
mock_resp = AsyncMock()
mock_resp.status = 200
mock_resp.json = AsyncMock(return_value={"id": "post456"})
mock_resp.text = AsyncMock(return_value="")
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False)
self.adapter._session.post = MagicMock(return_value=mock_resp)
result = await self.adapter.send("channel_1", "Reply!", reply_to="root_post")
assert result.success is True
payload = self.adapter._session.post.call_args[1]["json"]
assert payload["root_id"] == "root_post"
@pytest.mark.asyncio
async def test_send_without_thread_no_root_id(self):
"""When reply_mode is 'off', reply_to should NOT set root_id."""
self.adapter._reply_mode = "off"
mock_resp = AsyncMock()
mock_resp.status = 200
mock_resp.json = AsyncMock(return_value={"id": "post789"})
mock_resp.text = AsyncMock(return_value="")
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False)
self.adapter._session.post = MagicMock(return_value=mock_resp)
result = await self.adapter.send("channel_1", "Reply!", reply_to="root_post")
assert result.success is True
payload = self.adapter._session.post.call_args[1]["json"]
assert "root_id" not in payload
@pytest.mark.asyncio
async def test_send_api_failure(self):
"""When API returns error, send should return failure."""
mock_resp = AsyncMock()
mock_resp.status = 500
mock_resp.json = AsyncMock(return_value={})
mock_resp.text = AsyncMock(return_value="Internal Server Error")
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False)
self.adapter._session.post = MagicMock(return_value=mock_resp)
result = await self.adapter.send("channel_1", "Hello!")
assert result.success is False
# ---------------------------------------------------------------------------
# WebSocket event parsing
# ---------------------------------------------------------------------------
class TestMattermostWebSocketParsing:
def setup_method(self):
self.adapter = _make_adapter()
self.adapter._bot_user_id = "bot_user_id"
# Mock handle_message to capture the MessageEvent without processing
self.adapter.handle_message = AsyncMock()
@pytest.mark.asyncio
async def test_parse_posted_event(self):
"""'posted' events should extract message from double-encoded post JSON."""
post_data = {
"id": "post_abc",
"user_id": "user_123",
"channel_id": "chan_456",
"message": "Hello from Matrix!",
}
event = {
"event": "posted",
"data": {
"post": json.dumps(post_data), # double-encoded JSON string
"channel_type": "O",
"sender_name": "@alice",
},
}
await self.adapter._handle_ws_event(event)
assert self.adapter.handle_message.called
msg_event = self.adapter.handle_message.call_args[0][0]
assert msg_event.text == "Hello from Matrix!"
assert msg_event.message_id == "post_abc"
@pytest.mark.asyncio
async def test_ignore_own_messages(self):
"""Messages from the bot's own user_id should be ignored."""
post_data = {
"id": "post_self",
"user_id": "bot_user_id", # same as bot
"channel_id": "chan_456",
"message": "Bot echo",
}
event = {
"event": "posted",
"data": {
"post": json.dumps(post_data),
"channel_type": "O",
},
}
await self.adapter._handle_ws_event(event)
assert not self.adapter.handle_message.called
@pytest.mark.asyncio
async def test_ignore_non_posted_events(self):
"""Non-'posted' events should be ignored."""
event = {
"event": "typing",
"data": {"user_id": "user_123"},
}
await self.adapter._handle_ws_event(event)
assert not self.adapter.handle_message.called
@pytest.mark.asyncio
async def test_ignore_system_posts(self):
"""Posts with a 'type' field (system messages) should be ignored."""
post_data = {
"id": "sys_post",
"user_id": "user_123",
"channel_id": "chan_456",
"message": "user joined",
"type": "system_join_channel",
}
event = {
"event": "posted",
"data": {
"post": json.dumps(post_data),
"channel_type": "O",
},
}
await self.adapter._handle_ws_event(event)
assert not self.adapter.handle_message.called
@pytest.mark.asyncio
async def test_channel_type_mapping(self):
"""channel_type 'D' should map to 'dm'."""
post_data = {
"id": "post_dm",
"user_id": "user_123",
"channel_id": "chan_dm",
"message": "DM message",
}
event = {
"event": "posted",
"data": {
"post": json.dumps(post_data),
"channel_type": "D",
"sender_name": "@bob",
},
}
await self.adapter._handle_ws_event(event)
assert self.adapter.handle_message.called
msg_event = self.adapter.handle_message.call_args[0][0]
assert msg_event.source.chat_type == "dm"
@pytest.mark.asyncio
async def test_thread_id_from_root_id(self):
"""Post with root_id should have thread_id set."""
post_data = {
"id": "post_reply",
"user_id": "user_123",
"channel_id": "chan_456",
"message": "Thread reply",
"root_id": "root_post_123",
}
event = {
"event": "posted",
"data": {
"post": json.dumps(post_data),
"channel_type": "O",
"sender_name": "@alice",
},
}
await self.adapter._handle_ws_event(event)
assert self.adapter.handle_message.called
msg_event = self.adapter.handle_message.call_args[0][0]
assert msg_event.source.thread_id == "root_post_123"
@pytest.mark.asyncio
async def test_invalid_post_json_ignored(self):
"""Invalid JSON in data.post should be silently ignored."""
event = {
"event": "posted",
"data": {
"post": "not-valid-json{{{",
"channel_type": "O",
},
}
await self.adapter._handle_ws_event(event)
assert not self.adapter.handle_message.called
# ---------------------------------------------------------------------------
# File upload (send_image)
# ---------------------------------------------------------------------------
class TestMattermostFileUpload:
def setup_method(self):
self.adapter = _make_adapter()
self.adapter._session = MagicMock()
@pytest.mark.asyncio
async def test_send_image_downloads_and_uploads(self):
"""send_image should download the URL, upload via /api/v4/files, then post."""
# Mock the download (GET)
mock_dl_resp = AsyncMock()
mock_dl_resp.status = 200
mock_dl_resp.read = AsyncMock(return_value=b"\x89PNG\x00fake-image-data")
mock_dl_resp.content_type = "image/png"
mock_dl_resp.__aenter__ = AsyncMock(return_value=mock_dl_resp)
mock_dl_resp.__aexit__ = AsyncMock(return_value=False)
# Mock the upload (POST to /files)
mock_upload_resp = AsyncMock()
mock_upload_resp.status = 200
mock_upload_resp.json = AsyncMock(return_value={
"file_infos": [{"id": "file_abc123"}]
})
mock_upload_resp.text = AsyncMock(return_value="")
mock_upload_resp.__aenter__ = AsyncMock(return_value=mock_upload_resp)
mock_upload_resp.__aexit__ = AsyncMock(return_value=False)
# Mock the post (POST to /posts)
mock_post_resp = AsyncMock()
mock_post_resp.status = 200
mock_post_resp.json = AsyncMock(return_value={"id": "post_with_file"})
mock_post_resp.text = AsyncMock(return_value="")
mock_post_resp.__aenter__ = AsyncMock(return_value=mock_post_resp)
mock_post_resp.__aexit__ = AsyncMock(return_value=False)
# Route calls: first GET (download), then POST (upload), then POST (create post)
self.adapter._session.get = MagicMock(return_value=mock_dl_resp)
post_call_count = 0
original_post_returns = [mock_upload_resp, mock_post_resp]
def post_side_effect(*args, **kwargs):
nonlocal post_call_count
resp = original_post_returns[min(post_call_count, len(original_post_returns) - 1)]
post_call_count += 1
return resp
self.adapter._session.post = MagicMock(side_effect=post_side_effect)
result = await self.adapter.send_image(
"channel_1", "https://img.example.com/cat.png", caption="A cat"
)
assert result.success is True
assert result.message_id == "post_with_file"
# ---------------------------------------------------------------------------
# Dedup cache
# ---------------------------------------------------------------------------
class TestMattermostDedup:
def setup_method(self):
self.adapter = _make_adapter()
self.adapter._bot_user_id = "bot_user_id"
# Mock handle_message to capture calls without processing
self.adapter.handle_message = AsyncMock()
@pytest.mark.asyncio
async def test_duplicate_post_ignored(self):
"""The same post_id within the TTL window should be ignored."""
post_data = {
"id": "post_dup",
"user_id": "user_123",
"channel_id": "chan_456",
"message": "Hello!",
}
event = {
"event": "posted",
"data": {
"post": json.dumps(post_data),
"channel_type": "O",
"sender_name": "@alice",
},
}
# First time: should process
await self.adapter._handle_ws_event(event)
assert self.adapter.handle_message.call_count == 1
# Second time (same post_id): should be deduped
await self.adapter._handle_ws_event(event)
assert self.adapter.handle_message.call_count == 1 # still 1
@pytest.mark.asyncio
async def test_different_post_ids_both_processed(self):
"""Different post IDs should both be processed."""
for i, pid in enumerate(["post_a", "post_b"]):
post_data = {
"id": pid,
"user_id": "user_123",
"channel_id": "chan_456",
"message": f"Message {i}",
}
event = {
"event": "posted",
"data": {
"post": json.dumps(post_data),
"channel_type": "O",
"sender_name": "@alice",
},
}
await self.adapter._handle_ws_event(event)
assert self.adapter.handle_message.call_count == 2
def test_prune_seen_clears_expired(self):
"""_prune_seen should remove entries older than _SEEN_TTL."""
now = time.time()
# Fill with enough expired entries to trigger pruning
for i in range(self.adapter._SEEN_MAX + 10):
self.adapter._seen_posts[f"old_{i}"] = now - 600 # 10 min ago
# Add a fresh one
self.adapter._seen_posts["fresh"] = now
self.adapter._prune_seen()
# Old entries should be pruned, fresh one kept
assert "fresh" in self.adapter._seen_posts
assert len(self.adapter._seen_posts) < self.adapter._SEEN_MAX
def test_seen_cache_tracks_post_ids(self):
"""Posts are tracked in _seen_posts dict."""
self.adapter._seen_posts["test_post"] = time.time()
assert "test_post" in self.adapter._seen_posts
# ---------------------------------------------------------------------------
# Requirements check
# ---------------------------------------------------------------------------
class TestMattermostRequirements:
def test_check_requirements_with_token_and_url(self, monkeypatch):
monkeypatch.setenv("MATTERMOST_TOKEN", "test-token")
monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com")
from gateway.platforms.mattermost import check_mattermost_requirements
assert check_mattermost_requirements() is True
def test_check_requirements_without_token(self, monkeypatch):
monkeypatch.delenv("MATTERMOST_TOKEN", raising=False)
monkeypatch.delenv("MATTERMOST_URL", raising=False)
from gateway.platforms.mattermost import check_mattermost_requirements
assert check_mattermost_requirements() is False
def test_check_requirements_without_url(self, monkeypatch):
monkeypatch.setenv("MATTERMOST_TOKEN", "test-token")
monkeypatch.delenv("MATTERMOST_URL", raising=False)
from gateway.platforms.mattermost import check_mattermost_requirements
assert check_mattermost_requirements() is False