diff --git a/gateway/run.py b/gateway/run.py index 18170abf1..eec3f2694 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -118,6 +118,7 @@ from gateway.session import ( SessionContext, build_session_context, build_session_context_prompt, + build_session_key, ) from gateway.delivery import DeliveryRouter, DeliveryTarget from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType @@ -637,12 +638,7 @@ class GatewayRunner: # PRIORITY: If an agent is already running for this session, interrupt it # immediately. This is before command parsing to minimize latency -- the # user's "stop" message reaches the agent as fast as possible. - if source.chat_type != "dm": - _quick_key = f"agent:main:{source.platform.value}:{source.chat_type}:{source.chat_id}" - elif source.platform.value == "whatsapp" and source.chat_id: - _quick_key = f"agent:main:{source.platform.value}:dm:{source.chat_id}" - else: - _quick_key = f"agent:main:{source.platform.value}:dm" + _quick_key = build_session_key(source) if _quick_key in self._running_agents: running_agent = self._running_agents[_quick_key] logger.debug("PRIORITY interrupt for session %s", _quick_key[:20]) @@ -720,12 +716,7 @@ class GatewayRunner: logger.debug("Skill command check failed (non-fatal): %s", e) # Check for pending exec approval responses - if source.chat_type != "dm": - session_key_preview = f"agent:main:{source.platform.value}:{source.chat_type}:{source.chat_id}" - elif source.platform and source.platform.value == "whatsapp" and source.chat_id: - session_key_preview = f"agent:main:{source.platform.value}:dm:{source.chat_id}" - else: - session_key_preview = f"agent:main:{source.platform.value}:dm" + session_key_preview = build_session_key(source) if session_key_preview in self._pending_approvals: user_text = event.text.strip().lower() if user_text in ("yes", "y", "approve", "ok", "go", "do it"): @@ -1362,12 +1353,7 @@ class GatewayRunner: async def _handle_usage_command(self, event: MessageEvent) -> str: """Handle /usage command -- show token usage for the session's last agent run.""" source = event.source - if source.chat_type != "dm": - session_key = f"agent:main:{source.platform.value}:{source.chat_type}:{source.chat_id}" - elif source.platform.value == "whatsapp" and source.chat_id: - session_key = f"agent:main:{source.platform.value}:dm:{source.chat_id}" - else: - session_key = f"agent:main:{source.platform.value}:dm" + session_key = build_session_key(source) agent = self._running_agents.get(session_key) if agent and hasattr(agent, "session_total_tokens") and agent.session_api_calls > 0: diff --git a/gateway/session.py b/gateway/session.py index b59196b81..a337384d5 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -281,6 +281,20 @@ class SessionEntry: ) +def build_session_key(source: SessionSource) -> str: + """Build a deterministic session key from a message source. + + This is the single source of truth for session key construction. + WhatsApp DMs include chat_id (multi-user), other DMs do not (single owner). + """ + platform = source.platform.value + if source.chat_type == "dm": + if platform == "whatsapp" and source.chat_id: + return f"agent:main:{platform}:dm:{source.chat_id}" + return f"agent:main:{platform}:dm" + return f"agent:main:{platform}:{source.chat_type}:{source.chat_id}" + + class SessionStore: """ Manages session storage and retrieval. @@ -337,16 +351,7 @@ class SessionStore: def _generate_session_key(self, source: SessionSource) -> str: """Generate a session key from a source.""" - platform = source.platform.value - - if source.chat_type == "dm": - # WhatsApp DMs come from different people, each needs its own session. - # Other platforms (Telegram, Discord) have a single DM with the bot owner. - if platform == "whatsapp" and source.chat_id: - return f"agent:main:{platform}:dm:{source.chat_id}" - return f"agent:main:{platform}:dm" - else: - return f"agent:main:{platform}:{source.chat_type}:{source.chat_id}" + return build_session_key(source) def _should_reset(self, entry: SessionEntry, source: SessionSource) -> bool: """ diff --git a/tests/gateway/test_session.py b/tests/gateway/test_session.py index 776e785fb..f4a0af6ea 100644 --- a/tests/gateway/test_session.py +++ b/tests/gateway/test_session.py @@ -10,6 +10,7 @@ from gateway.session import ( SessionStore, build_session_context, build_session_context_prompt, + build_session_key, ) @@ -315,8 +316,8 @@ class TestSessionStoreRewriteTranscript: class TestWhatsAppDMSessionKeyConsistency: - """Regression: inline session-key construction in handle_message must match - _generate_session_key for WhatsApp DMs, which include chat_id.""" + """Regression: all session-key construction must go through build_session_key + so WhatsApp DMs include chat_id while other DMs do not.""" @pytest.fixture() def store(self, tmp_path): @@ -327,73 +328,45 @@ class TestWhatsAppDMSessionKeyConsistency: s._loaded = True return s - def _build_quick_key(self, source: SessionSource) -> str: - """Reproduce the _quick_key logic from gateway/run.py handle_message.""" - if source.chat_type != "dm": - return f"agent:main:{source.platform.value}:{source.chat_type}:{source.chat_id}" - elif source.platform.value == "whatsapp" and source.chat_id: - return f"agent:main:{source.platform.value}:dm:{source.chat_id}" - else: - return f"agent:main:{source.platform.value}:dm" - - def _build_usage_key(self, source: SessionSource) -> str: - """Reproduce the session_key logic from _handle_usage_command.""" - if source.chat_type != "dm": - return f"agent:main:{source.platform.value}:{source.chat_type}:{source.chat_id}" - elif source.platform.value == "whatsapp" and source.chat_id: - return f"agent:main:{source.platform.value}:dm:{source.chat_id}" - else: - return f"agent:main:{source.platform.value}:dm" - - def test_whatsapp_dm_quick_key_includes_chat_id(self, store): + def test_whatsapp_dm_includes_chat_id(self): source = SessionSource( platform=Platform.WHATSAPP, chat_id="15551234567@s.whatsapp.net", chat_type="dm", user_name="Phone User", ) - real_key = store._generate_session_key(source) - quick_key = self._build_quick_key(source) - assert quick_key == real_key - assert "15551234567@s.whatsapp.net" in quick_key + key = build_session_key(source) + assert key == "agent:main:whatsapp:dm:15551234567@s.whatsapp.net" - def test_whatsapp_dm_usage_key_includes_chat_id(self, store): + def test_store_delegates_to_build_session_key(self, store): + """SessionStore._generate_session_key must produce the same result.""" source = SessionSource( platform=Platform.WHATSAPP, chat_id="15551234567@s.whatsapp.net", chat_type="dm", user_name="Phone User", ) - real_key = store._generate_session_key(source) - usage_key = self._build_usage_key(source) - assert usage_key == real_key - assert "15551234567@s.whatsapp.net" in usage_key + assert store._generate_session_key(source) == build_session_key(source) - def test_telegram_dm_key_unchanged(self, store): + def test_telegram_dm_omits_chat_id(self): """Non-WhatsApp DMs should still omit chat_id (single owner DM).""" source = SessionSource( platform=Platform.TELEGRAM, chat_id="99", chat_type="dm", ) - real_key = store._generate_session_key(source) - quick_key = self._build_quick_key(source) - usage_key = self._build_usage_key(source) - assert quick_key == real_key == "agent:main:telegram:dm" - assert usage_key == real_key + key = build_session_key(source) + assert key == "agent:main:telegram:dm" - def test_discord_group_key_unchanged(self, store): - """Group/channel keys should be unaffected by the fix.""" + def test_discord_group_includes_chat_id(self): + """Group/channel keys include chat_type and chat_id.""" source = SessionSource( platform=Platform.DISCORD, chat_id="guild-123", chat_type="group", ) - real_key = store._generate_session_key(source) - quick_key = self._build_quick_key(source) - usage_key = self._build_usage_key(source) - assert quick_key == real_key == "agent:main:discord:group:guild-123" - assert usage_key == real_key + key = build_session_key(source) + assert key == "agent:main:discord:group:guild-123" class TestSessionStoreEntriesAttribute: