diff --git a/tests/test_telegram_thread_routing.py b/tests/test_telegram_thread_routing.py new file mode 100644 index 000000000..56e1ccd48 --- /dev/null +++ b/tests/test_telegram_thread_routing.py @@ -0,0 +1,162 @@ +"""Tests for Telegram thread-aware session routing. + +Verifies that messages in different threads/topics get independent +conversation histories. +""" + +import pytest +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +from gateway.session import build_session_key, SessionSource +from gateway.platforms.base import Platform + + +class TestThreadSessionKey: + """Verify session keys include thread_id for isolation.""" + + def test_dm_with_thread_gets_unique_key(self): + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="123456", + chat_type="dm", + thread_id="100", + ) + key = build_session_key(source) + assert "123456" in key + assert "100" in key + assert key == "agent:main:telegram:dm:123456:100" + + def test_dm_without_thread_uses_chat_only(self): + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="123456", + chat_type="dm", + ) + key = build_session_key(source) + assert key == "agent:main:telegram:dm:123456" + assert ":100" not in key + + def test_different_threads_different_keys(self): + source_a = SessionSource( + platform=Platform.TELEGRAM, + chat_id="123456", + chat_type="dm", + thread_id="100", + ) + source_b = SessionSource( + platform=Platform.TELEGRAM, + chat_id="123456", + chat_type="dm", + thread_id="200", + ) + key_a = build_session_key(source_a) + key_b = build_session_key(source_b) + assert key_a != key_b + + def test_same_thread_same_key(self): + source_a = SessionSource( + platform=Platform.TELEGRAM, + chat_id="123456", + chat_type="dm", + thread_id="100", + ) + source_b = SessionSource( + platform=Platform.TELEGRAM, + chat_id="123456", + chat_type="dm", + thread_id="100", + ) + assert build_session_key(source_a) == build_session_key(source_b) + + def test_group_with_thread_includes_thread(self): + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="789", + chat_type="group", + thread_id="50", + user_id="user1", + ) + key = build_session_key(source) + assert "789" in key + assert "50" in key + + def test_group_without_thread_isolates_by_user(self): + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="789", + chat_type="group", + user_id="user1", + ) + key = build_session_key(source, group_sessions_per_user=True) + assert "789" in key + assert "user1" in key + + def test_group_thread_shared_across_users(self): + """In threads, all participants share the same session by default.""" + source_a = SessionSource( + platform=Platform.TELEGRAM, + chat_id="789", + chat_type="group", + thread_id="50", + user_id="user1", + ) + source_b = SessionSource( + platform=Platform.TELEGRAM, + chat_id="789", + chat_type="group", + thread_id="50", + user_id="user2", + ) + key_a = build_session_key(source_a, thread_sessions_per_user=False) + key_b = build_session_key(source_b, thread_sessions_per_user=False) + assert key_a == key_b # Shared session in thread + + def test_group_thread_per_user_when_enabled(self): + """With thread_sessions_per_user=True, users get isolated sessions.""" + source_a = SessionSource( + platform=Platform.TELEGRAM, + chat_id="789", + chat_type="group", + thread_id="50", + user_id="user1", + ) + source_b = SessionSource( + platform=Platform.TELEGRAM, + chat_id="789", + chat_type="group", + thread_id="50", + user_id="user2", + ) + key_a = build_session_key(source_a, thread_sessions_per_user=True) + key_b = build_session_key(source_b, thread_sessions_per_user=True) + assert key_a != key_b + + +class TestSessionSourceSerialization: + """Verify SessionSource round-trips correctly with thread_id.""" + + def test_thread_id_preserved_in_dict(self): + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="123", + chat_type="dm", + thread_id="456", + ) + d = source.to_dict() + assert d["thread_id"] == "456" + restored = SessionSource.from_dict(d) + assert restored.thread_id == "456" + + def test_none_thread_id_preserved(self): + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="123", + chat_type="dm", + ) + d = source.to_dict() + assert d.get("thread_id") is None + restored = SessionSource.from_dict(d) + assert restored.thread_id is None