diff --git a/gateway/config.py b/gateway/config.py index c7eb4adf1..e7794b751 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -563,6 +563,14 @@ def load_gateway_config() -> GatewayConfig: if isinstance(frc, list): frc = ",".join(str(v) for v in frc) os.environ["TELEGRAM_FREE_RESPONSE_CHATS"] = str(frc) + + whatsapp_cfg = yaml_cfg.get("whatsapp", {}) + if isinstance(whatsapp_cfg, dict): + if "require_mention" in whatsapp_cfg and not os.getenv("WHATSAPP_REQUIRE_MENTION"): + os.environ["WHATSAPP_REQUIRE_MENTION"] = str(whatsapp_cfg["require_mention"]).lower() + if "mention_patterns" in whatsapp_cfg and not os.getenv("WHATSAPP_MENTION_PATTERNS"): + import json as _json + os.environ["WHATSAPP_MENTION_PATTERNS"] = _json.dumps(whatsapp_cfg["mention_patterns"]) except Exception as e: logger.warning( "Failed to process config.yaml — falling back to .env / gateway.json values. " diff --git a/gateway/platforms/whatsapp.py b/gateway/platforms/whatsapp.py index 02448a6dd..fb5d1b2dc 100644 --- a/gateway/platforms/whatsapp.py +++ b/gateway/platforms/whatsapp.py @@ -16,9 +16,11 @@ with different backends via a bridge pattern. 
""" import asyncio +import json import logging import os import platform +import re import subprocess _IS_WINDOWS = platform.system() == "Windows" @@ -138,12 +140,113 @@ class WhatsAppAdapter(BasePlatformAdapter): get_hermes_dir("platforms/whatsapp/session", "whatsapp/session") )) self._reply_prefix: Optional[str] = config.extra.get("reply_prefix") + self._mention_patterns = self._compile_mention_patterns() self._message_queue: asyncio.Queue = asyncio.Queue() self._bridge_log_fh = None self._bridge_log: Optional[Path] = None self._poll_task: Optional[asyncio.Task] = None self._http_session: Optional["aiohttp.ClientSession"] = None self._session_lock_identity: Optional[str] = None + + def _whatsapp_require_mention(self) -> bool: + configured = self.config.extra.get("require_mention") + if configured is not None: + if isinstance(configured, str): + return configured.lower() in ("true", "1", "yes", "on") + return bool(configured) + return os.getenv("WHATSAPP_REQUIRE_MENTION", "false").lower() in ("true", "1", "yes", "on") + + def _compile_mention_patterns(self): + patterns = self.config.extra.get("mention_patterns") + if patterns is None: + raw = os.getenv("WHATSAPP_MENTION_PATTERNS", "").strip() + if raw: + try: + patterns = json.loads(raw) + except Exception: + patterns = [part.strip() for part in raw.splitlines() if part.strip()] + if not patterns: + patterns = [part.strip() for part in raw.split(",") if part.strip()] + if patterns is None: + return [] + if isinstance(patterns, str): + patterns = [patterns] + if not isinstance(patterns, list): + logger.warning("[%s] whatsapp mention_patterns must be a list or string; got %s", self.name, type(patterns).__name__) + return [] + + compiled = [] + for pattern in patterns: + if not isinstance(pattern, str) or not pattern.strip(): + continue + try: + compiled.append(re.compile(pattern, re.IGNORECASE)) + except re.error as exc: + logger.warning("[%s] Invalid WhatsApp mention pattern %r: %s", self.name, pattern, exc) + 
return compiled + + @staticmethod + def _normalize_whatsapp_id(value: Optional[str]) -> str: + if not value: + return "" + normalized = str(value).strip() + if ":" in normalized and "@" in normalized: + normalized = normalized.replace(":", "@", 1) + return normalized + + def _bot_ids_from_message(self, data: Dict[str, Any]) -> set[str]: + bot_ids = set() + for candidate in data.get("botIds") or []: + normalized = self._normalize_whatsapp_id(candidate) + if normalized: + bot_ids.add(normalized) + return bot_ids + + def _message_is_reply_to_bot(self, data: Dict[str, Any]) -> bool: + quoted_participant = self._normalize_whatsapp_id(data.get("quotedParticipant")) + if not quoted_participant: + return False + return quoted_participant in self._bot_ids_from_message(data) + + def _message_mentions_bot(self, data: Dict[str, Any]) -> bool: + bot_ids = self._bot_ids_from_message(data) + if not bot_ids: + return False + mentioned_ids = { + self._normalize_whatsapp_id(candidate) + for candidate in (data.get("mentionedIds") or []) + if self._normalize_whatsapp_id(candidate) + } + if mentioned_ids & bot_ids: + return True + + body = str(data.get("body") or "") + lower_body = body.lower() + for bot_id in bot_ids: + bare_id = bot_id.split("@", 1)[0].lower() + if bare_id and (f"@{bare_id}" in lower_body or bare_id in lower_body): + return True + return False + + def _message_matches_mention_patterns(self, data: Dict[str, Any]) -> bool: + if not self._mention_patterns: + return False + body = str(data.get("body") or "") + return any(pattern.search(body) for pattern in self._mention_patterns) + + def _should_process_message(self, data: Dict[str, Any]) -> bool: + if not data.get("isGroup"): + return True + if not self._whatsapp_require_mention(): + return True + body = str(data.get("body") or "").strip() + if body.startswith("/"): + return True + if self._message_is_reply_to_bot(data): + return True + if self._message_mentions_bot(data): + return True + return 
/**
 * Reduce a WhatsApp JID to its `user@server` form.
 *
 * A linked-device JID has the shape `user:device@server`. The device part is
 * dropped so the result compares equal to the plain `user@server` JIDs found
 * in mention and quote metadata. Values without an `@` are returned trimmed
 * but otherwise unchanged.
 *
 * @param {unknown} value - Raw JID (may be null/undefined).
 * @returns {string} Normalized JID, or '' for a falsy input.
 */
function normalizeWhatsAppId(value) {
  if (!value) return '';
  const jid = String(value).trim();
  const serverStart = jid.lastIndexOf('@');
  if (serverStart === -1) return jid;
  // Keep only what precedes the first ':' or '@' as the user part.
  const user = jid.slice(0, serverStart).split(/[:@]/, 1)[0];
  return `${user}${jid.slice(serverStart)}`;
}

/**
 * Return the innermost message content, unwrapping container messages.
 *
 * Wrappers can be nested (e.g. a view-once message inside an ephemeral
 * message), so unwrapping repeats until no known wrapper remains. The depth
 * is capped so a malformed payload cannot loop forever.
 *
 * @param {object} msg - Incoming message envelope with a `message` field.
 * @returns {object} The unwrapped content, or {} when there is none.
 */
function getMessageContent(msg) {
  const wrapperKeys = [
    'ephemeralMessage',
    'viewOnceMessage',
    'viewOnceMessageV2',
    'documentWithCaptionMessage',
  ];
  let content = msg?.message || {};
  for (let depth = 0; depth < 5; depth += 1) {
    const wrapper = wrapperKeys.find((key) => content[key]?.message);
    if (!wrapper) break;
    content = content[wrapper].message;
  }
  return content;
}

/**
 * Return the first `contextInfo` found on any sub-message of the content.
 *
 * @param {object} messageContent - Unwrapped message content.
 * @returns {object} The contextInfo object, or {} when none is present.
 */
function getContextInfo(messageContent) {
  if (!messageContent || typeof messageContent !== 'object') return {};
  const holder = Object.values(messageContent).find(
    (value) => value && typeof value === 'object' && value.contextInfo,
  );
  return holder ? holder.contextInfo : {};
}
[]).map(normalizeWhatsAppId).filter(Boolean))); + const quotedParticipant = normalizeWhatsAppId(contextInfo?.participant || contextInfo?.remoteJid || ''); + const botIds = Array.from(new Set([ + normalizeWhatsAppId(sock.user?.id), + normalizeWhatsAppId(sock.user?.lid), + ].filter(Boolean))); + // Extract message body let body = ''; let hasMedia = false; let mediaType = ''; const mediaUrls = []; - if (msg.message.conversation) { - body = msg.message.conversation; - } else if (msg.message.extendedTextMessage?.text) { - body = msg.message.extendedTextMessage.text; - } else if (msg.message.imageMessage) { - body = msg.message.imageMessage.caption || ''; + if (messageContent.conversation) { + body = messageContent.conversation; + } else if (messageContent.extendedTextMessage?.text) { + body = messageContent.extendedTextMessage.text; + } else if (messageContent.imageMessage) { + body = messageContent.imageMessage.caption || ''; hasMedia = true; mediaType = 'image'; try { const buf = await downloadMediaMessage(msg, 'buffer', {}, { logger, reuploadRequest: sock.updateMediaMessage }); - const mime = msg.message.imageMessage.mimetype || 'image/jpeg'; + const mime = messageContent.imageMessage.mimetype || 'image/jpeg'; const extMap = { 'image/jpeg': '.jpg', 'image/png': '.png', 'image/webp': '.webp', 'image/gif': '.gif' }; const ext = extMap[mime] || '.jpg'; mkdirSync(IMAGE_CACHE_DIR, { recursive: true }); @@ -226,13 +259,13 @@ async function startSocket() { } catch (err) { console.error('[bridge] Failed to download image:', err.message); } - } else if (msg.message.videoMessage) { - body = msg.message.videoMessage.caption || ''; + } else if (messageContent.videoMessage) { + body = messageContent.videoMessage.caption || ''; hasMedia = true; mediaType = 'video'; try { const buf = await downloadMediaMessage(msg, 'buffer', {}, { logger, reuploadRequest: sock.updateMediaMessage }); - const mime = msg.message.videoMessage.mimetype || 'video/mp4'; + const mime = 
messageContent.videoMessage.mimetype || 'video/mp4'; const ext = mime.includes('mp4') ? '.mp4' : '.mkv'; mkdirSync(DOCUMENT_CACHE_DIR, { recursive: true }); const filePath = path.join(DOCUMENT_CACHE_DIR, `vid_${randomBytes(6).toString('hex')}${ext}`); @@ -241,11 +274,11 @@ async function startSocket() { } catch (err) { console.error('[bridge] Failed to download video:', err.message); } - } else if (msg.message.audioMessage || msg.message.pttMessage) { + } else if (messageContent.audioMessage || messageContent.pttMessage) { hasMedia = true; - mediaType = msg.message.pttMessage ? 'ptt' : 'audio'; + mediaType = messageContent.pttMessage ? 'ptt' : 'audio'; try { - const audioMsg = msg.message.pttMessage || msg.message.audioMessage; + const audioMsg = messageContent.pttMessage || messageContent.audioMessage; const buf = await downloadMediaMessage(msg, 'buffer', {}, { logger, reuploadRequest: sock.updateMediaMessage }); const mime = audioMsg.mimetype || 'audio/ogg'; const ext = mime.includes('ogg') ? '.ogg' : mime.includes('mp4') ? 
"""Tests for WhatsApp group-message gating (require_mention / mention_patterns)."""

import json
import os
from unittest.mock import AsyncMock

from gateway.config import Platform, PlatformConfig, load_gateway_config


def _make_adapter(require_mention=None, mention_patterns=None):
    """Build a WhatsAppAdapter without running ``__init__``.

    Only the attributes the gating helpers read are populated; options left
    as ``None`` are omitted from ``config.extra``.
    """
    from gateway.platforms.whatsapp import WhatsAppAdapter

    extra = {}
    if require_mention is not None:
        extra["require_mention"] = require_mention
    if mention_patterns is not None:
        extra["mention_patterns"] = mention_patterns

    adapter = object.__new__(WhatsAppAdapter)
    adapter.platform = Platform.WHATSAPP
    adapter.config = PlatformConfig(enabled=True, extra=extra)
    adapter._message_handler = AsyncMock()
    adapter._mention_patterns = adapter._compile_mention_patterns()
    return adapter


def _group_message(body="hello", **overrides):
    """Return a bridge payload for a group message; ``overrides`` replace fields."""
    data = {
        "isGroup": True,
        "body": body,
        "mentionedIds": [],
        "botIds": ["15551230000@s.whatsapp.net", "15551230000@lid"],
        "quotedParticipant": "",
    }
    data.update(overrides)
    return data


def test_group_messages_can_be_opened_via_config():
    adapter = _make_adapter(require_mention=False)

    assert adapter._should_process_message(_group_message("hello everyone")) is True


def test_group_messages_can_require_direct_trigger_via_config():
    adapter = _make_adapter(require_mention=True)

    assert adapter._should_process_message(_group_message("hello everyone")) is False
    assert adapter._should_process_message(
        _group_message(
            "hi there",
            mentionedIds=["15551230000@s.whatsapp.net"],
        )
    ) is True
    assert adapter._should_process_message(
        _group_message(
            "replying",
            quotedParticipant="15551230000@lid",
        )
    ) is True
    assert adapter._should_process_message(_group_message("/status")) is True


def test_regex_mention_patterns_allow_custom_wake_words():
    adapter = _make_adapter(require_mention=True, mention_patterns=[r"^\s*chompy\b"])

    assert adapter._should_process_message(_group_message("chompy status")) is True
    assert adapter._should_process_message(_group_message(" chompy help")) is True
    assert adapter._should_process_message(_group_message("hey chompy")) is False


def test_invalid_regex_patterns_are_ignored():
    adapter = _make_adapter(require_mention=True, mention_patterns=[r"(", r"^\s*chompy\b"])

    assert adapter._should_process_message(_group_message("chompy status")) is True
    assert adapter._should_process_message(_group_message("hello everyone")) is False


def test_config_bridges_whatsapp_group_settings(monkeypatch, tmp_path):
    hermes_home = tmp_path / ".hermes"
    hermes_home.mkdir()
    (hermes_home / "config.yaml").write_text(
        "whatsapp:\n"
        " require_mention: true\n"
        " mention_patterns:\n"
        " - \"^\\\\s*chompy\\\\b\"\n",
        encoding="utf-8",
    )

    monkeypatch.setenv("HERMES_HOME", str(hermes_home))
    # Clear any ambient values so the assertions see what the loader wrote.
    monkeypatch.delenv("WHATSAPP_REQUIRE_MENTION", raising=False)
    monkeypatch.delenv("WHATSAPP_MENTION_PATTERNS", raising=False)

    config = load_gateway_config()

    assert config is not None
    assert config.platforms[Platform.WHATSAPP].extra["require_mention"] is True
    assert config.platforms[Platform.WHATSAPP].extra["mention_patterns"] == [r"^\s*chompy\b"]
    assert os.environ["WHATSAPP_REQUIRE_MENTION"] == "true"
    assert json.loads(os.environ["WHATSAPP_MENTION_PATTERNS"]) == [r"^\s*chompy\b"]